Add graphics shader support to RHI pipeline states

This commit is contained in:
2026-03-25 23:19:18 +08:00
parent aaf9cce418
commit 1597181458
10 changed files with 311 additions and 40 deletions

View File

@@ -32,6 +32,7 @@ public:
void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override; void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override;
void SetSampleCount(uint32_t count) override; void SetSampleCount(uint32_t count) override;
void SetComputeShader(RHIShader* shader) override; void SetComputeShader(RHIShader* shader) override;
void SetRootSignature(ID3D12RootSignature* rootSignature);
// State query // State query
const RasterizerDesc& GetRasterizerState() const override { return m_rasterizerDesc; } const RasterizerDesc& GetRasterizerState() const override { return m_rasterizerDesc; }
@@ -54,6 +55,7 @@ public:
void Shutdown() override; void Shutdown() override;
ID3D12PipelineState* GetPipelineState() const { return m_pipelineState.Get(); } ID3D12PipelineState* GetPipelineState() const { return m_pipelineState.Get(); }
ID3D12PipelineState* GetComputePipelineState() const { return m_computePipelineState.Get(); } ID3D12PipelineState* GetComputePipelineState() const { return m_computePipelineState.Get(); }
ID3D12RootSignature* GetRootSignature() const { return m_rootSignature.Get(); }
void* GetNativeHandle() override { return m_pipelineState.Get(); } void* GetNativeHandle() override { return m_pipelineState.Get(); }
PipelineType GetType() const override { return HasComputeShader() ? PipelineType::Compute : PipelineType::Graphics; } PipelineType GetType() const override { return HasComputeShader() ? PipelineType::Compute : PipelineType::Graphics; }
@@ -99,7 +101,7 @@ private:
D3D12_SHADER_BYTECODE m_gsBytecode = {}; D3D12_SHADER_BYTECODE m_gsBytecode = {};
D3D12_SHADER_BYTECODE m_csBytecode = {}; D3D12_SHADER_BYTECODE m_csBytecode = {};
class RHIShader* m_computeShader = nullptr; class RHIShader* m_computeShader = nullptr;
ID3D12RootSignature* m_rootSignature = nullptr; ComPtr<ID3D12RootSignature> m_rootSignature;
// D3D12 resources // D3D12 resources
ComPtr<ID3D12PipelineState> m_pipelineState; ComPtr<ID3D12PipelineState> m_pipelineState;

View File

@@ -1,6 +1,7 @@
#pragma once #pragma once
#include <cstdint> #include <cstdint>
#include <memory>
#include <string> #include <string>
#include "../RHIPipelineState.h" #include "../RHIPipelineState.h"
@@ -121,6 +122,7 @@ public:
void Clear(unsigned int buffers); void Clear(unsigned int buffers);
void AttachShader(unsigned int program); void AttachShader(unsigned int program);
void DetachShader(); void DetachShader();
void SetOwnedGraphicsShader(std::unique_ptr<class OpenGLShader> shader);
const OpenGLDepthStencilState& GetOpenGLDepthStencilState() const; const OpenGLDepthStencilState& GetOpenGLDepthStencilState() const;
const OpenGLBlendState& GetOpenGLBlendState() const; const OpenGLBlendState& GetOpenGLBlendState() const;
@@ -135,6 +137,7 @@ private:
unsigned int m_program = 0; unsigned int m_program = 0;
unsigned int m_computeProgram = 0; unsigned int m_computeProgram = 0;
class RHIShader* m_computeShader = nullptr; class RHIShader* m_computeShader = nullptr;
std::unique_ptr<class OpenGLShader> m_graphicsShader;
bool m_programAttached = false; bool m_programAttached = false;
// OpenGL specific state // OpenGL specific state

View File

