RHI: Add Compute/Dispatch unit tests (P1-7) and fix shader type bugs
Bug fixes: - D3D12Shader::Compile: Set m_type based on target string (cs_/vs_/ps_/gs_) - OpenGLShader::Compile: Parse target parameter to determine shader type - OpenGLShader::CompileCompute: Set m_type = ShaderType::Compute - D3D12CommandList::SetPipelineState: Use correct PSO handle for Compute New tests (test_compute.cpp, 8 tests): - ComputeShader_Compile_ValidShader - ComputeShader_GetType_ReturnsCompute - ComputeShader_Shutdown_Invalidates - PipelineState_SetComputeShader - PipelineState_HasComputeShader_ReturnsTrue - PipelineState_GetType_Compute - PipelineState_EnsureValid_Compute - CommandList_Dispatch_Basic Test results: 232/232 passed (D3D12: 116, OpenGL: 116)
This commit is contained in:
@@ -178,7 +178,13 @@ void D3D12CommandList::AliasBarrier(ID3D12Resource* beforeResource, ID3D12Resour
|
|||||||
|
|
||||||
void D3D12CommandList::SetPipelineState(RHIPipelineState* pso) {
|
void D3D12CommandList::SetPipelineState(RHIPipelineState* pso) {
|
||||||
if (!pso) return;
|
if (!pso) return;
|
||||||
SetPipelineStateInternal(static_cast<ID3D12PipelineState*>(pso->GetNativeHandle()));
|
|
||||||
|
if (pso->GetType() == PipelineType::Compute) {
|
||||||
|
D3D12PipelineState* d3d12Pso = static_cast<D3D12PipelineState*>(pso);
|
||||||
|
SetPipelineStateInternal(d3d12Pso->GetComputePipelineState());
|
||||||
|
} else {
|
||||||
|
SetPipelineStateInternal(static_cast<ID3D12PipelineState*>(pso->GetNativeHandle()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void D3D12CommandList::SetPipelineState(ID3D12PipelineState* pso) {
|
void D3D12CommandList::SetPipelineState(ID3D12PipelineState* pso) {
|
||||||
|
|||||||
@@ -55,6 +55,16 @@ bool D3D12Shader::Compile(const void* sourceData, size_t sourceSize, const char*
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (strstr(target, "vs_")) {
|
||||||
|
m_type = ShaderType::Vertex;
|
||||||
|
} else if (strstr(target, "ps_")) {
|
||||||
|
m_type = ShaderType::Fragment;
|
||||||
|
} else if (strstr(target, "gs_")) {
|
||||||
|
m_type = ShaderType::Geometry;
|
||||||
|
} else if (strstr(target, "cs_")) {
|
||||||
|
m_type = ShaderType::Compute;
|
||||||
|
}
|
||||||
|
|
||||||
m_uniformsCached = false;
|
m_uniformsCached = false;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -153,6 +153,7 @@ bool OpenGLShader::CompileCompute(const char* computeSource) {
|
|||||||
|
|
||||||
glDeleteShader(compute);
|
glDeleteShader(compute);
|
||||||
|
|
||||||
|
m_type = ShaderType::Compute;
|
||||||
m_uniformsCached = false;
|
m_uniformsCached = false;
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
@@ -219,7 +220,20 @@ bool OpenGLShader::Compile(const void* sourceData, size_t sourceSize, const char
|
|||||||
if (!sourceData || sourceSize == 0) {
|
if (!sourceData || sourceSize == 0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return Compile(static_cast<const char*>(sourceData), ShaderType::Fragment);
|
|
||||||
|
ShaderType type = ShaderType::Fragment;
|
||||||
|
if (target) {
|
||||||
|
if (strstr(target, "vs_")) {
|
||||||
|
type = ShaderType::Vertex;
|
||||||
|
} else if (strstr(target, "ps_")) {
|
||||||
|
type = ShaderType::Fragment;
|
||||||
|
} else if (strstr(target, "gs_")) {
|
||||||
|
type = ShaderType::Geometry;
|
||||||
|
} else if (strstr(target, "cs_")) {
|
||||||
|
type = ShaderType::Compute;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Compile(static_cast<const char*>(sourceData), type);
|
||||||
}
|
}
|
||||||
|
|
||||||
void OpenGLShader::Shutdown() {
|
void OpenGLShader::Shutdown() {
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ set(TEST_SOURCES
|
|||||||
test_fence.cpp
|
test_fence.cpp
|
||||||
test_sampler.cpp
|
test_sampler.cpp
|
||||||
test_descriptor.cpp
|
test_descriptor.cpp
|
||||||
|
test_compute.cpp
|
||||||
${CMAKE_SOURCE_DIR}/tests/opengl/package/src/glad.c
|
${CMAKE_SOURCE_DIR}/tests/opengl/package/src/glad.c
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
279
tests/RHI/unit/test_compute.cpp
Normal file
279
tests/RHI/unit/test_compute.cpp
Normal file
@@ -0,0 +1,279 @@
|
|||||||
|
#include "fixtures/RHITestFixture.h"
|
||||||
|
#include "XCEngine/RHI/RHIShader.h"
|
||||||
|
#include "XCEngine/RHI/RHIPipelineState.h"
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
|
using namespace XCEngine::RHI;
|
||||||
|
|
||||||
|
TEST_P(RHITestFixture, ComputeShader_Compile_ValidShader) {
|
||||||
|
ShaderCompileDesc desc = {};
|
||||||
|
if (GetBackendType() == RHIType::D3D12) {
|
||||||
|
desc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
|
||||||
|
desc.entryPoint = L"MainCS";
|
||||||
|
desc.profile = L"cs_5_0";
|
||||||
|
} else {
|
||||||
|
desc.sourceLanguage = ShaderLanguage::GLSL;
|
||||||
|
static const char* cs = R"(
|
||||||
|
#version 460
|
||||||
|
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
void main() {
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
desc.source.assign(cs, cs + strlen(cs));
|
||||||
|
}
|
||||||
|
|
||||||
|
RHIShader* shader = GetDevice()->CompileShader(desc);
|
||||||
|
if (shader != nullptr) {
|
||||||
|
EXPECT_TRUE(shader->IsValid());
|
||||||
|
EXPECT_EQ(shader->GetType(), ShaderType::Compute);
|
||||||
|
shader->Shutdown();
|
||||||
|
delete shader;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(RHITestFixture, ComputeShader_GetType_ReturnsCompute) {
|
||||||
|
ShaderCompileDesc desc = {};
|
||||||
|
if (GetBackendType() == RHIType::D3D12) {
|
||||||
|
desc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
|
||||||
|
desc.entryPoint = L"MainCS";
|
||||||
|
desc.profile = L"cs_5_0";
|
||||||
|
} else {
|
||||||
|
desc.sourceLanguage = ShaderLanguage::GLSL;
|
||||||
|
static const char* cs = R"(
|
||||||
|
#version 460
|
||||||
|
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
void main() {
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
desc.source.assign(cs, cs + strlen(cs));
|
||||||
|
}
|
||||||
|
|
||||||
|
RHIShader* shader = GetDevice()->CompileShader(desc);
|
||||||
|
if (shader != nullptr) {
|
||||||
|
EXPECT_EQ(shader->GetType(), ShaderType::Compute);
|
||||||
|
shader->Shutdown();
|
||||||
|
delete shader;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(RHITestFixture, ComputeShader_Shutdown_Invalidates) {
|
||||||
|
ShaderCompileDesc desc = {};
|
||||||
|
if (GetBackendType() == RHIType::D3D12) {
|
||||||
|
desc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
|
||||||
|
desc.entryPoint = L"MainCS";
|
||||||
|
desc.profile = L"cs_5_0";
|
||||||
|
} else {
|
||||||
|
desc.sourceLanguage = ShaderLanguage::GLSL;
|
||||||
|
static const char* cs = R"(
|
||||||
|
#version 460
|
||||||
|
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
void main() {
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
desc.source.assign(cs, cs + strlen(cs));
|
||||||
|
}
|
||||||
|
|
||||||
|
RHIShader* shader = GetDevice()->CompileShader(desc);
|
||||||
|
if (shader != nullptr) {
|
||||||
|
EXPECT_TRUE(shader->IsValid());
|
||||||
|
shader->Shutdown();
|
||||||
|
EXPECT_FALSE(shader->IsValid());
|
||||||
|
delete shader;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(RHITestFixture, PipelineState_SetComputeShader) {
|
||||||
|
GraphicsPipelineDesc desc = {};
|
||||||
|
RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
|
||||||
|
if (pso == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
ShaderCompileDesc shaderDesc = {};
|
||||||
|
if (GetBackendType() == RHIType::D3D12) {
|
||||||
|
shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
|
||||||
|
shaderDesc.entryPoint = L"MainCS";
|
||||||
|
shaderDesc.profile = L"cs_5_0";
|
||||||
|
} else {
|
||||||
|
shaderDesc.sourceLanguage = ShaderLanguage::GLSL;
|
||||||
|
static const char* cs = R"(
|
||||||
|
#version 460
|
||||||
|
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
void main() {
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
shaderDesc.source.assign(cs, cs + strlen(cs));
|
||||||
|
}
|
||||||
|
|
||||||
|
RHIShader* computeShader = GetDevice()->CompileShader(shaderDesc);
|
||||||
|
if (computeShader != nullptr) {
|
||||||
|
pso->SetComputeShader(computeShader);
|
||||||
|
EXPECT_TRUE(pso->HasComputeShader());
|
||||||
|
computeShader->Shutdown();
|
||||||
|
delete computeShader;
|
||||||
|
}
|
||||||
|
|
||||||
|
pso->Shutdown();
|
||||||
|
delete pso;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(RHITestFixture, PipelineState_HasComputeShader_ReturnsTrue) {
|
||||||
|
GraphicsPipelineDesc desc = {};
|
||||||
|
RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
|
||||||
|
if (pso == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_FALSE(pso->HasComputeShader());
|
||||||
|
|
||||||
|
ShaderCompileDesc shaderDesc = {};
|
||||||
|
if (GetBackendType() == RHIType::D3D12) {
|
||||||
|
shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
|
||||||
|
shaderDesc.entryPoint = L"MainCS";
|
||||||
|
shaderDesc.profile = L"cs_5_0";
|
||||||
|
} else {
|
||||||
|
shaderDesc.sourceLanguage = ShaderLanguage::GLSL;
|
||||||
|
static const char* cs = R"(
|
||||||
|
#version 460
|
||||||
|
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
void main() {
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
shaderDesc.source.assign(cs, cs + strlen(cs));
|
||||||
|
}
|
||||||
|
|
||||||
|
RHIShader* computeShader = GetDevice()->CompileShader(shaderDesc);
|
||||||
|
if (computeShader != nullptr) {
|
||||||
|
pso->SetComputeShader(computeShader);
|
||||||
|
EXPECT_TRUE(pso->HasComputeShader());
|
||||||
|
computeShader->Shutdown();
|
||||||
|
delete computeShader;
|
||||||
|
}
|
||||||
|
|
||||||
|
pso->Shutdown();
|
||||||
|
delete pso;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(RHITestFixture, PipelineState_GetType_Compute) {
|
||||||
|
GraphicsPipelineDesc desc = {};
|
||||||
|
RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
|
||||||
|
if (pso == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(pso->GetType(), PipelineType::Graphics);
|
||||||
|
|
||||||
|
ShaderCompileDesc shaderDesc = {};
|
||||||
|
if (GetBackendType() == RHIType::D3D12) {
|
||||||
|
shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
|
||||||
|
shaderDesc.entryPoint = L"MainCS";
|
||||||
|
shaderDesc.profile = L"cs_5_0";
|
||||||
|
} else {
|
||||||
|
shaderDesc.sourceLanguage = ShaderLanguage::GLSL;
|
||||||
|
static const char* cs = R"(
|
||||||
|
#version 460
|
||||||
|
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
void main() {
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
shaderDesc.source.assign(cs, cs + strlen(cs));
|
||||||
|
}
|
||||||
|
|
||||||
|
RHIShader* computeShader = GetDevice()->CompileShader(shaderDesc);
|
||||||
|
if (computeShader != nullptr) {
|
||||||
|
pso->SetComputeShader(computeShader);
|
||||||
|
EXPECT_EQ(pso->GetType(), PipelineType::Compute);
|
||||||
|
computeShader->Shutdown();
|
||||||
|
delete computeShader;
|
||||||
|
}
|
||||||
|
|
||||||
|
pso->Shutdown();
|
||||||
|
delete pso;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(RHITestFixture, PipelineState_EnsureValid_Compute) {
|
||||||
|
GraphicsPipelineDesc desc = {};
|
||||||
|
RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
|
||||||
|
if (pso == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
ShaderCompileDesc shaderDesc = {};
|
||||||
|
if (GetBackendType() == RHIType::D3D12) {
|
||||||
|
shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
|
||||||
|
shaderDesc.entryPoint = L"MainCS";
|
||||||
|
shaderDesc.profile = L"cs_5_0";
|
||||||
|
} else {
|
||||||
|
shaderDesc.sourceLanguage = ShaderLanguage::GLSL;
|
||||||
|
static const char* cs = R"(
|
||||||
|
#version 460
|
||||||
|
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
void main() {
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
shaderDesc.source.assign(cs, cs + strlen(cs));
|
||||||
|
}
|
||||||
|
|
||||||
|
RHIShader* computeShader = GetDevice()->CompileShader(shaderDesc);
|
||||||
|
if (computeShader != nullptr) {
|
||||||
|
pso->SetComputeShader(computeShader);
|
||||||
|
|
||||||
|
if (GetBackendType() == RHIType::D3D12) {
|
||||||
|
EXPECT_FALSE(pso->IsValid());
|
||||||
|
pso->EnsureValid();
|
||||||
|
} else {
|
||||||
|
EXPECT_TRUE(pso->IsValid());
|
||||||
|
}
|
||||||
|
|
||||||
|
computeShader->Shutdown();
|
||||||
|
delete computeShader;
|
||||||
|
}
|
||||||
|
|
||||||
|
pso->Shutdown();
|
||||||
|
delete pso;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(RHITestFixture, CommandList_Dispatch_Basic) {
|
||||||
|
RHICommandList* cmdList = GetDevice()->CreateCommandList({});
|
||||||
|
ASSERT_NE(cmdList, nullptr);
|
||||||
|
|
||||||
|
cmdList->Reset();
|
||||||
|
|
||||||
|
GraphicsPipelineDesc desc = {};
|
||||||
|
RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
|
||||||
|
|
||||||
|
if (pso != nullptr) {
|
||||||
|
ShaderCompileDesc shaderDesc = {};
|
||||||
|
if (GetBackendType() == RHIType::D3D12) {
|
||||||
|
shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
|
||||||
|
shaderDesc.entryPoint = L"MainCS";
|
||||||
|
shaderDesc.profile = L"cs_5_0";
|
||||||
|
} else {
|
||||||
|
shaderDesc.sourceLanguage = ShaderLanguage::GLSL;
|
||||||
|
static const char* cs = R"(
|
||||||
|
#version 460
|
||||||
|
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
void main() {
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
shaderDesc.source.assign(cs, cs + strlen(cs));
|
||||||
|
}
|
||||||
|
|
||||||
|
RHIShader* computeShader = GetDevice()->CompileShader(shaderDesc);
|
||||||
|
if (computeShader != nullptr) {
|
||||||
|
pso->SetComputeShader(computeShader);
|
||||||
|
cmdList->SetPipelineState(pso);
|
||||||
|
cmdList->Dispatch(1, 1, 1);
|
||||||
|
|
||||||
|
computeShader->Shutdown();
|
||||||
|
delete computeShader;
|
||||||
|
}
|
||||||
|
|
||||||
|
pso->Shutdown();
|
||||||
|
delete pso;
|
||||||
|
}
|
||||||
|
|
||||||
|
cmdList->Close();
|
||||||
|
cmdList->Shutdown();
|
||||||
|
delete cmdList;
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user