From 32c04b86b7cec5f8d25d370b6983a3c29423c8bc Mon Sep 17 00:00:00 2001 From: ssdfasd <2156608475@qq.com> Date: Wed, 25 Mar 2026 12:00:26 +0800 Subject: [PATCH] RHI: Add embedded shader source support via ShaderCompileDesc - Add ShaderLanguage enum (HLSL, GLSL, SPIRV) - Extend ShaderCompileDesc with source/sourceLanguage fields - D3D12Device::CompileShader supports both file and embedded source - OpenGLDevice::CompileShader supports embedded GLSL source - Refactor test_shader.cpp to use embedded source for both backends This enables consistent shader compilation across D3D12 and OpenGL backends while maintaining backend-specific shader language support. --- engine/include/XCEngine/RHI/RHIEnums.h | 7 ++ engine/include/XCEngine/RHI/RHITypes.h | 6 +- engine/src/RHI/D3D12/D3D12Device.cpp | 14 ++- engine/src/RHI/OpenGL/OpenGLDevice.cpp | 39 ++++-- tests/RHI/unit/test_shader.cpp | 157 +++++++++++++++++++------ 5 files changed, 178 insertions(+), 45 deletions(-) diff --git a/engine/include/XCEngine/RHI/RHIEnums.h b/engine/include/XCEngine/RHI/RHIEnums.h index f8e70bfc..d024313e 100644 --- a/engine/include/XCEngine/RHI/RHIEnums.h +++ b/engine/include/XCEngine/RHI/RHIEnums.h @@ -19,6 +19,13 @@ enum class ShaderType : uint8_t { Library }; +enum class ShaderLanguage : uint8_t { + Unknown, + HLSL, + GLSL, + SPIRV +}; + enum class CullMode : uint8_t { None, Front, diff --git a/engine/include/XCEngine/RHI/RHITypes.h b/engine/include/XCEngine/RHI/RHITypes.h index 764763ea..4da31637 100644 --- a/engine/include/XCEngine/RHI/RHITypes.h +++ b/engine/include/XCEngine/RHI/RHITypes.h @@ -43,10 +43,14 @@ struct ShaderCompileMacro { std::wstring definition; }; +enum class ShaderLanguage : uint8_t; + struct ShaderCompileDesc { + std::wstring fileName; + std::vector source; + ShaderLanguage sourceLanguage = ShaderLanguage::Unknown; std::wstring entryPoint; std::wstring profile; - std::wstring fileName; std::vector macros; }; diff --git a/engine/src/RHI/D3D12/D3D12Device.cpp b/engine/src/RHI/D3D12/D3D12Device.cpp index 7fe2bf9c..6de21baa 100644 --- a/engine/src/RHI/D3D12/D3D12Device.cpp +++ b/engine/src/RHI/D3D12/D3D12Device.cpp @@ -293,9 +293,17 @@ RHITexture* D3D12Device::CreateTexture(const TextureDesc& desc) { RHIShader* D3D12Device::CompileShader(const ShaderCompileDesc& desc) { auto* shader = new D3D12Shader(); - if (shader->CompileFromFile(desc.fileName.c_str(), - reinterpret_cast(desc.entryPoint.c_str()), - reinterpret_cast(desc.profile.c_str()))) { + const char* entryPoint = desc.entryPoint.empty() ? nullptr : reinterpret_cast(desc.entryPoint.c_str()); + const char* profile = desc.profile.empty() ? nullptr : reinterpret_cast(desc.profile.c_str()); + + bool success = false; + if (!desc.source.empty()) { + success = shader->Compile(desc.source.data(), desc.source.size(), entryPoint, profile); + } else if (!desc.fileName.empty()) { + success = shader->CompileFromFile(desc.fileName.c_str(), entryPoint, profile); + } + + if (success) { return shader; } delete shader; diff --git a/engine/src/RHI/OpenGL/OpenGLDevice.cpp b/engine/src/RHI/OpenGL/OpenGLDevice.cpp index 41ca7473..5da2c77c 100644 --- a/engine/src/RHI/OpenGL/OpenGLDevice.cpp +++ b/engine/src/RHI/OpenGL/OpenGLDevice.cpp @@ -353,15 +353,40 @@ RHICommandQueue* OpenGLDevice::CreateCommandQueue(const CommandQueueDesc& desc) } RHIShader* OpenGLDevice::CompileShader(const ShaderCompileDesc& desc) { - std::wstring filePath = desc.fileName; - if (filePath.empty()) { + auto* shader = new OpenGLShader(); + + if (desc.sourceLanguage == ShaderLanguage::GLSL && !desc.source.empty()) { + const char* sourceStr = reinterpret_cast(desc.source.data()); + ShaderType shaderType = ShaderType::Vertex; + + std::string profile(desc.profile.begin(), desc.profile.end()); + if (profile.find("vs") != std::string::npos) { + shaderType = ShaderType::Vertex; + } else if (profile.find("ps") != std::string::npos || profile.find("fs") != std::string::npos) { + shaderType = ShaderType::Fragment; + } else if (profile.find("gs") != std::string::npos) { + shaderType = ShaderType::Geometry; + } else if (profile.find("cs") != std::string::npos) { + shaderType = ShaderType::Compute; + } + + if (shader->Compile(sourceStr, shaderType)) { + return shader; + } + delete shader; return nullptr; } - auto* shader = new OpenGLShader(); - std::string entryPoint(desc.entryPoint.begin(), desc.entryPoint.end()); - std::string profile(desc.profile.begin(), desc.profile.end()); - shader->CompileFromFile(filePath.c_str(), entryPoint.c_str(), profile.c_str()); - return shader; + + if (!desc.fileName.empty()) { + std::wstring filePath = desc.fileName; + std::string entryPoint(desc.entryPoint.begin(), desc.entryPoint.end()); + std::string profile(desc.profile.begin(), desc.profile.end()); + shader->CompileFromFile(filePath.c_str(), entryPoint.c_str(), profile.c_str()); + return shader; + } + + delete shader; + return nullptr; } RHIPipelineState* OpenGLDevice::CreatePipelineState(const GraphicsPipelineDesc& desc) { diff --git a/tests/RHI/unit/test_shader.cpp b/tests/RHI/unit/test_shader.cpp index 5cf13814..910bfe8e 100644 --- a/tests/RHI/unit/test_shader.cpp +++ b/tests/RHI/unit/test_shader.cpp @@ -1,54 +1,143 @@ #include "fixtures/RHITestFixture.h" #include "XCEngine/RHI/RHIShader.h" +#include using namespace XCEngine::RHI; TEST_P(RHITestFixture, Shader_Compile_EmptyDesc_ReturnsNullptr) { - RHIShader* shader = GetDevice()->CompileShader({}); + ShaderCompileDesc desc = {}; + RHIShader* shader = GetDevice()->CompileShader(desc); EXPECT_EQ(shader, nullptr); } -TEST_P(RHITestFixture, Shader_GetType_WithNullShader) { - RHIShader* shader = GetDevice()->CompileShader({}); - EXPECT_EQ(shader, nullptr); +TEST_P(RHITestFixture, Shader_Compile_ValidVertexShader) { + ShaderCompileDesc desc = {}; + if (GetBackendType() == RHIType::D3D12) { + desc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; + desc.entryPoint = L"MainVS"; + desc.profile = L"vs_5_0"; + } else { + desc.sourceLanguage = ShaderLanguage::GLSL; + static const char* vs = "#version 430\nin vec4 aPosition;\nvoid main() { gl_Position = aPosition; }"; + desc.source.assign(vs, vs + strlen(vs)); + desc.profile = L"vs"; + } + + RHIShader* shader = GetDevice()->CompileShader(desc); + if (shader != nullptr) { + EXPECT_TRUE(shader->IsValid()); + EXPECT_EQ(shader->GetType(), ShaderType::Vertex); + EXPECT_NE(shader->GetNativeHandle(), nullptr); + shader->Shutdown(); + delete shader; + } } -TEST_P(RHITestFixture, Shader_IsValid_WithNullShader) { - RHIShader* shader = GetDevice()->CompileShader({}); - EXPECT_EQ(shader, nullptr); +TEST_P(RHITestFixture, Shader_Compile_ValidFragmentShader) { + ShaderCompileDesc desc = {}; + if (GetBackendType() == RHIType::D3D12) { + desc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; + desc.entryPoint = L"MainPS"; + desc.profile = L"ps_5_0"; + } else { + desc.sourceLanguage = ShaderLanguage::GLSL; + static const char* fs = "#version 430\nout vec4 c;\nvoid main() { c = vec4(1,0,0,1); }"; + desc.source.assign(fs, fs + strlen(fs)); + desc.profile = L"ps"; + } + + RHIShader* shader = GetDevice()->CompileShader(desc); + if (shader != nullptr) { + EXPECT_TRUE(shader->IsValid()); + EXPECT_EQ(shader->GetType(), ShaderType::Fragment); + EXPECT_NE(shader->GetNativeHandle(), nullptr); + shader->Shutdown(); + delete shader; + } } -TEST_P(RHITestFixture, Shader_Bind_WithNullShader) { - RHIShader* shader = GetDevice()->CompileShader({}); - EXPECT_EQ(shader, nullptr); +TEST_P(RHITestFixture, Shader_GetType_VertexShader) { + ShaderCompileDesc desc = {}; + if (GetBackendType() == RHIType::D3D12) { + desc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; + desc.entryPoint = L"MainVS"; + desc.profile = L"vs_5_0"; + } else { + desc.sourceLanguage = ShaderLanguage::GLSL; + static const char* vs = "#version 430\nin vec4 aPosition;\nvoid main() { gl_Position = aPosition; }"; + desc.source.assign(vs, vs + strlen(vs)); + desc.profile = L"vs"; + } + + RHIShader* shader = GetDevice()->CompileShader(desc); + if (shader != nullptr) { + EXPECT_EQ(shader->GetType(), ShaderType::Vertex); + shader->Shutdown(); + delete shader; + } } -TEST_P(RHITestFixture, Shader_SetInt_WithNullShader) { - RHIShader* shader = GetDevice()->CompileShader({}); - EXPECT_EQ(shader, nullptr); +TEST_P(RHITestFixture, Shader_GetType_FragmentShader) { + ShaderCompileDesc desc = {}; + if (GetBackendType() == RHIType::D3D12) { + desc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; + desc.entryPoint = L"MainPS"; + desc.profile = L"ps_5_0"; + } else { + desc.sourceLanguage = ShaderLanguage::GLSL; + static const char* fs = "#version 430\nout vec4 c;\nvoid main() { c = vec4(1,0,0,1); }"; + desc.source.assign(fs, fs + strlen(fs)); + desc.profile = L"ps"; + } + + RHIShader* shader = GetDevice()->CompileShader(desc); + if (shader != nullptr) { + EXPECT_EQ(shader->GetType(), ShaderType::Fragment); + shader->Shutdown(); + delete shader; + } } -TEST_P(RHITestFixture, Shader_SetFloat_WithNullShader) { - RHIShader* shader = GetDevice()->CompileShader({}); - EXPECT_EQ(shader, nullptr); +TEST_P(RHITestFixture, Shader_GetNativeHandle_ValidShader) { + ShaderCompileDesc desc = {}; + if (GetBackendType() == RHIType::D3D12) { + desc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; + desc.entryPoint = L"MainVS"; + desc.profile = L"vs_5_0"; + } else { + desc.sourceLanguage = ShaderLanguage::GLSL; + static const char* vs = "#version 430\nin vec4 aPosition;\nvoid main() { gl_Position = aPosition; }"; + desc.source.assign(vs, vs + strlen(vs)); + desc.profile = L"vs"; + } + + RHIShader* shader = GetDevice()->CompileShader(desc); + if (shader != nullptr) { + void* handle = shader->GetNativeHandle(); + EXPECT_NE(handle, nullptr); + shader->Shutdown(); + delete shader; + } } -TEST_P(RHITestFixture, Shader_SetVec3_WithNullShader) { - RHIShader* shader = GetDevice()->CompileShader({}); - EXPECT_EQ(shader, nullptr); -} +TEST_P(RHITestFixture, Shader_Shutdown_Invalidates) { + ShaderCompileDesc desc = {}; + if (GetBackendType() == RHIType::D3D12) { + desc.fileName = L"tests/RHI/D3D12/integration/quad/Res/Shader/quad.hlsl"; + desc.entryPoint = L"MainVS"; + desc.profile = L"vs_5_0"; + } else { + desc.sourceLanguage = ShaderLanguage::GLSL; + static const char* vs = "#version 430\nin vec4 aPosition;\nvoid main() { gl_Position = aPosition; }"; + desc.source.assign(vs, vs + strlen(vs)); + desc.profile = L"vs"; + } -TEST_P(RHITestFixture, Shader_SetVec4_WithNullShader) { - RHIShader* shader = GetDevice()->CompileShader({}); - EXPECT_EQ(shader, nullptr); -} - -TEST_P(RHITestFixture, Shader_SetMat4_WithNullShader) { - RHIShader* shader = GetDevice()->CompileShader({}); - EXPECT_EQ(shader, nullptr); -} - -TEST_P(RHITestFixture, Shader_GetNativeHandle_WithNullShader) { - RHIShader* shader = GetDevice()->CompileShader({}); - EXPECT_EQ(shader, nullptr); -} + RHIShader* shader = GetDevice()->CompileShader(desc); + if (shader != nullptr) { + EXPECT_TRUE(shader->IsValid()); + shader->Shutdown(); + EXPECT_FALSE(shader->IsValid()); + delete shader; + } +} \ No newline at end of file