Fix D3D12 compute pipeline unit coverage
This commit is contained in:
@@ -79,6 +79,7 @@ public:
|
||||
private:
|
||||
bool CreateD3D12PSO();
|
||||
bool CreateD3D12ComputePSO();
|
||||
bool EnsureDefaultRootSignature();
|
||||
|
||||
ID3D12Device* m_device;
|
||||
bool m_finalized = false;
|
||||
|
||||
@@ -139,16 +139,32 @@ void D3D12CommandList::AliasBarrier(ID3D12Resource* beforeResource, ID3D12Resour
|
||||
|
||||
void D3D12CommandList::SetPipelineState(RHIPipelineState* pso) {
|
||||
if (!pso) return;
|
||||
|
||||
|
||||
pso->EnsureValid();
|
||||
|
||||
D3D12PipelineState* d3d12Pso = static_cast<D3D12PipelineState*>(pso);
|
||||
if (pso->GetType() == PipelineType::Compute) {
|
||||
D3D12PipelineState* d3d12Pso = static_cast<D3D12PipelineState*>(pso);
|
||||
SetPipelineStateInternal(d3d12Pso->GetComputePipelineState());
|
||||
ID3D12RootSignature* computeRootSignature = d3d12Pso->GetRootSignature();
|
||||
ID3D12PipelineState* computePipelineState = d3d12Pso->GetComputePipelineState();
|
||||
if (computeRootSignature == nullptr || computePipelineState == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
m_commandList->SetComputeRootSignature(computeRootSignature);
|
||||
m_commandList->SetPipelineState(computePipelineState);
|
||||
m_currentPipelineState = computePipelineState;
|
||||
m_currentRootSignature = computeRootSignature;
|
||||
} else {
|
||||
D3D12PipelineState* d3d12Pso = static_cast<D3D12PipelineState*>(pso);
|
||||
if (d3d12Pso->GetRootSignature() != nullptr) {
|
||||
SetRootSignature(d3d12Pso->GetRootSignature());
|
||||
}
|
||||
SetPipelineStateInternal(static_cast<ID3D12PipelineState*>(pso->GetNativeHandle()));
|
||||
|
||||
ID3D12PipelineState* graphicsPipelineState = static_cast<ID3D12PipelineState*>(pso->GetNativeHandle());
|
||||
if (graphicsPipelineState == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
SetPipelineStateInternal(graphicsPipelineState);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -484,6 +500,10 @@ void D3D12CommandList::EndRenderPass() {
|
||||
}
|
||||
|
||||
void D3D12CommandList::SetPipelineStateInternal(ID3D12PipelineState* pso) {
|
||||
if (pso == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
m_commandList->SetPipelineState(pso);
|
||||
m_currentPipelineState = pso;
|
||||
if (m_currentRootSignature) {
|
||||
|
||||
@@ -126,6 +126,17 @@ void D3D12PipelineState::SetShaderBytecodes(const D3D12_SHADER_BYTECODE& vs, con
|
||||
|
||||
void D3D12PipelineState::SetComputeShader(RHIShader* shader) {
|
||||
m_computeShader = shader;
|
||||
m_csBytecode = {};
|
||||
m_computePipelineState.Reset();
|
||||
m_pipelineState.Reset();
|
||||
m_finalized = false;
|
||||
|
||||
if (shader != nullptr) {
|
||||
auto* d3d12Shader = static_cast<D3D12Shader*>(shader);
|
||||
if (d3d12Shader->IsValid() && d3d12Shader->GetType() == ShaderType::Compute) {
|
||||
m_csBytecode = d3d12Shader->GetD3D12Bytecode();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void D3D12PipelineState::SetRootSignature(ID3D12RootSignature* rootSignature) {
|
||||
@@ -217,6 +228,10 @@ bool D3D12PipelineState::CreateD3D12ComputePSO() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!EnsureDefaultRootSignature()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
D3D12_COMPUTE_PIPELINE_STATE_DESC desc = {};
|
||||
desc.pRootSignature = m_rootSignature.Get();
|
||||
desc.CS = m_csBytecode;
|
||||
@@ -230,10 +245,46 @@ bool D3D12PipelineState::CreateD3D12ComputePSO() {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool D3D12PipelineState::EnsureDefaultRootSignature() {
|
||||
if (m_rootSignature != nullptr) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (m_device == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
D3D12_ROOT_SIGNATURE_DESC rootSignatureDesc = {};
|
||||
rootSignatureDesc.Flags = D3D12_ROOT_SIGNATURE_FLAG_NONE;
|
||||
|
||||
ComPtr<ID3DBlob> serializedSignature;
|
||||
ComPtr<ID3DBlob> error;
|
||||
HRESULT hr = D3D12SerializeRootSignature(
|
||||
&rootSignatureDesc,
|
||||
D3D_ROOT_SIGNATURE_VERSION_1,
|
||||
&serializedSignature,
|
||||
&error);
|
||||
if (FAILED(hr)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
hr = m_device->CreateRootSignature(
|
||||
0,
|
||||
serializedSignature->GetBufferPointer(),
|
||||
serializedSignature->GetBufferSize(),
|
||||
IID_PPV_ARGS(&m_rootSignature));
|
||||
return SUCCEEDED(hr);
|
||||
}
|
||||
|
||||
void D3D12PipelineState::Shutdown() {
|
||||
m_pipelineState.Reset();
|
||||
m_computePipelineState.Reset();
|
||||
m_rootSignature.Reset();
|
||||
m_vsBytecode = {};
|
||||
m_psBytecode = {};
|
||||
m_gsBytecode = {};
|
||||
m_csBytecode = {};
|
||||
m_computeShader = nullptr;
|
||||
m_finalized = false;
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
#include "XCEngine/RHI/OpenGL/OpenGLShader.h"
|
||||
#include "XCEngine/RHI/OpenGL/OpenGLCommandList.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
using namespace XCEngine::RHI;
|
||||
|
||||
TEST_F(OpenGLTestFixture, Shader_Compile_VertexFragment) {
|
||||
@@ -74,12 +76,15 @@ TEST_F(OpenGLTestFixture, Shader_Compile_InvalidSource) {
|
||||
undefined_symbol;
|
||||
}
|
||||
)";
|
||||
|
||||
|
||||
OpenGLShader shader;
|
||||
testing::internal::CaptureStdout();
|
||||
bool result = shader.Compile(invalidSource, "void main() { }");
|
||||
|
||||
const std::string output = testing::internal::GetCapturedStdout();
|
||||
|
||||
EXPECT_FALSE(result);
|
||||
EXPECT_FALSE(shader.IsValid());
|
||||
|
||||
EXPECT_NE(output.find("ERROR::SHADER_COMPILATION_ERROR"), std::string::npos);
|
||||
|
||||
shader.Shutdown();
|
||||
}
|
||||
|
||||
@@ -5,23 +5,39 @@
|
||||
|
||||
using namespace XCEngine::RHI;
|
||||
|
||||
TEST_P(RHITestFixture, ComputeShader_Compile_ValidShader) {
|
||||
namespace {
|
||||
|
||||
ShaderCompileDesc MakeComputeShaderDesc(RHIType backendType) {
|
||||
ShaderCompileDesc desc = {};
|
||||
if (GetBackendType() == RHIType::D3D12) {
|
||||
desc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
|
||||
if (backendType == RHIType::D3D12) {
|
||||
static const char* cs = R"(
|
||||
[numthreads(1, 1, 1)]
|
||||
void MainCS(uint3 dispatchThreadId : SV_DispatchThreadID) {
|
||||
}
|
||||
)";
|
||||
desc.source.assign(cs, cs + std::strlen(cs));
|
||||
desc.sourceLanguage = ShaderLanguage::HLSL;
|
||||
desc.entryPoint = L"MainCS";
|
||||
desc.profile = L"cs_5_0";
|
||||
} else {
|
||||
desc.sourceLanguage = ShaderLanguage::GLSL;
|
||||
static const char* cs = R"(
|
||||
#version 460
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
void main() {
|
||||
}
|
||||
)";
|
||||
desc.source.assign(cs, cs + strlen(cs));
|
||||
desc.source.assign(cs, cs + std::strlen(cs));
|
||||
desc.sourceLanguage = ShaderLanguage::GLSL;
|
||||
}
|
||||
|
||||
return desc;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST_P(RHITestFixture, ComputeShader_Compile_ValidShader) {
|
||||
ShaderCompileDesc desc = MakeComputeShaderDesc(GetBackendType());
|
||||
|
||||
RHIShader* shader = GetDevice()->CreateShader(desc);
|
||||
if (shader != nullptr) {
|
||||
EXPECT_TRUE(shader->IsValid());
|
||||
@@ -32,21 +48,7 @@ TEST_P(RHITestFixture, ComputeShader_Compile_ValidShader) {
|
||||
}
|
||||
|
||||
TEST_P(RHITestFixture, ComputeShader_GetType_ReturnsCompute) {
|
||||
ShaderCompileDesc desc = {};
|
||||
if (GetBackendType() == RHIType::D3D12) {
|
||||
desc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
|
||||
desc.entryPoint = L"MainCS";
|
||||
desc.profile = L"cs_5_0";
|
||||
} else {
|
||||
desc.sourceLanguage = ShaderLanguage::GLSL;
|
||||
static const char* cs = R"(
|
||||
#version 460
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
void main() {
|
||||
}
|
||||
)";
|
||||
desc.source.assign(cs, cs + strlen(cs));
|
||||
}
|
||||
ShaderCompileDesc desc = MakeComputeShaderDesc(GetBackendType());
|
||||
|
||||
RHIShader* shader = GetDevice()->CreateShader(desc);
|
||||
if (shader != nullptr) {
|
||||
@@ -57,21 +59,7 @@ TEST_P(RHITestFixture, ComputeShader_GetType_ReturnsCompute) {
|
||||
}
|
||||
|
||||
TEST_P(RHITestFixture, ComputeShader_Shutdown_Invalidates) {
|
||||
ShaderCompileDesc desc = {};
|
||||
if (GetBackendType() == RHIType::D3D12) {
|
||||
desc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
|
||||
desc.entryPoint = L"MainCS";
|
||||
desc.profile = L"cs_5_0";
|
||||
} else {
|
||||
desc.sourceLanguage = ShaderLanguage::GLSL;
|
||||
static const char* cs = R"(
|
||||
#version 460
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
void main() {
|
||||
}
|
||||
)";
|
||||
desc.source.assign(cs, cs + strlen(cs));
|
||||
}
|
||||
ShaderCompileDesc desc = MakeComputeShaderDesc(GetBackendType());
|
||||
|
||||
RHIShader* shader = GetDevice()->CreateShader(desc);
|
||||
if (shader != nullptr) {
|
||||
@@ -89,21 +77,7 @@ TEST_P(RHITestFixture, PipelineState_SetComputeShader) {
|
||||
return;
|
||||
}
|
||||
|
||||
ShaderCompileDesc shaderDesc = {};
|
||||
if (GetBackendType() == RHIType::D3D12) {
|
||||
shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
|
||||
shaderDesc.entryPoint = L"MainCS";
|
||||
shaderDesc.profile = L"cs_5_0";
|
||||
} else {
|
||||
shaderDesc.sourceLanguage = ShaderLanguage::GLSL;
|
||||
static const char* cs = R"(
|
||||
#version 460
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
void main() {
|
||||
}
|
||||
)";
|
||||
shaderDesc.source.assign(cs, cs + strlen(cs));
|
||||
}
|
||||
ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType());
|
||||
|
||||
RHIShader* computeShader = GetDevice()->CreateShader(shaderDesc);
|
||||
if (computeShader != nullptr) {
|
||||
@@ -126,21 +100,7 @@ TEST_P(RHITestFixture, PipelineState_HasComputeShader_ReturnsTrue) {
|
||||
|
||||
EXPECT_FALSE(pso->HasComputeShader());
|
||||
|
||||
ShaderCompileDesc shaderDesc = {};
|
||||
if (GetBackendType() == RHIType::D3D12) {
|
||||
shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
|
||||
shaderDesc.entryPoint = L"MainCS";
|
||||
shaderDesc.profile = L"cs_5_0";
|
||||
} else {
|
||||
shaderDesc.sourceLanguage = ShaderLanguage::GLSL;
|
||||
static const char* cs = R"(
|
||||
#version 460
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
void main() {
|
||||
}
|
||||
)";
|
||||
shaderDesc.source.assign(cs, cs + strlen(cs));
|
||||
}
|
||||
ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType());
|
||||
|
||||
RHIShader* computeShader = GetDevice()->CreateShader(shaderDesc);
|
||||
if (computeShader != nullptr) {
|
||||
@@ -163,21 +123,7 @@ TEST_P(RHITestFixture, PipelineState_GetType_Compute) {
|
||||
|
||||
EXPECT_EQ(pso->GetType(), PipelineType::Graphics);
|
||||
|
||||
ShaderCompileDesc shaderDesc = {};
|
||||
if (GetBackendType() == RHIType::D3D12) {
|
||||
shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
|
||||
shaderDesc.entryPoint = L"MainCS";
|
||||
shaderDesc.profile = L"cs_5_0";
|
||||
} else {
|
||||
shaderDesc.sourceLanguage = ShaderLanguage::GLSL;
|
||||
static const char* cs = R"(
|
||||
#version 460
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
void main() {
|
||||
}
|
||||
)";
|
||||
shaderDesc.source.assign(cs, cs + strlen(cs));
|
||||
}
|
||||
ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType());
|
||||
|
||||
RHIShader* computeShader = GetDevice()->CreateShader(shaderDesc);
|
||||
if (computeShader != nullptr) {
|
||||
@@ -198,21 +144,7 @@ TEST_P(RHITestFixture, PipelineState_EnsureValid_Compute) {
|
||||
return;
|
||||
}
|
||||
|
||||
ShaderCompileDesc shaderDesc = {};
|
||||
if (GetBackendType() == RHIType::D3D12) {
|
||||
shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
|
||||
shaderDesc.entryPoint = L"MainCS";
|
||||
shaderDesc.profile = L"cs_5_0";
|
||||
} else {
|
||||
shaderDesc.sourceLanguage = ShaderLanguage::GLSL;
|
||||
static const char* cs = R"(
|
||||
#version 460
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
void main() {
|
||||
}
|
||||
)";
|
||||
shaderDesc.source.assign(cs, cs + strlen(cs));
|
||||
}
|
||||
ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType());
|
||||
|
||||
RHIShader* computeShader = GetDevice()->CreateShader(shaderDesc);
|
||||
if (computeShader != nullptr) {
|
||||
@@ -243,21 +175,7 @@ TEST_P(RHITestFixture, CommandList_Dispatch_Basic) {
|
||||
RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
|
||||
|
||||
if (pso != nullptr) {
|
||||
ShaderCompileDesc shaderDesc = {};
|
||||
if (GetBackendType() == RHIType::D3D12) {
|
||||
shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
|
||||
shaderDesc.entryPoint = L"MainCS";
|
||||
shaderDesc.profile = L"cs_5_0";
|
||||
} else {
|
||||
shaderDesc.sourceLanguage = ShaderLanguage::GLSL;
|
||||
static const char* cs = R"(
|
||||
#version 460
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
void main() {
|
||||
}
|
||||
)";
|
||||
shaderDesc.source.assign(cs, cs + strlen(cs));
|
||||
}
|
||||
ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType());
|
||||
|
||||
RHIShader* computeShader = GetDevice()->CreateShader(shaderDesc);
|
||||
if (computeShader != nullptr) {
|
||||
|
||||
Reference in New Issue
Block a user