refactor: add shader pass and backend variants

This commit is contained in:
2026-04-02 16:10:50 +08:00
parent 35d9b1d465
commit 70ced2d91f
4 changed files with 297 additions and 2 deletions

View File

@@ -22,6 +22,13 @@ enum class ShaderLanguage : Core::uint8 {
SPIRV
};
enum class ShaderBackend : Core::uint8 {
Generic = 0,
D3D12,
OpenGL,
Vulkan
};
struct ShaderUniform {
Containers::String name;
Core::uint32 location;
@@ -36,6 +43,27 @@ struct ShaderAttribute {
Core::uint32 type;
};
struct ShaderPassTagEntry {
Containers::String name;
Containers::String value;
};
struct ShaderStageVariant {
ShaderType stage = ShaderType::Fragment;
ShaderLanguage language = ShaderLanguage::GLSL;
ShaderBackend backend = ShaderBackend::Generic;
Containers::String entryPoint;
Containers::String profile;
Containers::String sourceCode;
Containers::Array<Core::uint8> compiledBinary;
};
struct ShaderPass {
Containers::String name;
Containers::Array<ShaderPassTagEntry> tags;
Containers::Array<ShaderStageVariant> variants;
};
class Shader : public IResource {
public:
Shader();
@@ -49,10 +77,10 @@ public:
size_t GetMemorySize() const override { return m_memorySize; }
void Release() override;
void SetShaderType(ShaderType type) { m_shaderType = type; }
void SetShaderType(ShaderType type);
ShaderType GetShaderType() const { return m_shaderType; }
void SetShaderLanguage(ShaderLanguage lang) { m_language = lang; }
void SetShaderLanguage(ShaderLanguage lang);
ShaderLanguage GetShaderLanguage() const { return m_language; }
void SetSourceCode(const Containers::String& source);
@@ -67,10 +95,32 @@ public:
void AddAttribute(const ShaderAttribute& attribute);
const Containers::Array<ShaderAttribute>& GetAttributes() const { return m_attributes; }
void AddPass(const ShaderPass& pass);
void ClearPasses();
Core::uint32 GetPassCount() const { return static_cast<Core::uint32>(m_passes.Size()); }
const Containers::Array<ShaderPass>& GetPasses() const { return m_passes; }
void AddPassVariant(const Containers::String& passName, const ShaderStageVariant& variant);
void SetPassTag(
const Containers::String& passName,
const Containers::String& tagName,
const Containers::String& tagValue);
bool HasPass(const Containers::String& passName) const;
const ShaderPass* FindPass(const Containers::String& passName) const;
ShaderPass* FindPass(const Containers::String& passName);
const ShaderStageVariant* FindVariant(
const Containers::String& passName,
ShaderType stage,
ShaderBackend backend = ShaderBackend::Generic) const;
class IRHIShader* GetRHIResource() const { return m_rhiResource; }
void SetRHIResource(class IRHIShader* resource);
private:
ShaderPass& GetOrCreatePass(const Containers::String& passName);
ShaderStageVariant& GetOrCreateLegacyVariant();
void SyncLegacyVariant();
ShaderType m_shaderType = ShaderType::Fragment;
ShaderLanguage m_language = ShaderLanguage::GLSL;
@@ -79,6 +129,7 @@ private:
Containers::Array<ShaderUniform> m_uniforms;
Containers::Array<ShaderAttribute> m_attributes;
Containers::Array<ShaderPass> m_passes;
class IRHIShader* m_rhiResource = nullptr;
};

View File

@@ -3,25 +3,46 @@
namespace XCEngine {
namespace Resources {
namespace {
const char* kLegacyShaderPassName = "Default";
} // namespace
Shader::Shader() = default;
Shader::~Shader() = default;
void Shader::Release() {
m_shaderType = ShaderType::Fragment;
m_language = ShaderLanguage::GLSL;
m_sourceCode.Clear();
m_compiledBinary.Clear();
m_uniforms.Clear();
m_attributes.Clear();
m_passes.Clear();
m_rhiResource = nullptr;
m_isValid = false;
}
void Shader::SetShaderType(ShaderType type) {
m_shaderType = type;
SyncLegacyVariant();
}
void Shader::SetShaderLanguage(ShaderLanguage lang) {
m_language = lang;
SyncLegacyVariant();
}
void Shader::SetSourceCode(const Containers::String& source) {
m_sourceCode = source;
SyncLegacyVariant();
}
void Shader::SetCompiledBinary(const Containers::Array<Core::uint8>& binary) {
m_compiledBinary = binary;
SyncLegacyVariant();
}
void Shader::AddUniform(const ShaderUniform& uniform) {
@@ -32,9 +53,123 @@ void Shader::AddAttribute(const ShaderAttribute& attribute) {
m_attributes.PushBack(attribute);
}
void Shader::AddPass(const ShaderPass& pass) {
m_passes.PushBack(pass);
}
void Shader::ClearPasses() {
m_passes.Clear();
}
void Shader::AddPassVariant(
const Containers::String& passName,
const ShaderStageVariant& variant) {
ShaderPass& pass = GetOrCreatePass(passName);
pass.variants.PushBack(variant);
}
void Shader::SetPassTag(
const Containers::String& passName,
const Containers::String& tagName,
const Containers::String& tagValue) {
ShaderPass& pass = GetOrCreatePass(passName);
for (ShaderPassTagEntry& tag : pass.tags) {
if (tag.name == tagName) {
tag.value = tagValue;
return;
}
}
ShaderPassTagEntry& tag = pass.tags.EmplaceBack();
tag.name = tagName;
tag.value = tagValue;
}
bool Shader::HasPass(const Containers::String& passName) const {
return FindPass(passName) != nullptr;
}
const ShaderPass* Shader::FindPass(const Containers::String& passName) const {
for (const ShaderPass& pass : m_passes) {
if (pass.name == passName) {
return &pass;
}
}
return nullptr;
}
ShaderPass* Shader::FindPass(const Containers::String& passName) {
for (ShaderPass& pass : m_passes) {
if (pass.name == passName) {
return &pass;
}
}
return nullptr;
}
const ShaderStageVariant* Shader::FindVariant(
const Containers::String& passName,
ShaderType stage,
ShaderBackend backend) const {
const ShaderPass* pass = FindPass(passName);
if (pass == nullptr) {
return nullptr;
}
const ShaderStageVariant* genericVariant = nullptr;
for (const ShaderStageVariant& variant : pass->variants) {
if (variant.stage != stage) {
continue;
}
if (variant.backend == backend) {
return &variant;
}
if (variant.backend == ShaderBackend::Generic && genericVariant == nullptr) {
genericVariant = &variant;
}
}
return genericVariant;
}
void Shader::SetRHIResource(class IRHIShader* resource) {
m_rhiResource = resource;
}
ShaderPass& Shader::GetOrCreatePass(const Containers::String& passName) {
if (ShaderPass* pass = FindPass(passName)) {
return *pass;
}
ShaderPass& pass = m_passes.EmplaceBack();
pass.name = passName;
return pass;
}
ShaderStageVariant& Shader::GetOrCreateLegacyVariant() {
ShaderPass& pass = GetOrCreatePass(kLegacyShaderPassName);
for (ShaderStageVariant& variant : pass.variants) {
if (variant.backend == ShaderBackend::Generic) {
return variant;
}
}
ShaderStageVariant& variant = pass.variants.EmplaceBack();
variant.backend = ShaderBackend::Generic;
return variant;
}
void Shader::SyncLegacyVariant() {
ShaderStageVariant& variant = GetOrCreateLegacyVariant();
variant.stage = m_shaderType;
variant.language = m_language;
variant.sourceCode = m_sourceCode;
variant.compiledBinary = m_compiledBinary;
}
} // namespace Resources
} // namespace XCEngine

View File

@@ -100,4 +100,78 @@ TEST(Shader, AddGetAttributes) {
EXPECT_EQ(attributes[0].name, "aPosition");
}
TEST(Shader, LegacySingleStageStateSyncsIntoDefaultPassVariant) {
Shader shader;
shader.SetShaderType(ShaderType::Vertex);
shader.SetShaderLanguage(ShaderLanguage::HLSL);
shader.SetSourceCode("float4 MainVS() : SV_POSITION { return 0; }");
ASSERT_EQ(shader.GetPassCount(), 1u);
const ShaderPass* pass = shader.FindPass("Default");
ASSERT_NE(pass, nullptr);
ASSERT_EQ(pass->variants.Size(), 1u);
EXPECT_EQ(pass->variants[0].stage, ShaderType::Vertex);
EXPECT_EQ(pass->variants[0].language, ShaderLanguage::HLSL);
EXPECT_EQ(pass->variants[0].backend, ShaderBackend::Generic);
EXPECT_EQ(pass->variants[0].sourceCode, "float4 MainVS() : SV_POSITION { return 0; }");
}
TEST(Shader, FindsBackendSpecificVariantAndFallsBackToGeneric) {
Shader shader;
ShaderStageVariant genericFragment = {};
genericFragment.stage = ShaderType::Fragment;
genericFragment.language = ShaderLanguage::GLSL;
genericFragment.backend = ShaderBackend::Generic;
genericFragment.sourceCode = "generic fragment";
shader.AddPassVariant("ForwardLit", genericFragment);
ShaderStageVariant d3d12Fragment = {};
d3d12Fragment.stage = ShaderType::Fragment;
d3d12Fragment.language = ShaderLanguage::HLSL;
d3d12Fragment.backend = ShaderBackend::D3D12;
d3d12Fragment.sourceCode = "d3d12 fragment";
shader.AddPassVariant("ForwardLit", d3d12Fragment);
const ShaderStageVariant* d3d12Variant =
shader.FindVariant("ForwardLit", ShaderType::Fragment, ShaderBackend::D3D12);
ASSERT_NE(d3d12Variant, nullptr);
EXPECT_EQ(d3d12Variant->sourceCode, "d3d12 fragment");
const ShaderStageVariant* openglVariant =
shader.FindVariant("ForwardLit", ShaderType::Fragment, ShaderBackend::OpenGL);
ASSERT_NE(openglVariant, nullptr);
EXPECT_EQ(openglVariant->sourceCode, "generic fragment");
}
TEST(Shader, StoresPerPassTags) {
Shader shader;
shader.SetPassTag("ForwardLit", "LightMode", "ForwardBase");
shader.SetPassTag("ForwardLit", "Queue", "Geometry");
const ShaderPass* pass = shader.FindPass("ForwardLit");
ASSERT_NE(pass, nullptr);
ASSERT_EQ(pass->tags.Size(), 2u);
EXPECT_EQ(pass->tags[0].name, "LightMode");
EXPECT_EQ(pass->tags[0].value, "ForwardBase");
EXPECT_EQ(pass->tags[1].name, "Queue");
EXPECT_EQ(pass->tags[1].value, "Geometry");
}
TEST(Shader, ReleaseClearsPassRuntimeData) {
Shader shader;
shader.SetSourceCode("void main() {}");
ShaderStageVariant variant = {};
variant.stage = ShaderType::Fragment;
variant.sourceCode = "fragment";
shader.AddPassVariant("ForwardLit", variant);
shader.Release();
EXPECT_EQ(shader.GetPassCount(), 0u);
EXPECT_EQ(shader.GetSourceCode(), "");
EXPECT_EQ(shader.GetCompiledBinary().Size(), 0u);
EXPECT_FALSE(shader.IsValid());
}
} // namespace

View File

@@ -3,6 +3,10 @@
#include <XCEngine/Core/Asset/ResourceTypes.h>
#include <XCEngine/Core/Containers/Array.h>
#include <cstdio>
#include <filesystem>
#include <fstream>
using namespace XCEngine::Resources;
using namespace XCEngine::Containers;
@@ -35,4 +39,35 @@ TEST(ShaderLoader, LoadInvalidPath) {
EXPECT_FALSE(result);
}
TEST(ShaderLoader, LoadLegacySingleStageShaderBuildsDefaultPassVariant) {
namespace fs = std::filesystem;
const fs::path shaderPath = fs::temp_directory_path() / "xc_shader_loader_stage.vert";
{
std::ofstream shaderFile(shaderPath);
ASSERT_TRUE(shaderFile.is_open());
shaderFile << "#version 430\nvoid main() {}";
}
ShaderLoader loader;
LoadResult result = loader.Load(shaderPath.string().c_str());
ASSERT_TRUE(result);
ASSERT_NE(result.resource, nullptr);
Shader* shader = static_cast<Shader*>(result.resource);
ASSERT_NE(shader, nullptr);
EXPECT_EQ(shader->GetShaderType(), ShaderType::Vertex);
ASSERT_EQ(shader->GetPassCount(), 1u);
const ShaderPass* pass = shader->FindPass("Default");
ASSERT_NE(pass, nullptr);
ASSERT_EQ(pass->variants.Size(), 1u);
EXPECT_EQ(pass->variants[0].stage, ShaderType::Vertex);
EXPECT_EQ(pass->variants[0].backend, ShaderBackend::Generic);
EXPECT_EQ(pass->variants[0].sourceCode, shader->GetSourceCode());
delete shader;
std::remove(shaderPath.string().c_str());
}
} // namespace