RHI: Add Compute/Dispatch unit tests (P1-7) and fix shader type bugs

Bug fixes:
- D3D12Shader::Compile: Set m_type based on target string (cs_/vs_/ps_/gs_)
- OpenGLShader::Compile: Parse target parameter to determine shader type
- OpenGLShader::CompileCompute: Set m_type = ShaderType::Compute
- D3D12CommandList::SetPipelineState: Use correct PSO handle for Compute

New tests (test_compute.cpp, 8 tests):
- ComputeShader_Compile_ValidShader
- ComputeShader_GetType_ReturnsCompute
- ComputeShader_Shutdown_Invalidates
- PipelineState_SetComputeShader
- PipelineState_HasComputeShader_ReturnsTrue
- PipelineState_GetType_Compute
- PipelineState_EnsureValid_Compute
- CommandList_Dispatch_Basic

Test results: 232/232 passed (D3D12: 116, OpenGL: 116)
This commit is contained in:
2026-03-25 13:52:11 +08:00
parent da1f8cfb58
commit 720dd422d5
5 changed files with 312 additions and 2 deletions

View File

@@ -178,7 +178,13 @@ void D3D12CommandList::AliasBarrier(ID3D12Resource* beforeResource, ID3D12Resour
void D3D12CommandList::SetPipelineState(RHIPipelineState* pso) {
if (!pso) return;
SetPipelineStateInternal(static_cast<ID3D12PipelineState*>(pso->GetNativeHandle()));
if (pso->GetType() == PipelineType::Compute) {
D3D12PipelineState* d3d12Pso = static_cast<D3D12PipelineState*>(pso);
SetPipelineStateInternal(d3d12Pso->GetComputePipelineState());
} else {
SetPipelineStateInternal(static_cast<ID3D12PipelineState*>(pso->GetNativeHandle()));
}
}
void D3D12CommandList::SetPipelineState(ID3D12PipelineState* pso) {

View File

@@ -55,6 +55,16 @@ bool D3D12Shader::Compile(const void* sourceData, size_t sourceSize, const char*
return false;
}
if (strstr(target, "vs_")) {
m_type = ShaderType::Vertex;
} else if (strstr(target, "ps_")) {
m_type = ShaderType::Fragment;
} else if (strstr(target, "gs_")) {
m_type = ShaderType::Geometry;
} else if (strstr(target, "cs_")) {
m_type = ShaderType::Compute;
}
m_uniformsCached = false;
return true;
}

View File

@@ -153,6 +153,7 @@ bool OpenGLShader::CompileCompute(const char* computeSource) {
glDeleteShader(compute);
m_type = ShaderType::Compute;
m_uniformsCached = false;
return true;
@@ -219,7 +220,20 @@ bool OpenGLShader::Compile(const void* sourceData, size_t sourceSize, const char
if (!sourceData || sourceSize == 0) {
return false;
}
return Compile(static_cast<const char*>(sourceData), ShaderType::Fragment);
ShaderType type = ShaderType::Fragment;
if (target) {
if (strstr(target, "vs_")) {
type = ShaderType::Vertex;
} else if (strstr(target, "ps_")) {
type = ShaderType::Fragment;
} else if (strstr(target, "gs_")) {
type = ShaderType::Geometry;
} else if (strstr(target, "cs_")) {
type = ShaderType::Compute;
}
}
return Compile(static_cast<const char*>(sourceData), type);
}
void OpenGLShader::Shutdown() {