RHI: Add Compute Pipeline abstraction with D3D12 and OpenGL support
- Add SetComputeShader/GetComputeShader/HasComputeShader to RHIPipelineState - Add m_computePipelineState for D3D12 compute PSO - Add m_computeProgram/m_computeShader for OpenGL - Fix OpenGLCommandList::DispatchCompute bug (was ignoring x,y,z params) - Fix OpenGLShader::GetID usage in OpenGLPipelineState - Mark Priority 8 as completed in RHI_Design_Issues.md
This commit is contained in:
@@ -563,7 +563,7 @@ class RHITexture : public RHIResource { ... };
|
|||||||
| 5 | TransitionBarrier 针对 View 而非 Resource | 🟡 中 | 中 | ✅ 已完成 |
|
| 5 | TransitionBarrier 针对 View 而非 Resource | 🟡 中 | 中 | ✅ 已完成 |
|
||||||
| 6 | SetGlobal* 空操作 | 🟡 中 | 低 | ✅ 已完成 |
|
| 6 | SetGlobal* 空操作 | 🟡 中 | 低 | ✅ 已完成 |
|
||||||
| 7 | OpenGL 特有方法暴露 | 🟡 中 | 高 | ❌ 未完成 |
|
| 7 | OpenGL 特有方法暴露 | 🟡 中 | 高 | ❌ 未完成 |
|
||||||
| 8 | 缺少 Compute Pipeline 抽象 | 🟡 中 | 中 | ❌ 未完成 |
|
| 8 | 缺少 Compute Pipeline 抽象 | 🟡 中 | 中 | ✅ 已完成 |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ public:
|
|||||||
void SetTopology(uint32_t topologyType) override;
|
void SetTopology(uint32_t topologyType) override;
|
||||||
void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override;
|
void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override;
|
||||||
void SetSampleCount(uint32_t count) override;
|
void SetSampleCount(uint32_t count) override;
|
||||||
|
void SetComputeShader(RHIShader* shader) override;
|
||||||
|
|
||||||
// State query
|
// State query
|
||||||
const RasterizerDesc& GetRasterizerState() const override { return m_rasterizerDesc; }
|
const RasterizerDesc& GetRasterizerState() const override { return m_rasterizerDesc; }
|
||||||
@@ -38,6 +39,8 @@ public:
|
|||||||
const DepthStencilStateDesc& GetDepthStencilState() const override { return m_depthStencilDesc; }
|
const DepthStencilStateDesc& GetDepthStencilState() const override { return m_depthStencilDesc; }
|
||||||
const InputLayoutDesc& GetInputLayout() const override { return m_inputLayoutDesc; }
|
const InputLayoutDesc& GetInputLayout() const override { return m_inputLayoutDesc; }
|
||||||
PipelineStateHash GetHash() const override;
|
PipelineStateHash GetHash() const override;
|
||||||
|
RHIShader* GetComputeShader() const override { return m_computeShader; }
|
||||||
|
bool HasComputeShader() const override { return m_csBytecode.pShaderBytecode != nullptr && m_csBytecode.BytecodeLength > 0; }
|
||||||
|
|
||||||
// Finalization
|
// Finalization
|
||||||
bool IsFinalized() const override { return m_finalized; }
|
bool IsFinalized() const override { return m_finalized; }
|
||||||
@@ -45,12 +48,14 @@ public:
|
|||||||
|
|
||||||
// Shader Bytecode (set by CommandList when binding)
|
// Shader Bytecode (set by CommandList when binding)
|
||||||
void SetShaderBytecodes(const D3D12_SHADER_BYTECODE& vs, const D3D12_SHADER_BYTECODE& ps, const D3D12_SHADER_BYTECODE& gs = {});
|
void SetShaderBytecodes(const D3D12_SHADER_BYTECODE& vs, const D3D12_SHADER_BYTECODE& ps, const D3D12_SHADER_BYTECODE& gs = {});
|
||||||
|
void SetComputeShaderBytecodes(const D3D12_SHADER_BYTECODE& cs);
|
||||||
|
|
||||||
// Lifecycle
|
// Lifecycle
|
||||||
void Shutdown() override;
|
void Shutdown() override;
|
||||||
ID3D12PipelineState* GetPipelineState() const { return m_pipelineState.Get(); }
|
ID3D12PipelineState* GetPipelineState() const { return m_pipelineState.Get(); }
|
||||||
|
ID3D12PipelineState* GetComputePipelineState() const { return m_computePipelineState.Get(); }
|
||||||
void* GetNativeHandle() override { return m_pipelineState.Get(); }
|
void* GetNativeHandle() override { return m_pipelineState.Get(); }
|
||||||
PipelineType GetType() const override { return PipelineType::Graphics; }
|
PipelineType GetType() const override { return HasComputeShader() ? PipelineType::Compute : PipelineType::Graphics; }
|
||||||
|
|
||||||
void Bind() override;
|
void Bind() override;
|
||||||
void Unbind() override;
|
void Unbind() override;
|
||||||
@@ -71,6 +76,7 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
bool CreateD3D12PSO();
|
bool CreateD3D12PSO();
|
||||||
|
bool CreateD3D12ComputePSO();
|
||||||
|
|
||||||
ID3D12Device* m_device;
|
ID3D12Device* m_device;
|
||||||
bool m_finalized = false;
|
bool m_finalized = false;
|
||||||
@@ -91,10 +97,13 @@ private:
|
|||||||
D3D12_SHADER_BYTECODE m_vsBytecode = {};
|
D3D12_SHADER_BYTECODE m_vsBytecode = {};
|
||||||
D3D12_SHADER_BYTECODE m_psBytecode = {};
|
D3D12_SHADER_BYTECODE m_psBytecode = {};
|
||||||
D3D12_SHADER_BYTECODE m_gsBytecode = {};
|
D3D12_SHADER_BYTECODE m_gsBytecode = {};
|
||||||
|
D3D12_SHADER_BYTECODE m_csBytecode = {};
|
||||||
|
class RHIShader* m_computeShader = nullptr;
|
||||||
ID3D12RootSignature* m_rootSignature = nullptr;
|
ID3D12RootSignature* m_rootSignature = nullptr;
|
||||||
|
|
||||||
// D3D12 resources
|
// D3D12 resources
|
||||||
ComPtr<ID3D12PipelineState> m_pipelineState;
|
ComPtr<ID3D12PipelineState> m_pipelineState;
|
||||||
|
ComPtr<ID3D12PipelineState> m_computePipelineState;
|
||||||
std::vector<D3D12_INPUT_ELEMENT_DESC> m_inputElements;
|
std::vector<D3D12_INPUT_ELEMENT_DESC> m_inputElements;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ public:
|
|||||||
|
|
||||||
void Dispatch(uint32_t x, uint32_t y, uint32_t z) override;
|
void Dispatch(uint32_t x, uint32_t y, uint32_t z) override;
|
||||||
void DispatchIndirect(unsigned int buffer, size_t offset);
|
void DispatchIndirect(unsigned int buffer, size_t offset);
|
||||||
void DispatchCompute(unsigned int x, unsigned int y, unsigned int z, unsigned int groupX, unsigned int groupY, unsigned int groupZ);
|
void DispatchCompute(unsigned int x, unsigned int y, unsigned int z);
|
||||||
|
|
||||||
void MemoryBarrier(unsigned int barriers);
|
void MemoryBarrier(unsigned int barriers);
|
||||||
void TextureBarrier();
|
void TextureBarrier();
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ public:
|
|||||||
void SetTopology(uint32_t topologyType) override;
|
void SetTopology(uint32_t topologyType) override;
|
||||||
void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override;
|
void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override;
|
||||||
void SetSampleCount(uint32_t count) override;
|
void SetSampleCount(uint32_t count) override;
|
||||||
|
void SetComputeShader(RHIShader* shader) override;
|
||||||
|
|
||||||
// State query
|
// State query
|
||||||
const RasterizerDesc& GetRasterizerState() const override { return m_rasterizerDesc; }
|
const RasterizerDesc& GetRasterizerState() const override { return m_rasterizerDesc; }
|
||||||
@@ -86,6 +87,8 @@ public:
|
|||||||
const DepthStencilStateDesc& GetDepthStencilState() const override { return m_depthStencilDesc; }
|
const DepthStencilStateDesc& GetDepthStencilState() const override { return m_depthStencilDesc; }
|
||||||
const InputLayoutDesc& GetInputLayout() const override { return m_inputLayoutDesc; }
|
const InputLayoutDesc& GetInputLayout() const override { return m_inputLayoutDesc; }
|
||||||
PipelineStateHash GetHash() const override;
|
PipelineStateHash GetHash() const override;
|
||||||
|
RHIShader* GetComputeShader() const override { return m_computeShader; }
|
||||||
|
bool HasComputeShader() const override { return m_computeProgram != 0; }
|
||||||
|
|
||||||
// Finalization (OpenGL doesn't need it)
|
// Finalization (OpenGL doesn't need it)
|
||||||
bool IsFinalized() const override { return true; }
|
bool IsFinalized() const override { return true; }
|
||||||
@@ -94,7 +97,7 @@ public:
|
|||||||
// Lifecycle
|
// Lifecycle
|
||||||
void Shutdown() override;
|
void Shutdown() override;
|
||||||
void* GetNativeHandle() override { return reinterpret_cast<void*>(static_cast<uintptr_t>(m_program)); }
|
void* GetNativeHandle() override { return reinterpret_cast<void*>(static_cast<uintptr_t>(m_program)); }
|
||||||
PipelineType GetType() const override { return PipelineType::Graphics; }
|
PipelineType GetType() const override { return HasComputeShader() ? PipelineType::Compute : PipelineType::Graphics; }
|
||||||
|
|
||||||
void Bind() override;
|
void Bind() override;
|
||||||
void Unbind() override;
|
void Unbind() override;
|
||||||
@@ -130,6 +133,8 @@ private:
|
|||||||
DepthStencilStateDesc m_depthStencilDesc;
|
DepthStencilStateDesc m_depthStencilDesc;
|
||||||
uint32_t m_topologyType = 0;
|
uint32_t m_topologyType = 0;
|
||||||
unsigned int m_program = 0;
|
unsigned int m_program = 0;
|
||||||
|
unsigned int m_computeProgram = 0;
|
||||||
|
class RHIShader* m_computeShader = nullptr;
|
||||||
bool m_programAttached = false;
|
bool m_programAttached = false;
|
||||||
|
|
||||||
// OpenGL specific state
|
// OpenGL specific state
|
||||||
|
|||||||
@@ -6,6 +6,8 @@
|
|||||||
namespace XCEngine {
|
namespace XCEngine {
|
||||||
namespace RHI {
|
namespace RHI {
|
||||||
|
|
||||||
|
class RHIShader;
|
||||||
|
|
||||||
class RHIPipelineState {
|
class RHIPipelineState {
|
||||||
public:
|
public:
|
||||||
virtual ~RHIPipelineState() = default;
|
virtual ~RHIPipelineState() = default;
|
||||||
@@ -18,6 +20,7 @@ public:
|
|||||||
virtual void SetTopology(uint32_t topologyType) = 0; // PrimitiveTopologyType
|
virtual void SetTopology(uint32_t topologyType) = 0; // PrimitiveTopologyType
|
||||||
virtual void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) = 0;
|
virtual void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) = 0;
|
||||||
virtual void SetSampleCount(uint32_t count) = 0;
|
virtual void SetSampleCount(uint32_t count) = 0;
|
||||||
|
virtual void SetComputeShader(RHIShader* shader) = 0;
|
||||||
|
|
||||||
// State query
|
// State query
|
||||||
virtual const RasterizerDesc& GetRasterizerState() const = 0;
|
virtual const RasterizerDesc& GetRasterizerState() const = 0;
|
||||||
@@ -25,6 +28,8 @@ public:
|
|||||||
virtual const DepthStencilStateDesc& GetDepthStencilState() const = 0;
|
virtual const DepthStencilStateDesc& GetDepthStencilState() const = 0;
|
||||||
virtual const InputLayoutDesc& GetInputLayout() const = 0;
|
virtual const InputLayoutDesc& GetInputLayout() const = 0;
|
||||||
virtual PipelineStateHash GetHash() const = 0;
|
virtual PipelineStateHash GetHash() const = 0;
|
||||||
|
virtual RHIShader* GetComputeShader() const = 0;
|
||||||
|
virtual bool HasComputeShader() const = 0;
|
||||||
|
|
||||||
// Finalization (D3D12/Vulkan creates real PSO)
|
// Finalization (D3D12/Vulkan creates real PSO)
|
||||||
virtual bool IsFinalized() const = 0;
|
virtual bool IsFinalized() const = 0;
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#include "XCEngine/RHI/D3D12/D3D12PipelineState.h"
|
#include "XCEngine/RHI/D3D12/D3D12PipelineState.h"
|
||||||
|
#include "XCEngine/RHI/D3D12/D3D12Shader.h"
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
|
||||||
namespace XCEngine {
|
namespace XCEngine {
|
||||||
@@ -121,8 +122,19 @@ void D3D12PipelineState::SetShaderBytecodes(const D3D12_SHADER_BYTECODE& vs, con
|
|||||||
m_gsBytecode = gs;
|
m_gsBytecode = gs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void D3D12PipelineState::SetComputeShader(RHIShader* shader) {
|
||||||
|
m_computeShader = shader;
|
||||||
|
}
|
||||||
|
|
||||||
|
void D3D12PipelineState::SetComputeShaderBytecodes(const D3D12_SHADER_BYTECODE& cs) {
|
||||||
|
m_csBytecode = cs;
|
||||||
|
}
|
||||||
|
|
||||||
bool D3D12PipelineState::Finalize() {
|
bool D3D12PipelineState::Finalize() {
|
||||||
if (m_finalized) return true;
|
if (m_finalized) return true;
|
||||||
|
if (HasComputeShader()) {
|
||||||
|
return CreateD3D12ComputePSO();
|
||||||
|
}
|
||||||
return CreateD3D12PSO();
|
return CreateD3D12PSO();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,8 +202,27 @@ bool D3D12PipelineState::CreateD3D12PSO() {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool D3D12PipelineState::CreateD3D12ComputePSO() {
|
||||||
|
if (!m_csBytecode.pShaderBytecode || !m_csBytecode.BytecodeLength) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
D3D12_COMPUTE_PIPELINE_STATE_DESC desc = {};
|
||||||
|
desc.pRootSignature = m_rootSignature;
|
||||||
|
desc.CS = m_csBytecode;
|
||||||
|
|
||||||
|
HRESULT hr = m_device->CreateComputePipelineState(&desc, IID_PPV_ARGS(&m_computePipelineState));
|
||||||
|
if (FAILED(hr)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
m_finalized = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
void D3D12PipelineState::Shutdown() {
|
void D3D12PipelineState::Shutdown() {
|
||||||
m_pipelineState.Reset();
|
m_pipelineState.Reset();
|
||||||
|
m_computePipelineState.Reset();
|
||||||
m_finalized = false;
|
m_finalized = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -271,11 +271,8 @@ void OpenGLCommandList::DispatchIndirect(unsigned int buffer, size_t offset) {
|
|||||||
glBindBuffer(GL_DISPATCH_INDIRECT_BUFFER, 0);
|
glBindBuffer(GL_DISPATCH_INDIRECT_BUFFER, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
void OpenGLCommandList::DispatchCompute(unsigned int x, unsigned int y, unsigned int z, unsigned int groupX, unsigned int groupY, unsigned int groupZ) {
|
void OpenGLCommandList::DispatchCompute(unsigned int x, unsigned int y, unsigned int z) {
|
||||||
glDispatchCompute(groupX, groupY, groupZ);
|
glDispatchCompute(x, y, z);
|
||||||
(void)x;
|
|
||||||
(void)y;
|
|
||||||
(void)z;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void OpenGLCommandList::MemoryBarrier(unsigned int barriers) {
|
void OpenGLCommandList::MemoryBarrier(unsigned int barriers) {
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#include "XCEngine/RHI/OpenGL/OpenGLPipelineState.h"
|
#include "XCEngine/RHI/OpenGL/OpenGLPipelineState.h"
|
||||||
|
#include "XCEngine/RHI/OpenGL/OpenGLShader.h"
|
||||||
#include <glad/glad.h>
|
#include <glad/glad.h>
|
||||||
|
|
||||||
namespace XCEngine {
|
namespace XCEngine {
|
||||||
@@ -37,6 +38,16 @@ void OpenGLPipelineState::SetRenderTargetFormats(uint32_t count, const uint32_t*
|
|||||||
void OpenGLPipelineState::SetSampleCount(uint32_t count) {
|
void OpenGLPipelineState::SetSampleCount(uint32_t count) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void OpenGLPipelineState::SetComputeShader(RHIShader* shader) {
|
||||||
|
m_computeShader = shader;
|
||||||
|
if (shader) {
|
||||||
|
OpenGLShader* glShader = static_cast<OpenGLShader*>(shader);
|
||||||
|
m_computeProgram = glShader->GetID();
|
||||||
|
} else {
|
||||||
|
m_computeProgram = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
PipelineStateHash OpenGLPipelineState::GetHash() const {
|
PipelineStateHash OpenGLPipelineState::GetHash() const {
|
||||||
PipelineStateHash hash = {};
|
PipelineStateHash hash = {};
|
||||||
return hash;
|
return hash;
|
||||||
@@ -44,11 +55,15 @@ PipelineStateHash OpenGLPipelineState::GetHash() const {
|
|||||||
|
|
||||||
void OpenGLPipelineState::Shutdown() {
|
void OpenGLPipelineState::Shutdown() {
|
||||||
m_program = 0;
|
m_program = 0;
|
||||||
|
m_computeProgram = 0;
|
||||||
|
m_computeShader = nullptr;
|
||||||
m_programAttached = false;
|
m_programAttached = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void OpenGLPipelineState::Bind() {
|
void OpenGLPipelineState::Bind() {
|
||||||
if (m_programAttached) {
|
if (HasComputeShader()) {
|
||||||
|
glUseProgram(m_computeProgram);
|
||||||
|
} else if (m_programAttached) {
|
||||||
glUseProgram(m_program);
|
glUseProgram(m_program);
|
||||||
}
|
}
|
||||||
Apply();
|
Apply();
|
||||||
|
|||||||
Reference in New Issue
Block a user