RHI: Add Compute Pipeline abstraction with D3D12 and OpenGL support
- Add SetComputeShader/GetComputeShader/HasComputeShader to RHIPipelineState - Add m_computePipelineState for D3D12 compute PSO - Add m_computeProgram/m_computeShader for OpenGL - Fix OpenGLCommandList::DispatchCompute bug (was ignoring x,y,z params) - Fix OpenGLShader::GetID usage in OpenGLPipelineState - Mark Priority 8 as completed in RHI_Design_Issues.md
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
#include "XCEngine/RHI/D3D12/D3D12PipelineState.h"
|
||||
#include "XCEngine/RHI/D3D12/D3D12Shader.h"
|
||||
#include <cstring>
|
||||
|
||||
namespace XCEngine {
|
||||
@@ -121,8 +122,19 @@ void D3D12PipelineState::SetShaderBytecodes(const D3D12_SHADER_BYTECODE& vs, con
|
||||
m_gsBytecode = gs;
|
||||
}
|
||||
|
||||
void D3D12PipelineState::SetComputeShader(RHIShader* shader) {
|
||||
m_computeShader = shader;
|
||||
}
|
||||
|
||||
void D3D12PipelineState::SetComputeShaderBytecodes(const D3D12_SHADER_BYTECODE& cs) {
|
||||
m_csBytecode = cs;
|
||||
}
|
||||
|
||||
bool D3D12PipelineState::Finalize() {
|
||||
if (m_finalized) return true;
|
||||
if (HasComputeShader()) {
|
||||
return CreateD3D12ComputePSO();
|
||||
}
|
||||
return CreateD3D12PSO();
|
||||
}
|
||||
|
||||
@@ -190,8 +202,27 @@ bool D3D12PipelineState::CreateD3D12PSO() {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool D3D12PipelineState::CreateD3D12ComputePSO() {
|
||||
if (!m_csBytecode.pShaderBytecode || !m_csBytecode.BytecodeLength) {
|
||||
return false;
|
||||
}
|
||||
|
||||
D3D12_COMPUTE_PIPELINE_STATE_DESC desc = {};
|
||||
desc.pRootSignature = m_rootSignature;
|
||||
desc.CS = m_csBytecode;
|
||||
|
||||
HRESULT hr = m_device->CreateComputePipelineState(&desc, IID_PPV_ARGS(&m_computePipelineState));
|
||||
if (FAILED(hr)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
m_finalized = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
void D3D12PipelineState::Shutdown() {
|
||||
m_pipelineState.Reset();
|
||||
m_computePipelineState.Reset();
|
||||
m_finalized = false;
|
||||
}
|
||||
|
||||
|
||||
@@ -271,11 +271,8 @@ void OpenGLCommandList::DispatchIndirect(unsigned int buffer, size_t offset) {
|
||||
glBindBuffer(GL_DISPATCH_INDIRECT_BUFFER, 0);
|
||||
}
|
||||
|
||||
void OpenGLCommandList::DispatchCompute(unsigned int x, unsigned int y, unsigned int z, unsigned int groupX, unsigned int groupY, unsigned int groupZ) {
|
||||
glDispatchCompute(groupX, groupY, groupZ);
|
||||
(void)x;
|
||||
(void)y;
|
||||
(void)z;
|
||||
void OpenGLCommandList::DispatchCompute(unsigned int x, unsigned int y, unsigned int z) {
|
||||
glDispatchCompute(x, y, z);
|
||||
}
|
||||
|
||||
void OpenGLCommandList::MemoryBarrier(unsigned int barriers) {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "XCEngine/RHI/OpenGL/OpenGLPipelineState.h"
|
||||
#include "XCEngine/RHI/OpenGL/OpenGLShader.h"
|
||||
#include <glad/glad.h>
|
||||
|
||||
namespace XCEngine {
|
||||
@@ -37,6 +38,16 @@ void OpenGLPipelineState::SetRenderTargetFormats(uint32_t count, const uint32_t*
|
||||
void OpenGLPipelineState::SetSampleCount(uint32_t count) {
|
||||
}
|
||||
|
||||
void OpenGLPipelineState::SetComputeShader(RHIShader* shader) {
|
||||
m_computeShader = shader;
|
||||
if (shader) {
|
||||
OpenGLShader* glShader = static_cast<OpenGLShader*>(shader);
|
||||
m_computeProgram = glShader->GetID();
|
||||
} else {
|
||||
m_computeProgram = 0;
|
||||
}
|
||||
}
|
||||
|
||||
PipelineStateHash OpenGLPipelineState::GetHash() const {
|
||||
PipelineStateHash hash = {};
|
||||
return hash;
|
||||
@@ -44,11 +55,15 @@ PipelineStateHash OpenGLPipelineState::GetHash() const {
|
||||
|
||||
void OpenGLPipelineState::Shutdown() {
|
||||
m_program = 0;
|
||||
m_computeProgram = 0;
|
||||
m_computeShader = nullptr;
|
||||
m_programAttached = false;
|
||||
}
|
||||
|
||||
void OpenGLPipelineState::Bind() {
|
||||
if (m_programAttached) {
|
||||
if (HasComputeShader()) {
|
||||
glUseProgram(m_computeProgram);
|
||||
} else if (m_programAttached) {
|
||||
glUseProgram(m_program);
|
||||
}
|
||||
Apply();
|
||||
|
||||
Reference in New Issue
Block a user