- Add SetComputeShader/GetComputeShader/HasComputeShader to RHIPipelineState - Add m_computePipelineState for D3D12 compute PSO - Add m_computeProgram/m_computeShader for OpenGL - Fix OpenGLCommandList::DispatchCompute bug (was ignoring x,y,z params) - Fix OpenGLShader::GetID usage in OpenGLPipelineState - Mark Priority 8 as completed in RHI_Design_Issues.md
261 lines
11 KiB
C++
261 lines
11 KiB
C++
#include "XCEngine/RHI/D3D12/D3D12PipelineState.h"
|
|
#include "XCEngine/RHI/D3D12/D3D12Shader.h"
|
|
#include <cstring>
|
|
|
|
namespace XCEngine {
|
|
namespace RHI {
|
|
|
|
/// Constructs an unfinalized pipeline state bound to `device`.
///
/// Seeds one default colour-target format and a default depth-stencil format.
/// CreateD3D12PSO() casts m_renderTargetFormats / m_depthStencilFormat
/// directly to DXGI_FORMAT, and Initialize() fills them with DXGI values, so
/// the defaults are converted from the engine Format enum via ToD3D12() to
/// keep every stored format in the same (DXGI) encoding.
D3D12PipelineState::D3D12PipelineState(ID3D12Device* device)
    : m_device(device), m_finalized(false) {
    m_renderTargetFormats[0] = static_cast<uint32_t>(ToD3D12(Format::R8G8B8A8_UNorm));
    m_depthStencilFormat = static_cast<uint32_t>(ToD3D12(Format::D24_UNorm_S8_UInt));
}
|
|
|
|
/// Initializes the pipeline from a populated D3D12 graphics PSO description.
///
/// The description is decomposed into this object's engine-side state
/// (rasterizer, blend, depth-stencil, topology, render-target formats, sample
/// count, shader bytecode, input layout) and then Finalize() is called. For a
/// graphics pipeline, Finalize() rebuilds a D3D12 description from those
/// members, so any field not copied here is lost.
///
/// @param device  Device used to create the PSO.
/// @param desc    Source description. Only BlendState.RenderTarget[0] is read.
/// @return        The result of Finalize().
///
/// NOTE(review): shader bytecode and the input elements' SemanticName strings
/// are copied by pointer only, so the memory they reference must outlive any
/// later Finalize() call. If this object is already finalized, Finalize()
/// returns true without rebuilding. Confirm whether re-initialization should
/// call Shutdown() first.
bool D3D12PipelineState::Initialize(ID3D12Device* device, const D3D12_GRAPHICS_PIPELINE_STATE_DESC& desc) {
    m_device = device;
    m_rootSignature = desc.pRootSignature;

    const auto& raster = desc.RasterizerState;
    m_rasterizerDesc.fillMode = static_cast<uint32_t>(raster.FillMode);
    m_rasterizerDesc.cullMode = static_cast<uint32_t>(raster.CullMode);
    m_rasterizerDesc.frontFace = raster.FrontCounterClockwise ? 1 : 0; // 1 = CounterClockwise
    m_rasterizerDesc.depthClipEnable = raster.DepthClipEnable != 0;
    // CreateD3D12PSO() reads these two flags back, so copy them to survive the round trip.
    m_rasterizerDesc.multisampleEnable = raster.MultisampleEnable != 0;
    m_rasterizerDesc.antialiasedLineEnable = raster.AntialiasedLineEnable != 0;

    const auto& blend = desc.BlendState.RenderTarget[0];
    m_blendDesc.blendEnable = blend.BlendEnable != 0;
    m_blendDesc.srcBlend = static_cast<uint32_t>(blend.SrcBlend);
    m_blendDesc.dstBlend = static_cast<uint32_t>(blend.DestBlend);
    m_blendDesc.blendOp = static_cast<uint32_t>(blend.BlendOp);
    m_blendDesc.srcBlendAlpha = static_cast<uint32_t>(blend.SrcBlendAlpha);
    m_blendDesc.dstBlendAlpha = static_cast<uint32_t>(blend.DestBlendAlpha);
    m_blendDesc.blendOpAlpha = static_cast<uint32_t>(blend.BlendOpAlpha);
    m_blendDesc.colorWriteMask = blend.RenderTargetWriteMask;

    const auto& depth = desc.DepthStencilState;
    m_depthStencilDesc.depthTestEnable = depth.DepthEnable != 0;
    m_depthStencilDesc.depthWriteEnable = (depth.DepthWriteMask == D3D12_DEPTH_WRITE_MASK_ALL);
    m_depthStencilDesc.depthFunc = static_cast<uint32_t>(depth.DepthFunc);
    m_depthStencilDesc.stencilEnable = depth.StencilEnable != 0;
    m_depthStencilDesc.stencilReadMask = depth.StencilReadMask;
    m_depthStencilDesc.stencilWriteMask = depth.StencilWriteMask;
    // CreateD3D12PSO() consumes the per-face stencil ops. Without copying them,
    // the rebuilt PSO would use whatever values these members held before.
    m_depthStencilDesc.front.failOp = static_cast<uint32_t>(depth.FrontFace.StencilFailOp);
    m_depthStencilDesc.front.passOp = static_cast<uint32_t>(depth.FrontFace.StencilPassOp);
    m_depthStencilDesc.front.depthFailOp = static_cast<uint32_t>(depth.FrontFace.StencilDepthFailOp);
    m_depthStencilDesc.front.func = static_cast<uint32_t>(depth.FrontFace.StencilFunc);
    m_depthStencilDesc.back.failOp = static_cast<uint32_t>(depth.BackFace.StencilFailOp);
    m_depthStencilDesc.back.passOp = static_cast<uint32_t>(depth.BackFace.StencilPassOp);
    m_depthStencilDesc.back.depthFailOp = static_cast<uint32_t>(depth.BackFace.StencilDepthFailOp);
    m_depthStencilDesc.back.func = static_cast<uint32_t>(depth.BackFace.StencilFunc);

    m_topologyType = static_cast<uint32_t>(desc.PrimitiveTopologyType);
    m_renderTargetCount = desc.NumRenderTargets;
    // Bounded by 8, like SetRenderTargetFormats(): desc.RTVFormats has 8 entries.
    for (UINT i = 0; i < desc.NumRenderTargets && i < 8; ++i) {
        m_renderTargetFormats[i] = static_cast<uint32_t>(desc.RTVFormats[i]);
    }
    m_depthStencilFormat = static_cast<uint32_t>(desc.DSVFormat);
    m_sampleCount = desc.SampleDesc.Count;

    // Shader bytecode (pointer + length; the bytes themselves are not copied).
    m_vsBytecode = desc.VS;
    m_psBytecode = desc.PS;
    m_gsBytecode = desc.GS;

    // Input layout: element descs are copied by value; SemanticName stays borrowed.
    m_inputElements.clear();
    if (desc.InputLayout.pInputElementDescs != nullptr) {
        m_inputElements.assign(desc.InputLayout.pInputElementDescs,
                               desc.InputLayout.pInputElementDescs + desc.InputLayout.NumElements);
    }

    return Finalize();
}
|
|
|
|
/// Destructor: releases the graphics and compute PSOs via Shutdown().
D3D12PipelineState::~D3D12PipelineState() {
    // Inside the destructor the call resolves to this class's Shutdown() anyway;
    // the qualification just makes that explicit.
    D3D12PipelineState::Shutdown();
}
|
|
|
|
void D3D12PipelineState::SetInputLayout(const InputLayoutDesc& layout) {
|
|
m_inputLayoutDesc = layout;
|
|
m_inputElements.clear();
|
|
for (const auto& elem : layout.elements) {
|
|
D3D12_INPUT_ELEMENT_DESC desc = {};
|
|
desc.SemanticName = elem.semanticName.c_str();
|
|
desc.SemanticIndex = elem.semanticIndex;
|
|
desc.Format = ToD3D12(static_cast<Format>(elem.format));
|
|
desc.InputSlot = elem.inputSlot;
|
|
desc.AlignedByteOffset = elem.alignedByteOffset;
|
|
desc.InputSlotClass = D3D12_INPUT_CLASSIFICATION_PER_VERTEX_DATA;
|
|
desc.InstanceDataStepRate = 0;
|
|
m_inputElements.push_back(desc);
|
|
}
|
|
}
|
|
|
|
void D3D12PipelineState::SetRasterizerState(const RasterizerDesc& state) {
|
|
m_rasterizerDesc = state;
|
|
}
|
|
|
|
void D3D12PipelineState::SetBlendState(const BlendDesc& state) {
|
|
m_blendDesc = state;
|
|
}
|
|
|
|
void D3D12PipelineState::SetDepthStencilState(const DepthStencilStateDesc& state) {
|
|
m_depthStencilDesc = state;
|
|
}
|
|
|
|
void D3D12PipelineState::SetTopology(uint32_t topologyType) {
|
|
m_topologyType = topologyType;
|
|
}
|
|
|
|
void D3D12PipelineState::SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) {
|
|
m_renderTargetCount = count;
|
|
for (uint32_t i = 0; i < count && i < 8; ++i) {
|
|
m_renderTargetFormats[i] = formats[i];
|
|
}
|
|
m_depthStencilFormat = depthFormat;
|
|
}
|
|
|
|
void D3D12PipelineState::SetSampleCount(uint32_t count) {
|
|
m_sampleCount = count;
|
|
}
|
|
|
|
/// Computes a hash of the pipeline's fixed-function state.
///
/// Every sub-hash folds the individual fields of the matching descriptor
/// through a boost::hash_combine-style mixer. Reinterpreting raw struct bytes
/// instead would read padding (nondeterministic), cover only part of each
/// struct, and violate strict aliasing.
///
/// Only the descriptor fields referenced elsewhere in this file are hashed.
/// NOTE(review): extend this if the descriptor structs gain fields. Shader
/// bytecode and the input layout are not part of this hash.
PipelineStateHash D3D12PipelineState::GetHash() const {
    const auto combine = [](uint64_t seed, uint64_t value) -> uint64_t {
        return seed ^ (value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2));
    };

