Files
XCEngine/engine/include/XCEngine/RHI/D3D12/D3D12PipelineState.h

120 lines
4.4 KiB
C++

#pragma once
#include <d3d12.h>
#include <dxgi1_4.h>
#include <wrl/client.h>
#include <memory>
#include <vector>
#include "../RHIPipelineState.h"
#include "../RHITypes.h"
#include "../RHIEnums.h"
#include "D3D12Enums.h"
using Microsoft::WRL::ComPtr;
namespace XCEngine {
namespace RHI {
class D3D12PipelineState : public RHIPipelineState {
public:
D3D12PipelineState() = default;
D3D12PipelineState(ID3D12Device* device);
~D3D12PipelineState() override;
bool Initialize(ID3D12Device* device, const D3D12_GRAPHICS_PIPELINE_STATE_DESC& desc);
// State configuration (Unity SRP style)
void SetInputLayout(const InputLayoutDesc& layout) override;
void SetRasterizerState(const RasterizerDesc& state) override;
void SetBlendState(const BlendDesc& state) override;
void SetDepthStencilState(const DepthStencilStateDesc& state) override;
void SetTopology(uint32_t topologyType) override;
void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override;
void SetSampleCount(uint32_t count) override;
void SetSampleQuality(uint32_t quality) override;
void SetComputeShader(RHIShader* shader) override;
void SetOwnedComputeShader(std::unique_ptr<class D3D12Shader> shader);
void SetRootSignature(ID3D12RootSignature* rootSignature);
// State query
const RasterizerDesc& GetRasterizerState() const override { return m_rasterizerDesc; }
const BlendDesc& GetBlendState() const override { return m_blendDesc; }
const DepthStencilStateDesc& GetDepthStencilState() const override { return m_depthStencilDesc; }
const InputLayoutDesc& GetInputLayout() const override { return m_inputLayoutDesc; }
PipelineStateHash GetHash() const override;
RHIShader* GetComputeShader() const override { return m_computeShader; }
bool HasComputeShader() const override { return m_csBytecode.pShaderBytecode != nullptr && m_csBytecode.BytecodeLength > 0; }
// Validation
bool IsValid() const override { return m_finalized; }
void EnsureValid() override;
// Shader Bytecode (set by CommandList when binding)
void SetShaderBytecodes(const D3D12_SHADER_BYTECODE& vs, const D3D12_SHADER_BYTECODE& ps, const D3D12_SHADER_BYTECODE& gs = {});
void SetComputeShaderBytecodes(const D3D12_SHADER_BYTECODE& cs);
// Lifecycle
void Shutdown() override;
ID3D12PipelineState* GetPipelineState() const { return m_pipelineState.Get(); }
ID3D12PipelineState* GetComputePipelineState() const { return m_computePipelineState.Get(); }
ID3D12RootSignature* GetRootSignature() const { return m_rootSignature.Get(); }
void* GetNativeHandle() override { return m_pipelineState.Get(); }
PipelineType GetType() const override { return HasComputeShader() ? PipelineType::Compute : PipelineType::Graphics; }
void Bind() override;
void Unbind() override;
// Helper functions
static D3D12_INPUT_ELEMENT_DESC CreateInputElement(
const char* semanticName,
uint32_t semanticIndex,
Format format,
uint32_t inputSlot,
uint32_t alignedByteOffset);
static D3D12_INPUT_ELEMENT_DESC CreateInputElement(
const char* semanticName,
uint32_t semanticIndex,
Format format,
uint32_t inputSlot);
private:
bool CreateD3D12PSO();
bool CreateD3D12ComputePSO();
bool EnsureDefaultRootSignature();
ID3D12Device* m_device;
bool m_finalized = false;
// Stored configuration (Unity SRP style)
GraphicsPipelineDesc m_desc;
InputLayoutDesc m_inputLayoutDesc;
RasterizerDesc m_rasterizerDesc;
BlendDesc m_blendDesc;
DepthStencilStateDesc m_depthStencilDesc;
uint32_t m_topologyType = 0;
uint32_t m_renderTargetCount = 1;
uint32_t m_renderTargetFormats[8] = { 0 };
uint32_t m_depthStencilFormat = 0;
uint32_t m_sampleCount = 1;
uint32_t m_sampleQuality = 0;
// Shader bytecodes (set externally)
D3D12_SHADER_BYTECODE m_vsBytecode = {};
D3D12_SHADER_BYTECODE m_psBytecode = {};
D3D12_SHADER_BYTECODE m_gsBytecode = {};
D3D12_SHADER_BYTECODE m_csBytecode = {};
class RHIShader* m_computeShader = nullptr;
std::unique_ptr<class D3D12Shader> m_ownedComputeShader;
ComPtr<ID3D12RootSignature> m_rootSignature;
// D3D12 resources
ComPtr<ID3D12PipelineState> m_pipelineState;
ComPtr<ID3D12PipelineState> m_computePipelineState;
std::vector<D3D12_INPUT_ELEMENT_DESC> m_inputElements;
};
} // namespace RHI
} // namespace XCEngine