Files
XCEngine/tests/RHI/unit/test_compute.cpp
ssdfasd 720dd422d5 RHI: Add Compute/Dispatch unit tests (P1-7) and fix shader type bugs
Bug fixes:
- D3D12Shader::Compile: Set m_type based on target string (cs_/vs_/ps_/gs_)
- OpenGLShader::Compile: Parse target parameter to determine shader type
- OpenGLShader::CompileCompute: Set m_type = ShaderType::Compute
- D3D12CommandList::SetPipelineState: Use correct PSO handle for Compute

New tests (test_compute.cpp, 8 tests):
- ComputeShader_Compile_ValidShader
- ComputeShader_GetType_ReturnsCompute
- ComputeShader_Shutdown_Invalidates
- PipelineState_SetComputeShader
- PipelineState_HasComputeShader_ReturnsTrue
- PipelineState_GetType_Compute
- PipelineState_EnsureValid_Compute
- CommandList_Dispatch_Basic

Test results: 232/232 passed (D3D12: 116, OpenGL: 116)
2026-03-25 13:52:11 +08:00

279 lines
8.6 KiB
C++

#include "fixtures/RHITestFixture.h"
#include "XCEngine/RHI/RHIShader.h"
#include "XCEngine/RHI/RHIPipelineState.h"
#include <cstring>
using namespace XCEngine::RHI;
// Compiles a minimal compute shader for the active backend and, when the
// device hands one back, checks that it reports itself as valid and as a
// compute-stage shader before releasing it.
// NOTE(review): a null result from CompileShader() leaves the test without
// any assertion, so it passes vacuously - confirm that this is intended.
TEST_P(RHITestFixture, ComputeShader_Compile_ValidShader) {
    ShaderCompileDesc compileDesc = {};
    const bool usesD3D12 = (GetBackendType() == RHIType::D3D12);
    if (!usesD3D12) {
        // Every non-D3D12 backend is fed an inline, empty GLSL compute kernel.
        static const char* const kGlslKernel = R"(
#version 460
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main() {
}
)";
        compileDesc.sourceLanguage = ShaderLanguage::GLSL;
        compileDesc.source.assign(kGlslKernel, kGlslKernel + std::strlen(kGlslKernel));
    } else {
        // D3D12 compiles the MainCS entry point from an HLSL file on disk.
        compileDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
        compileDesc.entryPoint = L"MainCS";
        compileDesc.profile = L"cs_5_0";
    }

    RHIShader* const compiled = GetDevice()->CompileShader(compileDesc);
    if (compiled == nullptr) {
        return;
    }

    EXPECT_TRUE(compiled->IsValid());
    EXPECT_EQ(compiled->GetType(), ShaderType::Compute);

    compiled->Shutdown();
    delete compiled;
}
// Verifies that a shader compiled from a compute source/profile reports
// ShaderType::Compute from GetType().
// NOTE(review): nothing is asserted when CompileShader() returns null -
// confirm whether a failed compile should be tolerated here.
TEST_P(RHITestFixture, ComputeShader_GetType_ReturnsCompute) {
    ShaderCompileDesc shaderInfo = {};
    switch (GetBackendType()) {
        case RHIType::D3D12: {
            // HLSL path: compile the MainCS entry point with the cs_5_0 profile.
            shaderInfo.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
            shaderInfo.entryPoint = L"MainCS";
            shaderInfo.profile = L"cs_5_0";
            break;
        }
        default: {
            // Any other backend: inline GLSL kernel with an empty main().
            static const char* glslText = R"(
#version 460
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main() {
}
)";
            const char* const glslEnd = glslText + std::strlen(glslText);
            shaderInfo.sourceLanguage = ShaderLanguage::GLSL;
            shaderInfo.source.assign(glslText, glslEnd);
            break;
        }
    }

    RHIShader* const result = GetDevice()->CompileShader(shaderInfo);
    if (result == nullptr) {
        return;
    }

    EXPECT_EQ(result->GetType(), ShaderType::Compute);

    result->Shutdown();
    delete result;
}
// Checks the validity transition of a compute shader: it must report valid
// after a successful compile and invalid once Shutdown() has been called.
// NOTE(review): a null shader from CompileShader() skips both checks - confirm
// that a silent pass is the desired outcome in that case.
TEST_P(RHITestFixture, ComputeShader_Shutdown_Invalidates) {
    ShaderCompileDesc request = {};

    const auto backend = GetBackendType();
    if (backend == RHIType::D3D12) {
        request.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
        request.entryPoint = L"MainCS";
        request.profile = L"cs_5_0";
    } else {
        request.sourceLanguage = ShaderLanguage::GLSL;
        // Smallest possible GLSL compute kernel: one invocation, empty body.
        static const char* kernelText = R"(
#version 460
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main() {
}
)";
        const auto kernelLength = std::strlen(kernelText);
        request.source.assign(kernelText, kernelText + kernelLength);
    }

    RHIShader* const shaderUnderTest = GetDevice()->CompileShader(request);
    if (!shaderUnderTest) {
        return;
    }

    // Valid before Shutdown(), invalid afterwards; the object itself is only
    // deleted once both states have been observed.
    EXPECT_TRUE(shaderUnderTest->IsValid());
    shaderUnderTest->Shutdown();
    EXPECT_FALSE(shaderUnderTest->IsValid());

    delete shaderUnderTest;
}
// Attaches a compiled compute shader to a pipeline state and verifies that
// the pipeline then reports having a compute shader.
// The test returns without assertions when the device cannot create a
// pipeline state, and asserts nothing when shader compilation yields null.
TEST_P(RHITestFixture, PipelineState_SetComputeShader) {
    GraphicsPipelineDesc desc = {};
    RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
    if (pso == nullptr) {
        return;
    }

    ShaderCompileDesc shaderDesc = {};
    if (GetBackendType() == RHIType::D3D12) {
        shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
        shaderDesc.entryPoint = L"MainCS";
        shaderDesc.profile = L"cs_5_0";
    } else {
        shaderDesc.sourceLanguage = ShaderLanguage::GLSL;
        static const char* cs = R"(
#version 460
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main() {
}
)";
        shaderDesc.source.assign(cs, cs + std::strlen(cs));
    }

    RHIShader* computeShader = GetDevice()->CompileShader(shaderDesc);
    if (computeShader != nullptr) {
        pso->SetComputeShader(computeShader);
        EXPECT_TRUE(pso->HasComputeShader());
    }

    // Tear down in reverse dependency order: SetComputeShader() was handed a
    // raw shader pointer, so the pipeline is released first and is never left
    // alive with a pointer to an already-deleted shader.
    pso->Shutdown();
    delete pso;
    if (computeShader != nullptr) {
        computeShader->Shutdown();
        delete computeShader;
    }
}
// Verifies both sides of HasComputeShader(): false on a freshly created
// pipeline state, true after a compute shader has been attached.
// The test returns without assertions when the device cannot create a
// pipeline state; only the "false" half runs when shader compilation fails.
TEST_P(RHITestFixture, PipelineState_HasComputeShader_ReturnsTrue) {
    GraphicsPipelineDesc desc = {};
    RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
    if (pso == nullptr) {
        return;
    }

    // Baseline: nothing has been attached yet.
    EXPECT_FALSE(pso->HasComputeShader());

    ShaderCompileDesc shaderDesc = {};
    if (GetBackendType() == RHIType::D3D12) {
        shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
        shaderDesc.entryPoint = L"MainCS";
        shaderDesc.profile = L"cs_5_0";
    } else {
        shaderDesc.sourceLanguage = ShaderLanguage::GLSL;
        static const char* cs = R"(
#version 460
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main() {
}
)";
        shaderDesc.source.assign(cs, cs + std::strlen(cs));
    }

    RHIShader* computeShader = GetDevice()->CompileShader(shaderDesc);
    if (computeShader != nullptr) {
        pso->SetComputeShader(computeShader);
        EXPECT_TRUE(pso->HasComputeShader());
    }

    // Release the pipeline before the shader it was given, so the pipeline
    // never holds a pointer to a shader that has already been deleted.
    pso->Shutdown();
    delete pso;
    if (computeShader != nullptr) {
        computeShader->Shutdown();
        delete computeShader;
    }
}
// Verifies that a pipeline state created from a GraphicsPipelineDesc starts
// out as PipelineType::Graphics and reports PipelineType::Compute once a
// compute shader has been attached.
// The test returns without assertions when the device cannot create a
// pipeline state; only the Graphics check runs when shader compilation fails.
TEST_P(RHITestFixture, PipelineState_GetType_Compute) {
    GraphicsPipelineDesc desc = {};
    RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
    if (pso == nullptr) {
        return;
    }

    // Baseline type before any compute shader is attached.
    EXPECT_EQ(pso->GetType(), PipelineType::Graphics);

    ShaderCompileDesc shaderDesc = {};
    if (GetBackendType() == RHIType::D3D12) {
        shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
        shaderDesc.entryPoint = L"MainCS";
        shaderDesc.profile = L"cs_5_0";
    } else {
        shaderDesc.sourceLanguage = ShaderLanguage::GLSL;
        static const char* cs = R"(
#version 460
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main() {
}
)";
        shaderDesc.source.assign(cs, cs + std::strlen(cs));
    }

    RHIShader* computeShader = GetDevice()->CompileShader(shaderDesc);
    if (computeShader != nullptr) {
        pso->SetComputeShader(computeShader);
        EXPECT_EQ(pso->GetType(), PipelineType::Compute);
    }

    // Release the pipeline before the shader it was given, so the pipeline
    // never holds a pointer to a shader that has already been deleted.
    pso->Shutdown();
    delete pso;
    if (computeShader != nullptr) {
        computeShader->Shutdown();
        delete computeShader;
    }
}
// Exercises pipeline validity after a compute shader is attached.
// On D3D12 the pipeline is expected to report invalid right after
// SetComputeShader(), and EnsureValid() is then invoked; on every other
// backend the pipeline is expected to report valid immediately.
// The test returns without assertions when the device cannot create a
// pipeline state, and asserts nothing when shader compilation yields null.
TEST_P(RHITestFixture, PipelineState_EnsureValid_Compute) {
    GraphicsPipelineDesc desc = {};
    RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
    if (pso == nullptr) {
        return;
    }

    ShaderCompileDesc shaderDesc = {};
    if (GetBackendType() == RHIType::D3D12) {
        shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
        shaderDesc.entryPoint = L"MainCS";
        shaderDesc.profile = L"cs_5_0";
    } else {
        shaderDesc.sourceLanguage = ShaderLanguage::GLSL;
        static const char* cs = R"(
#version 460
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main() {
}
)";
        shaderDesc.source.assign(cs, cs + std::strlen(cs));
    }

    RHIShader* computeShader = GetDevice()->CompileShader(shaderDesc);
    if (computeShader != nullptr) {
        pso->SetComputeShader(computeShader);
        if (GetBackendType() == RHIType::D3D12) {
            EXPECT_FALSE(pso->IsValid());
            // NOTE(review): the outcome of EnsureValid() is not checked here -
            // confirm whether IsValid() should be asserted afterwards.
            pso->EnsureValid();
        } else {
            EXPECT_TRUE(pso->IsValid());
        }
    }

    // Release the pipeline before the shader it was given, so the pipeline
    // never holds a pointer to a shader that has already been deleted.
    pso->Shutdown();
    delete pso;
    if (computeShader != nullptr) {
        computeShader->Shutdown();
        delete computeShader;
    }
}
// Records a single 1x1x1 compute dispatch into a command list: the list is
// reset, a pipeline state carrying a compute shader is bound, Dispatch() is
// called, and the list is closed. No result is asserted beyond the command
// list being created; the test checks that the sequence can be recorded.
// When the pipeline state or the shader cannot be created, the list is
// simply reset and closed without recording anything.
TEST_P(RHITestFixture, CommandList_Dispatch_Basic) {
    RHICommandList* cmdList = GetDevice()->CreateCommandList({});
    ASSERT_NE(cmdList, nullptr);
    cmdList->Reset();

    GraphicsPipelineDesc desc = {};
    RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
    RHIShader* computeShader = nullptr;
    if (pso != nullptr) {
        ShaderCompileDesc shaderDesc = {};
        if (GetBackendType() == RHIType::D3D12) {
            shaderDesc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";
            shaderDesc.entryPoint = L"MainCS";
            shaderDesc.profile = L"cs_5_0";
        } else {
            shaderDesc.sourceLanguage = ShaderLanguage::GLSL;
            static const char* cs = R"(
#version 460
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main() {
}
)";
            shaderDesc.source.assign(cs, cs + std::strlen(cs));
        }

        computeShader = GetDevice()->CompileShader(shaderDesc);
        if (computeShader != nullptr) {
            pso->SetComputeShader(computeShader);
            cmdList->SetPipelineState(pso);
            cmdList->Dispatch(1, 1, 1);
        }
    }

    // Finish recording and release the command list while the pipeline it was
    // given is still alive; the objects are then destroyed in reverse
    // dependency order (command list -> pipeline -> shader).
    cmdList->Close();
    cmdList->Shutdown();
    delete cmdList;

    if (pso != nullptr) {
        pso->Shutdown();
        delete pso;
    }
    if (computeShader != nullptr) {
        computeShader->Shutdown();
        delete computeShader;
    }
}