    PipelineStateHash hash = {};

    uint64_t blend = 0;
    blend = combine(blend, static_cast<uint64_t>(m_blendDesc.blendEnable));
    blend = combine(blend, static_cast<uint64_t>(m_blendDesc.srcBlend));
    blend = combine(blend, static_cast<uint64_t>(m_blendDesc.dstBlend));
    blend = combine(blend, static_cast<uint64_t>(m_blendDesc.blendOp));
    blend = combine(blend, static_cast<uint64_t>(m_blendDesc.srcBlendAlpha));
    blend = combine(blend, static_cast<uint64_t>(m_blendDesc.dstBlendAlpha));
    blend = combine(blend, static_cast<uint64_t>(m_blendDesc.blendOpAlpha));
    blend = combine(blend, static_cast<uint64_t>(m_blendDesc.colorWriteMask));
    hash.blendStateHash = blend;

    uint64_t depth = 0;
    depth = combine(depth, static_cast<uint64_t>(m_depthStencilDesc.depthTestEnable));
    depth = combine(depth, static_cast<uint64_t>(m_depthStencilDesc.depthWriteEnable));
    depth = combine(depth, static_cast<uint64_t>(m_depthStencilDesc.depthFunc));
    depth = combine(depth, static_cast<uint64_t>(m_depthStencilDesc.stencilEnable));
    depth = combine(depth, static_cast<uint64_t>(m_depthStencilDesc.stencilReadMask));
    depth = combine(depth, static_cast<uint64_t>(m_depthStencilDesc.stencilWriteMask));
    depth = combine(depth, static_cast<uint64_t>(m_depthStencilDesc.front.failOp));
    depth = combine(depth, static_cast<uint64_t>(m_depthStencilDesc.front.passOp));
    depth = combine(depth, static_cast<uint64_t>(m_depthStencilDesc.front.depthFailOp));
    depth = combine(depth, static_cast<uint64_t>(m_depthStencilDesc.front.func));
    depth = combine(depth, static_cast<uint64_t>(m_depthStencilDesc.back.failOp));
    depth = combine(depth, static_cast<uint64_t>(m_depthStencilDesc.back.passOp));
    depth = combine(depth, static_cast<uint64_t>(m_depthStencilDesc.back.depthFailOp));
    depth = combine(depth, static_cast<uint64_t>(m_depthStencilDesc.back.func));
    hash.depthStateHash = depth;

