#include "fixtures/RHITestFixture.h"
#include "XCEngine/RHI/RHIShader.h"
#include "XCEngine/RHI/RHIPipelineState.h"

#include <cstring>

using namespace XCEngine::RHI;

namespace {

// Builds a compile description for a trivial 1x1x1 compute shader suitable
// for the given backend.
//
// D3D12 gets an HLSL source with entry point "MainCS" and profile "cs_5_0".
// Every other backend gets a GLSL 4.60 source. The GLSL branch sets no entry
// point or profile; presumably the GLSL path defaults to `main`. TODO confirm
// against the backend's CreateShader.
ShaderCompileDesc MakeComputeShaderDesc(RHIType backendType) {
    ShaderCompileDesc desc = {};
    if (backendType == RHIType::D3D12) {
        static const char* cs = R"(
[numthreads(1, 1, 1)]
void MainCS(uint3 dispatchThreadId : SV_DispatchThreadID) {
}
)";
        desc.source.assign(cs, cs + std::strlen(cs));
        desc.sourceLanguage = ShaderLanguage::HLSL;
        desc.entryPoint = L"MainCS";
        desc.profile = L"cs_5_0";
    } else {
        // `#version` must be on its own line, before any non-whitespace
        // content; the preprocessor consumes the rest of the directive's line.
        static const char* cs = R"(
#version 460
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main() {
}
)";
        desc.source.assign(cs, cs + std::strlen(cs));
        desc.sourceLanguage = ShaderLanguage::GLSL;
    }
    return desc;
}

} // namespace

// A valid compute shader compiles to a valid shader object of type Compute.
// The test is a no-op if the backend returns nullptr.
TEST_P(RHITestFixture, ComputeShader_Compile_ValidShader) {
    ShaderCompileDesc desc = MakeComputeShaderDesc(GetBackendType());
    RHIShader* shader = GetDevice()->CreateShader(desc);
    if (shader != nullptr) {
        EXPECT_TRUE(shader->IsValid());
        EXPECT_EQ(shader->GetType(), ShaderType::Compute);
        shader->Shutdown();
        delete shader;
    }
}

// GetType() reports ShaderType::Compute for a compute shader.
TEST_P(RHITestFixture, ComputeShader_GetType_ReturnsCompute) {
    ShaderCompileDesc desc = MakeComputeShaderDesc(GetBackendType());
    RHIShader* shader = GetDevice()->CreateShader(desc);
    if (shader != nullptr) {
        EXPECT_EQ(shader->GetType(), ShaderType::Compute);
        shader->Shutdown();
        delete shader;
    }
}

// Shutdown() moves a valid shader into the invalid state.
TEST_P(RHITestFixture, ComputeShader_Shutdown_Invalidates) {
    ShaderCompileDesc desc = MakeComputeShaderDesc(GetBackendType());
    RHIShader* shader = GetDevice()->CreateShader(desc);
    if (shader != nullptr) {
        EXPECT_TRUE(shader->IsValid());
        shader->Shutdown();
        EXPECT_FALSE(shader->IsValid());
        delete shader;
    }
}

// Attaching a compute shader to a pipeline state is reflected by
// HasComputeShader().
TEST_P(RHITestFixture, PipelineState_SetComputeShader) {
    GraphicsPipelineDesc desc = {};
    RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
    if (pso == nullptr) {
        return;
    }

    ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType());
    RHIShader* computeShader = GetDevice()->CreateShader(shaderDesc);
    if (computeShader != nullptr) {
        pso->SetComputeShader(computeShader);
        EXPECT_TRUE(pso->HasComputeShader());
        computeShader->Shutdown();
        delete computeShader;
    }

    pso->Shutdown();
    delete pso;
}

// HasComputeShader() is false for a fresh PSO and true once a compute shader
// is attached.
TEST_P(RHITestFixture, PipelineState_HasComputeShader_ReturnsTrue) {
    GraphicsPipelineDesc desc = {};
    RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
    if (pso == nullptr) {
        return;
    }

    EXPECT_FALSE(pso->HasComputeShader());

    ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType());
    RHIShader* computeShader = GetDevice()->CreateShader(shaderDesc);
    if (computeShader != nullptr) {
        pso->SetComputeShader(computeShader);
        EXPECT_TRUE(pso->HasComputeShader());
        computeShader->Shutdown();
        delete computeShader;
    }

    pso->Shutdown();
    delete pso;
}

// A PSO built from a GraphicsPipelineDesc starts as PipelineType::Graphics
// and reports PipelineType::Compute after a compute shader is attached.
TEST_P(RHITestFixture, PipelineState_GetType_Compute) {
    GraphicsPipelineDesc desc = {};
    RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
    if (pso == nullptr) {
        return;
    }

    EXPECT_EQ(pso->GetType(), PipelineType::Graphics);

    ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType());
    RHIShader* computeShader = GetDevice()->CreateShader(shaderDesc);
    if (computeShader != nullptr) {
        pso->SetComputeShader(computeShader);
        EXPECT_EQ(pso->GetType(), PipelineType::Compute);
        computeShader->Shutdown();
        delete computeShader;
    }

    pso->Shutdown();
    delete pso;
}

// Validity after attaching a compute shader is backend-specific. On D3D12 the
// PSO is expected to be invalid until EnsureValid() is called. On other
// backends it is expected to be valid immediately.
TEST_P(RHITestFixture, PipelineState_EnsureValid_Compute) {
    GraphicsPipelineDesc desc = {};
    RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
    if (pso == nullptr) {
        return;
    }

    ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType());
    RHIShader* computeShader = GetDevice()->CreateShader(shaderDesc);
    if (computeShader != nullptr) {
        pso->SetComputeShader(computeShader);
        if (GetBackendType() == RHIType::D3D12) {
            EXPECT_FALSE(pso->IsValid());
            // NOTE(review): the state after EnsureValid() is not asserted;
            // consider EXPECT_TRUE(pso->IsValid()) once the D3D12 contract
            // for EnsureValid() is confirmed.
            pso->EnsureValid();
        } else {
            EXPECT_TRUE(pso->IsValid());
        }
        computeShader->Shutdown();
        delete computeShader;
    }

    pso->Shutdown();
    delete pso;
}

// Records a 1x1x1 dispatch with a compute pipeline bound, then closes the
// command list.
TEST_P(RHITestFixture, CommandList_Dispatch_Basic) {
    RHICommandList* cmdList = GetDevice()->CreateCommandList({});
    ASSERT_NE(cmdList, nullptr);

    cmdList->Reset();

    GraphicsPipelineDesc desc = {};
    RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
    RHIShader* computeShader = nullptr;
    if (pso != nullptr) {
        ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType());
        computeShader = GetDevice()->CreateShader(shaderDesc);
        if (computeShader != nullptr) {
            pso->SetComputeShader(computeShader);
            cmdList->SetPipelineState(pso);
            cmdList->Dispatch(1, 1, 1);
        }
    }

    // Finish recording before destroying the shader and PSO: the recorded
    // commands reference them, so they must outlive the open command list.
    cmdList->Close();

    // Tear down in the same order as before: shader first, then the PSO.
    if (computeShader != nullptr) {
        computeShader->Shutdown();
        delete computeShader;
    }
    if (pso != nullptr) {
        pso->Shutdown();
        delete pso;
    }

    cmdList->Shutdown();
    delete cmdList;
}