Add graphics shader support to RHI pipeline states

This commit is contained in:
2026-03-25 23:19:18 +08:00
parent aaf9cce418
commit 1597181458
10 changed files with 311 additions and 40 deletions

View File

@@ -39,6 +39,27 @@ std::string NarrowAscii(const std::wstring& value) {
return result;
}
bool HasShaderPayload(const ShaderCompileDesc& desc) {
return !desc.source.empty() || !desc.fileName.empty();
}
bool CompileD3D12Shader(const ShaderCompileDesc& desc, D3D12Shader& shader) {
const std::string entryPoint = NarrowAscii(desc.entryPoint);
const std::string profile = NarrowAscii(desc.profile);
const char* entryPointPtr = entryPoint.empty() ? nullptr : entryPoint.c_str();
const char* profilePtr = profile.empty() ? nullptr : profile.c_str();
if (!desc.source.empty()) {
return shader.Compile(desc.source.data(), desc.source.size(), entryPointPtr, profilePtr);
}
if (!desc.fileName.empty()) {
return shader.CompileFromFile(desc.fileName.c_str(), entryPointPtr, profilePtr);
}
return false;
}
} // namespace
D3D12Device::D3D12Device()
@@ -489,6 +510,55 @@ RHIPipelineState* D3D12Device::CreatePipelineState(const GraphicsPipelineDesc& d
pso->SetTopology(desc.topologyType);
pso->SetRenderTargetFormats(desc.renderTargetCount, desc.renderTargetFormats, desc.depthStencilFormat);
pso->SetSampleCount(desc.sampleCount);
const bool hasVertexShader = HasShaderPayload(desc.vertexShader);
const bool hasFragmentShader = HasShaderPayload(desc.fragmentShader);
const bool hasGeometryShader = HasShaderPayload(desc.geometryShader);
if (!hasVertexShader && !hasFragmentShader && !hasGeometryShader) {
return pso;
}
if (!hasVertexShader || !hasFragmentShader) {
delete pso;
return nullptr;
}
auto* rootSignature = CreateRootSignature({});
if (rootSignature == nullptr) {
delete pso;
return nullptr;
}
pso->SetRootSignature(rootSignature->GetRootSignature());
D3D12Shader vertexShader;
D3D12Shader fragmentShader;
D3D12Shader geometryShader;
const bool vertexCompiled = CompileD3D12Shader(desc.vertexShader, vertexShader);
const bool fragmentCompiled = CompileD3D12Shader(desc.fragmentShader, fragmentShader);
const bool geometryCompiled = !hasGeometryShader || CompileD3D12Shader(desc.geometryShader, geometryShader);
if (!vertexCompiled || !fragmentCompiled || !geometryCompiled) {
rootSignature->Shutdown();
delete rootSignature;
delete pso;
return nullptr;
}
pso->SetShaderBytecodes(
vertexShader.GetD3D12Bytecode(),
fragmentShader.GetD3D12Bytecode(),
hasGeometryShader ? geometryShader.GetD3D12Bytecode() : D3D12_SHADER_BYTECODE{});
pso->EnsureValid();
rootSignature->Shutdown();
delete rootSignature;
if (!pso->IsValid()) {
delete pso;
return nullptr;
}
return pso;
}