    uint64_t raster = 0;
    raster = combine(raster, static_cast<uint64_t>(m_rasterizerDesc.fillMode));
    raster = combine(raster, static_cast<uint64_t>(m_rasterizerDesc.cullMode));
    raster = combine(raster, static_cast<uint64_t>(m_rasterizerDesc.frontFace));
    raster = combine(raster, static_cast<uint64_t>(m_rasterizerDesc.depthClipEnable));
    raster = combine(raster, static_cast<uint64_t>(m_rasterizerDesc.multisampleEnable));
    raster = combine(raster, static_cast<uint64_t>(m_rasterizerDesc.antialiasedLineEnable));
    hash.rasterizerStateHash = raster;

    hash.topologyHash = m_topologyType;

    // Render-target setup: count, each colour format actually in use, the
    // depth-stencil format and the sample count (which must match the targets).
    uint64_t targets = combine(0, static_cast<uint64_t>(m_renderTargetCount));
    for (uint32_t i = 0; i < m_renderTargetCount && i < 8; ++i) {
        targets = combine(targets, static_cast<uint64_t>(m_renderTargetFormats[i]));
    }
    targets = combine(targets, static_cast<uint64_t>(m_depthStencilFormat));
    targets = combine(targets, static_cast<uint64_t>(m_sampleCount));
    hash.renderTargetHash = targets;

