120 lines
4.4 KiB
C++
120 lines
4.4 KiB
C++
#pragma once
|
|
|
|
#include <d3d12.h>
|
|
#include <dxgi1_4.h>
|
|
#include <wrl/client.h>
|
|
#include <memory>
|
|
#include <vector>
|
|
|
|
#include "../RHIPipelineState.h"
|
|
#include "../RHITypes.h"
|
|
#include "../RHIEnums.h"
|
|
#include "D3D12Enums.h"
|
|
|
|
using Microsoft::WRL::ComPtr;
|
|
|
|
namespace XCEngine {
|
|
namespace RHI {
|
|
|
|
class D3D12PipelineState : public RHIPipelineState {
|
|
public:
|
|
D3D12PipelineState() = default;
|
|
D3D12PipelineState(ID3D12Device* device);
|
|
~D3D12PipelineState() override;
|
|
|
|
bool Initialize(ID3D12Device* device, const D3D12_GRAPHICS_PIPELINE_STATE_DESC& desc);
|
|
|
|
// State configuration (Unity SRP style)
|
|
void SetInputLayout(const InputLayoutDesc& layout) override;
|
|
void SetRasterizerState(const RasterizerDesc& state) override;
|
|
void SetBlendState(const BlendDesc& state) override;
|
|
void SetDepthStencilState(const DepthStencilStateDesc& state) override;
|
|
void SetTopology(uint32_t topologyType) override;
|
|
void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override;
|
|
void SetSampleCount(uint32_t count) override;
|
|
void SetSampleQuality(uint32_t quality) override;
|
|
void SetComputeShader(RHIShader* shader) override;
|
|
void SetOwnedComputeShader(std::unique_ptr<class D3D12Shader> shader);
|
|
void SetRootSignature(ID3D12RootSignature* rootSignature);
|
|
|
|
// State query
|
|
const RasterizerDesc& GetRasterizerState() const override { return m_rasterizerDesc; }
|
|
const BlendDesc& GetBlendState() const override { return m_blendDesc; }
|
|
const DepthStencilStateDesc& GetDepthStencilState() const override { return m_depthStencilDesc; }
|
|
const InputLayoutDesc& GetInputLayout() const override { return m_inputLayoutDesc; }
|
|
PipelineStateHash GetHash() const override;
|
|
RHIShader* GetComputeShader() const override { return m_computeShader; }
|
|
bool HasComputeShader() const override { return m_csBytecode.pShaderBytecode != nullptr && m_csBytecode.BytecodeLength > 0; }
|
|
|
|
// Validation
|
|
bool IsValid() const override { return m_finalized; }
|
|
void EnsureValid() override;
|
|
|
|
// Shader Bytecode (set by CommandList when binding)
|
|
void SetShaderBytecodes(const D3D12_SHADER_BYTECODE& vs, const D3D12_SHADER_BYTECODE& ps, const D3D12_SHADER_BYTECODE& gs = {});
|
|
void SetComputeShaderBytecodes(const D3D12_SHADER_BYTECODE& cs);
|
|
|
|
// Lifecycle
|
|
void Shutdown() override;
|
|
ID3D12PipelineState* GetPipelineState() const { return m_pipelineState.Get(); }
|
|
ID3D12PipelineState* GetComputePipelineState() const { return m_computePipelineState.Get(); }
|
|
ID3D12RootSignature* GetRootSignature() const { return m_rootSignature.Get(); }
|
|
void* GetNativeHandle() override { return m_pipelineState.Get(); }
|
|
PipelineType GetType() const override { return HasComputeShader() ? PipelineType::Compute : PipelineType::Graphics; }
|
|
|
|
void Bind() override;
|
|
void Unbind() override;
|
|
|
|
// Helper functions
|
|
static D3D12_INPUT_ELEMENT_DESC CreateInputElement(
|
|
const char* semanticName,
|
|
uint32_t semanticIndex,
|
|
Format format,
|
|
uint32_t inputSlot,
|
|
uint32_t alignedByteOffset);
|
|
|
|
static D3D12_INPUT_ELEMENT_DESC CreateInputElement(
|
|
const char* semanticName,
|
|
uint32_t semanticIndex,
|
|
Format format,
|
|
uint32_t inputSlot);
|
|
|
|
private:
|
|
bool CreateD3D12PSO();
|
|
bool CreateD3D12ComputePSO();
|
|
bool EnsureDefaultRootSignature();
|
|
|
|
ID3D12Device* m_device;
|
|
bool m_finalized = false;
|
|
|
|
// Stored configuration (Unity SRP style)
|
|
GraphicsPipelineDesc m_desc;
|
|
InputLayoutDesc m_inputLayoutDesc;
|
|
RasterizerDesc m_rasterizerDesc;
|
|
BlendDesc m_blendDesc;
|
|
DepthStencilStateDesc m_depthStencilDesc;
|
|
uint32_t m_topologyType = 0;
|
|
uint32_t m_renderTargetCount = 1;
|
|
uint32_t m_renderTargetFormats[8] = { 0 };
|
|
uint32_t m_depthStencilFormat = 0;
|
|
uint32_t m_sampleCount = 1;
|
|
uint32_t m_sampleQuality = 0;
|
|
|
|
// Shader bytecodes (set externally)
|
|
D3D12_SHADER_BYTECODE m_vsBytecode = {};
|
|
D3D12_SHADER_BYTECODE m_psBytecode = {};
|
|
D3D12_SHADER_BYTECODE m_gsBytecode = {};
|
|
D3D12_SHADER_BYTECODE m_csBytecode = {};
|
|
class RHIShader* m_computeShader = nullptr;
|
|
std::unique_ptr<class D3D12Shader> m_ownedComputeShader;
|
|
ComPtr<ID3D12RootSignature> m_rootSignature;
|
|
|
|
// D3D12 resources
|
|
ComPtr<ID3D12PipelineState> m_pipelineState;
|
|
ComPtr<ID3D12PipelineState> m_computePipelineState;
|
|
std::vector<D3D12_INPUT_ELEMENT_DESC> m_inputElements;
|
|
};
|
|
|
|
} // namespace RHI
|
|
} // namespace XCEngine
|