diff --git a/engine/include/XCEngine/RHI/D3D12/D3D12PipelineState.h b/engine/include/XCEngine/RHI/D3D12/D3D12PipelineState.h index b8596b1e..63278813 100644 --- a/engine/include/XCEngine/RHI/D3D12/D3D12PipelineState.h +++ b/engine/include/XCEngine/RHI/D3D12/D3D12PipelineState.h @@ -32,6 +32,7 @@ public: void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override; void SetSampleCount(uint32_t count) override; void SetComputeShader(RHIShader* shader) override; + void SetRootSignature(ID3D12RootSignature* rootSignature); // State query const RasterizerDesc& GetRasterizerState() const override { return m_rasterizerDesc; } @@ -54,6 +55,7 @@ public: void Shutdown() override; ID3D12PipelineState* GetPipelineState() const { return m_pipelineState.Get(); } ID3D12PipelineState* GetComputePipelineState() const { return m_computePipelineState.Get(); } + ID3D12RootSignature* GetRootSignature() const { return m_rootSignature.Get(); } void* GetNativeHandle() override { return m_pipelineState.Get(); } PipelineType GetType() const override { return HasComputeShader() ? PipelineType::Compute : PipelineType::Graphics; } @@ -99,7 +101,7 @@ private: D3D12_SHADER_BYTECODE m_gsBytecode = {}; D3D12_SHADER_BYTECODE m_csBytecode = {}; class RHIShader* m_computeShader = nullptr; - ID3D12RootSignature* m_rootSignature = nullptr; + ComPtr m_rootSignature; // D3D12 resources ComPtr m_pipelineState; diff --git a/engine/include/XCEngine/RHI/OpenGL/OpenGLPipelineState.h b/engine/include/XCEngine/RHI/OpenGL/OpenGLPipelineState.h index 3455f1cc..3fcee3a8 100644 --- a/engine/include/XCEngine/RHI/OpenGL/OpenGLPipelineState.h +++ b/engine/include/XCEngine/RHI/OpenGL/OpenGLPipelineState.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include "../RHIPipelineState.h" @@ -121,6 +122,7 @@ public: void Clear(unsigned int buffers); void AttachShader(unsigned int program); void DetachShader(); + void SetOwnedGraphicsShader(std::unique_ptr shader); const OpenGLDepthStencilState& GetOpenGLDepthStencilState() const; const OpenGLBlendState& GetOpenGLBlendState() const; @@ -135,6 +137,7 @@ private: unsigned int m_program = 0; unsigned int m_computeProgram = 0; class RHIShader* m_computeShader = nullptr; + std::unique_ptr m_graphicsShader; bool m_programAttached = false; // OpenGL specific state diff --git a/engine/include/XCEngine/RHI/RHITypes.h b/engine/include/XCEngine/RHI/RHITypes.h index 0cd2f930..d5e6adc7 100644 --- a/engine/include/XCEngine/RHI/RHITypes.h +++ b/engine/include/XCEngine/RHI/RHITypes.h @@ -245,17 +245,17 @@ struct RootSignatureDesc { // ========== Pipeline State Structures (Unity SRP style) ========== struct StencilOpDesc { - uint32_t failOp = 0; // StencilOp - uint32_t passOp = 0; // StencilOp - uint32_t depthFailOp = 0; // StencilOp - uint32_t func = 0; // ComparisonFunc + uint32_t failOp = static_cast(StencilOp::Keep); + uint32_t passOp = static_cast(StencilOp::Keep); + uint32_t depthFailOp = static_cast(StencilOp::Keep); + uint32_t func = static_cast(ComparisonFunc::Always); }; struct DepthStencilStateDesc { bool depthTestEnable = true; bool depthWriteEnable = true; bool depthBoundsEnable = false; - uint32_t depthFunc = 0; // ComparisonFunc + uint32_t depthFunc = static_cast(ComparisonFunc::Less); bool stencilEnable = false; uint8_t stencilReadMask = 0xFF; uint8_t stencilWriteMask = 0xFF; @@ -265,20 +265,20 @@ struct DepthStencilStateDesc { struct BlendDesc { bool blendEnable = false; - uint32_t srcBlend = 0; // BlendFactor - uint32_t dstBlend = 0; // BlendFactor - uint32_t srcBlendAlpha = 0; // BlendFactor - uint32_t dstBlendAlpha = 0; // BlendFactor - uint32_t blendOp = 0; // BlendOp - uint32_t blendOpAlpha = 0; // BlendOp + uint32_t srcBlend = static_cast(BlendFactor::One); + uint32_t dstBlend = static_cast(BlendFactor::Zero); + uint32_t srcBlendAlpha = static_cast(BlendFactor::One); + uint32_t dstBlendAlpha = static_cast(BlendFactor::Zero); + uint32_t blendOp = static_cast(BlendOp::Add); + uint32_t blendOpAlpha = static_cast(BlendOp::Add); uint8_t colorWriteMask = 0xF; float blendFactor[4] = {1.0f, 1.0f, 1.0f, 1.0f}; }; struct RasterizerDesc { - uint32_t fillMode = 0; // FillMode (default: Solid) - uint32_t cullMode = 0; // CullMode (default: Back) - uint32_t frontFace = 0; // FrontFace (default: CounterClockwise) + uint32_t fillMode = static_cast(FillMode::Solid); + uint32_t cullMode = static_cast(CullMode::Back); + uint32_t frontFace = static_cast(FrontFace::CounterClockwise); bool depthClipEnable = true; bool scissorTestEnable = false; bool multisampleEnable = false; @@ -307,12 +307,15 @@ struct PipelineStateHash { }; struct GraphicsPipelineDesc { + ShaderCompileDesc vertexShader; + ShaderCompileDesc fragmentShader; + ShaderCompileDesc geometryShader; InputLayoutDesc inputLayout; RasterizerDesc rasterizerState; BlendDesc blendState; DepthStencilStateDesc depthStencilState; - uint32_t topologyType = 0; // PrimitiveTopologyType + uint32_t topologyType = static_cast(PrimitiveTopologyType::Triangle); uint32_t renderTargetCount = 1; uint32_t renderTargetFormats[8] = { 0 }; // Format uint32_t depthStencilFormat = 0; // Format diff --git a/engine/src/RHI/D3D12/D3D12CommandList.cpp b/engine/src/RHI/D3D12/D3D12CommandList.cpp index 882bcecf..2ff67e2f 100644 --- a/engine/src/RHI/D3D12/D3D12CommandList.cpp +++ b/engine/src/RHI/D3D12/D3D12CommandList.cpp @@ -127,6 +127,10 @@ void D3D12CommandList::SetPipelineState(RHIPipelineState* pso) { D3D12PipelineState* d3d12Pso = static_cast(pso); SetPipelineStateInternal(d3d12Pso->GetComputePipelineState()); } else { + D3D12PipelineState* d3d12Pso = static_cast(pso); + if (d3d12Pso->GetRootSignature() != nullptr) { + SetRootSignature(d3d12Pso->GetRootSignature()); + } SetPipelineStateInternal(static_cast(pso->GetNativeHandle())); } } diff --git a/engine/src/RHI/D3D12/D3D12Device.cpp b/engine/src/RHI/D3D12/D3D12Device.cpp index 2b8a06d8..d9c53740 100644 --- a/engine/src/RHI/D3D12/D3D12Device.cpp +++ b/engine/src/RHI/D3D12/D3D12Device.cpp @@ -39,6 +39,27 @@ std::string NarrowAscii(const std::wstring& value) { return result; } +bool HasShaderPayload(const ShaderCompileDesc& desc) { + return !desc.source.empty() || !desc.fileName.empty(); +} + +bool CompileD3D12Shader(const ShaderCompileDesc& desc, D3D12Shader& shader) { + const std::string entryPoint = NarrowAscii(desc.entryPoint); + const std::string profile = NarrowAscii(desc.profile); + const char* entryPointPtr = entryPoint.empty() ? nullptr : entryPoint.c_str(); + const char* profilePtr = profile.empty() ? nullptr : profile.c_str(); + + if (!desc.source.empty()) { + return shader.Compile(desc.source.data(), desc.source.size(), entryPointPtr, profilePtr); + } + + if (!desc.fileName.empty()) { + return shader.CompileFromFile(desc.fileName.c_str(), entryPointPtr, profilePtr); + } + + return false; +} + } // namespace D3D12Device::D3D12Device() @@ -489,6 +510,55 @@ RHIPipelineState* D3D12Device::CreatePipelineState(const GraphicsPipelineDesc& d pso->SetTopology(desc.topologyType); pso->SetRenderTargetFormats(desc.renderTargetCount, desc.renderTargetFormats, desc.depthStencilFormat); pso->SetSampleCount(desc.sampleCount); + + const bool hasVertexShader = HasShaderPayload(desc.vertexShader); + const bool hasFragmentShader = HasShaderPayload(desc.fragmentShader); + const bool hasGeometryShader = HasShaderPayload(desc.geometryShader); + if (!hasVertexShader && !hasFragmentShader && !hasGeometryShader) { + return pso; + } + + if (!hasVertexShader || !hasFragmentShader) { + delete pso; + return nullptr; + } + + auto* rootSignature = CreateRootSignature({}); + if (rootSignature == nullptr) { + delete pso; + return nullptr; + } + + pso->SetRootSignature(rootSignature->GetRootSignature()); + + D3D12Shader vertexShader; + D3D12Shader fragmentShader; + D3D12Shader geometryShader; + const bool vertexCompiled = CompileD3D12Shader(desc.vertexShader, vertexShader); + const bool fragmentCompiled = CompileD3D12Shader(desc.fragmentShader, fragmentShader); + const bool geometryCompiled = !hasGeometryShader || CompileD3D12Shader(desc.geometryShader, geometryShader); + + if (!vertexCompiled || !fragmentCompiled || !geometryCompiled) { + rootSignature->Shutdown(); + delete rootSignature; + delete pso; + return nullptr; + } + + pso->SetShaderBytecodes( + vertexShader.GetD3D12Bytecode(), + fragmentShader.GetD3D12Bytecode(), + hasGeometryShader ? geometryShader.GetD3D12Bytecode() : D3D12_SHADER_BYTECODE{}); + pso->EnsureValid(); + + rootSignature->Shutdown(); + delete rootSignature; + + if (!pso->IsValid()) { + delete pso; + return nullptr; + } + return pso; } diff --git a/engine/src/RHI/D3D12/D3D12PipelineState.cpp b/engine/src/RHI/D3D12/D3D12PipelineState.cpp index eeef8382..e1fcc791 100644 --- a/engine/src/RHI/D3D12/D3D12PipelineState.cpp +++ b/engine/src/RHI/D3D12/D3D12PipelineState.cpp @@ -127,6 +127,10 @@ void D3D12PipelineState::SetComputeShader(RHIShader* shader) { m_computeShader = shader; } +void D3D12PipelineState::SetRootSignature(ID3D12RootSignature* rootSignature) { + m_rootSignature = rootSignature; +} + void D3D12PipelineState::SetComputeShaderBytecodes(const D3D12_SHADER_BYTECODE& cs) { m_csBytecode = cs; } @@ -146,7 +150,7 @@ bool D3D12PipelineState::CreateD3D12PSO() { } D3D12_GRAPHICS_PIPELINE_STATE_DESC desc = {}; - desc.pRootSignature = m_rootSignature; + desc.pRootSignature = m_rootSignature.Get(); desc.VS = m_vsBytecode; desc.PS = m_psBytecode; desc.GS = m_gsBytecode; @@ -154,36 +158,36 @@ bool D3D12PipelineState::CreateD3D12PSO() { desc.InputLayout.NumElements = static_cast(m_inputElements.size()); desc.InputLayout.pInputElementDescs = m_inputElements.data(); - desc.RasterizerState.FillMode = static_cast(m_rasterizerDesc.fillMode); - desc.RasterizerState.CullMode = static_cast(m_rasterizerDesc.cullMode); - desc.RasterizerState.FrontCounterClockwise = (m_rasterizerDesc.frontFace != 0); + desc.RasterizerState.FillMode = ToD3D12(static_cast(m_rasterizerDesc.fillMode)); + desc.RasterizerState.CullMode = ToD3D12(static_cast(m_rasterizerDesc.cullMode)); + desc.RasterizerState.FrontCounterClockwise = static_cast(m_rasterizerDesc.frontFace) == FrontFace::CounterClockwise; desc.RasterizerState.DepthClipEnable = m_rasterizerDesc.depthClipEnable; desc.RasterizerState.MultisampleEnable = m_rasterizerDesc.multisampleEnable; desc.RasterizerState.AntialiasedLineEnable = m_rasterizerDesc.antialiasedLineEnable; desc.BlendState.RenderTarget[0].BlendEnable = m_blendDesc.blendEnable; - desc.BlendState.RenderTarget[0].SrcBlend = static_cast(m_blendDesc.srcBlend); - desc.BlendState.RenderTarget[0].DestBlend = static_cast(m_blendDesc.dstBlend); - desc.BlendState.RenderTarget[0].BlendOp = static_cast(m_blendDesc.blendOp); - desc.BlendState.RenderTarget[0].SrcBlendAlpha = static_cast(m_blendDesc.srcBlendAlpha); - desc.BlendState.RenderTarget[0].DestBlendAlpha = static_cast(m_blendDesc.dstBlendAlpha); - desc.BlendState.RenderTarget[0].BlendOpAlpha = static_cast(m_blendDesc.blendOpAlpha); + desc.BlendState.RenderTarget[0].SrcBlend = ToD3D12(static_cast(m_blendDesc.srcBlend)); + desc.BlendState.RenderTarget[0].DestBlend = ToD3D12(static_cast(m_blendDesc.dstBlend)); + desc.BlendState.RenderTarget[0].BlendOp = ToD3D12(static_cast(m_blendDesc.blendOp)); + desc.BlendState.RenderTarget[0].SrcBlendAlpha = ToD3D12(static_cast(m_blendDesc.srcBlendAlpha)); + desc.BlendState.RenderTarget[0].DestBlendAlpha = ToD3D12(static_cast(m_blendDesc.dstBlendAlpha)); + desc.BlendState.RenderTarget[0].BlendOpAlpha = ToD3D12(static_cast(m_blendDesc.blendOpAlpha)); desc.BlendState.RenderTarget[0].RenderTargetWriteMask = m_blendDesc.colorWriteMask; desc.DepthStencilState.DepthEnable = m_depthStencilDesc.depthTestEnable; desc.DepthStencilState.DepthWriteMask = m_depthStencilDesc.depthWriteEnable ? D3D12_DEPTH_WRITE_MASK_ALL : D3D12_DEPTH_WRITE_MASK_ZERO; - desc.DepthStencilState.DepthFunc = static_cast(m_depthStencilDesc.depthFunc); + desc.DepthStencilState.DepthFunc = ToD3D12(static_cast(m_depthStencilDesc.depthFunc)); desc.DepthStencilState.StencilEnable = m_depthStencilDesc.stencilEnable; desc.DepthStencilState.StencilReadMask = m_depthStencilDesc.stencilReadMask; desc.DepthStencilState.StencilWriteMask = m_depthStencilDesc.stencilWriteMask; - desc.DepthStencilState.FrontFace.StencilFailOp = static_cast(m_depthStencilDesc.front.failOp); - desc.DepthStencilState.FrontFace.StencilPassOp = static_cast(m_depthStencilDesc.front.passOp); - desc.DepthStencilState.FrontFace.StencilDepthFailOp = static_cast(m_depthStencilDesc.front.depthFailOp); - desc.DepthStencilState.FrontFace.StencilFunc = static_cast(m_depthStencilDesc.front.func); - desc.DepthStencilState.BackFace.StencilFailOp = static_cast(m_depthStencilDesc.back.failOp); - desc.DepthStencilState.BackFace.StencilPassOp = static_cast(m_depthStencilDesc.back.passOp); - desc.DepthStencilState.BackFace.StencilDepthFailOp = static_cast(m_depthStencilDesc.back.depthFailOp); - desc.DepthStencilState.BackFace.StencilFunc = static_cast(m_depthStencilDesc.back.func); + desc.DepthStencilState.FrontFace.StencilFailOp = ToD3D12(static_cast(m_depthStencilDesc.front.failOp)); + desc.DepthStencilState.FrontFace.StencilPassOp = ToD3D12(static_cast(m_depthStencilDesc.front.passOp)); + desc.DepthStencilState.FrontFace.StencilDepthFailOp = ToD3D12(static_cast(m_depthStencilDesc.front.depthFailOp)); + desc.DepthStencilState.FrontFace.StencilFunc = ToD3D12(static_cast(m_depthStencilDesc.front.func)); + desc.DepthStencilState.BackFace.StencilFailOp = ToD3D12(static_cast(m_depthStencilDesc.back.failOp)); + desc.DepthStencilState.BackFace.StencilPassOp = ToD3D12(static_cast(m_depthStencilDesc.back.passOp)); + desc.DepthStencilState.BackFace.StencilDepthFailOp = ToD3D12(static_cast(m_depthStencilDesc.back.depthFailOp)); + desc.DepthStencilState.BackFace.StencilFunc = ToD3D12(static_cast(m_depthStencilDesc.back.func)); desc.NumRenderTargets = m_renderTargetCount; for (uint32_t i = 0; i < m_renderTargetCount && i < 8; ++i) { @@ -210,7 +214,7 @@ bool D3D12PipelineState::CreateD3D12ComputePSO() { } D3D12_COMPUTE_PIPELINE_STATE_DESC desc = {}; - desc.pRootSignature = m_rootSignature; + desc.pRootSignature = m_rootSignature.Get(); desc.CS = m_csBytecode; HRESULT hr = m_device->CreateComputePipelineState(&desc, IID_PPV_ARGS(&m_computePipelineState)); @@ -225,6 +229,7 @@ bool D3D12PipelineState::CreateD3D12ComputePSO() { void D3D12PipelineState::Shutdown() { m_pipelineState.Reset(); m_computePipelineState.Reset(); + m_rootSignature.Reset(); m_finalized = false; } @@ -260,4 +265,4 @@ D3D12_INPUT_ELEMENT_DESC D3D12PipelineState::CreateInputElement( } } // namespace RHI -} // namespace XCEngine \ No newline at end of file +} // namespace XCEngine diff --git a/engine/src/RHI/OpenGL/OpenGLCommandList.cpp b/engine/src/RHI/OpenGL/OpenGLCommandList.cpp index 7e506c1a..149fad12 100644 --- a/engine/src/RHI/OpenGL/OpenGLCommandList.cpp +++ b/engine/src/RHI/OpenGL/OpenGLCommandList.cpp @@ -598,7 +598,7 @@ void OpenGLCommandList::ClearDepthStencil(RHIResourceView* depthStencil, float d void OpenGLCommandList::SetPipelineState(RHIPipelineState* pipelineState) { if (pipelineState) { - UseShader(reinterpret_cast(pipelineState->GetNativeHandle())); + pipelineState->Bind(); } } diff --git a/engine/src/RHI/OpenGL/OpenGLDevice.cpp b/engine/src/RHI/OpenGL/OpenGLDevice.cpp index 36a9902e..f6a8a7ba 100644 --- a/engine/src/RHI/OpenGL/OpenGLDevice.cpp +++ b/engine/src/RHI/OpenGL/OpenGLDevice.cpp @@ -41,6 +41,27 @@ static PFNWGLCREATECONTEXTATTRIBSARBPROC wglCreateContextAttribsARB = nullptr; namespace XCEngine { namespace RHI { +namespace { + +std::string NarrowAscii(const std::wstring& value) { + std::string result; + result.reserve(value.size()); + for (wchar_t ch : value) { + result.push_back(static_cast(ch)); + } + return result; +} + +bool HasShaderPayload(const ShaderCompileDesc& desc) { + return !desc.source.empty() || !desc.fileName.empty(); +} + +std::string SourceToString(const ShaderCompileDesc& desc) { + return std::string(reinterpret_cast(desc.source.data()), desc.source.size()); +} + +} // namespace + OpenGLDevice::OpenGLDevice() : m_hwnd(nullptr) , m_hdc(nullptr) @@ -405,6 +426,52 @@ RHIPipelineState* OpenGLDevice::CreatePipelineState(const GraphicsPipelineDesc& pso->SetTopology(desc.topologyType); pso->SetRenderTargetFormats(desc.renderTargetCount, desc.renderTargetFormats, desc.depthStencilFormat); pso->SetSampleCount(desc.sampleCount); + + const bool hasVertexShader = HasShaderPayload(desc.vertexShader); + const bool hasFragmentShader = HasShaderPayload(desc.fragmentShader); + const bool hasGeometryShader = HasShaderPayload(desc.geometryShader); + if (!hasVertexShader && !hasFragmentShader && !hasGeometryShader) { + return pso; + } + + if (!hasVertexShader || !hasFragmentShader) { + delete pso; + return nullptr; + } + + if (!MakeContextCurrent()) { + delete pso; + return nullptr; + } + + auto graphicsShader = std::make_unique(); + bool compiled = false; + if (!desc.vertexShader.source.empty() && !desc.fragmentShader.source.empty()) { + const std::string vertexSource = SourceToString(desc.vertexShader); + const std::string fragmentSource = SourceToString(desc.fragmentShader); + if (hasGeometryShader && !desc.geometryShader.source.empty()) { + const std::string geometrySource = SourceToString(desc.geometryShader); + compiled = graphicsShader->Compile(vertexSource.c_str(), fragmentSource.c_str(), geometrySource.c_str()); + } else { + compiled = graphicsShader->Compile(vertexSource.c_str(), fragmentSource.c_str()); + } + } else if (!desc.vertexShader.fileName.empty() && !desc.fragmentShader.fileName.empty()) { + const std::string vertexPath = NarrowAscii(desc.vertexShader.fileName); + const std::string fragmentPath = NarrowAscii(desc.fragmentShader.fileName); + if (hasGeometryShader && !desc.geometryShader.fileName.empty()) { + const std::string geometryPath = NarrowAscii(desc.geometryShader.fileName); + compiled = graphicsShader->CompileFromFile(vertexPath.c_str(), fragmentPath.c_str(), geometryPath.c_str()); + } else { + compiled = graphicsShader->CompileFromFile(vertexPath.c_str(), fragmentPath.c_str()); + } + } + + if (!compiled) { + delete pso; + return nullptr; + } + + pso->SetOwnedGraphicsShader(std::move(graphicsShader)); return pso; } diff --git a/engine/src/RHI/OpenGL/OpenGLPipelineState.cpp b/engine/src/RHI/OpenGL/OpenGLPipelineState.cpp index 4204304e..2899ad3c 100644 --- a/engine/src/RHI/OpenGL/OpenGLPipelineState.cpp +++ b/engine/src/RHI/OpenGL/OpenGLPipelineState.cpp @@ -1,6 +1,7 @@ #include "XCEngine/RHI/OpenGL/OpenGLPipelineState.h" #include "XCEngine/RHI/OpenGL/OpenGLShader.h" #include "XCEngine/RHI/OpenGL/OpenGLEnums.h" +#include #include namespace XCEngine { @@ -19,14 +20,42 @@ void OpenGLPipelineState::SetInputLayout(const InputLayoutDesc& layout) { void OpenGLPipelineState::SetRasterizerState(const RasterizerDesc& state) { m_rasterizerDesc = state; + m_glRasterizerState.cullFaceEnable = static_cast(state.cullMode) != CullMode::None; + m_glRasterizerState.cullFace = static_cast(state.cullMode); + m_glRasterizerState.frontFace = static_cast(state.frontFace); + m_glRasterizerState.polygonMode = static_cast(state.fillMode); + m_glRasterizerState.depthClipEnable = state.depthClipEnable; + m_glRasterizerState.scissorTestEnable = state.scissorTestEnable; + m_glRasterizerState.multisampleEnable = state.multisampleEnable; + m_glRasterizerState.polygonOffsetFactor = state.slopeScaledDepthBias; + m_glRasterizerState.polygonOffsetUnits = state.depthBiasClamp; } void OpenGLPipelineState::SetBlendState(const BlendDesc& state) { m_blendDesc = state; + m_glBlendState.blendEnable = state.blendEnable; + m_glBlendState.srcBlend = static_cast(state.srcBlend); + m_glBlendState.dstBlend = static_cast(state.dstBlend); + m_glBlendState.srcBlendAlpha = static_cast(state.srcBlendAlpha); + m_glBlendState.dstBlendAlpha = static_cast(state.dstBlendAlpha); + m_glBlendState.blendOp = static_cast(state.blendOp); + m_glBlendState.blendOpAlpha = static_cast(state.blendOpAlpha); + m_glBlendState.colorWriteMask = state.colorWriteMask; + std::memcpy(m_glBlendState.blendFactor, state.blendFactor, sizeof(state.blendFactor)); } void OpenGLPipelineState::SetDepthStencilState(const DepthStencilStateDesc& state) { m_depthStencilDesc = state; + m_glDepthStencilState.depthTestEnable = state.depthTestEnable; + m_glDepthStencilState.depthWriteEnable = state.depthWriteEnable; + m_glDepthStencilState.depthFunc = static_cast(state.depthFunc); + m_glDepthStencilState.stencilEnable = state.stencilEnable; + m_glDepthStencilState.stencilReadMask = state.stencilReadMask; + m_glDepthStencilState.stencilWriteMask = state.stencilWriteMask; + m_glDepthStencilState.stencilFunc = static_cast(state.front.func); + m_glDepthStencilState.stencilFailOp = static_cast(state.front.failOp); + m_glDepthStencilState.stencilDepthFailOp = static_cast(state.front.depthFailOp); + m_glDepthStencilState.stencilDepthPassOp = static_cast(state.front.passOp); } void OpenGLPipelineState::SetTopology(uint32_t topologyType) { @@ -55,6 +84,7 @@ PipelineStateHash OpenGLPipelineState::GetHash() const { } void OpenGLPipelineState::Shutdown() { + m_graphicsShader.reset(); m_program = 0; m_computeProgram = 0; m_computeShader = nullptr; @@ -64,7 +94,7 @@ void OpenGLPipelineState::Shutdown() { void OpenGLPipelineState::Bind() { if (HasComputeShader()) { glUseProgram(m_computeProgram); - } else if (m_programAttached) { + } else if (m_programAttached && m_program != 0) { glUseProgram(m_program); } Apply(); @@ -213,6 +243,15 @@ void OpenGLPipelineState::DetachShader() { glUseProgram(0); } +void OpenGLPipelineState::SetOwnedGraphicsShader(std::unique_ptr shader) { + m_graphicsShader = std::move(shader); + if (m_graphicsShader) { + SetProgram(m_graphicsShader->GetID()); + } else { + DetachShader(); + } +} + const OpenGLDepthStencilState& OpenGLPipelineState::GetOpenGLDepthStencilState() const { return m_glDepthStencilState; } diff --git a/tests/RHI/unit/test_pipeline_state.cpp b/tests/RHI/unit/test_pipeline_state.cpp index 9eb6ad04..7e90fb54 100644 --- a/tests/RHI/unit/test_pipeline_state.cpp +++ b/tests/RHI/unit/test_pipeline_state.cpp @@ -169,4 +169,82 @@ TEST_P(RHITestFixture, PipelineState_GetType) { pso->Shutdown(); delete pso; -} \ No newline at end of file +} + +TEST_P(RHITestFixture, PipelineState_Create_GraphicsShadersFromDesc) { + GraphicsPipelineDesc desc = {}; + desc.topologyType = static_cast(PrimitiveTopologyType::Triangle); + desc.renderTargetFormats[0] = static_cast(Format::R8G8B8A8_UNorm); + desc.depthStencilFormat = static_cast(Format::Unknown); + + InputElementDesc position = {}; + position.semanticName = "POSITION"; + position.semanticIndex = 0; + position.format = static_cast(Format::R32G32B32A32_Float); + position.inputSlot = 0; + position.alignedByteOffset = 0; + desc.inputLayout.elements.push_back(position); + + if (GetBackendType() == RHIType::D3D12) { + const char* hlslSource = R"( +struct VSInput { + float4 position : POSITION; +}; + +struct PSInput { + float4 position : SV_POSITION; +}; + +PSInput MainVS(VSInput input) { + PSInput output; + output.position = input.position; + return output; +} + +float4 MainPS(PSInput input) : SV_TARGET { + return float4(1.0f, 0.0f, 0.0f, 1.0f); +} +)"; + + desc.vertexShader.source.assign(hlslSource, hlslSource + std::strlen(hlslSource)); + desc.vertexShader.entryPoint = L"MainVS"; + desc.vertexShader.profile = L"vs_5_0"; + desc.vertexShader.sourceLanguage = ShaderLanguage::HLSL; + + desc.fragmentShader.source.assign(hlslSource, hlslSource + std::strlen(hlslSource)); + desc.fragmentShader.entryPoint = L"MainPS"; + desc.fragmentShader.profile = L"ps_5_0"; + desc.fragmentShader.sourceLanguage = ShaderLanguage::HLSL; + } else { + const char* vertexSource = R"(#version 430 +layout(location = 0) in vec4 aPosition; + +void main() { + gl_Position = aPosition; +} +)"; + const char* fragmentSource = R"(#version 430 +layout(location = 0) out vec4 fragColor; + +void main() { + fragColor = vec4(1.0, 0.0, 0.0, 1.0); +} +)"; + + desc.vertexShader.source.assign(vertexSource, vertexSource + std::strlen(vertexSource)); + desc.vertexShader.sourceLanguage = ShaderLanguage::GLSL; + desc.vertexShader.profile = L"vs_4_30"; + + desc.fragmentShader.source.assign(fragmentSource, fragmentSource + std::strlen(fragmentSource)); + desc.fragmentShader.sourceLanguage = ShaderLanguage::GLSL; + desc.fragmentShader.profile = L"fs_4_30"; + } + + RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc); + ASSERT_NE(pso, nullptr); + EXPECT_TRUE(pso->IsValid()); + EXPECT_NE(pso->GetNativeHandle(), nullptr); + + pso->Shutdown(); + delete pso; +}