    return hash;
}
|
|
|
|
void D3D12PipelineState::SetShaderBytecodes(const D3D12_SHADER_BYTECODE& vs, const D3D12_SHADER_BYTECODE& ps, const D3D12_SHADER_BYTECODE& gs) {
|
|
m_vsBytecode = vs;
|
|
m_psBytecode = ps;
|
|
m_gsBytecode = gs;
|
|
}
|
|
|
|
/// Stores the backend-agnostic compute shader pointer.
///
/// Only the pointer is kept; m_csBytecode is not filled here. CreateD3D12ComputePSO()
/// reads m_csBytecode, so the bytecode must also be supplied through
/// SetComputeShaderBytecodes(). NOTE(review): presumably HasComputeShader()
/// checks this pointer. Confirm against the header.
void D3D12PipelineState::SetComputeShader(RHIShader* shader) {
    this->m_computeShader = shader;
}
|
|
|
|
void D3D12PipelineState::SetComputeShaderBytecodes(const D3D12_SHADER_BYTECODE& cs) {
|
|
m_csBytecode = cs;
|
|
}
|
|
|
|
bool D3D12PipelineState::Finalize() {
|
|
if (m_finalized) return true;
|
|
if (HasComputeShader()) {
|
|
return CreateD3D12ComputePSO();
|
|
}
|
|
return CreateD3D12PSO();
|
|
}
|
|
|
|
bool D3D12PipelineState::CreateD3D12PSO() {
|
|
if (!m_vsBytecode.pShaderBytecode || !m_psBytecode.pShaderBytecode) {
|
|
return false;
|
|
}
|
|
|
|
D3D12_GRAPHICS_PIPELINE_STATE_DESC desc = {};
|
|
desc.pRootSignature = m_rootSignature;
|
|
desc.VS = m_vsBytecode;
|
|
desc.PS = m_psBytecode;
|
|
desc.GS = m_gsBytecode;
|
|
|
|
desc.InputLayout.NumElements = static_cast<UINT>(m_inputElements.size());
|
|
desc.InputLayout.pInputElementDescs = m_inputElements.data();
|
|
|
|
desc.RasterizerState.FillMode = static_cast<D3D12_FILL_MODE>(m_rasterizerDesc.fillMode);
|
|
desc.RasterizerState.CullMode = static_cast<D3D12_CULL_MODE>(m_rasterizerDesc.cullMode);
|
|
desc.RasterizerState.FrontCounterClockwise = (m_rasterizerDesc.frontFace != 0);
|
|
desc.RasterizerState.DepthClipEnable = m_rasterizerDesc.depthClipEnable;
|
|
desc.RasterizerState.MultisampleEnable = m_rasterizerDesc.multisampleEnable;
|
|
desc.RasterizerState.AntialiasedLineEnable = m_rasterizerDesc.antialiasedLineEnable;
|
|
|
|
desc.BlendState.RenderTarget[0].BlendEnable = m_blendDesc.blendEnable;
|
|
desc.BlendState.RenderTarget[0].SrcBlend = static_cast<D3D12_BLEND>(m_blendDesc.srcBlend);
|
|
desc.BlendState.RenderTarget[0].DestBlend = static_cast<D3D12_BLEND>(m_blendDesc.dstBlend);
|
|
desc.BlendState.RenderTarget[0].BlendOp = static_cast<D3D12_BLEND_OP>(m_blendDesc.blendOp);
|
|
desc.BlendState.RenderTarget[0].SrcBlendAlpha = static_cast<D3D12_BLEND>(m_blendDesc.srcBlendAlpha);
|
|
desc.BlendState.RenderTarget[0].DestBlendAlpha = static_cast<D3D12_BLEND>(m_blendDesc.dstBlendAlpha);
|
|
desc.BlendState.RenderTarget[0].BlendOpAlpha = static_cast<D3D12_BLEND_OP>(m_blendDesc.blendOpAlpha);
|
|
desc.BlendState.RenderTarget[0].RenderTargetWriteMask = m_blendDesc.colorWriteMask;
|
|
|
|
desc.DepthStencilState.DepthEnable = m_depthStencilDesc.depthTestEnable;
|
|
desc.DepthStencilState.DepthWriteMask = m_depthStencilDesc.depthWriteEnable ? D3D12_DEPTH_WRITE_MASK_ALL : D3D12_DEPTH_WRITE_MASK_ZERO;
|
|
desc.DepthStencilState.DepthFunc = static_cast<D3D12_COMPARISON_FUNC>(m_depthStencilDesc.depthFunc);
|
|
desc.DepthStencilState.StencilEnable = m_depthStencilDesc.stencilEnable;
|
|
desc.DepthStencilState.StencilReadMask = m_depthStencilDesc.stencilReadMask;
|
|
desc.DepthStencilState.StencilWriteMask = m_depthStencilDesc.stencilWriteMask;
|
|
desc.DepthStencilState.FrontFace.StencilFailOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.front.failOp);
|
|
desc.DepthStencilState.FrontFace.StencilPassOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.front.passOp);
|
|
desc.DepthStencilState.FrontFace.StencilDepthFailOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.front.depthFailOp);
|
|
desc.DepthStencilState.FrontFace.StencilFunc = static_cast<D3D12_COMPARISON_FUNC>(m_depthStencilDesc.front.func);
|
|
desc.DepthStencilState.BackFace.StencilFailOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.back.failOp);
|
|
desc.DepthStencilState.BackFace.StencilPassOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.back.passOp);
|
|
desc.DepthStencilState.BackFace.StencilDepthFailOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.back.depthFailOp);
|
|
desc.DepthStencilState.BackFace.StencilFunc = static_cast<D3D12_COMPARISON_FUNC>(m_depthStencilDesc.back.func);
|
|
|
|
desc.NumRenderTargets = m_renderTargetCount;
|
|
for (uint32_t i = 0; i < m_renderTargetCount && i < 8; ++i) {
|
|
desc.RTVFormats[i] = static_cast<DXGI_FORMAT>(m_renderTargetFormats[i]);
|
|
}
|
|
desc.DSVFormat = static_cast<DXGI_FORMAT>(m_depthStencilFormat);
|
|
desc.SampleDesc.Count = m_sampleCount;
|
|
desc.SampleDesc.Quality = 0;
|
|
desc.SampleMask = 0xffffffff;
|
|
desc.PrimitiveTopologyType = static_cast<D3D12_PRIMITIVE_TOPOLOGY_TYPE>(m_topologyType);
|
|
|
|
HRESULT hr = m_device->CreateGraphicsPipelineState(&desc, IID_PPV_ARGS(&m_pipelineState));
|
|
if (FAILED(hr)) {
|
|
return false;
|
|
}
|
|
|
|
m_finalized = true;
|
|
return true;
|
|
}
|
|
|
|
bool D3D12PipelineState::CreateD3D12ComputePSO() {
|
|
if (!m_csBytecode.pShaderBytecode || !m_csBytecode.BytecodeLength) {
|
|
return false;
|
|
}
|
|
|
|
D3D12_COMPUTE_PIPELINE_STATE_DESC desc = {};
|
|
desc.pRootSignature = m_rootSignature;
|
|
desc.CS = m_csBytecode;
|
|
|
|
HRESULT hr = m_device->CreateComputePipelineState(&desc, IID_PPV_ARGS(&m_computePipelineState));
|
|
if (FAILED(hr)) {
|
|
return false;
|
|
}
|
|
|
|
m_finalized = true;
|
|
return true;
|
|
}
|
|
|
|
/// Releases both PSOs and marks the object as not finalized, so that a later
/// Finalize() creates them again. m_rootSignature is not touched here.
void D3D12PipelineState::Shutdown() {
    m_finalized = false;
    m_pipelineState.Reset();
    m_computePipelineState.Reset();
}
|
|
|
|
/// Intentionally a no-op in the D3D12 backend.
/// NOTE(review): presumably the command list sets the PSO. Confirm.
void D3D12PipelineState::Bind() { /* nothing to do */ }
|
|
|
|
/// Intentionally a no-op in the D3D12 backend.
void D3D12PipelineState::Unbind() { /* nothing to do */ }
|
|
|
|
/// Builds a per-vertex D3D12_INPUT_ELEMENT_DESC (InstanceDataStepRate 0).
///
/// @param semanticName       Borrowed pointer; must outlive any PSO creation using it.
/// @param semanticIndex      Semantic index.
/// @param format             Engine format, converted with ToD3D12().
/// @param inputSlot          Vertex buffer slot.
/// @param alignedByteOffset  Byte offset within the vertex.
D3D12_INPUT_ELEMENT_DESC D3D12PipelineState::CreateInputElement(
    const char* semanticName,
    uint32_t semanticIndex,
    Format format,
    uint32_t inputSlot,
    uint32_t alignedByteOffset) {
    // Aggregate initialisation in D3D12_INPUT_ELEMENT_DESC field order.
    return D3D12_INPUT_ELEMENT_DESC{
        semanticName,
        semanticIndex,
        ToD3D12(format),
        inputSlot,
        alignedByteOffset,
        D3D12_INPUT_CLASSIFICATION_PER_VERTEX_DATA,
        0
    };
}
|
|
|
|
/// Convenience overload: same as the five-argument form, with the byte offset
/// set to D3D12_APPEND_ALIGNED_ELEMENT so that D3D12 computes it.
D3D12_INPUT_ELEMENT_DESC D3D12PipelineState::CreateInputElement(
    const char* semanticName,
    uint32_t semanticIndex,
    Format format,
    uint32_t inputSlot) {
    const uint32_t appendOffset = D3D12_APPEND_ALIGNED_ELEMENT;
    return CreateInputElement(semanticName, semanticIndex, format, inputSlot, appendOffset);
}
|
|
|
|
} // namespace RHI
|
|
} // namespace XCEngine
|