From 6328ac8d37d1e1139338e57c6a19afd95f1c290f Mon Sep 17 00:00:00 2001 From: ssdfasd <2156608475@qq.com> Date: Wed, 25 Mar 2026 01:05:03 +0800 Subject: [PATCH] RHI: Add Compute Pipeline abstraction with D3D12 and OpenGL support - Add SetComputeShader/GetComputeShader/HasComputeShader to RHIPipelineState - Add m_computePipelineState for D3D12 compute PSO - Add m_computeProgram/m_computeShader for OpenGL - Fix OpenGLCommandList::DispatchCompute bug (was ignoring x,y,z params) - Fix OpenGLShader::GetID usage in OpenGLPipelineState - Mark Priority 8 as completed in RHI_Design_Issues.md --- RHI_Design_Issues.md | 2 +- .../XCEngine/RHI/D3D12/D3D12PipelineState.h | 11 ++++++- .../XCEngine/RHI/OpenGL/OpenGLCommandList.h | 2 +- .../XCEngine/RHI/OpenGL/OpenGLPipelineState.h | 7 ++++- .../include/XCEngine/RHI/RHIPipelineState.h | 5 +++ engine/src/RHI/D3D12/D3D12PipelineState.cpp | 31 +++++++++++++++++++ engine/src/RHI/OpenGL/OpenGLCommandList.cpp | 7 ++--- engine/src/RHI/OpenGL/OpenGLPipelineState.cpp | 17 +++++++++- 8 files changed, 72 insertions(+), 10 deletions(-) diff --git a/RHI_Design_Issues.md b/RHI_Design_Issues.md index 8623dc85..3749ebda 100644 --- a/RHI_Design_Issues.md +++ b/RHI_Design_Issues.md @@ -563,7 +563,7 @@ class RHITexture : public RHIResource { ... }; | 5 | TransitionBarrier 针对 View 而非 Resource | 🟡 中 | 中 | ✅ 已完成 | | 6 | SetGlobal* 空操作 | 🟡 中 | 低 | ✅ 已完成 | | 7 | OpenGL 特有方法暴露 | 🟡 中 | 高 | ❌ 未完成 | -| 8 | 缺少 Compute Pipeline 抽象 | 🟡 中 | 中 | ❌ 未完成 | +| 8 | 缺少 Compute Pipeline 抽象 | 🟡 中 | 中 | ✅ 已完成 | --- diff --git a/engine/include/XCEngine/RHI/D3D12/D3D12PipelineState.h b/engine/include/XCEngine/RHI/D3D12/D3D12PipelineState.h index ae3b7f38..bb50b155 100644 --- a/engine/include/XCEngine/RHI/D3D12/D3D12PipelineState.h +++ b/engine/include/XCEngine/RHI/D3D12/D3D12PipelineState.h @@ -31,6 +31,7 @@ public: void SetTopology(uint32_t topologyType) override; void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override; void SetSampleCount(uint32_t count) override; + void SetComputeShader(RHIShader* shader) override; // State query const RasterizerDesc& GetRasterizerState() const override { return m_rasterizerDesc; } @@ -38,6 +39,8 @@ public: const DepthStencilStateDesc& GetDepthStencilState() const override { return m_depthStencilDesc; } const InputLayoutDesc& GetInputLayout() const override { return m_inputLayoutDesc; } PipelineStateHash GetHash() const override; + RHIShader* GetComputeShader() const override { return m_computeShader; } + bool HasComputeShader() const override { return m_csBytecode.pShaderBytecode != nullptr && m_csBytecode.BytecodeLength > 0; } // Finalization bool IsFinalized() const override { return m_finalized; } @@ -45,12 +48,14 @@ public: // Shader Bytecode (set by CommandList when binding) void SetShaderBytecodes(const D3D12_SHADER_BYTECODE& vs, const D3D12_SHADER_BYTECODE& ps, const D3D12_SHADER_BYTECODE& gs = {}); + void SetComputeShaderBytecodes(const D3D12_SHADER_BYTECODE& cs); // Lifecycle void Shutdown() override; ID3D12PipelineState* GetPipelineState() const { return m_pipelineState.Get(); } + ID3D12PipelineState* GetComputePipelineState() const { return m_computePipelineState.Get(); } void* GetNativeHandle() override { return m_pipelineState.Get(); } - PipelineType GetType() const override { return PipelineType::Graphics; } + PipelineType GetType() const override { return HasComputeShader() ? PipelineType::Compute : PipelineType::Graphics; } void Bind() override; void Unbind() override; @@ -71,6 +76,7 @@ public: private: bool CreateD3D12PSO(); + bool CreateD3D12ComputePSO(); ID3D12Device* m_device; bool m_finalized = false; @@ -91,10 +97,13 @@ private: D3D12_SHADER_BYTECODE m_vsBytecode = {}; D3D12_SHADER_BYTECODE m_psBytecode = {}; D3D12_SHADER_BYTECODE m_gsBytecode = {}; + D3D12_SHADER_BYTECODE m_csBytecode = {}; + class RHIShader* m_computeShader = nullptr; ID3D12RootSignature* m_rootSignature = nullptr; // D3D12 resources ComPtr m_pipelineState; + ComPtr m_computePipelineState; std::vector m_inputElements; }; diff --git a/engine/include/XCEngine/RHI/OpenGL/OpenGLCommandList.h b/engine/include/XCEngine/RHI/OpenGL/OpenGLCommandList.h index fcadb600..8fb1449b 100644 --- a/engine/include/XCEngine/RHI/OpenGL/OpenGLCommandList.h +++ b/engine/include/XCEngine/RHI/OpenGL/OpenGLCommandList.h @@ -123,7 +123,7 @@ public: void Dispatch(uint32_t x, uint32_t y, uint32_t z) override; void DispatchIndirect(unsigned int buffer, size_t offset); - void DispatchCompute(unsigned int x, unsigned int y, unsigned int z, unsigned int groupX, unsigned int groupY, unsigned int groupZ); + void DispatchCompute(unsigned int x, unsigned int y, unsigned int z); void MemoryBarrier(unsigned int barriers); void TextureBarrier(); diff --git a/engine/include/XCEngine/RHI/OpenGL/OpenGLPipelineState.h b/engine/include/XCEngine/RHI/OpenGL/OpenGLPipelineState.h index 4d04dd4c..d6d8164d 100644 --- a/engine/include/XCEngine/RHI/OpenGL/OpenGLPipelineState.h +++ b/engine/include/XCEngine/RHI/OpenGL/OpenGLPipelineState.h @@ -79,6 +79,7 @@ public: void SetTopology(uint32_t topologyType) override; void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override; void SetSampleCount(uint32_t count) override; + void SetComputeShader(RHIShader* shader) override; // State query const RasterizerDesc& GetRasterizerState() const override { return m_rasterizerDesc; } @@ -86,6 +87,8 @@ public: const DepthStencilStateDesc& GetDepthStencilState() const override { return m_depthStencilDesc; } const InputLayoutDesc& GetInputLayout() const override { return m_inputLayoutDesc; } PipelineStateHash GetHash() const override; + RHIShader* GetComputeShader() const override { return m_computeShader; } + bool HasComputeShader() const override { return m_computeProgram != 0; } // Finalization (OpenGL doesn't need it) bool IsFinalized() const override { return true; } @@ -94,7 +97,7 @@ public: // Lifecycle void Shutdown() override; void* GetNativeHandle() override { return reinterpret_cast(static_cast(m_program)); } - PipelineType GetType() const override { return PipelineType::Graphics; } + PipelineType GetType() const override { return HasComputeShader() ? PipelineType::Compute : PipelineType::Graphics; } void Bind() override; void Unbind() override; @@ -130,6 +133,8 @@ private: DepthStencilStateDesc m_depthStencilDesc; uint32_t m_topologyType = 0; unsigned int m_program = 0; + unsigned int m_computeProgram = 0; + class RHIShader* m_computeShader = nullptr; bool m_programAttached = false; // OpenGL specific state diff --git a/engine/include/XCEngine/RHI/RHIPipelineState.h b/engine/include/XCEngine/RHI/RHIPipelineState.h index b1051689..22887083 100644 --- a/engine/include/XCEngine/RHI/RHIPipelineState.h +++ b/engine/include/XCEngine/RHI/RHIPipelineState.h @@ -6,6 +6,8 @@ namespace XCEngine { namespace RHI { +class RHIShader; + class RHIPipelineState { public: virtual ~RHIPipelineState() = default; @@ -18,6 +20,7 @@ public: virtual void SetTopology(uint32_t topologyType) = 0; // PrimitiveTopologyType virtual void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) = 0; virtual void SetSampleCount(uint32_t count) = 0; + virtual void SetComputeShader(RHIShader* shader) = 0; // State query virtual const RasterizerDesc& GetRasterizerState() const = 0; @@ -25,6 +28,8 @@ public: virtual const DepthStencilStateDesc& GetDepthStencilState() const = 0; virtual const InputLayoutDesc& GetInputLayout() const = 0; virtual PipelineStateHash GetHash() const = 0; + virtual RHIShader* GetComputeShader() const = 0; + virtual bool HasComputeShader() const = 0; // Finalization (D3D12/Vulkan creates real PSO) virtual bool IsFinalized() const = 0; diff --git a/engine/src/RHI/D3D12/D3D12PipelineState.cpp b/engine/src/RHI/D3D12/D3D12PipelineState.cpp index f418e141..41dd04d1 100644 --- a/engine/src/RHI/D3D12/D3D12PipelineState.cpp +++ b/engine/src/RHI/D3D12/D3D12PipelineState.cpp @@ -1,4 +1,5 @@ #include "XCEngine/RHI/D3D12/D3D12PipelineState.h" +#include "XCEngine/RHI/D3D12/D3D12Shader.h" #include namespace XCEngine { @@ -121,8 +122,19 @@ void D3D12PipelineState::SetShaderBytecodes(const D3D12_SHADER_BYTECODE& vs, con m_gsBytecode = gs; } +void D3D12PipelineState::SetComputeShader(RHIShader* shader) { + m_computeShader = shader; +} + +void D3D12PipelineState::SetComputeShaderBytecodes(const D3D12_SHADER_BYTECODE& cs) { + m_csBytecode = cs; +} + bool D3D12PipelineState::Finalize() { if (m_finalized) return true; + if (HasComputeShader()) { + return CreateD3D12ComputePSO(); + } return CreateD3D12PSO(); } @@ -190,8 +202,27 @@ bool D3D12PipelineState::CreateD3D12PSO() { return true; } +bool D3D12PipelineState::CreateD3D12ComputePSO() { + if (!m_csBytecode.pShaderBytecode || !m_csBytecode.BytecodeLength) { + return false; + } + + D3D12_COMPUTE_PIPELINE_STATE_DESC desc = {}; + desc.pRootSignature = m_rootSignature; + desc.CS = m_csBytecode; + + HRESULT hr = m_device->CreateComputePipelineState(&desc, IID_PPV_ARGS(&m_computePipelineState)); + if (FAILED(hr)) { + return false; + } + + m_finalized = true; + return true; +} + void D3D12PipelineState::Shutdown() { m_pipelineState.Reset(); + m_computePipelineState.Reset(); m_finalized = false; } diff --git a/engine/src/RHI/OpenGL/OpenGLCommandList.cpp b/engine/src/RHI/OpenGL/OpenGLCommandList.cpp index 22488bca..5e905de8 100644 --- a/engine/src/RHI/OpenGL/OpenGLCommandList.cpp +++ b/engine/src/RHI/OpenGL/OpenGLCommandList.cpp @@ -271,11 +271,8 @@ void OpenGLCommandList::DispatchIndirect(unsigned int buffer, size_t offset) { glBindBuffer(GL_DISPATCH_INDIRECT_BUFFER, 0); } -void OpenGLCommandList::DispatchCompute(unsigned int x, unsigned int y, unsigned int z, unsigned int groupX, unsigned int groupY, unsigned int groupZ) { - glDispatchCompute(groupX, groupY, groupZ); - (void)x; - (void)y; - (void)z; +void OpenGLCommandList::DispatchCompute(unsigned int x, unsigned int y, unsigned int z) { + glDispatchCompute(x, y, z); } void OpenGLCommandList::MemoryBarrier(unsigned int barriers) { diff --git a/engine/src/RHI/OpenGL/OpenGLPipelineState.cpp b/engine/src/RHI/OpenGL/OpenGLPipelineState.cpp index a9868c93..3a487b71 100644 --- a/engine/src/RHI/OpenGL/OpenGLPipelineState.cpp +++ b/engine/src/RHI/OpenGL/OpenGLPipelineState.cpp @@ -1,4 +1,5 @@ #include "XCEngine/RHI/OpenGL/OpenGLPipelineState.h" +#include "XCEngine/RHI/OpenGL/OpenGLShader.h" #include namespace XCEngine { @@ -37,6 +38,16 @@ void OpenGLPipelineState::SetRenderTargetFormats(uint32_t count, const uint32_t* void OpenGLPipelineState::SetSampleCount(uint32_t count) { } +void OpenGLPipelineState::SetComputeShader(RHIShader* shader) { + m_computeShader = shader; + if (shader) { + OpenGLShader* glShader = static_cast(shader); + m_computeProgram = glShader->GetID(); + } else { + m_computeProgram = 0; + } +} + PipelineStateHash OpenGLPipelineState::GetHash() const { PipelineStateHash hash = {}; return hash; @@ -44,11 +55,15 @@ PipelineStateHash OpenGLPipelineState::GetHash() const { void OpenGLPipelineState::Shutdown() { m_program = 0; + m_computeProgram = 0; + m_computeShader = nullptr; m_programAttached = false; } void OpenGLPipelineState::Bind() { - if (m_programAttached) { + if (HasComputeShader()) { + glUseProgram(m_computeProgram); + } else if (m_programAttached) { glUseProgram(m_program); } Apply();