Add graphics shader support to RHI pipeline states
This commit is contained in:
@@ -127,6 +127,10 @@ void D3D12CommandList::SetPipelineState(RHIPipelineState* pso) {
|
||||
D3D12PipelineState* d3d12Pso = static_cast<D3D12PipelineState*>(pso);
|
||||
SetPipelineStateInternal(d3d12Pso->GetComputePipelineState());
|
||||
} else {
|
||||
D3D12PipelineState* d3d12Pso = static_cast<D3D12PipelineState*>(pso);
|
||||
if (d3d12Pso->GetRootSignature() != nullptr) {
|
||||
SetRootSignature(d3d12Pso->GetRootSignature());
|
||||
}
|
||||
SetPipelineStateInternal(static_cast<ID3D12PipelineState*>(pso->GetNativeHandle()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,6 +39,27 @@ std::string NarrowAscii(const std::wstring& value) {
|
||||
return result;
|
||||
}
|
||||
|
||||
bool HasShaderPayload(const ShaderCompileDesc& desc) {
|
||||
return !desc.source.empty() || !desc.fileName.empty();
|
||||
}
|
||||
|
||||
bool CompileD3D12Shader(const ShaderCompileDesc& desc, D3D12Shader& shader) {
|
||||
const std::string entryPoint = NarrowAscii(desc.entryPoint);
|
||||
const std::string profile = NarrowAscii(desc.profile);
|
||||
const char* entryPointPtr = entryPoint.empty() ? nullptr : entryPoint.c_str();
|
||||
const char* profilePtr = profile.empty() ? nullptr : profile.c_str();
|
||||
|
||||
if (!desc.source.empty()) {
|
||||
return shader.Compile(desc.source.data(), desc.source.size(), entryPointPtr, profilePtr);
|
||||
}
|
||||
|
||||
if (!desc.fileName.empty()) {
|
||||
return shader.CompileFromFile(desc.fileName.c_str(), entryPointPtr, profilePtr);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
D3D12Device::D3D12Device()
|
||||
@@ -489,6 +510,55 @@ RHIPipelineState* D3D12Device::CreatePipelineState(const GraphicsPipelineDesc& d
|
||||
pso->SetTopology(desc.topologyType);
|
||||
pso->SetRenderTargetFormats(desc.renderTargetCount, desc.renderTargetFormats, desc.depthStencilFormat);
|
||||
pso->SetSampleCount(desc.sampleCount);
|
||||
|
||||
const bool hasVertexShader = HasShaderPayload(desc.vertexShader);
|
||||
const bool hasFragmentShader = HasShaderPayload(desc.fragmentShader);
|
||||
const bool hasGeometryShader = HasShaderPayload(desc.geometryShader);
|
||||
if (!hasVertexShader && !hasFragmentShader && !hasGeometryShader) {
|
||||
return pso;
|
||||
}
|
||||
|
||||
if (!hasVertexShader || !hasFragmentShader) {
|
||||
delete pso;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto* rootSignature = CreateRootSignature({});
|
||||
if (rootSignature == nullptr) {
|
||||
delete pso;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
pso->SetRootSignature(rootSignature->GetRootSignature());
|
||||
|
||||
D3D12Shader vertexShader;
|
||||
D3D12Shader fragmentShader;
|
||||
D3D12Shader geometryShader;
|
||||
const bool vertexCompiled = CompileD3D12Shader(desc.vertexShader, vertexShader);
|
||||
const bool fragmentCompiled = CompileD3D12Shader(desc.fragmentShader, fragmentShader);
|
||||
const bool geometryCompiled = !hasGeometryShader || CompileD3D12Shader(desc.geometryShader, geometryShader);
|
||||
|
||||
if (!vertexCompiled || !fragmentCompiled || !geometryCompiled) {
|
||||
rootSignature->Shutdown();
|
||||
delete rootSignature;
|
||||
delete pso;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
pso->SetShaderBytecodes(
|
||||
vertexShader.GetD3D12Bytecode(),
|
||||
fragmentShader.GetD3D12Bytecode(),
|
||||
hasGeometryShader ? geometryShader.GetD3D12Bytecode() : D3D12_SHADER_BYTECODE{});
|
||||
pso->EnsureValid();
|
||||
|
||||
rootSignature->Shutdown();
|
||||
delete rootSignature;
|
||||
|
||||
if (!pso->IsValid()) {
|
||||
delete pso;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return pso;
|
||||
}
|
||||
|
||||
|
||||
@@ -127,6 +127,10 @@ void D3D12PipelineState::SetComputeShader(RHIShader* shader) {
|
||||
m_computeShader = shader;
|
||||
}
|
||||
|
||||
void D3D12PipelineState::SetRootSignature(ID3D12RootSignature* rootSignature) {
|
||||
m_rootSignature = rootSignature;
|
||||
}
|
||||
|
||||
void D3D12PipelineState::SetComputeShaderBytecodes(const D3D12_SHADER_BYTECODE& cs) {
|
||||
m_csBytecode = cs;
|
||||
}
|
||||
@@ -146,7 +150,7 @@ bool D3D12PipelineState::CreateD3D12PSO() {
|
||||
}
|
||||
|
||||
D3D12_GRAPHICS_PIPELINE_STATE_DESC desc = {};
|
||||
desc.pRootSignature = m_rootSignature;
|
||||
desc.pRootSignature = m_rootSignature.Get();
|
||||
desc.VS = m_vsBytecode;
|
||||
desc.PS = m_psBytecode;
|
||||
desc.GS = m_gsBytecode;
|
||||
@@ -154,36 +158,36 @@ bool D3D12PipelineState::CreateD3D12PSO() {
|
||||
desc.InputLayout.NumElements = static_cast<UINT>(m_inputElements.size());
|
||||
desc.InputLayout.pInputElementDescs = m_inputElements.data();
|
||||
|
||||
desc.RasterizerState.FillMode = static_cast<D3D12_FILL_MODE>(m_rasterizerDesc.fillMode);
|
||||
desc.RasterizerState.CullMode = static_cast<D3D12_CULL_MODE>(m_rasterizerDesc.cullMode);
|
||||
desc.RasterizerState.FrontCounterClockwise = (m_rasterizerDesc.frontFace != 0);
|
||||
desc.RasterizerState.FillMode = ToD3D12(static_cast<FillMode>(m_rasterizerDesc.fillMode));
|
||||
desc.RasterizerState.CullMode = ToD3D12(static_cast<CullMode>(m_rasterizerDesc.cullMode));
|
||||
desc.RasterizerState.FrontCounterClockwise = static_cast<FrontFace>(m_rasterizerDesc.frontFace) == FrontFace::CounterClockwise;
|
||||
desc.RasterizerState.DepthClipEnable = m_rasterizerDesc.depthClipEnable;
|
||||
desc.RasterizerState.MultisampleEnable = m_rasterizerDesc.multisampleEnable;
|
||||
desc.RasterizerState.AntialiasedLineEnable = m_rasterizerDesc.antialiasedLineEnable;
|
||||
|
||||
desc.BlendState.RenderTarget[0].BlendEnable = m_blendDesc.blendEnable;
|
||||
desc.BlendState.RenderTarget[0].SrcBlend = static_cast<D3D12_BLEND>(m_blendDesc.srcBlend);
|
||||
desc.BlendState.RenderTarget[0].DestBlend = static_cast<D3D12_BLEND>(m_blendDesc.dstBlend);
|
||||
desc.BlendState.RenderTarget[0].BlendOp = static_cast<D3D12_BLEND_OP>(m_blendDesc.blendOp);
|
||||
desc.BlendState.RenderTarget[0].SrcBlendAlpha = static_cast<D3D12_BLEND>(m_blendDesc.srcBlendAlpha);
|
||||
desc.BlendState.RenderTarget[0].DestBlendAlpha = static_cast<D3D12_BLEND>(m_blendDesc.dstBlendAlpha);
|
||||
desc.BlendState.RenderTarget[0].BlendOpAlpha = static_cast<D3D12_BLEND_OP>(m_blendDesc.blendOpAlpha);
|
||||
desc.BlendState.RenderTarget[0].SrcBlend = ToD3D12(static_cast<BlendFactor>(m_blendDesc.srcBlend));
|
||||
desc.BlendState.RenderTarget[0].DestBlend = ToD3D12(static_cast<BlendFactor>(m_blendDesc.dstBlend));
|
||||
desc.BlendState.RenderTarget[0].BlendOp = ToD3D12(static_cast<BlendOp>(m_blendDesc.blendOp));
|
||||
desc.BlendState.RenderTarget[0].SrcBlendAlpha = ToD3D12(static_cast<BlendFactor>(m_blendDesc.srcBlendAlpha));
|
||||
desc.BlendState.RenderTarget[0].DestBlendAlpha = ToD3D12(static_cast<BlendFactor>(m_blendDesc.dstBlendAlpha));
|
||||
desc.BlendState.RenderTarget[0].BlendOpAlpha = ToD3D12(static_cast<BlendOp>(m_blendDesc.blendOpAlpha));
|
||||
desc.BlendState.RenderTarget[0].RenderTargetWriteMask = m_blendDesc.colorWriteMask;
|
||||
|
||||
desc.DepthStencilState.DepthEnable = m_depthStencilDesc.depthTestEnable;
|
||||
desc.DepthStencilState.DepthWriteMask = m_depthStencilDesc.depthWriteEnable ? D3D12_DEPTH_WRITE_MASK_ALL : D3D12_DEPTH_WRITE_MASK_ZERO;
|
||||
desc.DepthStencilState.DepthFunc = static_cast<D3D12_COMPARISON_FUNC>(m_depthStencilDesc.depthFunc);
|
||||
desc.DepthStencilState.DepthFunc = ToD3D12(static_cast<ComparisonFunc>(m_depthStencilDesc.depthFunc));
|
||||
desc.DepthStencilState.StencilEnable = m_depthStencilDesc.stencilEnable;
|
||||
desc.DepthStencilState.StencilReadMask = m_depthStencilDesc.stencilReadMask;
|
||||
desc.DepthStencilState.StencilWriteMask = m_depthStencilDesc.stencilWriteMask;
|
||||
desc.DepthStencilState.FrontFace.StencilFailOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.front.failOp);
|
||||
desc.DepthStencilState.FrontFace.StencilPassOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.front.passOp);
|
||||
desc.DepthStencilState.FrontFace.StencilDepthFailOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.front.depthFailOp);
|
||||
desc.DepthStencilState.FrontFace.StencilFunc = static_cast<D3D12_COMPARISON_FUNC>(m_depthStencilDesc.front.func);
|
||||
desc.DepthStencilState.BackFace.StencilFailOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.back.failOp);
|
||||
desc.DepthStencilState.BackFace.StencilPassOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.back.passOp);
|
||||
desc.DepthStencilState.BackFace.StencilDepthFailOp = static_cast<D3D12_STENCIL_OP>(m_depthStencilDesc.back.depthFailOp);
|
||||
desc.DepthStencilState.BackFace.StencilFunc = static_cast<D3D12_COMPARISON_FUNC>(m_depthStencilDesc.back.func);
|
||||
desc.DepthStencilState.FrontFace.StencilFailOp = ToD3D12(static_cast<StencilOp>(m_depthStencilDesc.front.failOp));
|
||||
desc.DepthStencilState.FrontFace.StencilPassOp = ToD3D12(static_cast<StencilOp>(m_depthStencilDesc.front.passOp));
|
||||
desc.DepthStencilState.FrontFace.StencilDepthFailOp = ToD3D12(static_cast<StencilOp>(m_depthStencilDesc.front.depthFailOp));
|
||||
desc.DepthStencilState.FrontFace.StencilFunc = ToD3D12(static_cast<ComparisonFunc>(m_depthStencilDesc.front.func));
|
||||
desc.DepthStencilState.BackFace.StencilFailOp = ToD3D12(static_cast<StencilOp>(m_depthStencilDesc.back.failOp));
|
||||
desc.DepthStencilState.BackFace.StencilPassOp = ToD3D12(static_cast<StencilOp>(m_depthStencilDesc.back.passOp));
|
||||
desc.DepthStencilState.BackFace.StencilDepthFailOp = ToD3D12(static_cast<StencilOp>(m_depthStencilDesc.back.depthFailOp));
|
||||
desc.DepthStencilState.BackFace.StencilFunc = ToD3D12(static_cast<ComparisonFunc>(m_depthStencilDesc.back.func));
|
||||
|
||||
desc.NumRenderTargets = m_renderTargetCount;
|
||||
for (uint32_t i = 0; i < m_renderTargetCount && i < 8; ++i) {
|
||||
@@ -210,7 +214,7 @@ bool D3D12PipelineState::CreateD3D12ComputePSO() {
|
||||
}
|
||||
|
||||
D3D12_COMPUTE_PIPELINE_STATE_DESC desc = {};
|
||||
desc.pRootSignature = m_rootSignature;
|
||||
desc.pRootSignature = m_rootSignature.Get();
|
||||
desc.CS = m_csBytecode;
|
||||
|
||||
HRESULT hr = m_device->CreateComputePipelineState(&desc, IID_PPV_ARGS(&m_computePipelineState));
|
||||
@@ -225,6 +229,7 @@ bool D3D12PipelineState::CreateD3D12ComputePSO() {
|
||||
void D3D12PipelineState::Shutdown() {
|
||||
m_pipelineState.Reset();
|
||||
m_computePipelineState.Reset();
|
||||
m_rootSignature.Reset();
|
||||
m_finalized = false;
|
||||
}
|
||||
|
||||
@@ -260,4 +265,4 @@ D3D12_INPUT_ELEMENT_DESC D3D12PipelineState::CreateInputElement(
|
||||
}
|
||||
|
||||
} // namespace RHI
|
||||
} // namespace XCEngine
|
||||
} // namespace XCEngine
|
||||
|
||||
@@ -598,7 +598,7 @@ void OpenGLCommandList::ClearDepthStencil(RHIResourceView* depthStencil, float d
|
||||
|
||||
void OpenGLCommandList::SetPipelineState(RHIPipelineState* pipelineState) {
|
||||
if (pipelineState) {
|
||||
UseShader(reinterpret_cast<uintptr_t>(pipelineState->GetNativeHandle()));
|
||||
pipelineState->Bind();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -41,6 +41,27 @@ static PFNWGLCREATECONTEXTATTRIBSARBPROC wglCreateContextAttribsARB = nullptr;
|
||||
namespace XCEngine {
|
||||
namespace RHI {
|
||||
|
||||
namespace {
|
||||
|
||||
std::string NarrowAscii(const std::wstring& value) {
|
||||
std::string result;
|
||||
result.reserve(value.size());
|
||||
for (wchar_t ch : value) {
|
||||
result.push_back(static_cast<char>(ch));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool HasShaderPayload(const ShaderCompileDesc& desc) {
|
||||
return !desc.source.empty() || !desc.fileName.empty();
|
||||
}
|
||||
|
||||
std::string SourceToString(const ShaderCompileDesc& desc) {
|
||||
return std::string(reinterpret_cast<const char*>(desc.source.data()), desc.source.size());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
OpenGLDevice::OpenGLDevice()
|
||||
: m_hwnd(nullptr)
|
||||
, m_hdc(nullptr)
|
||||
@@ -405,6 +426,52 @@ RHIPipelineState* OpenGLDevice::CreatePipelineState(const GraphicsPipelineDesc&
|
||||
pso->SetTopology(desc.topologyType);
|
||||
pso->SetRenderTargetFormats(desc.renderTargetCount, desc.renderTargetFormats, desc.depthStencilFormat);
|
||||
pso->SetSampleCount(desc.sampleCount);
|
||||
|
||||
const bool hasVertexShader = HasShaderPayload(desc.vertexShader);
|
||||
const bool hasFragmentShader = HasShaderPayload(desc.fragmentShader);
|
||||
const bool hasGeometryShader = HasShaderPayload(desc.geometryShader);
|
||||
if (!hasVertexShader && !hasFragmentShader && !hasGeometryShader) {
|
||||
return pso;
|
||||
}
|
||||
|
||||
if (!hasVertexShader || !hasFragmentShader) {
|
||||
delete pso;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!MakeContextCurrent()) {
|
||||
delete pso;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto graphicsShader = std::make_unique<OpenGLShader>();
|
||||
bool compiled = false;
|
||||
if (!desc.vertexShader.source.empty() && !desc.fragmentShader.source.empty()) {
|
||||
const std::string vertexSource = SourceToString(desc.vertexShader);
|
||||
const std::string fragmentSource = SourceToString(desc.fragmentShader);
|
||||
if (hasGeometryShader && !desc.geometryShader.source.empty()) {
|
||||
const std::string geometrySource = SourceToString(desc.geometryShader);
|
||||
compiled = graphicsShader->Compile(vertexSource.c_str(), fragmentSource.c_str(), geometrySource.c_str());
|
||||
} else {
|
||||
compiled = graphicsShader->Compile(vertexSource.c_str(), fragmentSource.c_str());
|
||||
}
|
||||
} else if (!desc.vertexShader.fileName.empty() && !desc.fragmentShader.fileName.empty()) {
|
||||
const std::string vertexPath = NarrowAscii(desc.vertexShader.fileName);
|
||||
const std::string fragmentPath = NarrowAscii(desc.fragmentShader.fileName);
|
||||
if (hasGeometryShader && !desc.geometryShader.fileName.empty()) {
|
||||
const std::string geometryPath = NarrowAscii(desc.geometryShader.fileName);
|
||||
compiled = graphicsShader->CompileFromFile(vertexPath.c_str(), fragmentPath.c_str(), geometryPath.c_str());
|
||||
} else {
|
||||
compiled = graphicsShader->CompileFromFile(vertexPath.c_str(), fragmentPath.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (!compiled) {
|
||||
delete pso;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
pso->SetOwnedGraphicsShader(std::move(graphicsShader));
|
||||
return pso;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "XCEngine/RHI/OpenGL/OpenGLPipelineState.h"
|
||||
#include "XCEngine/RHI/OpenGL/OpenGLShader.h"
|
||||
#include "XCEngine/RHI/OpenGL/OpenGLEnums.h"
|
||||
#include <cstring>
|
||||
#include <glad/glad.h>
|
||||
|
||||
namespace XCEngine {
|
||||
@@ -19,14 +20,42 @@ void OpenGLPipelineState::SetInputLayout(const InputLayoutDesc& layout) {
|
||||
|
||||
void OpenGLPipelineState::SetRasterizerState(const RasterizerDesc& state) {
|
||||
m_rasterizerDesc = state;
|
||||
m_glRasterizerState.cullFaceEnable = static_cast<CullMode>(state.cullMode) != CullMode::None;
|
||||
m_glRasterizerState.cullFace = static_cast<CullMode>(state.cullMode);
|
||||
m_glRasterizerState.frontFace = static_cast<FrontFace>(state.frontFace);
|
||||
m_glRasterizerState.polygonMode = static_cast<FillMode>(state.fillMode);
|
||||
m_glRasterizerState.depthClipEnable = state.depthClipEnable;
|
||||
m_glRasterizerState.scissorTestEnable = state.scissorTestEnable;
|
||||
m_glRasterizerState.multisampleEnable = state.multisampleEnable;
|
||||
m_glRasterizerState.polygonOffsetFactor = state.slopeScaledDepthBias;
|
||||
m_glRasterizerState.polygonOffsetUnits = state.depthBiasClamp;
|
||||
}
|
||||
|
||||
void OpenGLPipelineState::SetBlendState(const BlendDesc& state) {
|
||||
m_blendDesc = state;
|
||||
m_glBlendState.blendEnable = state.blendEnable;
|
||||
m_glBlendState.srcBlend = static_cast<BlendFactor>(state.srcBlend);
|
||||
m_glBlendState.dstBlend = static_cast<BlendFactor>(state.dstBlend);
|
||||
m_glBlendState.srcBlendAlpha = static_cast<BlendFactor>(state.srcBlendAlpha);
|
||||
m_glBlendState.dstBlendAlpha = static_cast<BlendFactor>(state.dstBlendAlpha);
|
||||
m_glBlendState.blendOp = static_cast<BlendOp>(state.blendOp);
|
||||
m_glBlendState.blendOpAlpha = static_cast<BlendOp>(state.blendOpAlpha);
|
||||
m_glBlendState.colorWriteMask = state.colorWriteMask;
|
||||
std::memcpy(m_glBlendState.blendFactor, state.blendFactor, sizeof(state.blendFactor));
|
||||
}
|
||||
|
||||
void OpenGLPipelineState::SetDepthStencilState(const DepthStencilStateDesc& state) {
|
||||
m_depthStencilDesc = state;
|
||||
m_glDepthStencilState.depthTestEnable = state.depthTestEnable;
|
||||
m_glDepthStencilState.depthWriteEnable = state.depthWriteEnable;
|
||||
m_glDepthStencilState.depthFunc = static_cast<ComparisonFunc>(state.depthFunc);
|
||||
m_glDepthStencilState.stencilEnable = state.stencilEnable;
|
||||
m_glDepthStencilState.stencilReadMask = state.stencilReadMask;
|
||||
m_glDepthStencilState.stencilWriteMask = state.stencilWriteMask;
|
||||
m_glDepthStencilState.stencilFunc = static_cast<ComparisonFunc>(state.front.func);
|
||||
m_glDepthStencilState.stencilFailOp = static_cast<StencilOp>(state.front.failOp);
|
||||
m_glDepthStencilState.stencilDepthFailOp = static_cast<StencilOp>(state.front.depthFailOp);
|
||||
m_glDepthStencilState.stencilDepthPassOp = static_cast<StencilOp>(state.front.passOp);
|
||||
}
|
||||
|
||||
void OpenGLPipelineState::SetTopology(uint32_t topologyType) {
|
||||
@@ -55,6 +84,7 @@ PipelineStateHash OpenGLPipelineState::GetHash() const {
|
||||
}
|
||||
|
||||
void OpenGLPipelineState::Shutdown() {
|
||||
m_graphicsShader.reset();
|
||||
m_program = 0;
|
||||
m_computeProgram = 0;
|
||||
m_computeShader = nullptr;
|
||||
@@ -64,7 +94,7 @@ void OpenGLPipelineState::Shutdown() {
|
||||
void OpenGLPipelineState::Bind() {
|
||||
if (HasComputeShader()) {
|
||||
glUseProgram(m_computeProgram);
|
||||
} else if (m_programAttached) {
|
||||
} else if (m_programAttached && m_program != 0) {
|
||||
glUseProgram(m_program);
|
||||
}
|
||||
Apply();
|
||||
@@ -213,6 +243,15 @@ void OpenGLPipelineState::DetachShader() {
|
||||
glUseProgram(0);
|
||||
}
|
||||
|
||||
void OpenGLPipelineState::SetOwnedGraphicsShader(std::unique_ptr<OpenGLShader> shader) {
|
||||
m_graphicsShader = std::move(shader);
|
||||
if (m_graphicsShader) {
|
||||
SetProgram(m_graphicsShader->GetID());
|
||||
} else {
|
||||
DetachShader();
|
||||
}
|
||||
}
|
||||
|
||||
const OpenGLDepthStencilState& OpenGLPipelineState::GetOpenGLDepthStencilState() const {
|
||||
return m_glDepthStencilState;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user