Files
XCEngine/engine/src/RHI/D3D12/D3D12PipelineState.cpp
ssdfasd 6328ac8d37 RHI: Add Compute Pipeline abstraction with D3D12 and OpenGL support
- Add SetComputeShader/GetComputeShader/HasComputeShader to RHIPipelineState
- Add m_computePipelineState for D3D12 compute PSO
- Add m_computeProgram/m_computeShader for OpenGL
- Fix OpenGLCommandList::DispatchCompute bug (was ignoring x,y,z params)
- Fix OpenGLShader::GetID usage in OpenGLPipelineState
- Mark Priority 8 as completed in RHI_Design_Issues.md
2026-03-25 01:05:03 +08:00

261 lines
11 KiB
C++

#include "XCEngine/RHI/D3D12/D3D12PipelineState.h"
#include "XCEngine/RHI/D3D12/D3D12Shader.h"
#include <cstring>
#include <initializer_list>
namespace XCEngine {
namespace RHI {
/// Constructs an unfinalized pipeline description bound to `device`.
/// Seeds a default colour-target format (slot 0) and a depth/stencil format.
D3D12PipelineState::D3D12PipelineState(ID3D12Device* device)
    : m_device(device)
    , m_finalized(false) {
    // NOTE(review): these defaults store RHI `Format` enum values, whereas
    // Initialize() stores raw DXGI_FORMAT values and CreateD3D12PSO() casts the
    // stored value straight to DXGI_FORMAT. Confirm the two encodings agree
    // (or convert with ToD3D12) before relying on these defaults.
    m_depthStencilFormat = static_cast<uint32_t>(Format::D24_UNorm_S8_UInt);
    m_renderTargetFormats[0] = static_cast<uint32_t>(Format::R8G8B8A8_UNorm);
}
/// Mirrors a native D3D12 graphics PSO description into this object's state and
/// immediately builds the PSO via Finalize().
/// @param device Device used to create the PSO.
/// @param desc   Source description; only the subset of fields this abstraction
///               stores is captured (render target 0 blend state, VS/PS/GS, ...).
/// @return true if the pipeline state object was created.
bool D3D12PipelineState::Initialize(ID3D12Device* device, const D3D12_GRAPHICS_PIPELINE_STATE_DESC& desc) {
    // Finalize() returns early once a PSO exists, so release any previous PSO first.
    // Otherwise re-initializing would keep the stale pipeline and ignore `desc`.
    if (m_finalized) {
        Shutdown();
    }
    m_device = device;
    m_rootSignature = desc.pRootSignature;

    // Rasterizer state.
    const D3D12_RASTERIZER_DESC& raster = desc.RasterizerState;
    m_rasterizerDesc.fillMode = static_cast<uint32_t>(raster.FillMode);
    m_rasterizerDesc.cullMode = static_cast<uint32_t>(raster.CullMode);
    m_rasterizerDesc.frontFace = raster.FrontCounterClockwise ? 1 : 0; // 1 = CounterClockwise
    m_rasterizerDesc.depthClipEnable = raster.DepthClipEnable != 0;
    // CreateD3D12PSO() reads these two flags back, so they must be captured as well.
    m_rasterizerDesc.multisampleEnable = raster.MultisampleEnable != 0;
    m_rasterizerDesc.antialiasedLineEnable = raster.AntialiasedLineEnable != 0;

    // Blend state: only render target 0 is mirrored by this abstraction.
    const D3D12_RENDER_TARGET_BLEND_DESC& blend = desc.BlendState.RenderTarget[0];
    m_blendDesc.blendEnable = blend.BlendEnable != 0;
    m_blendDesc.srcBlend = static_cast<uint32_t>(blend.SrcBlend);
    m_blendDesc.dstBlend = static_cast<uint32_t>(blend.DestBlend);
    m_blendDesc.blendOp = static_cast<uint32_t>(blend.BlendOp);
    m_blendDesc.srcBlendAlpha = static_cast<uint32_t>(blend.SrcBlendAlpha);
    m_blendDesc.dstBlendAlpha = static_cast<uint32_t>(blend.DestBlendAlpha);
    m_blendDesc.blendOpAlpha = static_cast<uint32_t>(blend.BlendOpAlpha);
    m_blendDesc.colorWriteMask = blend.RenderTargetWriteMask;

    // Depth/stencil state.
    const D3D12_DEPTH_STENCIL_DESC& depthStencil = desc.DepthStencilState;
    m_depthStencilDesc.depthTestEnable = depthStencil.DepthEnable != 0;
    m_depthStencilDesc.depthWriteEnable = (depthStencil.DepthWriteMask == D3D12_DEPTH_WRITE_MASK_ALL);
    m_depthStencilDesc.depthFunc = static_cast<uint32_t>(depthStencil.DepthFunc);
    m_depthStencilDesc.stencilEnable = depthStencil.StencilEnable != 0;
    m_depthStencilDesc.stencilReadMask = depthStencil.StencilReadMask;
    m_depthStencilDesc.stencilWriteMask = depthStencil.StencilWriteMask;
    // CreateD3D12PSO() also consumes the per-face stencil ops. Previously they were
    // dropped here, so the PSO was rebuilt from whatever the members already held.
    // The casts follow the destination field types declared in the header.
    const auto copyStencilFace = [](auto& dst, const D3D12_DEPTH_STENCILOP_DESC& src) {
        dst.failOp = static_cast<decltype(dst.failOp)>(src.StencilFailOp);
        dst.passOp = static_cast<decltype(dst.passOp)>(src.StencilPassOp);
        dst.depthFailOp = static_cast<decltype(dst.depthFailOp)>(src.StencilDepthFailOp);
        dst.func = static_cast<decltype(dst.func)>(src.StencilFunc);
    };
    copyStencilFace(m_depthStencilDesc.front, depthStencil.FrontFace);
    copyStencilFace(m_depthStencilDesc.back, depthStencil.BackFace);

    m_topologyType = static_cast<uint32_t>(desc.PrimitiveTopologyType);

    // Render target formats. D3D12 exposes at most 8 RTV slots (RTVFormats[8]);
    // clamp so a malformed count cannot index past the end of either array.
    constexpr UINT kMaxRenderTargets = 8;
    const UINT rtCount = desc.NumRenderTargets < kMaxRenderTargets ? desc.NumRenderTargets : kMaxRenderTargets;
    m_renderTargetCount = rtCount;
    for (UINT i = 0; i < rtCount; ++i) {
        m_renderTargetFormats[i] = static_cast<uint32_t>(desc.RTVFormats[i]);
    }
    m_depthStencilFormat = static_cast<uint32_t>(desc.DSVFormat);
    m_sampleCount = desc.SampleDesc.Count;

    // Shader bytecodes. NOTE(review): hull/domain shaders (desc.HS/desc.DS) are
    // not mirrored by this abstraction and are therefore dropped.
    m_vsBytecode = desc.VS;
    m_psBytecode = desc.PS;
    m_gsBytecode = desc.GS;

    // Input layout. The copied elements keep the caller's SemanticName pointers,
    // so the caller's semantic strings must outlive this pipeline state.
    m_inputElements.clear();
    if (desc.InputLayout.pInputElementDescs != nullptr) {
        for (UINT i = 0; i < desc.InputLayout.NumElements; ++i) {
            m_inputElements.push_back(desc.InputLayout.pInputElementDescs[i]);
        }
    }
    return Finalize();
}
/// Releases any pipeline state objects still owned by this instance.
D3D12PipelineState::~D3D12PipelineState() {
    // Shutdown() resets both the graphics and compute PSOs.
    Shutdown();
}
/// Stores an engine-level input layout and translates it into D3D12 input elements.
/// Every element is classified as per-vertex data (instance step rate 0).
/// @param layout Engine input layout; it is copied into m_inputLayoutDesc.
void D3D12PipelineState::SetInputLayout(const InputLayoutDesc& layout) {
    // D3D12_INPUT_ELEMENT_DESC::SemanticName is a borrowed const char*, so it must
    // point at strings that outlive this call. Keep our own copy of the layout.
    m_inputLayoutDesc = layout;
    m_inputElements.clear();
    // Iterate the member copy, NOT `layout`. c_str() pointers taken from the
    // caller's object would dangle as soon as the caller's layout is destroyed.
    // These pointers remain valid only while m_inputLayoutDesc is not modified.
    for (const auto& elem : m_inputLayoutDesc.elements) {
        D3D12_INPUT_ELEMENT_DESC desc = {};
        desc.SemanticName = elem.semanticName.c_str();
        desc.SemanticIndex = elem.semanticIndex;
        desc.Format = ToD3D12(static_cast<Format>(elem.format));
        desc.InputSlot = elem.inputSlot;
        desc.AlignedByteOffset = elem.alignedByteOffset;
        desc.InputSlotClass = D3D12_INPUT_CLASSIFICATION_PER_VERTEX_DATA;
        desc.InstanceDataStepRate = 0;
        m_inputElements.push_back(desc);
    }
}
void D3D12PipelineState::SetRasterizerState(const RasterizerDesc& state) {
m_rasterizerDesc = state;
}
void D3D12PipelineState::SetBlendState(const BlendDesc& state) {
m_blendDesc = state;
}
void D3D12PipelineState::SetDepthStencilState(const DepthStencilStateDesc& state) {
m_depthStencilDesc = state;
}
void D3D12PipelineState::SetTopology(uint32_t topologyType) {
m_topologyType = topologyType;
}
/// Sets the colour render target formats and the depth/stencil format.
/// CreateD3D12PSO() casts these values directly to DXGI_FORMAT.
/// @param count       Number of entries in `formats`; clamped to D3D12's 8 RTV slots.
/// @param formats     Array of at least `count` formats; nullptr means "no colour targets".
/// @param depthFormat Depth/stencil format.
void D3D12PipelineState::SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) {
    // D3D12 allows at most 8 simultaneous render targets (the size of RTVFormats).
    constexpr uint32_t kMaxRenderTargets = 8;
    // Clamp the stored count as well as the copy. Previously a count above 8 was
    // stored verbatim and later handed to D3D12 as NumRenderTargets.
    // A null array is treated as zero targets instead of being dereferenced.
    const uint32_t clamped = (formats == nullptr) ? 0u
                           : (count < kMaxRenderTargets ? count : kMaxRenderTargets);
    m_renderTargetCount = clamped;
    for (uint32_t i = 0; i < clamped; ++i) {
        m_renderTargetFormats[i] = formats[i];
    }
    m_depthStencilFormat = depthFormat;
}
void D3D12PipelineState::SetSampleCount(uint32_t count) {
m_sampleCount = count;
}
/// Computes a hash of the fixed-function state that feeds PSO creation.
/// Each sub-hash mixes the individual fields of its description. The previous
/// version reinterpreted the first 8 bytes of each struct as a uint64_t. That
/// violated strict aliasing, ignored most fields and could read indeterminate
/// padding bytes.
/// @return Per-category hashes. Shader bytecode is not included.
PipelineStateHash D3D12PipelineState::GetHash() const {
    // Folds a list of integral values into one 64-bit seed (boost::hash_combine mixing).
    const auto hashFields = [](std::initializer_list<uint64_t> fields) {
        uint64_t seed = 0;
        for (const uint64_t value : fields) {
            seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
        }
        return seed;
    };
    // Widens bool/integral/enum fields to uint64_t so they fit the list above.
    const auto u64 = [](auto value) { return static_cast<uint64_t>(value); };

    PipelineStateHash hash = {};
    hash.blendStateHash = static_cast<decltype(hash.blendStateHash)>(hashFields({
        u64(m_blendDesc.blendEnable), u64(m_blendDesc.srcBlend), u64(m_blendDesc.dstBlend),
        u64(m_blendDesc.blendOp), u64(m_blendDesc.srcBlendAlpha), u64(m_blendDesc.dstBlendAlpha),
        u64(m_blendDesc.blendOpAlpha), u64(m_blendDesc.colorWriteMask)}));
    hash.depthStateHash = static_cast<decltype(hash.depthStateHash)>(hashFields({
        u64(m_depthStencilDesc.depthTestEnable), u64(m_depthStencilDesc.depthWriteEnable),
        u64(m_depthStencilDesc.depthFunc), u64(m_depthStencilDesc.stencilEnable),
        u64(m_depthStencilDesc.stencilReadMask), u64(m_depthStencilDesc.stencilWriteMask),
        u64(m_depthStencilDesc.front.failOp), u64(m_depthStencilDesc.front.passOp),
        u64(m_depthStencilDesc.front.depthFailOp), u64(m_depthStencilDesc.front.func),
        u64(m_depthStencilDesc.back.failOp), u64(m_depthStencilDesc.back.passOp),
        u64(m_depthStencilDesc.back.depthFailOp), u64(m_depthStencilDesc.back.func)}));
    hash.rasterizerStateHash = static_cast<decltype(hash.rasterizerStateHash)>(hashFields({
        u64(m_rasterizerDesc.fillMode), u64(m_rasterizerDesc.cullMode), u64(m_rasterizerDesc.frontFace),
        u64(m_rasterizerDesc.depthClipEnable), u64(m_rasterizerDesc.multisampleEnable),
        u64(m_rasterizerDesc.antialiasedLineEnable)}));
    hash.topologyHash = m_topologyType;

    // Render-target layout: count, depth format, sample count and each colour format.
    // Including the formats keeps two layouts that differ only in format from colliding.
    uint64_t rtSeed = hashFields({u64(m_renderTargetCount), u64(m_depthStencilFormat), u64(m_sampleCount)});
    for (uint32_t i = 0; i < m_renderTargetCount && i < 8; ++i) {
        rtSeed = hashFields({rtSeed, u64(m_renderTargetFormats[i])});
    }
    hash.renderTargetHash = static_cast<decltype(hash.renderTargetHash)>(rtSeed);
    return hash;
}
void D3D12PipelineState::SetShaderBytecodes(const D3D12_SHADER_BYTECODE& vs, const D3D12_SHADER_BYTECODE& ps, const D3D12_SHADER_BYTECODE& gs) {
m_vsBytecode = vs;
m_psBytecode = ps;
m_gsBytecode = gs;
}
/// Records the engine-level compute shader object (stored as a non-owning pointer).
/// NOTE(review): this only stores the shader object. CreateD3D12ComputePSO() reads
/// m_csBytecode, which in this file is set only by SetComputeShaderBytecodes().
/// Confirm callers set both, or extract the bytecode from the D3D12Shader here.
void D3D12PipelineState::SetComputeShader(RHIShader* shader) { this->m_computeShader = shader; }
void D3D12PipelineState::SetComputeShaderBytecodes(const D3D12_SHADER_BYTECODE& cs) {
m_csBytecode = cs;
}
/// Builds the native pipeline state object if it has not been built yet.
/// @return true if a PSO already exists or was created successfully.
bool D3D12PipelineState::Finalize() {
    // Once a PSO exists, later calls report success without rebuilding.
    if (m_finalized) {
        return true;
    }
    // A compute shader selects the compute pipeline; otherwise build a graphics PSO.
    return HasComputeShader() ? CreateD3D12ComputePSO() : CreateD3D12PSO();
}
bool D3D12PipelineState::CreateD3D12PSO() {
if (!m_vsBytecode.pShaderBytecode || !m_psBytecode.pShaderBytecode) {
return false;
}
D3D12_GRAPHICS_PIPELINE_STATE_DESC desc = {};
desc.pRootSignature = m_rootSignature;
desc.VS = m_vsBytecode;
desc.PS = m_psBytecode;
desc.GS = m_gsBytecode;
desc.InputLayout.NumElements = static_cast<UINT>(m_inputElements.size());
desc.InputLayout.pInputElementDescs = m_inputElements.data();
desc.RasterizerState.FillMode = static_cast<D3D12_FILL_MODE>(m_rasterizerDesc.fillMode);
desc.RasterizerState.CullMode = static_cast<D3D12_CULL_MODE>(m_rasterizerDesc.cullMode);
desc.RasterizerState.FrontCounterClockwise = (m_rasterizerDesc.frontFace != 0);
desc.RasterizerState.DepthClipEnable = m_rasterizerDesc.depthClipEnable;
desc.RasterizerState.MultisampleEnable = m_rasterizerDesc.multisampleEnable;
desc.RasterizerState.AntialiasedLineEnable = m_rasterizerDesc.antialiasedLineEnable;
desc.BlendState.RenderTarget[0].BlendEnable = m_blendDesc.blendEnable;
desc.BlendState.RenderTarget[0].SrcBlend = static_cast<D3D12_BLEND>(m_blendDesc.srcBlend);
desc.BlendState.RenderTarget[0].DestBlend = static_cast<D3D12_BLEND>(m_blendDesc.dstBlend);
desc.BlendState.RenderTarget[0].BlendOp = static_cast<D3D12_BLEND_OP>(m_blendDesc.blendOp);
desc.BlendState.RenderTarget[0].SrcBlendAlpha = static_cast<D3D12_BLEND>(m_blendDesc.srcBlendAlpha);
desc.BlendState.RenderTarget[0].DestBlendAlpha = static_cast<D3D12_BLEND>(m_blendDesc.dstBlendAlpha);
desc.BlendState.RenderTarget[0].BlendOpAlpha = static_cast<D3D12_BLEND_OP>(m_blendDesc.blendOpAlpha);
desc.BlendState.RenderTarget[0].RenderTargetWriteMask = m_blendDesc.colorWriteMask;
desc.DepthStencilState.DepthEnable = m_depthStencilDesc.depthTestEnable;
desc.DepthStencilState.DepthWriteMask = m_depthStencilDesc.depthWriteEnable ? D3D12_DEPTH_WRITE_MASK_ALL : D3D12_DEPTH_WRITE_MASK_ZERO;
desc.DepthStencilState.DepthFunc = static_cast<D3D12_COMPARISON_FUNC>(m_depthStencilDesc.depthFunc);
desc.DepthStencilState.StencilEnable = m_depthStencilDesc.stencilEnable;
desc.DepthStencilState.StencilReadMask = m_depthStencilDesc.stencilReadMask;
desc.DepthStencilState.StencilWriteMask = m_depthStencilDesc.stencilWriteMask;
desc.DepthStencilState.FrontFace.StencilFailOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.front.failOp);
desc.DepthStencilState.FrontFace.StencilPassOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.front.passOp);
desc.DepthStencilState.FrontFace.StencilDepthFailOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.front.depthFailOp);
desc.DepthStencilState.FrontFace.StencilFunc = static_cast<D3D12_COMPARISON_FUNC>(m_depthStencilDesc.front.func);
desc.DepthStencilState.BackFace.StencilFailOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.back.failOp);
desc.DepthStencilState.BackFace.StencilPassOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.back.passOp);
desc.DepthStencilState.BackFace.StencilDepthFailOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.back.depthFailOp);
desc.DepthStencilState.BackFace.StencilFunc = static_cast<D3D12_COMPARISON_FUNC>(m_depthStencilDesc.back.func);
desc.NumRenderTargets = m_renderTargetCount;
for (uint32_t i = 0; i < m_renderTargetCount && i < 8; ++i) {
desc.RTVFormats[i] = static_cast<DXGI_FORMAT>(m_renderTargetFormats[i]);
}
desc.DSVFormat = static_cast<DXGI_FORMAT>(m_depthStencilFormat);
desc.SampleDesc.Count = m_sampleCount;
desc.SampleDesc.Quality = 0;
desc.SampleMask = 0xffffffff;
desc.PrimitiveTopologyType = static_cast<D3D12_PRIMITIVE_TOPOLOGY_TYPE>(m_topologyType);
HRESULT hr = m_device->CreateGraphicsPipelineState(&desc, IID_PPV_ARGS(&m_pipelineState));
if (FAILED(hr)) {
return false;
}
m_finalized = true;
return true;
}
bool D3D12PipelineState::CreateD3D12ComputePSO() {
if (!m_csBytecode.pShaderBytecode || !m_csBytecode.BytecodeLength) {
return false;
}
D3D12_COMPUTE_PIPELINE_STATE_DESC desc = {};
desc.pRootSignature = m_rootSignature;
desc.CS = m_csBytecode;
HRESULT hr = m_device->CreateComputePipelineState(&desc, IID_PPV_ARGS(&m_computePipelineState));
if (FAILED(hr)) {
return false;
}
m_finalized = true;
return true;
}
void D3D12PipelineState::Shutdown() {
m_pipelineState.Reset();
m_computePipelineState.Reset();
m_finalized = false;
}
/// Intentionally a no-op in the D3D12 backend.
/// NOTE(review): presumably the PSO is bound through the command list
/// (e.g. SetPipelineState) rather than here; confirm against the command list code.
void D3D12PipelineState::Bind() {
}
/// Intentionally a no-op in the D3D12 backend; this method holds no binding state to clear.
void D3D12PipelineState::Unbind() {
}
/// Builds a per-vertex D3D12 input element description.
/// @param semanticName      Borrowed semantic string; it must outlive any PSO built from it.
/// @param semanticIndex     Semantic index (e.g. TEXCOORD1 -> 1).
/// @param format            Engine format, converted with ToD3D12.
/// @param inputSlot         Vertex buffer slot.
/// @param alignedByteOffset Byte offset within the vertex.
/// @return The filled element with per-vertex classification and step rate 0.
D3D12_INPUT_ELEMENT_DESC D3D12PipelineState::CreateInputElement(
    const char* semanticName,
    uint32_t semanticIndex,
    Format format,
    uint32_t inputSlot,
    uint32_t alignedByteOffset) {
    // Aggregate order: SemanticName, SemanticIndex, Format, InputSlot,
    // AlignedByteOffset, InputSlotClass, InstanceDataStepRate.
    return D3D12_INPUT_ELEMENT_DESC{
        semanticName,
        semanticIndex,
        ToD3D12(format),
        inputSlot,
        alignedByteOffset,
        D3D12_INPUT_CLASSIFICATION_PER_VERTEX_DATA,
        0u,
    };
}
/// Convenience overload that lets D3D12 place the element right after the
/// previous one in the same slot (D3D12_APPEND_ALIGNED_ELEMENT).
D3D12_INPUT_ELEMENT_DESC D3D12PipelineState::CreateInputElement(
    const char* semanticName,
    uint32_t semanticIndex,
    Format format,
    uint32_t inputSlot) {
    const uint32_t appendedOffset = D3D12_APPEND_ALIGNED_ELEMENT;
    return CreateInputElement(semanticName, semanticIndex, format, inputSlot, appendedOffset);
}
} // namespace RHI
} // namespace XCEngine