Files
XCEngine/engine/src/Rendering/Detail/ShaderVariantUtils.h

439 lines
14 KiB
C
Raw Normal View History

#pragma once
#include <XCEngine/Core/Containers/String.h>
#include <XCEngine/RHI/RHIEnums.h>
#include <XCEngine/RHI/RHIPipelineState.h>
#include <XCEngine/Rendering/Builtin/BuiltinPassLayoutUtils.h>
#include <XCEngine/Resources/BuiltinResources.h>
#include <XCEngine/Resources/Shader/Shader.h>
#include <regex>
#include <string>
namespace XCEngine {
namespace Rendering {
namespace Detail {
// Maps an RHI device type to the shader backend whose variants it consumes.
// D3D12 and Vulkan map one-to-one; OpenGL and any other RHI type fall back to
// the OpenGL backend.
inline Resources::ShaderBackend ToShaderBackend(RHI::RHIType backendType) {
    if (backendType == RHI::RHIType::D3D12) {
        return Resources::ShaderBackend::D3D12;
    }
    if (backendType == RHI::RHIType::Vulkan) {
        return Resources::ShaderBackend::Vulkan;
    }
    return Resources::ShaderBackend::OpenGL;
}
// Converts a resource-side shader language tag to the RHI enum of the same
// name. HLSL and SPIRV map directly; GLSL and any unrecognized value map to GLSL.
inline RHI::ShaderLanguage ToRHIShaderLanguage(Resources::ShaderLanguage language) {
    if (language == Resources::ShaderLanguage::HLSL) {
        return RHI::ShaderLanguage::HLSL;
    }
    if (language == Resources::ShaderLanguage::SPIRV) {
        return RHI::ShaderLanguage::SPIRV;
    }
    return RHI::ShaderLanguage::GLSL;
}
// Widens `value` into a std::wstring one byte per code unit. No UTF-8 decoding
// is performed. The text is expected to be ASCII; callers in this file pass
// entry points, profiles and file paths.
// Each byte goes through `unsigned char` first. On platforms where `char` is
// signed, this keeps bytes >= 0x80 mapping to U+0080..U+00FF instead of being
// sign-extended into out-of-range code units.
inline std::wstring ToWideAscii(const Containers::String& value) {
    const size_t length = value.Length();
    std::wstring wide;
    wide.reserve(length);
    for (size_t index = 0; index < length; ++index) {
        wide.push_back(static_cast<wchar_t>(static_cast<unsigned char>(value[index])));
    }
    return wide;
}
// Copies the characters of an engine string into a std::string.
// The copy uses the explicit length rather than NUL-termination.
inline std::string ToStdString(const Containers::String& value) {
    std::string result;
    result.assign(value.CStr(), value.Length());
    return result;
}
// Returns `value` with every ECMAScript regex metacharacter prefixed by a
// backslash, so the result matches `value` literally inside a std::regex.
inline std::string EscapeRegexLiteral(const Containers::String& value) {
    // The characters that carry special meaning in the ECMAScript grammar.
    // '\0' is not part of this std::string, so it is never escaped.
    const std::string metaCharacters = "\\^$.|?*+()[]{}";

    const size_t count = value.Length();
    std::string escaped;
    escaped.reserve(count * 2u);
    for (size_t position = 0; position < count; ++position) {
        const char current = value[position];
        if (metaCharacters.find(current) != std::string::npos) {
            escaped += '\\';
        }
        escaped += current;
    }
    return escaped;
}
// Replaces the contents of `outBindings` with a copy of the pass's declared
// resource bindings, in declaration order.
// Returns false (leaving `outBindings` empty) when the pass declares none.
inline bool TryCollectShaderPassResourceBindings(
    const Resources::ShaderPass& pass,
    Containers::Array<Resources::ShaderResourceBindingDesc>& outBindings) {
    outBindings.Clear();

    const bool hasResources = !pass.resources.Empty();
    if (hasResources) {
        outBindings.Reserve(pass.resources.Size());
        for (const Resources::ShaderResourceBindingDesc& declared : pass.resources) {
            outBindings.PushBack(declared);
        }
    }
    return hasResources;
}
// Replaces (or inserts) the `: register(...)` clause on every declaration
// matched by `declarationPattern` (ECMAScript syntax). The pattern must define
// exactly three capture groups:
//   $1 - declaration head, through any whitespace that follows the name;
//   $2 - an optional pre-existing `: register(...)` clause (discarded);
//   $3 - the terminator (`{` for cbuffers, `;` otherwise) with leading whitespace.
// Returns true only if the text changed. A declaration that already carries an
// identical clause therefore yields false.
inline bool ReplaceHlslDeclarationRegisterClause(
    std::string& sourceText,
    const std::string& declarationPattern,
    const std::string& registerClause) {
    const std::regex pattern(declarationPattern, std::regex::ECMAScript);
    std::string rewritten =
        std::regex_replace(sourceText, pattern, "$1: " + registerClause + "$3");
    if (rewritten == sourceText) {
        return false;
    }
    sourceText.swap(rewritten);
    return true;
}

// Rewrites the HLSL register binding of the declaration named `declarationName`.
// The new clause is `register(<prefix><bindingIndex>)`, or
// `register(<prefix><bindingIndex>, space<setIndex>)` when `includeRegisterSpace`
// is true. Any register clause already present on the declaration is replaced.
//
// The declaration form searched for is chosen by `resourceType`:
//   ConstantBuffer                      -> `cbuffer Name { ... }`
//   StructuredBuffer/RWStructuredBuffer -> `[globallycoherent] [RW]StructuredBuffer<T> Name;`
//   RawBuffer/RWRawBuffer               -> `[globallycoherent] [RW]ByteAddressBuffer Name;`
//   any other type                      -> `Texture2D[<T>] Name;`, `TextureCube[<T>] Name;`,
//                                          `SamplerState Name;`, `SamplerComparisonState Name;`
// Every matching declaration in `sourceText` is rewritten.
//
// Returns true when `sourceText` was modified. Returns false for:
//   - an empty name or a null prefix;
//   - no matching declaration;
//   - a declaration that already had exactly this clause.
inline bool TryRewriteHlslRegisterBindingWithName(
    std::string& sourceText,
    const Containers::String& declarationName,
    const char* registerPrefix,
    Core::uint32 bindingIndex,
    Core::uint32 setIndex,
    bool includeRegisterSpace,
    Resources::ShaderResourceType resourceType) {
    // A null prefix would make the std::string concatenation below undefined.
    if (declarationName.Empty() || registerPrefix == nullptr) {
        return false;
    }

    std::string registerClause =
        std::string("register(") + registerPrefix + std::to_string(bindingIndex);
    if (includeRegisterSpace) {
        registerClause += ", space" + std::to_string(setIndex);
    }
    registerClause += ')';

    // Optional existing `: register(...)` clause; becomes capture group $2.
    const std::string existingRegister = "(:\\s*register\\s*\\([^\\)]*\\))?";
    const std::string escapedName = EscapeRegexLiteral(declarationName);

    std::string declarationPattern;
    if (resourceType == Resources::ShaderResourceType::ConstantBuffer) {
        declarationPattern =
            "(cbuffer\\s+" + escapedName + "\\s*)" + existingRegister + "(\\s*\\{)";
    } else if (resourceType == Resources::ShaderResourceType::StructuredBuffer ||
               resourceType == Resources::ShaderResourceType::RWStructuredBuffer) {
        declarationPattern =
            "((?:globallycoherent\\s+)?(?:StructuredBuffer|RWStructuredBuffer)\\s*<[^;\\r\\n>]+>\\s+" +
            escapedName + "\\s*)" + existingRegister + "(\\s*;)";
    } else if (resourceType == Resources::ShaderResourceType::RawBuffer ||
               resourceType == Resources::ShaderResourceType::RWRawBuffer) {
        declarationPattern =
            "((?:globallycoherent\\s+)?(?:ByteAddressBuffer|RWByteAddressBuffer)\\s+" +
            escapedName + "\\s*)" + existingRegister + "(\\s*;)";
    } else {
        // The optional `<...>` group also accepts typed declarations such as
        // `Texture2D<float4> Name;`. Without it, those were never rebound.
        declarationPattern =
            "((?:Texture2D|TextureCube|SamplerState|SamplerComparisonState)(?:\\s*<[^;\\r\\n>]+>)?\\s+" +
            escapedName + "\\s*)" + existingRegister + "(\\s*;)";
    }

    return ReplaceHlslDeclarationRegisterClause(sourceText, declarationPattern, registerClause);
}
// Returns the HLSL register class letter for a resource type:
//   "b" - constant buffers;
//   "t" - read-only textures and buffers;
//   "s" - samplers;
//   "u" - read-write buffers.
// Returns nullptr for any type without a known register class.
inline const char* TryGetHlslRegisterPrefix(Resources::ShaderResourceType type) {
    const char* prefix = nullptr;
    switch (type) {
    case Resources::ShaderResourceType::ConstantBuffer:
        prefix = "b";
        break;
    case Resources::ShaderResourceType::Texture2D:
    case Resources::ShaderResourceType::TextureCube:
    case Resources::ShaderResourceType::StructuredBuffer:
    case Resources::ShaderResourceType::RawBuffer:
        prefix = "t";
        break;
    case Resources::ShaderResourceType::Sampler:
        prefix = "s";
        break;
    case Resources::ShaderResourceType::RWStructuredBuffer:
    case Resources::ShaderResourceType::RWRawBuffer:
        prefix = "u";
        break;
    default:
        break;
    }
    return prefix;
}
// Rewrites the register clause for one binding descriptor. The descriptor
// supplies the declaration name, type, binding index and set.
// Returns false without touching `sourceText` when the type has no HLSL
// register class. Otherwise returns whether the rewrite changed the text.
inline bool TryRewriteHlslRegisterBinding(
    std::string& sourceText,
    const Resources::ShaderResourceBindingDesc& binding,
    bool includeRegisterSpace) {
    const char* const prefix = TryGetHlslRegisterPrefix(binding.type);
    return prefix != nullptr &&
           TryRewriteHlslRegisterBindingWithName(
               sourceText,
               binding.name,
               prefix,
               binding.binding,
               binding.set,
               includeRegisterSpace,
               binding.type);
}
// Produces the binding table used to rewrite HLSL register clauses for `backend`.
// - Vulkan: the pass's declared set/binding values are kept as-is, and
//   `outIncludeRegisterSpace` is set so clauses carry `spaceN`.
// - D3D12 and OpenGL: every binding is moved to set 0 and renumbered
//   sequentially within its register class (b / t / s / u), in declaration
//   order. `outIncludeRegisterSpace` stays false.
// - Any other backend: fails with an empty table.
// Also returns false when the pass declares no resources.
inline bool TryBuildRuntimeShaderBindings(
    const Resources::ShaderPass& pass,
    Resources::ShaderBackend backend,
    Containers::Array<Resources::ShaderResourceBindingDesc>& outBindings,
    bool& outIncludeRegisterSpace) {
    outBindings.Clear();
    outIncludeRegisterSpace = false;
    if (!TryCollectShaderPassResourceBindings(pass, outBindings)) {
        return false;
    }

    switch (backend) {
    case Resources::ShaderBackend::Vulkan:
        outIncludeRegisterSpace = true;
        return true;
    case Resources::ShaderBackend::D3D12:
    case Resources::ShaderBackend::OpenGL:
        break;
    default:
        outBindings.Clear();
        return false;
    }

    Core::uint32 constantBufferSlot = 0;
    Core::uint32 textureSlot = 0;
    Core::uint32 samplerSlot = 0;
    Core::uint32 unorderedAccessSlot = 0;
    for (Resources::ShaderResourceBindingDesc& binding : outBindings) {
        // NOTE(review): types outside the known classes also consume a `u`
        // slot here, even though TryGetHlslRegisterPrefix gives them no
        // prefix. Confirm this matches the pipeline layout's numbering.
        Core::uint32* slot = &unorderedAccessSlot;
        switch (binding.type) {
        case Resources::ShaderResourceType::ConstantBuffer:
            slot = &constantBufferSlot;
            break;
        case Resources::ShaderResourceType::Texture2D:
        case Resources::ShaderResourceType::TextureCube:
        case Resources::ShaderResourceType::StructuredBuffer:
        case Resources::ShaderResourceType::RawBuffer:
            slot = &textureSlot;
            break;
        case Resources::ShaderResourceType::Sampler:
            slot = &samplerSlot;
            break;
        default:
            break;
        }
        binding.set = 0;
        binding.binding = (*slot)++;
    }
    return true;
}
// Returns the source text of `variant`, prepared for compilation on `backend`.
// For HLSL variants on a non-Generic backend, each declared resource's
// register clause is rewritten to the binding computed by
// TryBuildRuntimeShaderBindings. All other variants are returned verbatim.
inline std::string BuildRuntimeShaderSource(
    const Resources::ShaderPass& pass,
    Resources::ShaderBackend backend,
    const Resources::ShaderStageVariant& variant) {
    std::string sourceText = ToStdString(variant.sourceCode);

    const bool needsRegisterRewrite =
        variant.language == Resources::ShaderLanguage::HLSL &&
        backend != Resources::ShaderBackend::Generic;
    if (needsRegisterRewrite) {
        Containers::Array<Resources::ShaderResourceBindingDesc> runtimeBindings;
        bool useRegisterSpace = false;
        if (TryBuildRuntimeShaderBindings(pass, backend, runtimeBindings, useRegisterSpace)) {
            // Best effort: a binding whose declaration is not found is skipped.
            for (const Resources::ShaderResourceBindingDesc& entry : runtimeBindings) {
                TryRewriteHlslRegisterBinding(sourceText, entry, useRegisterSpace);
            }
        }
    }
    return sourceText;
}
// Appends a preprocessor macro `name=definition` to `compileDesc`.
// The call does nothing when:
//   - `name` is null or empty;
//   - a macro with the same name is already present (the first definition wins).
// A null `definition` is stored as an empty string.
inline void AddShaderCompileMacro(
    RHI::ShaderCompileDesc& compileDesc,
    const wchar_t* name,
    const wchar_t* definition = L"1") {
    const bool hasName = name != nullptr && name[0] != L'\0';
    if (!hasName) {
        return;
    }

    for (const RHI::ShaderCompileMacro& present : compileDesc.macros) {
        if (present.name == name) {
            return;
        }
    }

    RHI::ShaderCompileMacro added = {};
    added.name = name;
    added.definition = (definition == nullptr) ? L"" : definition;
    compileDesc.macros.push_back(std::move(added));
}
// Adds the backend-identifying macros to `compileDesc`, in this order:
//   1. SHADER_API_* (defined as "1");
//   2. UNITY_UV_STARTS_AT_TOP;
//   3. UNITY_NEAR_CLIP_VALUE.
// Generic and unrecognized backends add nothing.
inline void InjectShaderBackendMacros(
    Resources::ShaderBackend backend,
    RHI::ShaderCompileDesc& compileDesc) {
    const wchar_t* apiMacro = nullptr;
    const wchar_t* uvStartsAtTop = nullptr;
    const wchar_t* nearClipValue = nullptr;

    switch (backend) {
    case Resources::ShaderBackend::OpenGL:
        apiMacro = L"SHADER_API_GLCORE";
        uvStartsAtTop = L"0";
        nearClipValue = L"-1";
        break;
    case Resources::ShaderBackend::Vulkan:
        apiMacro = L"SHADER_API_VULKAN";
        uvStartsAtTop = L"1";
        nearClipValue = L"0";
        break;
    case Resources::ShaderBackend::D3D12:
        apiMacro = L"SHADER_API_D3D12";
        uvStartsAtTop = L"1";
        nearClipValue = L"0";
        break;
    case Resources::ShaderBackend::Generic:
    default:
        return;
    }

    AddShaderCompileMacro(compileDesc, apiMacro);
    AddShaderCompileMacro(compileDesc, L"UNITY_UV_STARTS_AT_TOP", uvStartsAtTop);
    AddShaderCompileMacro(compileDesc, L"UNITY_NEAR_CLIP_VALUE", nearClipValue);
}
// Copies a variant's raw source, language, entry point and profile into
// `compileDesc`. The source is not rewritten, and no file name or backend
// macros are set.
inline void ApplyShaderStageVariant(
    const Resources::ShaderStageVariant& variant,
    RHI::ShaderCompileDesc& compileDesc) {
    const auto sourceBegin = variant.sourceCode.CStr();
    const auto sourceEnd = sourceBegin + variant.sourceCode.Length();
    compileDesc.source.assign(sourceBegin, sourceEnd);

    compileDesc.sourceLanguage = ToRHIShaderLanguage(variant.language);
    compileDesc.entryPoint = ToWideAscii(variant.entryPoint);
    compileDesc.profile = ToWideAscii(variant.profile);
}
// Returns the on-disk path to report to the shader compiler for `shaderPath`.
// - Builtin shader paths are translated to their asset path.
// - Other paths are widened unchanged.
// An empty path, or a builtin path that cannot be resolved, yields an empty
// wide string.
inline std::wstring ResolveRuntimeShaderSourcePath(const Containers::String& shaderPath) {
    if (shaderPath.Empty()) {
        return std::wstring();
    }
    if (!Resources::IsBuiltinShaderPath(shaderPath)) {
        return ToWideAscii(shaderPath);
    }

    Containers::String builtinAssetPath;
    if (!Resources::TryResolveBuiltinShaderAssetPath(shaderPath, builtinAssetPath)) {
        return std::wstring();
    }
    return ToWideAscii(builtinAssetPath);
}
// Fills `compileDesc` for compiling `variant` of `pass` on `backend`:
//   - the source, with HLSL register clauses rebound for the backend;
//   - the resolved source file name;
//   - the language, entry point and profile;
//   - the backend macros.
inline void ApplyShaderStageVariant(
    const Containers::String& shaderPath,
    const Resources::ShaderPass& pass,
    Resources::ShaderBackend backend,
    const Resources::ShaderStageVariant& variant,
    RHI::ShaderCompileDesc& compileDesc) {
    const std::string runtimeSource = BuildRuntimeShaderSource(pass, backend, variant);
    compileDesc.source.assign(runtimeSource.begin(), runtimeSource.end());

    // Reported file name; empty when no path is given or it cannot be resolved.
    compileDesc.fileName = ResolveRuntimeShaderSourcePath(shaderPath);

    compileDesc.sourceLanguage = ToRHIShaderLanguage(variant.language);
    compileDesc.entryPoint = ToWideAscii(variant.entryPoint);
    compileDesc.profile = ToWideAscii(variant.profile);

    InjectShaderBackendMacros(backend, compileDesc);
}
// Same as the path-taking overload, but with no shader path.
// The compile description's file name is therefore left empty.
inline void ApplyShaderStageVariant(
    const Resources::ShaderPass& pass,
    Resources::ShaderBackend backend,
    const Resources::ShaderStageVariant& variant,
    RHI::ShaderCompileDesc& compileDesc) {
    const Containers::String noShaderPath;
    ApplyShaderStageVariant(noShaderPath, pass, backend, variant, compileDesc);
}
// Builds a string key for a keyword set. The set is first normalized via
// NormalizeShaderKeywordSetInPlace on a copy, so the input is not modified.
// The enabled keywords are then joined with ';'. An empty set yields an empty
// string.
inline Containers::String BuildShaderKeywordSignature(
    const Resources::ShaderKeywordSet& keywordSet) {
    Resources::ShaderKeywordSet normalized = keywordSet;
    Resources::NormalizeShaderKeywordSetInPlace(normalized);

    Containers::String signature;
    bool isFirst = true;
    for (const auto& keyword : normalized.enabledKeywords) {
        if (!isFirst) {
            signature += ";";
        }
        signature += keyword;
        isFirst = false;
    }
    return signature;
}
// Returns true when `shader` has both a vertex variant and a fragment variant
// for `passName` on `backend` with the given keywords. The fragment lookup is
// skipped if no vertex variant exists.
inline bool ShaderPassHasGraphicsVariants(
    const Resources::Shader& shader,
    const Containers::String& passName,
    Resources::ShaderBackend backend,
    const Resources::ShaderKeywordSet& enabledKeywords = Resources::ShaderKeywordSet()) {
    const auto* vertexVariant =
        shader.FindVariant(passName, Resources::ShaderType::Vertex, backend, enabledKeywords);
    if (vertexVariant == nullptr) {
        return false;
    }
    return shader.FindVariant(
               passName, Resources::ShaderType::Fragment, backend, enabledKeywords) != nullptr;
}
} // namespace Detail
} // namespace Rendering
} // namespace XCEngine