#include "XCEngine/RHI/D3D12/D3D12PipelineState.h"
#include "XCEngine/RHI/D3D12/D3D12Shader.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>

namespace XCEngine {
namespace RHI {

namespace {

// Maximum number of simultaneously bound render targets in D3D12; also the
// length of D3D12_GRAPHICS_PIPELINE_STATE_DESC::RTVFormats.
constexpr uint32_t kMaxRenderTargets = D3D12_SIMULTANEOUS_RENDER_TARGET_COUNT;

// Assigns `src` to `dst` via a static_cast to dst's declared type, so copies
// from D3D12 enums into the RHI state structs always target the member's own
// type without restating it at every call site.
template <typename Dst, typename Src>
void CastAssign(Dst& dst, Src src) {
    dst = static_cast<Dst>(src);
}

// Mixes the integral/enum/bool `value` into `seed` (hash_combine-style mixing).
template <typename T>
void HashCombine(std::size_t& seed, const T& value) {
    const std::size_t h = std::hash<std::uint64_t>{}(static_cast<std::uint64_t>(value));
    seed ^= h + static_cast<std::size_t>(0x9e3779b97f4a7c15ULL) + (seed << 6) + (seed >> 2);
}

// Copies one D3D12 stencil-face description into the RHI-side face struct.
template <typename Face>
void ReadStencilFace(Face& dst, const D3D12_DEPTH_STENCILOP_DESC& src) {
    CastAssign(dst.failOp, src.StencilFailOp);
    CastAssign(dst.passOp, src.StencilPassOp);
    CastAssign(dst.depthFailOp, src.StencilDepthFailOp);
    CastAssign(dst.func, src.StencilFunc);
}

// Builds a D3D12 stencil-face description from the RHI-side face struct.
template <typename Face>
D3D12_DEPTH_STENCILOP_DESC MakeD3D12StencilFace(const Face& src) {
    D3D12_DEPTH_STENCILOP_DESC dst = {};
    dst.StencilFailOp = static_cast<D3D12_STENCIL_OP>(src.failOp);
    dst.StencilPassOp = static_cast<D3D12_STENCIL_OP>(src.passOp);
    dst.StencilDepthFailOp = static_cast<D3D12_STENCIL_OP>(src.depthFailOp);
    dst.StencilFunc = static_cast<D3D12_COMPARISON_FUNC>(src.func);
    return dst;
}

} // namespace

// Creates an unfinalized pipeline state bound to `device`, defaulting to one
// R8G8B8A8_UNorm colour target and a D24_UNorm_S8_UInt depth target.
D3D12PipelineState::D3D12PipelineState(ID3D12Device* device)
    : m_device(device), m_finalized(false) {
    // The format members hold DXGI_FORMAT values: Initialize() fills them from
    // DXGI_FORMATs and CreateD3D12PSO() casts them straight back to DXGI_FORMAT.
    // Translate the RHI defaults instead of storing raw Format enumerators.
    CastAssign(m_renderTargetFormats[0], ToD3D12(Format::R8G8B8A8_UNorm));
    CastAssign(m_depthStencilFormat, ToD3D12(Format::D24_UNorm_S8_UInt));
}

// Mirrors a complete D3D12 graphics PSO description into this object's state
// and immediately builds the pipeline.
//
// Only RenderTarget[0]'s blend state is captured. SampleDesc.Quality and
// SampleMask are not captured; CreateD3D12PSO() always uses 0 / 0xffffffff.
// Returns the result of Finalize().
bool D3D12PipelineState::Initialize(ID3D12Device* device, const D3D12_GRAPHICS_PIPELINE_STATE_DESC& desc) {
    // Release any previously built PSO; otherwise Finalize() would return early
    // on m_finalized and silently ignore the new description.
    Shutdown();

    m_device = device;
    m_rootSignature = desc.pRootSignature;

    const D3D12_RASTERIZER_DESC& raster = desc.RasterizerState;
    CastAssign(m_rasterizerDesc.fillMode, raster.FillMode);
    CastAssign(m_rasterizerDesc.cullMode, raster.CullMode);
    m_rasterizerDesc.frontFace = raster.FrontCounterClockwise ? 1 : 0; // 1 = CounterClockwise
    m_rasterizerDesc.depthClipEnable = raster.DepthClipEnable != 0;
    // Read back by CreateD3D12PSO(), so they must round-trip too.
    m_rasterizerDesc.multisampleEnable = raster.MultisampleEnable != 0;
    m_rasterizerDesc.antialiasedLineEnable = raster.AntialiasedLineEnable != 0;

    const D3D12_RENDER_TARGET_BLEND_DESC& blend = desc.BlendState.RenderTarget[0];
    m_blendDesc.blendEnable = blend.BlendEnable != 0;
    CastAssign(m_blendDesc.srcBlend, blend.SrcBlend);
    CastAssign(m_blendDesc.dstBlend, blend.DestBlend);
    CastAssign(m_blendDesc.blendOp, blend.BlendOp);
    CastAssign(m_blendDesc.srcBlendAlpha, blend.SrcBlendAlpha);
    CastAssign(m_blendDesc.dstBlendAlpha, blend.DestBlendAlpha);
    CastAssign(m_blendDesc.blendOpAlpha, blend.BlendOpAlpha);
    m_blendDesc.colorWriteMask = blend.RenderTargetWriteMask;

    const D3D12_DEPTH_STENCIL_DESC& ds = desc.DepthStencilState;
    m_depthStencilDesc.depthTestEnable = ds.DepthEnable != 0;
    m_depthStencilDesc.depthWriteEnable = (ds.DepthWriteMask == D3D12_DEPTH_WRITE_MASK_ALL);
    CastAssign(m_depthStencilDesc.depthFunc, ds.DepthFunc);
    m_depthStencilDesc.stencilEnable = ds.StencilEnable != 0;
    m_depthStencilDesc.stencilReadMask = ds.StencilReadMask;
    m_depthStencilDesc.stencilWriteMask = ds.StencilWriteMask;
    // Stencil face ops are consumed by CreateD3D12PSO(); without copying them a
    // stencil-enabled description would be rebuilt with stale/default ops.
    ReadStencilFace(m_depthStencilDesc.front, ds.FrontFace);
    ReadStencilFace(m_depthStencilDesc.back, ds.BackFace);

    CastAssign(m_topologyType, desc.PrimitiveTopologyType);

    // Clamp so an out-of-range NumRenderTargets cannot overflow the format array.
    const uint32_t rtCount = std::min<uint32_t>(desc.NumRenderTargets, kMaxRenderTargets);
    m_renderTargetCount = rtCount;
    for (uint32_t i = 0; i < rtCount; ++i) {
        CastAssign(m_renderTargetFormats[i], desc.RTVFormats[i]);
    }
    CastAssign(m_depthStencilFormat, desc.DSVFormat);
    m_sampleCount = desc.SampleDesc.Count;

    // Set shader bytecodes
    m_vsBytecode = desc.VS;
    m_psBytecode = desc.PS;
    m_gsBytecode = desc.GS;

    // Set input layout. NOTE(review): the copied SemanticName pointers still
    // reference the caller's strings; they are consumed by Finalize() below, but
    // a later rebuild after Shutdown() would need the caller's strings alive.
    const D3D12_INPUT_LAYOUT_DESC& layout = desc.InputLayout;
    m_inputElements.clear();
    if (layout.pInputElementDescs) {
        m_inputElements.assign(layout.pInputElementDescs,
                               layout.pInputElementDescs + layout.NumElements);
    }

    return Finalize();
}

// Releases the underlying D3D12 pipeline objects.
D3D12PipelineState::~D3D12PipelineState() {
    Shutdown();
}

// Stores an RHI input layout and converts it to D3D12 input elements
// (per-vertex data only).
void D3D12PipelineState::SetInputLayout(const InputLayoutDesc& layout) {
    m_inputLayoutDesc = layout;
    m_inputElements.clear();
    // Build from the stored copy, not from `layout`: D3D12_INPUT_ELEMENT_DESC
    // keeps a raw SemanticName pointer, and the caller's object may no longer
    // exist when Finalize() eventually creates the PSO.
    for (const auto& elem : m_inputLayoutDesc.elements) {
        m_inputElements.push_back(CreateInputElement(elem.semanticName.c_str(),
                                                     elem.semanticIndex,
                                                     static_cast<Format>(elem.format),
                                                     elem.inputSlot,
                                                     elem.alignedByteOffset));
    }
}

void D3D12PipelineState::SetRasterizerState(const RasterizerDesc& state) {
    m_rasterizerDesc = state;
}

void D3D12PipelineState::SetBlendState(const BlendDesc& state) {
    m_blendDesc = state;
}

void D3D12PipelineState::SetDepthStencilState(const DepthStencilStateDesc& state) {
    m_depthStencilDesc = state;
}

void D3D12PipelineState::SetTopology(uint32_t topologyType) {
    m_topologyType = topologyType;
}

// Sets the colour target formats (at most kMaxRenderTargets are kept) and the
// depth-stencil format. A null `formats` pointer is treated as "no colour targets".
void D3D12PipelineState::SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) {
    // Keep the stored count in sync with the formats actually copied, so
    // CreateD3D12PSO() never reports more targets than it has formats for.
    const uint32_t rtCount = formats ? std::min(count, kMaxRenderTargets) : 0u;
    m_renderTargetCount = rtCount;
    for (uint32_t i = 0; i < rtCount; ++i) {
        m_renderTargetFormats[i] = formats[i];
    }
    m_depthStencilFormat = depthFormat;
}

void D3D12PipelineState::SetSampleCount(uint32_t count) {
    m_sampleCount = count;
}

// Computes a hash of the fixed-function state, field by field (no type punning
// or padding reads).
PipelineStateHash D3D12PipelineState::GetHash() const {
    PipelineStateHash hash = {};

    std::size_t blend = 0;
    HashCombine(blend, m_blendDesc.blendEnable);
    HashCombine(blend, m_blendDesc.srcBlend);
    HashCombine(blend, m_blendDesc.dstBlend);
    HashCombine(blend, m_blendDesc.blendOp);
    HashCombine(blend, m_blendDesc.srcBlendAlpha);
    HashCombine(blend, m_blendDesc.dstBlendAlpha);
    HashCombine(blend, m_blendDesc.blendOpAlpha);
    HashCombine(blend, m_blendDesc.colorWriteMask);
    hash.blendStateHash = blend;

    std::size_t depth = 0;
    HashCombine(depth, m_depthStencilDesc.depthTestEnable);
    HashCombine(depth, m_depthStencilDesc.depthWriteEnable);
    HashCombine(depth, m_depthStencilDesc.depthFunc);
    HashCombine(depth, m_depthStencilDesc.stencilEnable);
    HashCombine(depth, m_depthStencilDesc.stencilReadMask);
    HashCombine(depth, m_depthStencilDesc.stencilWriteMask);
    HashCombine(depth, m_depthStencilDesc.front.failOp);
    HashCombine(depth, m_depthStencilDesc.front.passOp);
    HashCombine(depth, m_depthStencilDesc.front.depthFailOp);
    HashCombine(depth, m_depthStencilDesc.front.func);
    HashCombine(depth, m_depthStencilDesc.back.failOp);
    HashCombine(depth, m_depthStencilDesc.back.passOp);
    HashCombine(depth, m_depthStencilDesc.back.depthFailOp);
    HashCombine(depth, m_depthStencilDesc.back.func);
    hash.depthStateHash = depth;

    std::size_t raster = 0;
    HashCombine(raster, m_rasterizerDesc.fillMode);
    HashCombine(raster, m_rasterizerDesc.cullMode);
    HashCombine(raster, m_rasterizerDesc.frontFace);
    HashCombine(raster, m_rasterizerDesc.depthClipEnable);
    HashCombine(raster, m_rasterizerDesc.multisampleEnable);
    HashCombine(raster, m_rasterizerDesc.antialiasedLineEnable);
    hash.rasterizerStateHash = raster;

    hash.topologyHash = m_topologyType;

    // Includes each colour format and the sample count, not only the target
    // count and depth format.
    std::size_t targets = 0;
    HashCombine(targets, m_renderTargetCount);
    for (uint32_t i = 0; i < m_renderTargetCount && i < kMaxRenderTargets; ++i) {
        HashCombine(targets, m_renderTargetFormats[i]);
    }
    HashCombine(targets, m_depthStencilFormat);
    HashCombine(targets, m_sampleCount);
    hash.renderTargetHash = targets;

    return hash;
}

void D3D12PipelineState::SetShaderBytecodes(const D3D12_SHADER_BYTECODE& vs, const D3D12_SHADER_BYTECODE& ps, const D3D12_SHADER_BYTECODE& gs) {
    m_vsBytecode = vs;
    m_psBytecode = ps;
    m_gsBytecode = gs;
}

void D3D12PipelineState::SetComputeShader(RHIShader* shader) {
    m_computeShader = shader;
}

void D3D12PipelineState::SetComputeShaderBytecodes(const D3D12_SHADER_BYTECODE& cs) {
    m_csBytecode = cs;
}

// Builds the compute or graphics PSO once; subsequent calls are no-ops that
// return true until Shutdown() resets m_finalized.
bool D3D12PipelineState::Finalize() {
    if (m_finalized) {
        return true;
    }
    if (HasComputeShader()) {
        return CreateD3D12ComputePSO();
    }
    return CreateD3D12PSO();
}

// Translates the stored RHI state into a D3D12 graphics PSO description and
// creates the pipeline. Requires a device and non-empty VS and PS bytecode.
bool D3D12PipelineState::CreateD3D12PSO() {
    if (!m_device) {
        return false;
    }
    if (!m_vsBytecode.pShaderBytecode || !m_vsBytecode.BytecodeLength ||
        !m_psBytecode.pShaderBytecode || !m_psBytecode.BytecodeLength) {
        return false;
    }

    D3D12_GRAPHICS_PIPELINE_STATE_DESC desc = {};
    desc.pRootSignature = m_rootSignature;
    desc.VS = m_vsBytecode;
    desc.PS = m_psBytecode;
    desc.GS = m_gsBytecode;
    desc.InputLayout.NumElements = static_cast<UINT>(m_inputElements.size());
    desc.InputLayout.pInputElementDescs = m_inputElements.data();

    D3D12_RASTERIZER_DESC& raster = desc.RasterizerState;
    raster.FillMode = static_cast<D3D12_FILL_MODE>(m_rasterizerDesc.fillMode);
    raster.CullMode = static_cast<D3D12_CULL_MODE>(m_rasterizerDesc.cullMode);
    raster.FrontCounterClockwise = (m_rasterizerDesc.frontFace != 0);
    raster.DepthClipEnable = m_rasterizerDesc.depthClipEnable;
    raster.MultisampleEnable = m_rasterizerDesc.multisampleEnable;
    raster.AntialiasedLineEnable = m_rasterizerDesc.antialiasedLineEnable;

    // IndependentBlendEnable stays FALSE (zero-init), so D3D12 applies
    // RenderTarget[0] to every bound colour target.
    D3D12_RENDER_TARGET_BLEND_DESC& blend = desc.BlendState.RenderTarget[0];
    blend.BlendEnable = m_blendDesc.blendEnable;
    blend.SrcBlend = static_cast<D3D12_BLEND>(m_blendDesc.srcBlend);
    blend.DestBlend = static_cast<D3D12_BLEND>(m_blendDesc.dstBlend);
    blend.BlendOp = static_cast<D3D12_BLEND_OP>(m_blendDesc.blendOp);
    blend.SrcBlendAlpha = static_cast<D3D12_BLEND>(m_blendDesc.srcBlendAlpha);
    blend.DestBlendAlpha = static_cast<D3D12_BLEND>(m_blendDesc.dstBlendAlpha);
    blend.BlendOpAlpha = static_cast<D3D12_BLEND_OP>(m_blendDesc.blendOpAlpha);
    blend.RenderTargetWriteMask = static_cast<UINT8>(m_blendDesc.colorWriteMask);

    D3D12_DEPTH_STENCIL_DESC& ds = desc.DepthStencilState;
    ds.DepthEnable = m_depthStencilDesc.depthTestEnable;
    ds.DepthWriteMask = m_depthStencilDesc.depthWriteEnable ? D3D12_DEPTH_WRITE_MASK_ALL : D3D12_DEPTH_WRITE_MASK_ZERO;
    ds.DepthFunc = static_cast<D3D12_COMPARISON_FUNC>(m_depthStencilDesc.depthFunc);
    ds.StencilEnable = m_depthStencilDesc.stencilEnable;
    ds.StencilReadMask = static_cast<UINT8>(m_depthStencilDesc.stencilReadMask);
    ds.StencilWriteMask = static_cast<UINT8>(m_depthStencilDesc.stencilWriteMask);
    ds.FrontFace = MakeD3D12StencilFace(m_depthStencilDesc.front);
    ds.BackFace = MakeD3D12StencilFace(m_depthStencilDesc.back);

    // NumRenderTargets must match the number of RTV formats actually filled in.
    const uint32_t rtCount = std::min<uint32_t>(m_renderTargetCount, kMaxRenderTargets);
    desc.NumRenderTargets = rtCount;
    for (uint32_t i = 0; i < rtCount; ++i) {
        desc.RTVFormats[i] = static_cast<DXGI_FORMAT>(m_renderTargetFormats[i]);
    }
    desc.DSVFormat = static_cast<DXGI_FORMAT>(m_depthStencilFormat);
    desc.SampleDesc.Count = m_sampleCount;
    desc.SampleDesc.Quality = 0;
    desc.SampleMask = 0xffffffff;
    desc.PrimitiveTopologyType = static_cast<D3D12_PRIMITIVE_TOPOLOGY_TYPE>(m_topologyType);

    const HRESULT hr = m_device->CreateGraphicsPipelineState(&desc, IID_PPV_ARGS(&m_pipelineState));
    if (FAILED(hr)) {
        return false;
    }

    m_finalized = true;
    return true;
}

// Creates a compute PSO from the stored CS bytecode and root signature.
// Requires a device and non-empty CS bytecode.
bool D3D12PipelineState::CreateD3D12ComputePSO() {
    if (!m_device) {
        return false;
    }
    if (!m_csBytecode.pShaderBytecode || !m_csBytecode.BytecodeLength) {
        return false;
    }

    D3D12_COMPUTE_PIPELINE_STATE_DESC desc = {};
    desc.pRootSignature = m_rootSignature;
    desc.CS = m_csBytecode;

    const HRESULT hr = m_device->CreateComputePipelineState(&desc, IID_PPV_ARGS(&m_computePipelineState));
    if (FAILED(hr)) {
        return false;
    }

    m_finalized = true;
    return true;
}

// Releases both pipeline objects and marks the state as needing Finalize() again.
void D3D12PipelineState::Shutdown() {
    m_pipelineState.Reset();
    m_computePipelineState.Reset();
    m_finalized = false;
}

void D3D12PipelineState::Bind() {
}

void D3D12PipelineState::Unbind() {
}

// Builds a per-vertex D3D12 input element. `semanticName` is stored as a raw
// pointer, so the string must outlive any PSO creation that uses the element.
D3D12_INPUT_ELEMENT_DESC D3D12PipelineState::CreateInputElement(
    const char* semanticName,
    uint32_t semanticIndex,
    Format format,
    uint32_t inputSlot,
    uint32_t alignedByteOffset) {
    D3D12_INPUT_ELEMENT_DESC element = {};
    element.SemanticName = semanticName;
    element.SemanticIndex = semanticIndex;
    element.Format = ToD3D12(format);
    element.InputSlot = inputSlot;
    element.AlignedByteOffset = alignedByteOffset;
    element.InputSlotClass = D3D12_INPUT_CLASSIFICATION_PER_VERTEX_DATA;
    element.InstanceDataStepRate = 0;
    return element;
}

// Same as above, with the offset left to D3D12 (D3D12_APPEND_ALIGNED_ELEMENT).
D3D12_INPUT_ELEMENT_DESC D3D12PipelineState::CreateInputElement(
    const char* semanticName,
    uint32_t semanticIndex,
    Format format,
    uint32_t inputSlot) {
    return CreateInputElement(semanticName, semanticIndex, format, inputSlot, D3D12_APPEND_ALIGNED_ELEMENT);
}

} // namespace RHI
} // namespace XCEngine