diff --git a/engine/include/XCEngine/RHI/D3D12/D3D12PipelineState.h b/engine/include/XCEngine/RHI/D3D12/D3D12PipelineState.h index 63278813..df373a94 100644 --- a/engine/include/XCEngine/RHI/D3D12/D3D12PipelineState.h +++ b/engine/include/XCEngine/RHI/D3D12/D3D12PipelineState.h @@ -79,6 +79,7 @@ public: private: bool CreateD3D12PSO(); bool CreateD3D12ComputePSO(); + bool EnsureDefaultRootSignature(); ID3D12Device* m_device; bool m_finalized = false; diff --git a/engine/src/RHI/D3D12/D3D12CommandList.cpp b/engine/src/RHI/D3D12/D3D12CommandList.cpp index 234d9844..319c6027 100644 --- a/engine/src/RHI/D3D12/D3D12CommandList.cpp +++ b/engine/src/RHI/D3D12/D3D12CommandList.cpp @@ -139,16 +139,32 @@ void D3D12CommandList::AliasBarrier(ID3D12Resource* beforeResource, ID3D12Resour void D3D12CommandList::SetPipelineState(RHIPipelineState* pso) { if (!pso) return; - + + pso->EnsureValid(); + + D3D12PipelineState* d3d12Pso = static_cast(pso); if (pso->GetType() == PipelineType::Compute) { - D3D12PipelineState* d3d12Pso = static_cast(pso); - SetPipelineStateInternal(d3d12Pso->GetComputePipelineState()); + ID3D12RootSignature* computeRootSignature = d3d12Pso->GetRootSignature(); + ID3D12PipelineState* computePipelineState = d3d12Pso->GetComputePipelineState(); + if (computeRootSignature == nullptr || computePipelineState == nullptr) { + return; + } + + m_commandList->SetComputeRootSignature(computeRootSignature); + m_commandList->SetPipelineState(computePipelineState); + m_currentPipelineState = computePipelineState; + m_currentRootSignature = computeRootSignature; } else { - D3D12PipelineState* d3d12Pso = static_cast(pso); if (d3d12Pso->GetRootSignature() != nullptr) { SetRootSignature(d3d12Pso->GetRootSignature()); } - SetPipelineStateInternal(static_cast(pso->GetNativeHandle())); + + ID3D12PipelineState* graphicsPipelineState = static_cast(pso->GetNativeHandle()); + if (graphicsPipelineState == nullptr) { + return; + } + + SetPipelineStateInternal(graphicsPipelineState); } } @@ -484,6 +500,10 @@ void D3D12CommandList::EndRenderPass() { } void D3D12CommandList::SetPipelineStateInternal(ID3D12PipelineState* pso) { + if (pso == nullptr) { + return; + } + m_commandList->SetPipelineState(pso); m_currentPipelineState = pso; if (m_currentRootSignature) { diff --git a/engine/src/RHI/D3D12/D3D12PipelineState.cpp b/engine/src/RHI/D3D12/D3D12PipelineState.cpp index 48dbdbb2..8cb3e572 100644 --- a/engine/src/RHI/D3D12/D3D12PipelineState.cpp +++ b/engine/src/RHI/D3D12/D3D12PipelineState.cpp @@ -126,6 +126,17 @@ void D3D12PipelineState::SetShaderBytecodes(const D3D12_SHADER_BYTECODE& vs, con void D3D12PipelineState::SetComputeShader(RHIShader* shader) { m_computeShader = shader; + m_csBytecode = {}; + m_computePipelineState.Reset(); + m_pipelineState.Reset(); + m_finalized = false; + + if (shader != nullptr) { + auto* d3d12Shader = static_cast(shader); + if (d3d12Shader->IsValid() && d3d12Shader->GetType() == ShaderType::Compute) { + m_csBytecode = d3d12Shader->GetD3D12Bytecode(); + } + } } void D3D12PipelineState::SetRootSignature(ID3D12RootSignature* rootSignature) { @@ -217,6 +228,10 @@ bool D3D12PipelineState::CreateD3D12ComputePSO() { return false; } + if (!EnsureDefaultRootSignature()) { + return false; + } + D3D12_COMPUTE_PIPELINE_STATE_DESC desc = {}; desc.pRootSignature = m_rootSignature.Get(); desc.CS = m_csBytecode; @@ -230,10 +245,46 @@ bool D3D12PipelineState::CreateD3D12ComputePSO() { return true; } +bool D3D12PipelineState::EnsureDefaultRootSignature() { + if (m_rootSignature != nullptr) { + return true; + } + + if (m_device == nullptr) { + return false; + } + + D3D12_ROOT_SIGNATURE_DESC rootSignatureDesc = {}; + rootSignatureDesc.Flags = D3D12_ROOT_SIGNATURE_FLAG_NONE; + + ComPtr serializedSignature; + ComPtr error; + HRESULT hr = D3D12SerializeRootSignature( + &rootSignatureDesc, + D3D_ROOT_SIGNATURE_VERSION_1, + &serializedSignature, + &error); + if (FAILED(hr)) { + return false; + } + + hr = m_device->CreateRootSignature( + 0, + serializedSignature->GetBufferPointer(), + serializedSignature->GetBufferSize(), + IID_PPV_ARGS(&m_rootSignature)); + return SUCCEEDED(hr); +} + void D3D12PipelineState::Shutdown() { m_pipelineState.Reset(); m_computePipelineState.Reset(); m_rootSignature.Reset(); + m_vsBytecode = {}; + m_psBytecode = {}; + m_gsBytecode = {}; + m_csBytecode = {}; + m_computeShader = nullptr; m_finalized = false; } diff --git a/tests/RHI/OpenGL/unit/test_shader.cpp b/tests/RHI/OpenGL/unit/test_shader.cpp index 59cda104..c3b0a32b 100644 --- a/tests/RHI/OpenGL/unit/test_shader.cpp +++ b/tests/RHI/OpenGL/unit/test_shader.cpp @@ -2,6 +2,8 @@ #include "XCEngine/RHI/OpenGL/OpenGLShader.h" #include "XCEngine/RHI/OpenGL/OpenGLCommandList.h" +#include + using namespace XCEngine::RHI; TEST_F(OpenGLTestFixture, Shader_Compile_VertexFragment) { @@ -74,12 +76,15 @@ TEST_F(OpenGLTestFixture, Shader_Compile_InvalidSource) { undefined_symbol; } )"; - + OpenGLShader shader; + testing::internal::CaptureStdout(); bool result = shader.Compile(invalidSource, "void main() { }"); - + const std::string output = testing::internal::GetCapturedStdout(); + EXPECT_FALSE(result); EXPECT_FALSE(shader.IsValid()); - + EXPECT_NE(output.find("ERROR::SHADER_COMPILATION_ERROR"), std::string::npos); + shader.Shutdown(); } diff --git a/tests/RHI/unit/test_compute.cpp b/tests/RHI/unit/test_compute.cpp index c0633e99..2604dd05 100644 --- a/tests/RHI/unit/test_compute.cpp +++ b/tests/RHI/unit/test_compute.cpp @@ -5,23 +5,39 @@ using namespace XCEngine::RHI; -TEST_P(RHITestFixture, ComputeShader_Compile_ValidShader) { +namespace { + +ShaderCompileDesc MakeComputeShaderDesc(RHIType backendType) { ShaderCompileDesc desc = {}; - if (GetBackendType() == RHIType::D3D12) { - desc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; + if (backendType == RHIType::D3D12) { + static const char* cs = R"( +[numthreads(1, 1, 1)] +void MainCS(uint3 dispatchThreadId : SV_DispatchThreadID) { +} +)"; + desc.source.assign(cs, cs + std::strlen(cs)); + desc.sourceLanguage = ShaderLanguage::HLSL; desc.entryPoint = L"MainCS"; desc.profile = L"cs_5_0"; } else { - desc.sourceLanguage = ShaderLanguage::GLSL; static const char* cs = R"( #version 460 layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; void main() { } )"; - desc.source.assign(cs, cs + strlen(cs)); + desc.source.assign(cs, cs + std::strlen(cs)); + desc.sourceLanguage = ShaderLanguage::GLSL; } + return desc; +} + +} // namespace + +TEST_P(RHITestFixture, ComputeShader_Compile_ValidShader) { + ShaderCompileDesc desc = MakeComputeShaderDesc(GetBackendType()); + RHIShader* shader = GetDevice()->CreateShader(desc); if (shader != nullptr) { EXPECT_TRUE(shader->IsValid()); @@ -32,21 +48,7 @@ TEST_P(RHITestFixture, ComputeShader_Compile_ValidShader) { } TEST_P(RHITestFixture, ComputeShader_GetType_ReturnsCompute) { - ShaderCompileDesc desc = {}; - if (GetBackendType() == RHIType::D3D12) { - desc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; - desc.entryPoint = L"MainCS"; - desc.profile = L"cs_5_0"; - } else { - desc.sourceLanguage = ShaderLanguage::GLSL; - static const char* cs = R"( - #version 460 - layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; - void main() { - } - )"; - desc.source.assign(cs, cs + strlen(cs)); - } + ShaderCompileDesc desc = MakeComputeShaderDesc(GetBackendType()); RHIShader* shader = GetDevice()->CreateShader(desc); if (shader != nullptr) { @@ -57,21 +59,7 @@ TEST_P(RHITestFixture, ComputeShader_GetType_ReturnsCompute) { } TEST_P(RHITestFixture, ComputeShader_Shutdown_Invalidates) { - ShaderCompileDesc desc = {}; - if (GetBackendType() == RHIType::D3D12) { - desc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; - desc.entryPoint = L"MainCS"; - desc.profile = L"cs_5_0"; - } else { - desc.sourceLanguage = ShaderLanguage::GLSL; - static const char* cs = R"( - #version 460 - layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; - void main() { - } - )"; - desc.source.assign(cs, cs + strlen(cs)); - } + ShaderCompileDesc desc = MakeComputeShaderDesc(GetBackendType()); RHIShader* shader = GetDevice()->CreateShader(desc); if (shader != nullptr) { @@ -89,21 +77,7 @@ TEST_P(RHITestFixture, PipelineState_SetComputeShader) { return; } - ShaderCompileDesc shaderDesc = {}; - if (GetBackendType() == RHIType::D3D12) { - shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; - shaderDesc.entryPoint = L"MainCS"; - shaderDesc.profile = L"cs_5_0"; - } else { - shaderDesc.sourceLanguage = ShaderLanguage::GLSL; - static const char* cs = R"( - #version 460 - layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; - void main() { - } - )"; - shaderDesc.source.assign(cs, cs + strlen(cs)); - } + ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType()); RHIShader* computeShader = GetDevice()->CreateShader(shaderDesc); if (computeShader != nullptr) { @@ -126,21 +100,7 @@ TEST_P(RHITestFixture, PipelineState_HasComputeShader_ReturnsTrue) { EXPECT_FALSE(pso->HasComputeShader()); - ShaderCompileDesc shaderDesc = {}; - if (GetBackendType() == RHIType::D3D12) { - shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; - shaderDesc.entryPoint = L"MainCS"; - shaderDesc.profile = L"cs_5_0"; - } else { - shaderDesc.sourceLanguage = ShaderLanguage::GLSL; - static const char* cs = R"( - #version 460 - layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; - void main() { - } - )"; - shaderDesc.source.assign(cs, cs + strlen(cs)); - } + ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType()); RHIShader* computeShader = GetDevice()->CreateShader(shaderDesc); if (computeShader != nullptr) { @@ -163,21 +123,7 @@ TEST_P(RHITestFixture, PipelineState_GetType_Compute) { EXPECT_EQ(pso->GetType(), PipelineType::Graphics); - ShaderCompileDesc shaderDesc = {}; - if (GetBackendType() == RHIType::D3D12) { - shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; - shaderDesc.entryPoint = L"MainCS"; - shaderDesc.profile = L"cs_5_0"; - } else { - shaderDesc.sourceLanguage = ShaderLanguage::GLSL; - static const char* cs = R"( - #version 460 - layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; - void main() { - } - )"; - shaderDesc.source.assign(cs, cs + strlen(cs)); - } + ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType()); RHIShader* computeShader = GetDevice()->CreateShader(shaderDesc); if (computeShader != nullptr) { @@ -198,21 +144,7 @@ TEST_P(RHITestFixture, PipelineState_EnsureValid_Compute) { return; } - ShaderCompileDesc shaderDesc = {}; - if (GetBackendType() == RHIType::D3D12) { - shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; - shaderDesc.entryPoint = L"MainCS"; - shaderDesc.profile = L"cs_5_0"; - } else { - shaderDesc.sourceLanguage = ShaderLanguage::GLSL; - static const char* cs = R"( - #version 460 - layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; - void main() { - } - )"; - shaderDesc.source.assign(cs, cs + strlen(cs)); - } + ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType()); RHIShader* computeShader = GetDevice()->CreateShader(shaderDesc); if (computeShader != nullptr) { @@ -243,21 +175,7 @@ TEST_P(RHITestFixture, CommandList_Dispatch_Basic) { RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc); if (pso != nullptr) { - ShaderCompileDesc shaderDesc = {}; - if (GetBackendType() == RHIType::D3D12) { - shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; - shaderDesc.entryPoint = L"MainCS"; - shaderDesc.profile = L"cs_5_0"; - } else { - shaderDesc.sourceLanguage = ShaderLanguage::GLSL; - static const char* cs = R"( - #version 460 - layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; - void main() { - } - )"; - shaderDesc.source.assign(cs, cs + strlen(cs)); - } + ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType()); RHIShader* computeShader = GetDevice()->CreateShader(shaderDesc); if (computeShader != nullptr) {