@@ -245,17 +245,17 @@ struct RootSignatureDesc {
// ========== Pipeline State Structures (Unity SRP style) ========== // ========== Pipeline State Structures (Unity SRP style) ==========
struct StencilOpDesc { struct StencilOpDesc {
uint32_t failOp = 0; // StencilOp uint32_t failOp = static_cast<uint32_t>(StencilOp::Keep);
uint32_t passOp = 0; // StencilOp uint32_t passOp = static_cast<uint32_t>(StencilOp::Keep);
uint32_t depthFailOp = 0; // StencilOp uint32_t depthFailOp = static_cast<uint32_t>(StencilOp::Keep);
uint32_t func = 0; // ComparisonFunc uint32_t func = static_cast<uint32_t>(ComparisonFunc::Always);
}; };
struct DepthStencilStateDesc { struct DepthStencilStateDesc {
bool depthTestEnable = true; bool depthTestEnable = true;
bool depthWriteEnable = true; bool depthWriteEnable = true;
bool depthBoundsEnable = false; bool depthBoundsEnable = false;
uint32_t depthFunc = 0; // ComparisonFunc uint32_t depthFunc = static_cast<uint32_t>(ComparisonFunc::Less);
bool stencilEnable = false; bool stencilEnable = false;
uint8_t stencilReadMask = 0xFF; uint8_t stencilReadMask = 0xFF;
uint8_t stencilWriteMask = 0xFF; uint8_t stencilWriteMask = 0xFF;
@@ -265,20 +265,20 @@ struct DepthStencilStateDesc {
struct BlendDesc { struct BlendDesc {
bool blendEnable = false; bool blendEnable = false;
uint32_t srcBlend = 0; // BlendFactor uint32_t srcBlend = static_cast<uint32_t>(BlendFactor::One);
uint32_t dstBlend = 0; // BlendFactor uint32_t dstBlend = static_cast<uint32_t>(BlendFactor::Zero);
uint32_t srcBlendAlpha = 0; // BlendFactor uint32_t srcBlendAlpha = static_cast<uint32_t>(BlendFactor::One);
uint32_t dstBlendAlpha = 0; // BlendFactor uint32_t dstBlendAlpha = static_cast<uint32_t>(BlendFactor::Zero);
uint32_t blendOp = 0; // BlendOp uint32_t blendOp = static_cast<uint32_t>(BlendOp::Add);
uint32_t blendOpAlpha = 0; // BlendOp uint32_t blendOpAlpha = static_cast<uint32_t>(BlendOp::Add);
uint8_t colorWriteMask = 0xF; uint8_t colorWriteMask = 0xF;
float blendFactor[4] = {1.0f, 1.0f, 1.0f, 1.0f}; float blendFactor[4] = {1.0f, 1.0f, 1.0f, 1.0f};
}; };
struct RasterizerDesc { struct RasterizerDesc {
uint32_t fillMode = 0; // FillMode (default: Solid) uint32_t fillMode = static_cast<uint32_t>(FillMode::Solid);
uint32_t cullMode = 0; // CullMode (default: Back) uint32_t cullMode = static_cast<uint32_t>(CullMode::Back);
uint32_t frontFace = 0; // FrontFace (default: CounterClockwise) uint32_t frontFace = static_cast<uint32_t>(FrontFace::CounterClockwise);
bool depthClipEnable = true; bool depthClipEnable = true;
bool scissorTestEnable = false; bool scissorTestEnable = false;
bool multisampleEnable = false; bool multisampleEnable = false;
@@ -307,12 +307,15 @@ struct PipelineStateHash {
}; };
struct GraphicsPipelineDesc { struct GraphicsPipelineDesc {
ShaderCompileDesc vertexShader;
ShaderCompileDesc fragmentShader;
ShaderCompileDesc geometryShader;
InputLayoutDesc inputLayout; InputLayoutDesc inputLayout;
RasterizerDesc rasterizerState; RasterizerDesc rasterizerState;
BlendDesc blendState; BlendDesc blendState;
DepthStencilStateDesc depthStencilState; DepthStencilStateDesc depthStencilState;
uint32_t topologyType = 0; // PrimitiveTopologyType uint32_t topologyType = static_cast<uint32_t>(PrimitiveTopologyType::Triangle);
uint32_t renderTargetCount = 1; uint32_t renderTargetCount = 1;
uint32_t renderTargetFormats[8] = { 0 }; // Format uint32_t renderTargetFormats[8] = { 0 }; // Format
uint32_t depthStencilFormat = 0; // Format uint32_t depthStencilFormat = 0; // Format

View File

@@ -127,6 +127,10 @@ void D3D12CommandList::SetPipelineState(RHIPipelineState* pso) {
D3D12PipelineState* d3d12Pso = static_cast<D3D12PipelineState*>(pso); D3D12PipelineState* d3d12Pso = static_cast<D3D12PipelineState*>(pso);
SetPipelineStateInternal(d3d12Pso->GetComputePipelineState()); SetPipelineStateInternal(d3d12Pso->GetComputePipelineState());
} else { } else {
D3D12PipelineState* d3d12Pso = static_cast<D3D12PipelineState*>(pso);
if (d3d12Pso->GetRootSignature() != nullptr) {
SetRootSignature(d3d12Pso->GetRootSignature());
}
SetPipelineStateInternal(static_cast<ID3D12PipelineState*>(pso->GetNativeHandle())); SetPipelineStateInternal(static_cast<ID3D12PipelineState*>(pso->GetNativeHandle()));
} }
} }

View File

@@ -39,6 +39,27 @@ std::string NarrowAscii(const std::wstring& value) {
return result; return result;
} }
bool HasShaderPayload(const ShaderCompileDesc& desc) {
return !desc.source.empty() || !desc.fileName.empty();
}
bool CompileD3D12Shader(const ShaderCompileDesc& desc, D3D12Shader& shader) {
const std::string entryPoint = NarrowAscii(desc.entryPoint);
const std::string profile = NarrowAscii(desc.profile);
const char* entryPointPtr = entryPoint.empty() ? nullptr : entryPoint.c_str();
const char* profilePtr = profile.empty() ? nullptr : profile.c_str();
if (!desc.source.empty()) {
return shader.Compile(desc.source.data(), desc.source.size(), entryPointPtr, profilePtr);
}
if (!desc.fileName.empty()) {
return shader.CompileFromFile(desc.fileName.c_str(), entryPointPtr, profilePtr);
}
return false;
}
} // namespace } // namespace
D3D12Device::D3D12Device() D3D12Device::D3D12Device()
@@ -489,6 +510,55 @@ RHIPipelineState* D3D12Device::CreatePipelineState(const GraphicsPipelineDesc& d
pso->SetTopology(desc.topologyType); pso->SetTopology(desc.topologyType);
pso->SetRenderTargetFormats(desc.renderTargetCount, desc.renderTargetFormats, desc.depthStencilFormat); pso->SetRenderTargetFormats(desc.renderTargetCount, desc.renderTargetFormats, desc.depthStencilFormat);
pso->SetSampleCount(desc.sampleCount); pso->SetSampleCount(desc.sampleCount);
const bool hasVertexShader = HasShaderPayload(desc.vertexShader);
const bool hasFragmentShader = HasShaderPayload(desc.fragmentShader);
const bool hasGeometryShader = HasShaderPayload(desc.geometryShader);
if (!hasVertexShader && !hasFragmentShader && !hasGeometryShader) {
return pso;
}
if (!hasVertexShader || !hasFragmentShader) {
delete pso;
return nullptr;
}
auto* rootSignature = CreateRootSignature({});
if (rootSignature == nullptr) {
delete pso;
return nullptr;
}
pso->SetRootSignature(rootSignature->GetRootSignature());
D3D12Shader vertexShader;
D3D12Shader fragmentShader;
D3D12Shader geometryShader;
const bool vertexCompiled = CompileD3D12Shader(desc.vertexShader, vertexShader);
const bool fragmentCompiled = CompileD3D12Shader(desc.fragmentShader, fragmentShader);
const bool geometryCompiled = !hasGeometryShader || CompileD3D12Shader(desc.geometryShader, geometryShader);
if (!vertexCompiled || !fragmentCompiled || !geometryCompiled) {
rootSignature->Shutdown();
delete rootSignature;
delete pso;
return nullptr;
}
pso->SetShaderBytecodes(
vertexShader.GetD3D12Bytecode(),
fragmentShader.GetD3D12Bytecode(),
hasGeometryShader ? geometryShader.GetD3D12Bytecode() : D3D12_SHADER_BYTECODE{});
pso->EnsureValid();
rootSignature->Shutdown();
delete rootSignature;
if (!pso->IsValid()) {
delete pso;
return nullptr;
}
return pso; return pso;
} }

View File

@@ -127,6 +127,10 @@ void D3D12PipelineState::SetComputeShader(RHIShader* shader) {
m_computeShader = shader; m_computeShader = shader;
} }
void D3D12PipelineState::SetRootSignature(ID3D12RootSignature* rootSignature) {
m_rootSignature = rootSignature;
}
void D3D12PipelineState::SetComputeShaderBytecodes(const D3D12_SHADER_BYTECODE& cs) { void D3D12PipelineState::SetComputeShaderBytecodes(const D3D12_SHADER_BYTECODE& cs) {
m_csBytecode = cs; m_csBytecode = cs;
} }
@@ -146,7 +150,7 @@ bool D3D12PipelineState::CreateD3D12PSO() {
} }
D3D12_GRAPHICS_PIPELINE_STATE_DESC desc = {}; D3D12_GRAPHICS_PIPELINE_STATE_DESC desc = {};
desc.pRootSignature = m_rootSignature; desc.pRootSignature = m_rootSignature.Get();
desc.VS = m_vsBytecode; desc.VS = m_vsBytecode;
desc.PS = m_psBytecode; desc.PS = m_psBytecode;
desc.GS = m_gsBytecode; desc.GS = m_gsBytecode;
@@ -154,36 +158,36 @@ bool D3D12PipelineState::CreateD3D12PSO() {
desc.InputLayout.NumElements = static_cast<UINT>(m_inputElements.size()); desc.InputLayout.NumElements = static_cast<UINT>(m_inputElements.size());
desc.InputLayout.pInputElementDescs = m_inputElements.data(); desc.InputLayout.pInputElementDescs = m_inputElements.data();
desc.RasterizerState.FillMode = static_cast<D3D12_FILL_MODE>(m_rasterizerDesc.fillMode); desc.RasterizerState.FillMode = ToD3D12(static_cast<FillMode>(m_rasterizerDesc.fillMode));
desc.RasterizerState.CullMode = static_cast<D3D12_CULL_MODE>(m_rasterizerDesc.cullMode); desc.RasterizerState.CullMode = ToD3D12(static_cast<CullMode>(m_rasterizerDesc.cullMode));
desc.RasterizerState.FrontCounterClockwise = (m_rasterizerDesc.frontFace != 0); desc.RasterizerState.FrontCounterClockwise = static_cast<FrontFace>(m_rasterizerDesc.frontFace) == FrontFace::CounterClockwise;
desc.RasterizerState.DepthClipEnable = m_rasterizerDesc.depthClipEnable; desc.RasterizerState.DepthClipEnable = m_rasterizerDesc.depthClipEnable;
desc.RasterizerState.MultisampleEnable = m_rasterizerDesc.multisampleEnable; desc.RasterizerState.MultisampleEnable = m_rasterizerDesc.multisampleEnable;
desc.RasterizerState.AntialiasedLineEnable = m_rasterizerDesc.antialiasedLineEnable; desc.RasterizerState.AntialiasedLineEnable = m_rasterizerDesc.antialiasedLineEnable;
desc.BlendState.RenderTarget[0].BlendEnable = m_blendDesc.blendEnable; desc.BlendState.RenderTarget[0].BlendEnable = m_blendDesc.blendEnable;
desc.BlendState.RenderTarget[0].SrcBlend = static_cast<D3D12_BLEND>(m_blendDesc.srcBlend); desc.BlendState.RenderTarget[0].SrcBlend = ToD3D12(static_cast<BlendFactor>(m_blendDesc.srcBlend));
desc.BlendState.RenderTarget[0].DestBlend = static_cast<D3D12_BLEND>(m_blendDesc.dstBlend); desc.BlendState.RenderTarget[0].DestBlend = ToD3D12(static_cast<BlendFactor>(m_blendDesc.dstBlend));
desc.BlendState.RenderTarget[0].BlendOp = static_cast<D3D12_BLEND_OP>(m_blendDesc.blendOp); desc.BlendState.RenderTarget[0].BlendOp = ToD3D12(static_cast<BlendOp>(m_blendDesc.blendOp));
desc.BlendState.RenderTarget[0].SrcBlendAlpha = static_cast<D3D12_BLEND>(m_blendDesc.srcBlendAlpha); desc.BlendState.RenderTarget[0].SrcBlendAlpha = ToD3D12(static_cast<BlendFactor>(m_blendDesc.srcBlendAlpha));
desc.BlendState.RenderTarget[0].DestBlendAlpha = static_cast<D3D12_BLEND>(m_blendDesc.dstBlendAlpha); desc.BlendState.RenderTarget[0].DestBlendAlpha = ToD3D12(static_cast<BlendFactor>(m_blendDesc.dstBlendAlpha));
desc.BlendState.RenderTarget[0].BlendOpAlpha = static_cast<D3D12_BLEND_OP>(m_blendDesc.blendOpAlpha); desc.BlendState.RenderTarget[0].BlendOpAlpha = ToD3D12(static_cast<BlendOp>(m_blendDesc.blendOpAlpha));
desc.BlendState.RenderTarget[0].RenderTargetWriteMask = m_blendDesc.colorWriteMask; desc.BlendState.RenderTarget[0].RenderTargetWriteMask = m_blendDesc.colorWriteMask;
desc.DepthStencilState.DepthEnable = m_depthStencilDesc.depthTestEnable; desc.DepthStencilState.DepthEnable = m_depthStencilDesc.depthTestEnable;
desc.DepthStencilState.DepthWriteMask = m_depthStencilDesc.depthWriteEnable ? D3D12_DEPTH_WRITE_MASK_ALL : D3D12_DEPTH_WRITE_MASK_ZERO; desc.DepthStencilState.DepthWriteMask = m_depthStencilDesc.depthWriteEnable ? D3D12_DEPTH_WRITE_MASK_ALL : D3D12_DEPTH_WRITE_MASK_ZERO;
desc.DepthStencilState.DepthFunc = static_cast<D3D12_COMPARISON_FUNC>(m_depthStencilDesc.depthFunc); desc.DepthStencilState.DepthFunc = ToD3D12(static_cast<ComparisonFunc>(m_depthStencilDesc.depthFunc));
desc.DepthStencilState.StencilEnable = m_depthStencilDesc.stencilEnable; desc.DepthStencilState.StencilEnable = m_depthStencilDesc.stencilEnable;
desc.DepthStencilState.StencilReadMask = m_depthStencilDesc.stencilReadMask; desc.DepthStencilState.StencilReadMask = m_depthStencilDesc.stencilReadMask;
desc.DepthStencilState.StencilWriteMask = m_depthStencilDesc.stencilWriteMask; desc.DepthStencilState.StencilWriteMask = m_depthStencilDesc.stencilWriteMask;
desc.DepthStencilState.FrontFace.StencilFailOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.front.failOp); desc.DepthStencilState.FrontFace.StencilFailOp = ToD3D12(static_cast<StencilOp>(m_depthStencilDesc.front.failOp));
desc.DepthStencilState.FrontFace.StencilPassOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.front.passOp); desc.DepthStencilState.FrontFace.StencilPassOp = ToD3D12(static_cast<StencilOp>(m_depthStencilDesc.front.passOp));
desc.DepthStencilState.FrontFace.StencilDepthFailOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.front.depthFailOp); desc.DepthStencilState.FrontFace.StencilDepthFailOp = ToD3D12(static_cast<StencilOp>(m_depthStencilDesc.front.depthFailOp));
desc.DepthStencilState.FrontFace.StencilFunc = static_cast<D3D12_COMPARISON_FUNC>(m_depthStencilDesc.front.func); desc.DepthStencilState.FrontFace.StencilFunc = ToD3D12(static_cast<ComparisonFunc>(m_depthStencilDesc.front.func));
desc.DepthStencilState.BackFace.StencilFailOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.back.failOp); desc.DepthStencilState.BackFace.StencilFailOp = ToD3D12(static_cast<StencilOp>(m_depthStencilDesc.back.failOp));
desc.DepthStencilState.BackFace.StencilPassOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.back.passOp); desc.DepthStencilState.BackFace.StencilPassOp = ToD3D12(static_cast<StencilOp>(m_depthStencilDesc.back.passOp));
desc.DepthStencilState.BackFace.StencilDepthFailOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.back.depthFailOp); desc.DepthStencilState.BackFace.StencilDepthFailOp = ToD3D12(static_cast<StencilOp>(m_depthStencilDesc.back.depthFailOp));
desc.DepthStencilState.BackFace.StencilFunc = static_cast<D3D12_COMPARISON_FUNC>(m_depthStencilDesc.back.func); desc.DepthStencilState.BackFace.StencilFunc = ToD3D12(static_cast<ComparisonFunc>(m_depthStencilDesc.back.func));
desc.NumRenderTargets = m_renderTargetCount; desc.NumRenderTargets = m_renderTargetCount;
for (uint32_t i = 0; i < m_renderTargetCount && i < 8; ++i) { for (uint32_t i = 0; i < m_renderTargetCount && i < 8; ++i) {
@@ -210,7 +214,7 @@ bool D3D12PipelineState::CreateD3D12ComputePSO() {
} }
D3D12_COMPUTE_PIPELINE_STATE_DESC desc = {}; D3D12_COMPUTE_PIPELINE_STATE_DESC desc = {};
desc.pRootSignature = m_rootSignature; desc.pRootSignature = m_rootSignature.Get();
desc.CS = m_csBytecode; desc.CS = m_csBytecode;
HRESULT hr = m_device->CreateComputePipelineState(&desc, IID_PPV_ARGS(&m_computePipelineState)); HRESULT hr = m_device->CreateComputePipelineState(&desc, IID_PPV_ARGS(&m_computePipelineState));
@@ -225,6 +229,7 @@ bool D3D12PipelineState::CreateD3D12ComputePSO() {
void D3D12PipelineState::Shutdown() { void D3D12PipelineState::Shutdown() {
m_pipelineState.Reset(); m_pipelineState.Reset();
m_computePipelineState.Reset(); m_computePipelineState.Reset();
m_rootSignature.Reset();
m_finalized = false; m_finalized = false;
} }
@@ -260,4 +265,4 @@ D3D12_INPUT_ELEMENT_DESC D3D12PipelineState::CreateInputElement(
} }
} // namespace RHI } // namespace RHI
} // namespace XCEngine } // namespace XCEngine

