diff --git a/engine/include/XCEngine/Resources/Shader/Shader.h b/engine/include/XCEngine/Resources/Shader/Shader.h index 13f58db8..9a2cd01f 100644 --- a/engine/include/XCEngine/Resources/Shader/Shader.h +++ b/engine/include/XCEngine/Resources/Shader/Shader.h @@ -22,6 +22,13 @@ enum class ShaderLanguage : Core::uint8 { SPIRV }; +enum class ShaderBackend : Core::uint8 { + Generic = 0, + D3D12, + OpenGL, + Vulkan +}; + struct ShaderUniform { Containers::String name; Core::uint32 location; @@ -36,6 +43,27 @@ struct ShaderAttribute { Core::uint32 type; }; +struct ShaderPassTagEntry { + Containers::String name; + Containers::String value; +}; + +struct ShaderStageVariant { + ShaderType stage = ShaderType::Fragment; + ShaderLanguage language = ShaderLanguage::GLSL; + ShaderBackend backend = ShaderBackend::Generic; + Containers::String entryPoint; + Containers::String profile; + Containers::String sourceCode; + Containers::Array compiledBinary; +}; + +struct ShaderPass { + Containers::String name; + Containers::Array tags; + Containers::Array variants; +}; + class Shader : public IResource { public: Shader(); @@ -49,10 +77,10 @@ public: size_t GetMemorySize() const override { return m_memorySize; } void Release() override; - void SetShaderType(ShaderType type) { m_shaderType = type; } + void SetShaderType(ShaderType type); ShaderType GetShaderType() const { return m_shaderType; } - void SetShaderLanguage(ShaderLanguage lang) { m_language = lang; } + void SetShaderLanguage(ShaderLanguage lang); ShaderLanguage GetShaderLanguage() const { return m_language; } void SetSourceCode(const Containers::String& source); @@ -66,11 +94,33 @@ public: void AddAttribute(const ShaderAttribute& attribute); const Containers::Array& GetAttributes() const { return m_attributes; } + + void AddPass(const ShaderPass& pass); + void ClearPasses(); + Core::uint32 GetPassCount() const { return static_cast(m_passes.Size()); } + const Containers::Array& GetPasses() const { return m_passes; } + + void AddPassVariant(const Containers::String& passName, const ShaderStageVariant& variant); + void SetPassTag( + const Containers::String& passName, + const Containers::String& tagName, + const Containers::String& tagValue); + bool HasPass(const Containers::String& passName) const; + const ShaderPass* FindPass(const Containers::String& passName) const; + ShaderPass* FindPass(const Containers::String& passName); + const ShaderStageVariant* FindVariant( + const Containers::String& passName, + ShaderType stage, + ShaderBackend backend = ShaderBackend::Generic) const; class IRHIShader* GetRHIResource() const { return m_rhiResource; } void SetRHIResource(class IRHIShader* resource); private: + ShaderPass& GetOrCreatePass(const Containers::String& passName); + ShaderStageVariant& GetOrCreateLegacyVariant(); + void SyncLegacyVariant(); + ShaderType m_shaderType = ShaderType::Fragment; ShaderLanguage m_language = ShaderLanguage::GLSL; @@ -79,6 +129,7 @@ private: Containers::Array m_uniforms; Containers::Array m_attributes; + Containers::Array m_passes; class IRHIShader* m_rhiResource = nullptr; }; diff --git a/engine/src/Resources/Shader/Shader.cpp b/engine/src/Resources/Shader/Shader.cpp index c11032f2..d8249eaa 100644 --- a/engine/src/Resources/Shader/Shader.cpp +++ b/engine/src/Resources/Shader/Shader.cpp @@ -3,25 +3,46 @@ namespace XCEngine { namespace Resources { +namespace { + +const char* kLegacyShaderPassName = "Default"; + +} // namespace + Shader::Shader() = default; Shader::~Shader() = default; void Shader::Release() { + m_shaderType = ShaderType::Fragment; + m_language = ShaderLanguage::GLSL; m_sourceCode.Clear(); m_compiledBinary.Clear(); m_uniforms.Clear(); m_attributes.Clear(); + m_passes.Clear(); m_rhiResource = nullptr; m_isValid = false; } +void Shader::SetShaderType(ShaderType type) { + m_shaderType = type; + SyncLegacyVariant(); +} + +void Shader::SetShaderLanguage(ShaderLanguage lang) { + m_language = lang; + SyncLegacyVariant(); +} + void Shader::SetSourceCode(const Containers::String& source) { m_sourceCode = source; + SyncLegacyVariant(); } void Shader::SetCompiledBinary(const Containers::Array& binary) { m_compiledBinary = binary; + SyncLegacyVariant(); } void Shader::AddUniform(const ShaderUniform& uniform) { @@ -32,9 +53,123 @@ void Shader::AddAttribute(const ShaderAttribute& attribute) { m_attributes.PushBack(attribute); } +void Shader::AddPass(const ShaderPass& pass) { + m_passes.PushBack(pass); +} + +void Shader::ClearPasses() { + m_passes.Clear(); +} + +void Shader::AddPassVariant( + const Containers::String& passName, + const ShaderStageVariant& variant) { + ShaderPass& pass = GetOrCreatePass(passName); + pass.variants.PushBack(variant); +} + +void Shader::SetPassTag( + const Containers::String& passName, + const Containers::String& tagName, + const Containers::String& tagValue) { + ShaderPass& pass = GetOrCreatePass(passName); + for (ShaderPassTagEntry& tag : pass.tags) { + if (tag.name == tagName) { + tag.value = tagValue; + return; + } + } + + ShaderPassTagEntry& tag = pass.tags.EmplaceBack(); + tag.name = tagName; + tag.value = tagValue; +} + +bool Shader::HasPass(const Containers::String& passName) const { + return FindPass(passName) != nullptr; +} + +const ShaderPass* Shader::FindPass(const Containers::String& passName) const { + for (const ShaderPass& pass : m_passes) { + if (pass.name == passName) { + return &pass; + } + } + + return nullptr; +} + +ShaderPass* Shader::FindPass(const Containers::String& passName) { + for (ShaderPass& pass : m_passes) { + if (pass.name == passName) { + return &pass; + } + } + + return nullptr; +} + +const ShaderStageVariant* Shader::FindVariant( + const Containers::String& passName, + ShaderType stage, + ShaderBackend backend) const { + const ShaderPass* pass = FindPass(passName); + if (pass == nullptr) { + return nullptr; + } + + const ShaderStageVariant* genericVariant = nullptr; + for (const ShaderStageVariant& variant : pass->variants) { + if (variant.stage != stage) { + continue; + } + + if (variant.backend == backend) { + return &variant; + } + + if (variant.backend == ShaderBackend::Generic && genericVariant == nullptr) { + genericVariant = &variant; + } + } + + return genericVariant; +} + void Shader::SetRHIResource(class IRHIShader* resource) { m_rhiResource = resource; } +ShaderPass& Shader::GetOrCreatePass(const Containers::String& passName) { + if (ShaderPass* pass = FindPass(passName)) { + return *pass; + } + + ShaderPass& pass = m_passes.EmplaceBack(); + pass.name = passName; + return pass; +} + +ShaderStageVariant& Shader::GetOrCreateLegacyVariant() { + ShaderPass& pass = GetOrCreatePass(kLegacyShaderPassName); + for (ShaderStageVariant& variant : pass.variants) { + if (variant.backend == ShaderBackend::Generic) { + return variant; + } + } + + ShaderStageVariant& variant = pass.variants.EmplaceBack(); + variant.backend = ShaderBackend::Generic; + return variant; +} + +void Shader::SyncLegacyVariant() { + ShaderStageVariant& variant = GetOrCreateLegacyVariant(); + variant.stage = m_shaderType; + variant.language = m_language; + variant.sourceCode = m_sourceCode; + variant.compiledBinary = m_compiledBinary; +} + } // namespace Resources } // namespace XCEngine diff --git a/tests/Resources/Shader/test_shader.cpp b/tests/Resources/Shader/test_shader.cpp index 4caf44dd..7bd49fca 100644 --- a/tests/Resources/Shader/test_shader.cpp +++ b/tests/Resources/Shader/test_shader.cpp @@ -100,4 +100,78 @@ TEST(Shader, AddGetAttributes) { EXPECT_EQ(attributes[0].name, "aPosition"); } +TEST(Shader, LegacySingleStageStateSyncsIntoDefaultPassVariant) { + Shader shader; + shader.SetShaderType(ShaderType::Vertex); + shader.SetShaderLanguage(ShaderLanguage::HLSL); + shader.SetSourceCode("float4 MainVS() : SV_POSITION { return 0; }"); + + ASSERT_EQ(shader.GetPassCount(), 1u); + const ShaderPass* pass = shader.FindPass("Default"); + ASSERT_NE(pass, nullptr); + ASSERT_EQ(pass->variants.Size(), 1u); + EXPECT_EQ(pass->variants[0].stage, ShaderType::Vertex); + EXPECT_EQ(pass->variants[0].language, ShaderLanguage::HLSL); + EXPECT_EQ(pass->variants[0].backend, ShaderBackend::Generic); + EXPECT_EQ(pass->variants[0].sourceCode, "float4 MainVS() : SV_POSITION { return 0; }"); +} + +TEST(Shader, FindsBackendSpecificVariantAndFallsBackToGeneric) { + Shader shader; + + ShaderStageVariant genericFragment = {}; + genericFragment.stage = ShaderType::Fragment; + genericFragment.language = ShaderLanguage::GLSL; + genericFragment.backend = ShaderBackend::Generic; + genericFragment.sourceCode = "generic fragment"; + shader.AddPassVariant("ForwardLit", genericFragment); + + ShaderStageVariant d3d12Fragment = {}; + d3d12Fragment.stage = ShaderType::Fragment; + d3d12Fragment.language = ShaderLanguage::HLSL; + d3d12Fragment.backend = ShaderBackend::D3D12; + d3d12Fragment.sourceCode = "d3d12 fragment"; + shader.AddPassVariant("ForwardLit", d3d12Fragment); + + const ShaderStageVariant* d3d12Variant = + shader.FindVariant("ForwardLit", ShaderType::Fragment, ShaderBackend::D3D12); + ASSERT_NE(d3d12Variant, nullptr); + EXPECT_EQ(d3d12Variant->sourceCode, "d3d12 fragment"); + + const ShaderStageVariant* openglVariant = + shader.FindVariant("ForwardLit", ShaderType::Fragment, ShaderBackend::OpenGL); + ASSERT_NE(openglVariant, nullptr); + EXPECT_EQ(openglVariant->sourceCode, "generic fragment"); +} + +TEST(Shader, StoresPerPassTags) { + Shader shader; + shader.SetPassTag("ForwardLit", "LightMode", "ForwardBase"); + shader.SetPassTag("ForwardLit", "Queue", "Geometry"); + + const ShaderPass* pass = shader.FindPass("ForwardLit"); + ASSERT_NE(pass, nullptr); + ASSERT_EQ(pass->tags.Size(), 2u); + EXPECT_EQ(pass->tags[0].name, "LightMode"); + EXPECT_EQ(pass->tags[0].value, "ForwardBase"); + EXPECT_EQ(pass->tags[1].name, "Queue"); + EXPECT_EQ(pass->tags[1].value, "Geometry"); +} + +TEST(Shader, ReleaseClearsPassRuntimeData) { + Shader shader; + shader.SetSourceCode("void main() {}"); + ShaderStageVariant variant = {}; + variant.stage = ShaderType::Fragment; + variant.sourceCode = "fragment"; + shader.AddPassVariant("ForwardLit", variant); + + shader.Release(); + + EXPECT_EQ(shader.GetPassCount(), 0u); + EXPECT_EQ(shader.GetSourceCode(), ""); + EXPECT_EQ(shader.GetCompiledBinary().Size(), 0u); + EXPECT_FALSE(shader.IsValid()); +} + } // namespace diff --git a/tests/Resources/Shader/test_shader_loader.cpp b/tests/Resources/Shader/test_shader_loader.cpp index c43abe5e..02f4b32e 100644 --- a/tests/Resources/Shader/test_shader_loader.cpp +++ b/tests/Resources/Shader/test_shader_loader.cpp @@ -3,6 +3,10 @@ #include #include +#include +#include +#include + using namespace XCEngine::Resources; using namespace XCEngine::Containers; @@ -35,4 +39,35 @@ TEST(ShaderLoader, LoadInvalidPath) { EXPECT_FALSE(result); } +TEST(ShaderLoader, LoadLegacySingleStageShaderBuildsDefaultPassVariant) { + namespace fs = std::filesystem; + + const fs::path shaderPath = fs::temp_directory_path() / "xc_shader_loader_stage.vert"; + { + std::ofstream shaderFile(shaderPath); + ASSERT_TRUE(shaderFile.is_open()); + shaderFile << "#version 430\nvoid main() {}"; + } + + ShaderLoader loader; + LoadResult result = loader.Load(shaderPath.string().c_str()); + ASSERT_TRUE(result); + ASSERT_NE(result.resource, nullptr); + + Shader* shader = static_cast(result.resource); + ASSERT_NE(shader, nullptr); + EXPECT_EQ(shader->GetShaderType(), ShaderType::Vertex); + ASSERT_EQ(shader->GetPassCount(), 1u); + + const ShaderPass* pass = shader->FindPass("Default"); + ASSERT_NE(pass, nullptr); + ASSERT_EQ(pass->variants.Size(), 1u); + EXPECT_EQ(pass->variants[0].stage, ShaderType::Vertex); + EXPECT_EQ(pass->variants[0].backend, ShaderBackend::Generic); + EXPECT_EQ(pass->variants[0].sourceCode, shader->GetSourceCode()); + + delete shader; + std::remove(shaderPath.string().c_str()); +} + } // namespace