#include "fixtures/RHITestFixture.h"
#include "XCEngine/RHI/RHIShader.h"
#include "XCEngine/RHI/RHIPipelineState.h"

#include <cstring>

using namespace XCEngine::RHI;

namespace {

// HLSL file compiled (entry point MainCS, profile cs_5_0) on the D3D12 backend.
constexpr const wchar_t* kComputeShaderHlslPath =
    L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl";

// Minimal GLSL compute shader used on every non-D3D12 backend.
// The line breaks are significant: `#version` is a preprocessor directive and
// must sit alone on the first line, otherwise the rest of the shader is
// swallowed into the directive and compilation fails.
constexpr char kComputeShaderGlsl[] = R"(#version 460
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main()
{
}
)";

// Builds the compile descriptor for a trivial compute shader on `backend`:
// the HLSL file + MainCS/cs_5_0 for D3D12, inline GLSL source otherwise.
ShaderCompileDesc MakeComputeShaderDesc(RHIType backend) {
    ShaderCompileDesc desc = {};
    if (backend == RHIType::D3D12) {
        desc.fileName = kComputeShaderHlslPath;
        desc.entryPoint = L"MainCS";
        desc.profile = L"cs_5_0";
    } else {
        desc.sourceLanguage = ShaderLanguage::GLSL;
        desc.source.assign(kComputeShaderGlsl,
                           kComputeShaderGlsl + std::strlen(kComputeShaderGlsl));
    }
    return desc;
}

// Calls Shutdown() and deletes `object`; no-op for nullptr.
template <typename T>
void ShutdownAndDelete(T* object) {
    if (object != nullptr) {
        object->Shutdown();
        delete object;
    }
}

}  // namespace

// A successfully compiled compute shader reports itself valid and of Compute type.
// NOTE(review): if compilation fails the test silently passes; consider GTEST_SKIP.
TEST_P(RHITestFixture, ComputeShader_Compile_ValidShader) {
    ShaderCompileDesc desc = MakeComputeShaderDesc(GetBackendType());
    RHIShader* shader = GetDevice()->CompileShader(desc);
    if (shader != nullptr) {
        EXPECT_TRUE(shader->IsValid());
        EXPECT_EQ(shader->GetType(), ShaderType::Compute);
        ShutdownAndDelete(shader);
    }
}

TEST_P(RHITestFixture, ComputeShader_GetType_ReturnsCompute) {
    ShaderCompileDesc desc = MakeComputeShaderDesc(GetBackendType());
    RHIShader* shader = GetDevice()->CompileShader(desc);
    if (shader != nullptr) {
        EXPECT_EQ(shader->GetType(), ShaderType::Compute);
        ShutdownAndDelete(shader);
    }
}

// Shutdown() must flip IsValid() from true to false.
TEST_P(RHITestFixture, ComputeShader_Shutdown_Invalidates) {
    ShaderCompileDesc desc = MakeComputeShaderDesc(GetBackendType());
    RHIShader* shader = GetDevice()->CompileShader(desc);
    if (shader != nullptr) {
        EXPECT_TRUE(shader->IsValid());
        shader->Shutdown();
        EXPECT_FALSE(shader->IsValid());
        delete shader;
    }
}

TEST_P(RHITestFixture, PipelineState_SetComputeShader) {
    GraphicsPipelineDesc desc = {};
    RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
    if (pso == nullptr) {
        return;
    }

    ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType());
    RHIShader* computeShader = GetDevice()->CompileShader(shaderDesc);
    if (computeShader != nullptr) {
        pso->SetComputeShader(computeShader);
        EXPECT_TRUE(pso->HasComputeShader());
    }

    // Destroy the PSO before the shader it was given, so the PSO never
    // holds a pointer to an already-deleted shader during its own Shutdown().
    ShutdownAndDelete(pso);
    ShutdownAndDelete(computeShader);
}

// A fresh PSO has no compute shader; after SetComputeShader it reports one.
TEST_P(RHITestFixture, PipelineState_HasComputeShader_ReturnsTrue) {
    GraphicsPipelineDesc desc = {};
    RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
    if (pso == nullptr) {
        return;
    }
    EXPECT_FALSE(pso->HasComputeShader());

    ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType());
    RHIShader* computeShader = GetDevice()->CompileShader(shaderDesc);
    if (computeShader != nullptr) {
        pso->SetComputeShader(computeShader);
        EXPECT_TRUE(pso->HasComputeShader());
    }

    // PSO first: it references the shader, not the other way round.
    ShutdownAndDelete(pso);
    ShutdownAndDelete(computeShader);
}

// A PSO created from a GraphicsPipelineDesc starts as Graphics and
// switches to Compute once a compute shader is attached.
TEST_P(RHITestFixture, PipelineState_GetType_Compute) {
    GraphicsPipelineDesc desc = {};
    RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
    if (pso == nullptr) {
        return;
    }
    EXPECT_EQ(pso->GetType(), PipelineType::Graphics);

    ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType());
    RHIShader* computeShader = GetDevice()->CompileShader(shaderDesc);
    if (computeShader != nullptr) {
        pso->SetComputeShader(computeShader);
        EXPECT_EQ(pso->GetType(), PipelineType::Compute);
    }

    ShutdownAndDelete(pso);
    ShutdownAndDelete(computeShader);
}

// On D3D12 attaching a compute shader leaves the PSO invalid until
// EnsureValid() is called; other backends are expected to be valid immediately.
TEST_P(RHITestFixture, PipelineState_EnsureValid_Compute) {
    GraphicsPipelineDesc desc = {};
    RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
    if (pso == nullptr) {
        return;
    }

    ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType());
    RHIShader* computeShader = GetDevice()->CompileShader(shaderDesc);
    if (computeShader != nullptr) {
        pso->SetComputeShader(computeShader);
        if (GetBackendType() == RHIType::D3D12) {
            EXPECT_FALSE(pso->IsValid());
            pso->EnsureValid();
        } else {
            EXPECT_TRUE(pso->IsValid());
        }
    }

    ShutdownAndDelete(pso);
    ShutdownAndDelete(computeShader);
}

// Records a 1x1x1 dispatch with a compute PSO bound.
TEST_P(RHITestFixture, CommandList_Dispatch_Basic) {
    RHICommandList* cmdList = GetDevice()->CreateCommandList({});
    ASSERT_NE(cmdList, nullptr);
    cmdList->Reset();

    GraphicsPipelineDesc desc = {};
    RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
    RHIShader* computeShader = nullptr;
    if (pso != nullptr) {
        ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType());
        computeShader = GetDevice()->CompileShader(shaderDesc);
        if (computeShader != nullptr) {
            pso->SetComputeShader(computeShader);
            cmdList->SetPipelineState(pso);
            cmdList->Dispatch(1, 1, 1);
        }
    }

    // Finish recording and tear down in dependency order: the command list
    // (which had the PSO bound) first, then the PSO, then the shader it uses.
    cmdList->Close();
    ShutdownAndDelete(cmdList);
    ShutdownAndDelete(pso);
    ShutdownAndDelete(computeShader);
}