View File

@@ -598,7 +598,7 @@ void OpenGLCommandList::ClearDepthStencil(RHIResourceView* depthStencil, float d
void OpenGLCommandList::SetPipelineState(RHIPipelineState* pipelineState) { void OpenGLCommandList::SetPipelineState(RHIPipelineState* pipelineState) {
if (pipelineState) { if (pipelineState) {
UseShader(reinterpret_cast<uintptr_t>(pipelineState->GetNativeHandle())); pipelineState->Bind();
} }
} }

View File

@@ -41,6 +41,27 @@ static PFNWGLCREATECONTEXTATTRIBSARBPROC wglCreateContextAttribsARB = nullptr;
namespace XCEngine { namespace XCEngine {
namespace RHI { namespace RHI {
namespace {
std::string NarrowAscii(const std::wstring& value) {
std::string result;
result.reserve(value.size());
for (wchar_t ch : value) {
result.push_back(static_cast<char>(ch));
}
return result;
}
bool HasShaderPayload(const ShaderCompileDesc& desc) {
return !desc.source.empty() || !desc.fileName.empty();
}
std::string SourceToString(const ShaderCompileDesc& desc) {
return std::string(reinterpret_cast<const char*>(desc.source.data()), desc.source.size());
}
} // namespace
OpenGLDevice::OpenGLDevice() OpenGLDevice::OpenGLDevice()
: m_hwnd(nullptr) : m_hwnd(nullptr)
, m_hdc(nullptr) , m_hdc(nullptr)
@@ -405,6 +426,52 @@ RHIPipelineState* OpenGLDevice::CreatePipelineState(const GraphicsPipelineDesc&
pso->SetTopology(desc.topologyType); pso->SetTopology(desc.topologyType);
pso->SetRenderTargetFormats(desc.renderTargetCount, desc.renderTargetFormats, desc.depthStencilFormat); pso->SetRenderTargetFormats(desc.renderTargetCount, desc.renderTargetFormats, desc.depthStencilFormat);
pso->SetSampleCount(desc.sampleCount); pso->SetSampleCount(desc.sampleCount);
const bool hasVertexShader = HasShaderPayload(desc.vertexShader);
const bool hasFragmentShader = HasShaderPayload(desc.fragmentShader);
const bool hasGeometryShader = HasShaderPayload(desc.geometryShader);
if (!hasVertexShader && !hasFragmentShader && !hasGeometryShader) {
return pso;
}
if (!hasVertexShader || !hasFragmentShader) {
delete pso;
return nullptr;
}
if (!MakeContextCurrent()) {
delete pso;
return nullptr;
}
auto graphicsShader = std::make_unique<OpenGLShader>();
bool compiled = false;
if (!desc.vertexShader.source.empty() && !desc.fragmentShader.source.empty()) {
const std::string vertexSource = SourceToString(desc.vertexShader);
const std::string fragmentSource = SourceToString(desc.fragmentShader);
if (hasGeometryShader && !desc.geometryShader.source.empty()) {
const std::string geometrySource = SourceToString(desc.geometryShader);
compiled = graphicsShader->Compile(vertexSource.c_str(), fragmentSource.c_str(), geometrySource.c_str());
} else {
compiled = graphicsShader->Compile(vertexSource.c_str(), fragmentSource.c_str());
}
} else if (!desc.vertexShader.fileName.empty() && !desc.fragmentShader.fileName.empty()) {
const std::string vertexPath = NarrowAscii(desc.vertexShader.fileName);
const std::string fragmentPath = NarrowAscii(desc.fragmentShader.fileName);
if (hasGeometryShader && !desc.geometryShader.fileName.empty()) {
const std::string geometryPath = NarrowAscii(desc.geometryShader.fileName);
compiled = graphicsShader->CompileFromFile(vertexPath.c_str(), fragmentPath.c_str(), geometryPath.c_str());
} else {
compiled = graphicsShader->CompileFromFile(vertexPath.c_str(), fragmentPath.c_str());
}
}
if (!compiled) {
delete pso;
return nullptr;
}
pso->SetOwnedGraphicsShader(std::move(graphicsShader));
return pso; return pso;
} }

View File

@@ -1,6 +1,7 @@
#include "XCEngine/RHI/OpenGL/OpenGLPipelineState.h" #include "XCEngine/RHI/OpenGL/OpenGLPipelineState.h"
#include "XCEngine/RHI/OpenGL/OpenGLShader.h" #include "XCEngine/RHI/OpenGL/OpenGLShader.h"
#include "XCEngine/RHI/OpenGL/OpenGLEnums.h" #include "XCEngine/RHI/OpenGL/OpenGLEnums.h"
#include <cstring>
#include <glad/glad.h> #include <glad/glad.h>
namespace XCEngine { namespace XCEngine {
@@ -19,14 +20,42 @@ void OpenGLPipelineState::SetInputLayout(const InputLayoutDesc& layout) {
void OpenGLPipelineState::SetRasterizerState(const RasterizerDesc& state) { void OpenGLPipelineState::SetRasterizerState(const RasterizerDesc& state) {
m_rasterizerDesc = state; m_rasterizerDesc = state;
m_glRasterizerState.cullFaceEnable = static_cast<CullMode>(state.cullMode) != CullMode::None;
m_glRasterizerState.cullFace = static_cast<CullMode>(state.cullMode);
m_glRasterizerState.frontFace = static_cast<FrontFace>(state.frontFace);
m_glRasterizerState.polygonMode = static_cast<FillMode>(state.fillMode);
m_glRasterizerState.depthClipEnable = state.depthClipEnable;
m_glRasterizerState.scissorTestEnable = state.scissorTestEnable;
m_glRasterizerState.multisampleEnable = state.multisampleEnable;
m_glRasterizerState.polygonOffsetFactor = state.slopeScaledDepthBias;
m_glRasterizerState.polygonOffsetUnits = state.depthBiasClamp;
} }
void OpenGLPipelineState::SetBlendState(const BlendDesc& state) { void OpenGLPipelineState::SetBlendState(const BlendDesc& state) {
m_blendDesc = state; m_blendDesc = state;
m_glBlendState.blendEnable = state.blendEnable;
m_glBlendState.srcBlend = static_cast<BlendFactor>(state.srcBlend);
m_glBlendState.dstBlend = static_cast<BlendFactor>(state.dstBlend);
m_glBlendState.srcBlendAlpha = static_cast<BlendFactor>(state.srcBlendAlpha);
m_glBlendState.dstBlendAlpha = static_cast<BlendFactor>(state.dstBlendAlpha);
m_glBlendState.blendOp = static_cast<BlendOp>(state.blendOp);
m_glBlendState.blendOpAlpha = static_cast<BlendOp>(state.blendOpAlpha);
m_glBlendState.colorWriteMask = state.colorWriteMask;
std::memcpy(m_glBlendState.blendFactor, state.blendFactor, sizeof(state.blendFactor));
} }
void OpenGLPipelineState::SetDepthStencilState(const DepthStencilStateDesc& state) { void OpenGLPipelineState::SetDepthStencilState(const DepthStencilStateDesc& state) {
m_depthStencilDesc = state; m_depthStencilDesc = state;
m_glDepthStencilState.depthTestEnable = state.depthTestEnable;
m_glDepthStencilState.depthWriteEnable = state.depthWriteEnable;
m_glDepthStencilState.depthFunc = static_cast<ComparisonFunc>(state.depthFunc);
m_glDepthStencilState.stencilEnable = state.stencilEnable;
m_glDepthStencilState.stencilReadMask = state.stencilReadMask;
m_glDepthStencilState.stencilWriteMask = state.stencilWriteMask;
m_glDepthStencilState.stencilFunc = static_cast<ComparisonFunc>(state.front.func);
m_glDepthStencilState.stencilFailOp = static_cast<StencilOp>(state.front.failOp);
m_glDepthStencilState.stencilDepthFailOp = static_cast<StencilOp>(state.front.depthFailOp);
m_glDepthStencilState.stencilDepthPassOp = static_cast<StencilOp>(state.front.passOp);
} }
void OpenGLPipelineState::SetTopology(uint32_t topologyType) { void OpenGLPipelineState::SetTopology(uint32_t topologyType) {
@@ -55,6 +84,7 @@ PipelineStateHash OpenGLPipelineState::GetHash() const {
} }
void OpenGLPipelineState::Shutdown() { void OpenGLPipelineState::Shutdown() {
m_graphicsShader.reset();
m_program = 0; m_program = 0;
m_computeProgram = 0; m_computeProgram = 0;
m_computeShader = nullptr; m_computeShader = nullptr;
@@ -64,7 +94,7 @@ void OpenGLPipelineState::Shutdown() {
void OpenGLPipelineState::Bind() { void OpenGLPipelineState::Bind() {
if (HasComputeShader()) { if (HasComputeShader()) {
glUseProgram(m_computeProgram); glUseProgram(m_computeProgram);
} else if (m_programAttached) { } else if (m_programAttached && m_program != 0) {
glUseProgram(m_program); glUseProgram(m_program);
} }
Apply(); Apply();
@@ -213,6 +243,15 @@ void OpenGLPipelineState::DetachShader() {
glUseProgram(0); glUseProgram(0);
} }
void OpenGLPipelineState::SetOwnedGraphicsShader(std::unique_ptr<OpenGLShader> shader) {
m_graphicsShader = std::move(shader);
if (m_graphicsShader) {
SetProgram(m_graphicsShader->GetID());
} else {
DetachShader();
}
}
const OpenGLDepthStencilState& OpenGLPipelineState::GetOpenGLDepthStencilState() const { const OpenGLDepthStencilState& OpenGLPipelineState::GetOpenGLDepthStencilState() const {
return m_glDepthStencilState; return m_glDepthStencilState;
} }

View File

@@ -169,4 +169,82 @@ TEST_P(RHITestFixture, PipelineState_GetType) {
pso->Shutdown(); pso->Shutdown();
delete pso; delete pso;
} }
TEST_P(RHITestFixture, PipelineState_Create_GraphicsShadersFromDesc) {
GraphicsPipelineDesc desc = {};
desc.topologyType = static_cast<uint32_t>(PrimitiveTopologyType::Triangle);
desc.renderTargetFormats[0] = static_cast<uint32_t>(Format::R8G8B8A8_UNorm);
desc.depthStencilFormat = static_cast<uint32_t>(Format::Unknown);
InputElementDesc position = {};
position.semanticName = "POSITION";
position.semanticIndex = 0;
position.format = static_cast<uint32_t>(Format::R32G32B32A32_Float);
position.inputSlot = 0;
position.alignedByteOffset = 0;
desc.inputLayout.elements.push_back(position);
if (GetBackendType() == RHIType::D3D12) {
const char* hlslSource = R"(
struct VSInput {
float4 position : POSITION;
};
struct PSInput {
float4 position : SV_POSITION;
};
PSInput MainVS(VSInput input) {
PSInput output;
output.position = input.position;
return output;
}
float4 MainPS(PSInput input) : SV_TARGET {
return float4(1.0f, 0.0f, 0.0f, 1.0f);
}
)";
desc.vertexShader.source.assign(hlslSource, hlslSource + std::strlen(hlslSource));
desc.vertexShader.entryPoint = L"MainVS";
desc.vertexShader.profile = L"vs_5_0";
desc.vertexShader.sourceLanguage = ShaderLanguage::HLSL;
desc.fragmentShader.source.assign(hlslSource, hlslSource + std::strlen(hlslSource));
desc.fragmentShader.entryPoint = L"MainPS";
desc.fragmentShader.profile = L"ps_5_0";
desc.fragmentShader.sourceLanguage = ShaderLanguage::HLSL;
} else {
const char* vertexSource = R"(#version 430
layout(location = 0) in vec4 aPosition;
void main() {
gl_Position = aPosition;
}
)";
const char* fragmentSource = R"(#version 430
layout(location = 0) out vec4 fragColor;
void main() {
fragColor = vec4(1.0, 0.0, 0.0, 1.0);
}
)";
desc.vertexShader.source.assign(vertexSource, vertexSource + std::strlen(vertexSource));
desc.vertexShader.sourceLanguage = ShaderLanguage::GLSL;
desc.vertexShader.profile = L"vs_4_30";
desc.fragmentShader.source.assign(fragmentSource, fragmentSource + std::strlen(fragmentSource));
desc.fragmentShader.sourceLanguage = ShaderLanguage::GLSL;
desc.fragmentShader.profile = L"fs_4_30";
}
RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
ASSERT_NE(pso, nullptr);
EXPECT_TRUE(pso->IsValid());
EXPECT_NE(pso->GetNativeHandle(), nullptr);
pso->Shutdown();
delete pso;
}