rendering: add keyword-aware shader variant selection

This commit is contained in:
2026-04-06 19:37:01 +08:00
parent a8b4da16a3
commit 261dd44fd5
26 changed files with 469 additions and 76 deletions

View File

@@ -569,12 +569,17 @@ bool WriteShaderArtifactFile(const fs::path& artifactPath, const Shader& shader)
variantHeader.stage = static_cast<Core::uint32>(variant.stage);
variantHeader.language = static_cast<Core::uint32>(variant.language);
variantHeader.backend = static_cast<Core::uint32>(variant.backend);
variantHeader.keywordCount =
static_cast<Core::uint32>(variant.requiredKeywords.enabledKeywords.Size());
variantHeader.compiledBinarySize = static_cast<Core::uint64>(variant.compiledBinary.Size());
output.write(reinterpret_cast<const char*>(&variantHeader), sizeof(variantHeader));
WriteString(output, variant.entryPoint);
WriteString(output, variant.profile);
WriteString(output, variant.sourceCode);
for (const Containers::String& keyword : variant.requiredKeywords.enabledKeywords) {
WriteString(output, keyword);
}
if (!variant.compiledBinary.Empty()) {
output.write(
reinterpret_cast<const char*>(variant.compiledBinary.Data()),

View File

@@ -55,12 +55,38 @@ inline void ApplyShaderStageVariant(
compileDesc.profile = ToWideAscii(variant.profile);
}
inline Containers::String BuildShaderKeywordSignature(
const Resources::ShaderKeywordSet& keywordSet) {
Resources::ShaderKeywordSet normalizedKeywords = keywordSet;
Resources::NormalizeShaderKeywordSetInPlace(normalizedKeywords);
Containers::String signature;
for (size_t keywordIndex = 0; keywordIndex < normalizedKeywords.enabledKeywords.Size(); ++keywordIndex) {
if (keywordIndex > 0) {
signature += ";";
}
signature += normalizedKeywords.enabledKeywords[keywordIndex];
}
return signature;
}
inline bool ShaderPassHasGraphicsVariants(
const Resources::Shader& shader,
const Containers::String& passName,
Resources::ShaderBackend backend) {
return shader.FindVariant(passName, Resources::ShaderType::Vertex, backend) != nullptr &&
shader.FindVariant(passName, Resources::ShaderType::Fragment, backend) != nullptr;
Resources::ShaderBackend backend,
const Resources::ShaderKeywordSet& enabledKeywords = Resources::ShaderKeywordSet()) {
return shader.FindVariant(
passName,
Resources::ShaderType::Vertex,
backend,
enabledKeywords) != nullptr &&
shader.FindVariant(
passName,
Resources::ShaderType::Fragment,
backend,
enabledKeywords) != nullptr;
}
} // namespace Detail

View File

@@ -81,12 +81,14 @@ RHI::GraphicsPipelineDesc CreatePipelineDesc(
pipelineDesc.depthStencilState.depthFunc = static_cast<uint32_t>(RHI::ComparisonFunc::LessEqual);
const Resources::ShaderBackend backend = ::XCEngine::Rendering::Detail::ToShaderBackend(backendType);
const Resources::ShaderKeywordSet keywordSet =
material != nullptr ? material->GetKeywordSet() : Resources::ShaderKeywordSet();
if (const Resources::ShaderStageVariant* vertexVariant =
shader.FindVariant(passName, Resources::ShaderType::Vertex, backend)) {
shader.FindVariant(passName, Resources::ShaderType::Vertex, backend, keywordSet)) {
::XCEngine::Rendering::Detail::ApplyShaderStageVariant(*vertexVariant, pipelineDesc.vertexShader);
}
if (const Resources::ShaderStageVariant* fragmentVariant =
shader.FindVariant(passName, Resources::ShaderType::Fragment, backend)) {
shader.FindVariant(passName, Resources::ShaderType::Fragment, backend, keywordSet)) {
::XCEngine::Rendering::Detail::ApplyShaderStageVariant(*fragmentVariant, pipelineDesc.fragmentShader);
}
@@ -165,6 +167,8 @@ BuiltinDepthStylePassBase::ResolvedShaderPass BuiltinDepthStylePassBase::Resolve
}
const bool shaderHasExplicitBuiltinMetadata = ShaderHasExplicitBuiltinMetadata(*shader);
const Resources::ShaderKeywordSet keywordSet =
ownerMaterial != nullptr ? ownerMaterial->GetKeywordSet() : Resources::ShaderKeywordSet();
auto tryAcceptPass =
[this, shader, &resolved](const Resources::ShaderPass& shaderPass) -> bool {
@@ -182,7 +186,11 @@ BuiltinDepthStylePassBase::ResolvedShaderPass BuiltinDepthStylePassBase::Resolve
for (const Resources::ShaderPass& shaderPass : shader->GetPasses()) {
if (!ShaderPassMatchesBuiltinPass(shaderPass, m_passType) ||
!::XCEngine::Rendering::Detail::ShaderPassHasGraphicsVariants(*shader, shaderPass.name, backend)) {
!::XCEngine::Rendering::Detail::ShaderPassHasGraphicsVariants(
*shader,
shaderPass.name,
backend,
keywordSet)) {
continue;
}
@@ -199,7 +207,8 @@ BuiltinDepthStylePassBase::ResolvedShaderPass BuiltinDepthStylePassBase::Resolve
::XCEngine::Rendering::Detail::ShaderPassHasGraphicsVariants(
*shader,
explicitPass->name,
backend) &&
backend,
keywordSet) &&
tryAcceptPass(*explicitPass)) {
return true;
}
@@ -340,6 +349,9 @@ RHI::RHIPipelineState* BuiltinDepthStylePassBase::GetOrCreatePipelineState(
material != nullptr ? material->GetRenderState() : Resources::MaterialRenderState();
pipelineKey.shader = resolvedShaderPass.shader;
pipelineKey.passName = resolvedShaderPass.passName;
pipelineKey.keywordSignature =
::XCEngine::Rendering::Detail::BuildShaderKeywordSignature(
material != nullptr ? material->GetKeywordSet() : Resources::ShaderKeywordSet());
pipelineKey.renderTargetCount = ResolveSurfaceColorAttachmentCount(surface);
pipelineKey.renderTargetFormat = static_cast<uint32_t>(ResolveSurfaceColorFormat(surface));
pipelineKey.depthStencilFormat = static_cast<uint32_t>(ResolveSurfaceDepthFormat(surface));

View File

@@ -29,10 +29,16 @@ const Resources::ShaderPass* FindCompatibleSurfacePass(
BuiltinMaterialPass pass,
Resources::ShaderBackend backend) {
const bool shaderHasExplicitBuiltinMetadata = ShaderHasExplicitBuiltinMetadata(shader);
const Resources::ShaderKeywordSet keywordSet =
material != nullptr ? material->GetKeywordSet() : Resources::ShaderKeywordSet();
for (const Resources::ShaderPass& shaderPass : shader.GetPasses()) {
if (ShaderPassMatchesBuiltinPass(shaderPass, pass) &&
::XCEngine::Rendering::Detail::ShaderPassHasGraphicsVariants(shader, shaderPass.name, backend)) {
::XCEngine::Rendering::Detail::ShaderPassHasGraphicsVariants(
shader,
shaderPass.name,
backend,
keywordSet)) {
return &shaderPass;
}
}
@@ -42,7 +48,11 @@ const Resources::ShaderPass* FindCompatibleSurfacePass(
!material->GetShaderPass().Empty()) {
const Resources::ShaderPass* explicitPass = shader.FindPass(material->GetShaderPass());
if (explicitPass != nullptr &&
::XCEngine::Rendering::Detail::ShaderPassHasGraphicsVariants(shader, explicitPass->name, backend)) {
::XCEngine::Rendering::Detail::ShaderPassHasGraphicsVariants(
shader,
explicitPass->name,
backend,
keywordSet)) {
return explicitPass;
}
}
@@ -57,18 +67,30 @@ const Resources::ShaderPass* FindCompatibleSurfacePass(
const Resources::ShaderPass* defaultPass = shader.FindPass("ForwardLit");
if (defaultPass != nullptr &&
::XCEngine::Rendering::Detail::ShaderPassHasGraphicsVariants(shader, defaultPass->name, backend)) {
::XCEngine::Rendering::Detail::ShaderPassHasGraphicsVariants(
shader,
defaultPass->name,
backend,
keywordSet)) {
return defaultPass;
}
defaultPass = shader.FindPass("Default");
if (defaultPass != nullptr &&
::XCEngine::Rendering::Detail::ShaderPassHasGraphicsVariants(shader, defaultPass->name, backend)) {
::XCEngine::Rendering::Detail::ShaderPassHasGraphicsVariants(
shader,
defaultPass->name,
backend,
keywordSet)) {
return defaultPass;
}
if (shader.GetPassCount() > 0 &&
::XCEngine::Rendering::Detail::ShaderPassHasGraphicsVariants(shader, shader.GetPasses()[0].name, backend)) {
::XCEngine::Rendering::Detail::ShaderPassHasGraphicsVariants(
shader,
shader.GetPasses()[0].name,
backend,
keywordSet)) {
return &shader.GetPasses()[0];
}
@@ -93,10 +115,12 @@ RHI::GraphicsPipelineDesc CreatePipelineDesc(
pipelineDesc.inputLayout = BuiltinForwardPipeline::BuildInputLayout();
const Resources::ShaderBackend backend = ::XCEngine::Rendering::Detail::ToShaderBackend(backendType);
const Resources::ShaderKeywordSet keywordSet =
material != nullptr ? material->GetKeywordSet() : Resources::ShaderKeywordSet();
const Resources::ShaderStageVariant* vertexVariant =
shader.FindVariant(passName, Resources::ShaderType::Vertex, backend);
shader.FindVariant(passName, Resources::ShaderType::Vertex, backend, keywordSet);
const Resources::ShaderStageVariant* fragmentVariant =
shader.FindVariant(passName, Resources::ShaderType::Fragment, backend);
shader.FindVariant(passName, Resources::ShaderType::Fragment, backend, keywordSet);
if (vertexVariant != nullptr) {
::XCEngine::Rendering::Detail::ApplyShaderStageVariant(*vertexVariant, pipelineDesc.vertexShader);
}
@@ -269,6 +293,9 @@ RHI::RHIPipelineState* BuiltinForwardPipeline::GetOrCreatePipelineState(
material != nullptr ? material->GetRenderState() : Resources::MaterialRenderState();
pipelineKey.shader = resolvedShaderPass.shader;
pipelineKey.passName = resolvedShaderPass.passName;
pipelineKey.keywordSignature =
::XCEngine::Rendering::Detail::BuildShaderKeywordSignature(
material != nullptr ? material->GetKeywordSet() : Resources::ShaderKeywordSet());
const auto existing = m_pipelineStates.find(pipelineKey);
if (existing != m_pipelineStates.end()) {

View File

@@ -95,21 +95,6 @@ bool IsTextureMaterialPropertyType(MaterialPropertyType type) {
return type == MaterialPropertyType::Texture || type == MaterialPropertyType::Cubemap;
}
Containers::String NormalizeShaderKeyword(const Containers::String& keyword) {
const Containers::String normalized = keyword.Trim();
if (normalized.Empty() ||
normalized == Containers::String("_") ||
normalized == Containers::String("__")) {
return Containers::String();
}
return normalized;
}
bool CompareShaderKeywords(const Containers::String& left, const Containers::String& right) {
return std::strcmp(left.CStr(), right.CStr()) < 0;
}
MaterialPropertyType GetMaterialPropertyTypeForShaderProperty(ShaderPropertyType type) {
switch (type) {
case ShaderPropertyType::Float:
@@ -439,7 +424,7 @@ Containers::String Material::GetTagValue(Core::uint32 index) const {
}
void Material::EnableKeyword(const Containers::String& keyword) {
const Containers::String normalizedKeyword = NormalizeShaderKeyword(keyword);
const Containers::String normalizedKeyword = NormalizeShaderKeywordToken(keyword);
if (normalizedKeyword.Empty()) {
return;
}
@@ -458,12 +443,12 @@ void Material::EnableKeyword(const Containers::String& keyword) {
std::sort(
m_keywordSet.enabledKeywords.begin(),
m_keywordSet.enabledKeywords.end(),
CompareShaderKeywords);
CompareShaderKeywordTokens);
MarkChanged(false);
}
void Material::DisableKeyword(const Containers::String& keyword) {
const Containers::String normalizedKeyword = NormalizeShaderKeyword(keyword);
const Containers::String normalizedKeyword = NormalizeShaderKeywordToken(keyword);
if (normalizedKeyword.Empty()) {
return;
}
@@ -477,7 +462,7 @@ void Material::DisableKeyword(const Containers::String& keyword) {
std::sort(
m_keywordSet.enabledKeywords.begin(),
m_keywordSet.enabledKeywords.end(),
CompareShaderKeywords);
CompareShaderKeywordTokens);
MarkChanged(false);
return;
}
@@ -493,7 +478,7 @@ void Material::SetKeywordEnabled(const Containers::String& keyword, bool enabled
}
bool Material::IsKeywordEnabled(const Containers::String& keyword) const {
const Containers::String normalizedKeyword = NormalizeShaderKeyword(keyword);
const Containers::String normalizedKeyword = NormalizeShaderKeywordToken(keyword);
if (normalizedKeyword.Empty()) {
return false;
}
@@ -1098,7 +1083,7 @@ void Material::SyncShaderSchemaKeywords(bool removeUnknownKeywords) {
std::sort(
m_keywordSet.enabledKeywords.begin(),
m_keywordSet.enabledKeywords.end(),
CompareShaderKeywords);
CompareShaderKeywordTokens);
}
void Material::MarkChanged(bool updateConstantBuffer) {

View File

@@ -376,14 +376,7 @@ std::string TrimCopy(const std::string& text) {
}
Containers::String NormalizeMaterialKeywordToken(const Containers::String& keyword) {
const Containers::String normalized = keyword.Trim();
if (normalized.Empty() ||
normalized == Containers::String("_") ||
normalized == Containers::String("__")) {
return Containers::String();
}
return normalized;
return NormalizeShaderKeywordToken(keyword);
}
bool IsJsonValueTerminator(char ch) {

View File

@@ -7,19 +7,15 @@ namespace {
const char* kLegacyShaderPassName = "Default";
bool IsShaderKeywordPlaceholder(const Containers::String& keyword) {
return keyword == Containers::String("_") ||
keyword == Containers::String("__");
}
bool PassDeclaresKeywordInternal(const ShaderPass& pass, const Containers::String& keyword) {
if (keyword.Empty() || IsShaderKeywordPlaceholder(keyword)) {
const Containers::String normalizedKeyword = NormalizeShaderKeywordToken(keyword);
if (normalizedKeyword.Empty()) {
return false;
}
for (const ShaderKeywordDeclaration& declaration : pass.keywordDeclarations) {
for (const Containers::String& option : declaration.options) {
if (!IsShaderKeywordPlaceholder(option) && option == keyword) {
if (NormalizeShaderKeywordToken(option) == normalizedKeyword) {
return true;
}
}
@@ -28,6 +24,18 @@ bool PassDeclaresKeywordInternal(const ShaderPass& pass, const Containers::Strin
return false;
}
const ShaderStageVariant* SelectMoreSpecificVariant(
const ShaderStageVariant* currentBest,
const ShaderStageVariant& candidate) {
if (currentBest == nullptr ||
candidate.requiredKeywords.enabledKeywords.Size() >
currentBest->requiredKeywords.enabledKeywords.Size()) {
return &candidate;
}
return currentBest;
}
} // namespace
Shader::Shader() = default;
@@ -112,7 +120,9 @@ void Shader::AddPassVariant(
const Containers::String& passName,
const ShaderStageVariant& variant) {
ShaderPass& pass = GetOrCreatePass(passName);
pass.variants.PushBack(variant);
ShaderStageVariant normalizedVariant = variant;
NormalizeShaderKeywordSetInPlace(normalizedVariant.requiredKeywords);
pass.variants.PushBack(normalizedVariant);
}
void Shader::SetPassTag(
@@ -214,28 +224,38 @@ const ShaderResourceBindingDesc* Shader::FindPassResourceBinding(
const ShaderStageVariant* Shader::FindVariant(
const Containers::String& passName,
ShaderType stage,
ShaderBackend backend) const {
ShaderBackend backend,
const ShaderKeywordSet& enabledKeywords) const {
const ShaderPass* pass = FindPass(passName);
if (pass == nullptr) {
return nullptr;
}
ShaderKeywordSet normalizedEnabledKeywords = enabledKeywords;
NormalizeShaderKeywordSetInPlace(normalizedEnabledKeywords);
const ShaderStageVariant* exactBackendVariant = nullptr;
const ShaderStageVariant* genericVariant = nullptr;
for (const ShaderStageVariant& variant : pass->variants) {
if (variant.stage != stage) {
continue;
}
if (variant.backend == backend) {
return &variant;
if (!IsShaderKeywordSubset(variant.requiredKeywords, normalizedEnabledKeywords)) {
continue;
}
if (variant.backend == ShaderBackend::Generic && genericVariant == nullptr) {
genericVariant = &variant;
if (variant.backend == backend) {
exactBackendVariant = SelectMoreSpecificVariant(exactBackendVariant, variant);
continue;
}
if (variant.backend == ShaderBackend::Generic) {
genericVariant = SelectMoreSpecificVariant(genericVariant, variant);
}
}
return genericVariant;
return exactBackendVariant != nullptr ? exactBackendVariant : genericVariant;
}
void Shader::SetRHIResource(class IRHIShader* resource) {

View File

@@ -369,6 +369,35 @@ bool SplitTopLevelArrayElements(const std::string& arrayText, std::vector<std::s
return true;
}
bool TryParseShaderKeywordsArray(
const std::string& arrayText,
ShaderKeywordSet& outKeywordSet) {
std::vector<std::string> keywordElements;
if (!SplitTopLevelArrayElements(arrayText, keywordElements)) {
return false;
}
outKeywordSet = ShaderKeywordSet();
outKeywordSet.enabledKeywords.Reserve(keywordElements.size());
for (const std::string& keywordElement : keywordElements) {
Containers::String keyword;
size_t nextPos = 0;
if (!ParseQuotedString(keywordElement, 0, keyword, &nextPos)) {
return false;
}
if (SkipWhitespace(keywordElement, nextPos) != keywordElement.size()) {
return false;
}
outKeywordSet.enabledKeywords.PushBack(keyword);
}
NormalizeShaderKeywordSetInPlace(outKeywordSet);
return true;
}
bool TryParseShaderType(const Containers::String& value, ShaderType& outType) {
const Containers::String normalized = value.Trim().ToLower();
if (normalized == "vertex" || normalized == "vs") {
@@ -2223,6 +2252,9 @@ size_t CalculateShaderMemorySize(const Shader& shader) {
}
}
for (const ShaderStageVariant& variant : pass.variants) {
for (const Containers::String& keyword : variant.requiredKeywords.enabledKeywords) {
memorySize += keyword.Length();
}
memorySize += variant.entryPoint.Length();
memorySize += variant.profile.Length();
memorySize += variant.sourceCode.Length();
@@ -2458,6 +2490,12 @@ LoadResult LoadShaderManifest(const Containers::String& path, const std::string&
variant.profile = GetDefaultProfile(variant.language, variant.backend, variant.stage);
}
std::string keywordsArray;
if (TryExtractArray(variantObject, "keywords", keywordsArray) &&
!TryParseShaderKeywordsArray(keywordsArray, variant.requiredKeywords)) {
return LoadResult("Shader manifest variant keywords could not be parsed: " + path);
}
shader->AddPassVariant(passName, variant);
}
}
@@ -2480,9 +2518,10 @@ LoadResult LoadShaderArtifact(const Containers::String& path) {
const std::string magic(fileHeader.magic, fileHeader.magic + 7);
const bool isLegacySchema = magic == "XCSHD01" && fileHeader.schemaVersion == 1u;
const bool isSchemaV2 = magic == "XCSHD02" && fileHeader.schemaVersion == 2u;
const bool isCurrentSchema =
magic == "XCSHD02" && fileHeader.schemaVersion == kShaderArtifactSchemaVersion;
if (!isLegacySchema && !isCurrentSchema) {
magic == "XCSHD03" && fileHeader.schemaVersion == kShaderArtifactSchemaVersion;
if (!isLegacySchema && !isSchemaV2 && !isCurrentSchema) {
return LoadResult("Invalid shader artifact header: " + path);
}
@@ -2604,29 +2643,58 @@ LoadResult LoadShaderArtifact(const Containers::String& path) {
for (Core::uint32 variantIndex = 0; variantIndex < variantCount; ++variantIndex) {
ShaderStageVariant variant = {};
ShaderVariantArtifactHeader variantHeader;
if (!ReadShaderArtifactValue(data, offset, variantHeader) ||
!ReadShaderArtifactString(data, offset, variant.entryPoint) ||
Core::uint64 compiledBinarySize = 0;
Core::uint32 keywordCount = 0;
if (isCurrentSchema) {
ShaderVariantArtifactHeader variantHeader = {};
if (!ReadShaderArtifactValue(data, offset, variantHeader)) {
return LoadResult("Failed to read shader artifact variants: " + path);
}
variant.stage = static_cast<ShaderType>(variantHeader.stage);
variant.language = static_cast<ShaderLanguage>(variantHeader.language);
variant.backend = static_cast<ShaderBackend>(variantHeader.backend);
keywordCount = variantHeader.keywordCount;
compiledBinarySize = variantHeader.compiledBinarySize;
} else {
ShaderVariantArtifactHeaderV2 variantHeader = {};
if (!ReadShaderArtifactValue(data, offset, variantHeader)) {
return LoadResult("Failed to read shader artifact variants: " + path);
}
variant.stage = static_cast<ShaderType>(variantHeader.stage);
variant.language = static_cast<ShaderLanguage>(variantHeader.language);
variant.backend = static_cast<ShaderBackend>(variantHeader.backend);
compiledBinarySize = variantHeader.compiledBinarySize;
}
if (!ReadShaderArtifactString(data, offset, variant.entryPoint) ||
!ReadShaderArtifactString(data, offset, variant.profile) ||
!ReadShaderArtifactString(data, offset, variant.sourceCode)) {
return LoadResult("Failed to read shader artifact variants: " + path);
}
variant.stage = static_cast<ShaderType>(variantHeader.stage);
variant.language = static_cast<ShaderLanguage>(variantHeader.language);
variant.backend = static_cast<ShaderBackend>(variantHeader.backend);
for (Core::uint32 keywordIndex = 0; keywordIndex < keywordCount; ++keywordIndex) {
Containers::String keyword;
if (!ReadShaderArtifactString(data, offset, keyword)) {
return LoadResult("Failed to read shader artifact variant keywords: " + path);
}
if (variantHeader.compiledBinarySize > 0) {
if (offset + variantHeader.compiledBinarySize > data.Size()) {
variant.requiredKeywords.enabledKeywords.PushBack(keyword);
}
NormalizeShaderKeywordSetInPlace(variant.requiredKeywords);
if (compiledBinarySize > 0) {
if (offset + compiledBinarySize > data.Size()) {
return LoadResult("Shader artifact variant binary payload is truncated: " + path);
}
variant.compiledBinary.Resize(static_cast<size_t>(variantHeader.compiledBinarySize));
variant.compiledBinary.Resize(static_cast<size_t>(compiledBinarySize));
std::memcpy(
variant.compiledBinary.Data(),
data.Data() + offset,
static_cast<size_t>(variantHeader.compiledBinarySize));
offset += static_cast<size_t>(variantHeader.compiledBinarySize);
static_cast<size_t>(compiledBinarySize));
offset += static_cast<size_t>(compiledBinarySize);
}
shader->AddPassVariant(passName, variant);