From 720dd422d581e3cd4cff49bcb6cb5778f7382e45 Mon Sep 17 00:00:00 2001 From: ssdfasd <2156608475@qq.com> Date: Wed, 25 Mar 2026 13:52:11 +0800 Subject: [PATCH] RHI: Add Compute/Dispatch unit tests (P1-7) and fix shader type bugs Bug fixes: - D3D12Shader::Compile: Set m_type based on target string (cs_/vs_/ps_/gs_) - OpenGLShader::Compile: Parse target parameter to determine shader type - OpenGLShader::CompileCompute: Set m_type = ShaderType::Compute - D3D12CommandList::SetPipelineState: Use correct PSO handle for Compute New tests (test_compute.cpp, 8 tests): - ComputeShader_Compile_ValidShader - ComputeShader_GetType_ReturnsCompute - ComputeShader_Shutdown_Invalidates - PipelineState_SetComputeShader - PipelineState_HasComputeShader_ReturnsTrue - PipelineState_GetType_Compute - PipelineState_EnsureValid_Compute - CommandList_Dispatch_Basic Test results: 232/232 passed (D3D12: 116, OpenGL: 116) --- engine/src/RHI/D3D12/D3D12CommandList.cpp | 8 +- engine/src/RHI/D3D12/D3D12Shader.cpp | 10 + engine/src/RHI/OpenGL/OpenGLShader.cpp | 16 +- tests/RHI/unit/CMakeLists.txt | 1 + tests/RHI/unit/test_compute.cpp | 279 ++++++++++++++++++++++ 5 files changed, 312 insertions(+), 2 deletions(-) create mode 100644 tests/RHI/unit/test_compute.cpp diff --git a/engine/src/RHI/D3D12/D3D12CommandList.cpp b/engine/src/RHI/D3D12/D3D12CommandList.cpp index e30e4428..cd165701 100644 --- a/engine/src/RHI/D3D12/D3D12CommandList.cpp +++ b/engine/src/RHI/D3D12/D3D12CommandList.cpp @@ -178,7 +178,13 @@ void D3D12CommandList::AliasBarrier(ID3D12Resource* beforeResource, ID3D12Resour void D3D12CommandList::SetPipelineState(RHIPipelineState* pso) { if (!pso) return; - SetPipelineStateInternal(static_cast(pso->GetNativeHandle())); + + if (pso->GetType() == PipelineType::Compute) { + D3D12PipelineState* d3d12Pso = static_cast(pso); + SetPipelineStateInternal(d3d12Pso->GetComputePipelineState()); + } else { + SetPipelineStateInternal(static_cast(pso->GetNativeHandle())); + } } void D3D12CommandList::SetPipelineState(ID3D12PipelineState* pso) { diff --git a/engine/src/RHI/D3D12/D3D12Shader.cpp b/engine/src/RHI/D3D12/D3D12Shader.cpp index 2a3a9217..e2221e0c 100644 --- a/engine/src/RHI/D3D12/D3D12Shader.cpp +++ b/engine/src/RHI/D3D12/D3D12Shader.cpp @@ -55,6 +55,16 @@ bool D3D12Shader::Compile(const void* sourceData, size_t sourceSize, const char* return false; } + if (strstr(target, "vs_")) { + m_type = ShaderType::Vertex; + } else if (strstr(target, "ps_")) { + m_type = ShaderType::Fragment; + } else if (strstr(target, "gs_")) { + m_type = ShaderType::Geometry; + } else if (strstr(target, "cs_")) { + m_type = ShaderType::Compute; + } + m_uniformsCached = false; return true; } diff --git a/engine/src/RHI/OpenGL/OpenGLShader.cpp b/engine/src/RHI/OpenGL/OpenGLShader.cpp index a366fa7b..ac4c52e4 100644 --- a/engine/src/RHI/OpenGL/OpenGLShader.cpp +++ b/engine/src/RHI/OpenGL/OpenGLShader.cpp @@ -153,6 +153,7 @@ bool OpenGLShader::CompileCompute(const char* computeSource) { glDeleteShader(compute); + m_type = ShaderType::Compute; m_uniformsCached = false; return true; @@ -219,7 +220,20 @@ bool OpenGLShader::Compile(const void* sourceData, size_t sourceSize, const char if (!sourceData || sourceSize == 0) { return false; } - return Compile(static_cast(sourceData), ShaderType::Fragment); + + ShaderType type = ShaderType::Fragment; + if (target) { + if (strstr(target, "vs_")) { + type = ShaderType::Vertex; + } else if (strstr(target, "ps_")) { + type = ShaderType::Fragment; + } else if (strstr(target, "gs_")) { + type = ShaderType::Geometry; + } else if (strstr(target, "cs_")) { + type = ShaderType::Compute; + } + } + return Compile(static_cast(sourceData), type); } void OpenGLShader::Shutdown() { diff --git a/tests/RHI/unit/CMakeLists.txt b/tests/RHI/unit/CMakeLists.txt index 94201b7e..c966de4f 100644 --- a/tests/RHI/unit/CMakeLists.txt +++ b/tests/RHI/unit/CMakeLists.txt @@ -19,6 +19,7 @@ set(TEST_SOURCES test_fence.cpp test_sampler.cpp test_descriptor.cpp + test_compute.cpp ${CMAKE_SOURCE_DIR}/tests/opengl/package/src/glad.c ) diff --git a/tests/RHI/unit/test_compute.cpp b/tests/RHI/unit/test_compute.cpp new file mode 100644 index 00000000..590ec6bc --- /dev/null +++ b/tests/RHI/unit/test_compute.cpp @@ -0,0 +1,279 @@ +#include "fixtures/RHITestFixture.h" +#include "XCEngine/RHI/RHIShader.h" +#include "XCEngine/RHI/RHIPipelineState.h" +#include + +using namespace XCEngine::RHI; + +TEST_P(RHITestFixture, ComputeShader_Compile_ValidShader) { + ShaderCompileDesc desc = {}; + if (GetBackendType() == RHIType::D3D12) { + desc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; + desc.entryPoint = L"MainCS"; + desc.profile = L"cs_5_0"; + } else { + desc.sourceLanguage = ShaderLanguage::GLSL; + static const char* cs = R"( + #version 460 + layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + void main() { + } + )"; + desc.source.assign(cs, cs + strlen(cs)); + } + + RHIShader* shader = GetDevice()->CompileShader(desc); + if (shader != nullptr) { + EXPECT_TRUE(shader->IsValid()); + EXPECT_EQ(shader->GetType(), ShaderType::Compute); + shader->Shutdown(); + delete shader; + } +} + +TEST_P(RHITestFixture, ComputeShader_GetType_ReturnsCompute) { + ShaderCompileDesc desc = {}; + if (GetBackendType() == RHIType::D3D12) { + desc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; + desc.entryPoint = L"MainCS"; + desc.profile = L"cs_5_0"; + } else { + desc.sourceLanguage = ShaderLanguage::GLSL; + static const char* cs = R"( + #version 460 + layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + void main() { + } + )"; + desc.source.assign(cs, cs + strlen(cs)); + } + + RHIShader* shader = GetDevice()->CompileShader(desc); + if (shader != nullptr) { + EXPECT_EQ(shader->GetType(), ShaderType::Compute); + shader->Shutdown(); + delete shader; + } +} + +TEST_P(RHITestFixture, ComputeShader_Shutdown_Invalidates) { + ShaderCompileDesc desc = {}; + if (GetBackendType() == RHIType::D3D12) { + desc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; + desc.entryPoint = L"MainCS"; + desc.profile = L"cs_5_0"; + } else { + desc.sourceLanguage = ShaderLanguage::GLSL; + static const char* cs = R"( + #version 460 + layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + void main() { + } + )"; + desc.source.assign(cs, cs + strlen(cs)); + } + + RHIShader* shader = GetDevice()->CompileShader(desc); + if (shader != nullptr) { + EXPECT_TRUE(shader->IsValid()); + shader->Shutdown(); + EXPECT_FALSE(shader->IsValid()); + delete shader; + } +} + +TEST_P(RHITestFixture, PipelineState_SetComputeShader) { + GraphicsPipelineDesc desc = {}; + RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc); + if (pso == nullptr) { + return; + } + + ShaderCompileDesc shaderDesc = {}; + if (GetBackendType() == RHIType::D3D12) { + shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; + shaderDesc.entryPoint = L"MainCS"; + shaderDesc.profile = L"cs_5_0"; + } else { + shaderDesc.sourceLanguage = ShaderLanguage::GLSL; + static const char* cs = R"( + #version 460 + layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + void main() { + } + )"; + shaderDesc.source.assign(cs, cs + strlen(cs)); + } + + RHIShader* computeShader = GetDevice()->CompileShader(shaderDesc); + if (computeShader != nullptr) { + pso->SetComputeShader(computeShader); + EXPECT_TRUE(pso->HasComputeShader()); + computeShader->Shutdown(); + delete computeShader; + } + + pso->Shutdown(); + delete pso; +} + +TEST_P(RHITestFixture, PipelineState_HasComputeShader_ReturnsTrue) { + GraphicsPipelineDesc desc = {}; + RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc); + if (pso == nullptr) { + return; + } + + EXPECT_FALSE(pso->HasComputeShader()); + + ShaderCompileDesc shaderDesc = {}; + if (GetBackendType() == RHIType::D3D12) { + shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; + shaderDesc.entryPoint = L"MainCS"; + shaderDesc.profile = L"cs_5_0"; + } else { + shaderDesc.sourceLanguage = ShaderLanguage::GLSL; + static const char* cs = R"( + #version 460 + layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + void main() { + } + )"; + shaderDesc.source.assign(cs, cs + strlen(cs)); + } + + RHIShader* computeShader = GetDevice()->CompileShader(shaderDesc); + if (computeShader != nullptr) { + pso->SetComputeShader(computeShader); + EXPECT_TRUE(pso->HasComputeShader()); + computeShader->Shutdown(); + delete computeShader; + } + + pso->Shutdown(); + delete pso; +} + +TEST_P(RHITestFixture, PipelineState_GetType_Compute) { + GraphicsPipelineDesc desc = {}; + RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc); + if (pso == nullptr) { + return; + } + + EXPECT_EQ(pso->GetType(), PipelineType::Graphics); + + ShaderCompileDesc shaderDesc = {}; + if (GetBackendType() == RHIType::D3D12) { + shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; + shaderDesc.entryPoint = L"MainCS"; + shaderDesc.profile = L"cs_5_0"; + } else { + shaderDesc.sourceLanguage = ShaderLanguage::GLSL; + static const char* cs = R"( + #version 460 + layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + void main() { + } + )"; + shaderDesc.source.assign(cs, cs + strlen(cs)); + } + + RHIShader* computeShader = GetDevice()->CompileShader(shaderDesc); + if (computeShader != nullptr) { + pso->SetComputeShader(computeShader); + EXPECT_EQ(pso->GetType(), PipelineType::Compute); + computeShader->Shutdown(); + delete computeShader; + } + + pso->Shutdown(); + delete pso; +} + +TEST_P(RHITestFixture, PipelineState_EnsureValid_Compute) { + GraphicsPipelineDesc desc = {}; + RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc); + if (pso == nullptr) { + return; + } + + ShaderCompileDesc shaderDesc = {}; + if (GetBackendType() == RHIType::D3D12) { + shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; + shaderDesc.entryPoint = L"MainCS"; + shaderDesc.profile = L"cs_5_0"; + } else { + shaderDesc.sourceLanguage = ShaderLanguage::GLSL; + static const char* cs = R"( + #version 460 + layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + void main() { + } + )"; + shaderDesc.source.assign(cs, cs + strlen(cs)); + } + + RHIShader* computeShader = GetDevice()->CompileShader(shaderDesc); + if (computeShader != nullptr) { + pso->SetComputeShader(computeShader); + + if (GetBackendType() == RHIType::D3D12) { + EXPECT_FALSE(pso->IsValid()); + pso->EnsureValid(); + } else { + EXPECT_TRUE(pso->IsValid()); + } + + computeShader->Shutdown(); + delete computeShader; + } + + pso->Shutdown(); + delete pso; +} + +TEST_P(RHITestFixture, CommandList_Dispatch_Basic) { + RHICommandList* cmdList = GetDevice()->CreateCommandList({}); + ASSERT_NE(cmdList, nullptr); + + cmdList->Reset(); + + GraphicsPipelineDesc desc = {}; + RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc); + + if (pso != nullptr) { + ShaderCompileDesc shaderDesc = {}; + if (GetBackendType() == RHIType::D3D12) { + shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; + shaderDesc.entryPoint = L"MainCS"; + shaderDesc.profile = L"cs_5_0"; + } else { + shaderDesc.sourceLanguage = ShaderLanguage::GLSL; + static const char* cs = R"( + #version 460 + layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + void main() { + } + )"; + shaderDesc.source.assign(cs, cs + strlen(cs)); + } + + RHIShader* computeShader = GetDevice()->CompileShader(shaderDesc); + if (computeShader != nullptr) { + pso->SetComputeShader(computeShader); + cmdList->SetPipelineState(pso); + cmdList->Dispatch(1, 1, 1); + + computeShader->Shutdown(); + delete computeShader; + } + + pso->Shutdown(); + delete pso; + } + + cmdList->Close(); + cmdList->Shutdown(); + delete cmdList; +} \ No newline at end of file