Fix D3D12 compute pipeline unit coverage

This commit is contained in:
2026-03-27 21:48:23 +08:00
parent 126860e79d
commit 1ea00a1879
5 changed files with 113 additions and 118 deletions

View File

@@ -139,16 +139,32 @@ void D3D12CommandList::AliasBarrier(ID3D12Resource* beforeResource, ID3D12Resour
void D3D12CommandList::SetPipelineState(RHIPipelineState* pso) {
if (!pso) return;
pso->EnsureValid();
D3D12PipelineState* d3d12Pso = static_cast<D3D12PipelineState*>(pso);
if (pso->GetType() == PipelineType::Compute) {
D3D12PipelineState* d3d12Pso = static_cast<D3D12PipelineState*>(pso);
SetPipelineStateInternal(d3d12Pso->GetComputePipelineState());
ID3D12RootSignature* computeRootSignature = d3d12Pso->GetRootSignature();
ID3D12PipelineState* computePipelineState = d3d12Pso->GetComputePipelineState();
if (computeRootSignature == nullptr || computePipelineState == nullptr) {
return;
}
m_commandList->SetComputeRootSignature(computeRootSignature);
m_commandList->SetPipelineState(computePipelineState);
m_currentPipelineState = computePipelineState;
m_currentRootSignature = computeRootSignature;
} else {
D3D12PipelineState* d3d12Pso = static_cast<D3D12PipelineState*>(pso);
if (d3d12Pso->GetRootSignature() != nullptr) {
SetRootSignature(d3d12Pso->GetRootSignature());
}
SetPipelineStateInternal(static_cast<ID3D12PipelineState*>(pso->GetNativeHandle()));
ID3D12PipelineState* graphicsPipelineState = static_cast<ID3D12PipelineState*>(pso->GetNativeHandle());
if (graphicsPipelineState == nullptr) {
return;
}
SetPipelineStateInternal(graphicsPipelineState);
}
}
@@ -484,6 +500,10 @@ void D3D12CommandList::EndRenderPass() {
}
void D3D12CommandList::SetPipelineStateInternal(ID3D12PipelineState* pso) {
if (pso == nullptr) {
return;
}
m_commandList->SetPipelineState(pso);
m_currentPipelineState = pso;
if (m_currentRootSignature) {

View File

@@ -126,6 +126,17 @@ void D3D12PipelineState::SetShaderBytecodes(const D3D12_SHADER_BYTECODE& vs, con
void D3D12PipelineState::SetComputeShader(RHIShader* shader) {
m_computeShader = shader;
m_csBytecode = {};
m_computePipelineState.Reset();
m_pipelineState.Reset();
m_finalized = false;
if (shader != nullptr) {
auto* d3d12Shader = static_cast<D3D12Shader*>(shader);
if (d3d12Shader->IsValid() && d3d12Shader->GetType() == ShaderType::Compute) {
m_csBytecode = d3d12Shader->GetD3D12Bytecode();
}
}
}
void D3D12PipelineState::SetRootSignature(ID3D12RootSignature* rootSignature) {
@@ -217,6 +228,10 @@ bool D3D12PipelineState::CreateD3D12ComputePSO() {
return false;
}
if (!EnsureDefaultRootSignature()) {
return false;
}
D3D12_COMPUTE_PIPELINE_STATE_DESC desc = {};
desc.pRootSignature = m_rootSignature.Get();
desc.CS = m_csBytecode;
@@ -230,10 +245,46 @@ bool D3D12PipelineState::CreateD3D12ComputePSO() {
return true;
}
bool D3D12PipelineState::EnsureDefaultRootSignature() {
if (m_rootSignature != nullptr) {
return true;
}
if (m_device == nullptr) {
return false;
}
D3D12_ROOT_SIGNATURE_DESC rootSignatureDesc = {};
rootSignatureDesc.Flags = D3D12_ROOT_SIGNATURE_FLAG_NONE;
ComPtr<ID3DBlob> serializedSignature;
ComPtr<ID3DBlob> error;
HRESULT hr = D3D12SerializeRootSignature(
&rootSignatureDesc,
D3D_ROOT_SIGNATURE_VERSION_1,
&serializedSignature,
&error);
if (FAILED(hr)) {
return false;
}
hr = m_device->CreateRootSignature(
0,
serializedSignature->GetBufferPointer(),
serializedSignature->GetBufferSize(),
IID_PPV_ARGS(&m_rootSignature));
return SUCCEEDED(hr);
}
void D3D12PipelineState::Shutdown() {
m_pipelineState.Reset();
m_computePipelineState.Reset();
m_rootSignature.Reset();
m_vsBytecode = {};
m_psBytecode = {};
m_gsBytecode = {};
m_csBytecode = {};
m_computeShader = nullptr;
m_finalized = false;
}