RHI: Add Compute Pipeline abstraction with D3D12 and OpenGL support
- Add SetComputeShader/GetComputeShader/HasComputeShader to RHIPipelineState - Add m_computePipelineState for D3D12 compute PSO - Add m_computeProgram/m_computeShader for OpenGL - Fix OpenGLCommandList::DispatchCompute bug (was ignoring x,y,z params) - Fix OpenGLShader::GetID usage in OpenGLPipelineState - Mark Priority 8 as completed in RHI_Design_Issues.md
This commit is contained in:
@@ -31,6 +31,7 @@ public:
|
||||
void SetTopology(uint32_t topologyType) override;
|
||||
void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override;
|
||||
void SetSampleCount(uint32_t count) override;
|
||||
void SetComputeShader(RHIShader* shader) override;
|
||||
|
||||
// State query
|
||||
const RasterizerDesc& GetRasterizerState() const override { return m_rasterizerDesc; }
|
||||
@@ -38,6 +39,8 @@ public:
|
||||
const DepthStencilStateDesc& GetDepthStencilState() const override { return m_depthStencilDesc; }
|
||||
const InputLayoutDesc& GetInputLayout() const override { return m_inputLayoutDesc; }
|
||||
PipelineStateHash GetHash() const override;
|
||||
RHIShader* GetComputeShader() const override { return m_computeShader; }
|
||||
bool HasComputeShader() const override { return m_csBytecode.pShaderBytecode != nullptr && m_csBytecode.BytecodeLength > 0; }
|
||||
|
||||
// Finalization
|
||||
bool IsFinalized() const override { return m_finalized; }
|
||||
@@ -45,12 +48,14 @@ public:
|
||||
|
||||
// Shader Bytecode (set by CommandList when binding)
|
||||
void SetShaderBytecodes(const D3D12_SHADER_BYTECODE& vs, const D3D12_SHADER_BYTECODE& ps, const D3D12_SHADER_BYTECODE& gs = {});
|
||||
void SetComputeShaderBytecodes(const D3D12_SHADER_BYTECODE& cs);
|
||||
|
||||
// Lifecycle
|
||||
void Shutdown() override;
|
||||
ID3D12PipelineState* GetPipelineState() const { return m_pipelineState.Get(); }
|
||||
ID3D12PipelineState* GetComputePipelineState() const { return m_computePipelineState.Get(); }
|
||||
void* GetNativeHandle() override { return m_pipelineState.Get(); }
|
||||
PipelineType GetType() const override { return PipelineType::Graphics; }
|
||||
PipelineType GetType() const override { return HasComputeShader() ? PipelineType::Compute : PipelineType::Graphics; }
|
||||
|
||||
void Bind() override;
|
||||
void Unbind() override;
|
||||
@@ -71,6 +76,7 @@ public:
|
||||
|
||||
private:
|
||||
bool CreateD3D12PSO();
|
||||
bool CreateD3D12ComputePSO();
|
||||
|
||||
ID3D12Device* m_device;
|
||||
bool m_finalized = false;
|
||||
@@ -91,10 +97,13 @@ private:
|
||||
D3D12_SHADER_BYTECODE m_vsBytecode = {};
|
||||
D3D12_SHADER_BYTECODE m_psBytecode = {};
|
||||
D3D12_SHADER_BYTECODE m_gsBytecode = {};
|
||||
D3D12_SHADER_BYTECODE m_csBytecode = {};
|
||||
class RHIShader* m_computeShader = nullptr;
|
||||
ID3D12RootSignature* m_rootSignature = nullptr;
|
||||
|
||||
// D3D12 resources
|
||||
ComPtr<ID3D12PipelineState> m_pipelineState;
|
||||
ComPtr<ID3D12PipelineState> m_computePipelineState;
|
||||
std::vector<D3D12_INPUT_ELEMENT_DESC> m_inputElements;
|
||||
};
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ public:
|
||||
|
||||
void Dispatch(uint32_t x, uint32_t y, uint32_t z) override;
|
||||
void DispatchIndirect(unsigned int buffer, size_t offset);
|
||||
void DispatchCompute(unsigned int x, unsigned int y, unsigned int z, unsigned int groupX, unsigned int groupY, unsigned int groupZ);
|
||||
void DispatchCompute(unsigned int x, unsigned int y, unsigned int z);
|
||||
|
||||
void MemoryBarrier(unsigned int barriers);
|
||||
void TextureBarrier();
|
||||
|
||||
@@ -79,6 +79,7 @@ public:
|
||||
void SetTopology(uint32_t topologyType) override;
|
||||
void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override;
|
||||
void SetSampleCount(uint32_t count) override;
|
||||
void SetComputeShader(RHIShader* shader) override;
|
||||
|
||||
// State query
|
||||
const RasterizerDesc& GetRasterizerState() const override { return m_rasterizerDesc; }
|
||||
@@ -86,6 +87,8 @@ public:
|
||||
const DepthStencilStateDesc& GetDepthStencilState() const override { return m_depthStencilDesc; }
|
||||
const InputLayoutDesc& GetInputLayout() const override { return m_inputLayoutDesc; }
|
||||
PipelineStateHash GetHash() const override;
|
||||
RHIShader* GetComputeShader() const override { return m_computeShader; }
|
||||
bool HasComputeShader() const override { return m_computeProgram != 0; }
|
||||
|
||||
// Finalization (OpenGL doesn't need it)
|
||||
bool IsFinalized() const override { return true; }
|
||||
@@ -94,7 +97,7 @@ public:
|
||||
// Lifecycle
|
||||
void Shutdown() override;
|
||||
void* GetNativeHandle() override { return reinterpret_cast<void*>(static_cast<uintptr_t>(m_program)); }
|
||||
PipelineType GetType() const override { return PipelineType::Graphics; }
|
||||
PipelineType GetType() const override { return HasComputeShader() ? PipelineType::Compute : PipelineType::Graphics; }
|
||||
|
||||
void Bind() override;
|
||||
void Unbind() override;
|
||||
@@ -130,6 +133,8 @@ private:
|
||||
DepthStencilStateDesc m_depthStencilDesc;
|
||||
uint32_t m_topologyType = 0;
|
||||
unsigned int m_program = 0;
|
||||
unsigned int m_computeProgram = 0;
|
||||
class RHIShader* m_computeShader = nullptr;
|
||||
bool m_programAttached = false;
|
||||
|
||||
// OpenGL specific state
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
namespace XCEngine {
|
||||
namespace RHI {
|
||||
|
||||
class RHIShader;
|
||||
|
||||
class RHIPipelineState {
|
||||
public:
|
||||
virtual ~RHIPipelineState() = default;
|
||||
@@ -18,6 +20,7 @@ public:
|
||||
virtual void SetTopology(uint32_t topologyType) = 0; // PrimitiveTopologyType
|
||||
virtual void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) = 0;
|
||||
virtual void SetSampleCount(uint32_t count) = 0;
|
||||
virtual void SetComputeShader(RHIShader* shader) = 0;
|
||||
|
||||
// State query
|
||||
virtual const RasterizerDesc& GetRasterizerState() const = 0;
|
||||
@@ -25,6 +28,8 @@ public:
|
||||
virtual const DepthStencilStateDesc& GetDepthStencilState() const = 0;
|
||||
virtual const InputLayoutDesc& GetInputLayout() const = 0;
|
||||
virtual PipelineStateHash GetHash() const = 0;
|
||||
virtual RHIShader* GetComputeShader() const = 0;
|
||||
virtual bool HasComputeShader() const = 0;
|
||||
|
||||
// Finalization (D3D12/Vulkan creates real PSO)
|
||||
virtual bool IsFinalized() const = 0;
|
||||
|
||||
Reference in New Issue
Block a user