fix(rhi): validate opengl compute uav set bindings

This commit is contained in:
2026-03-26 15:42:44 +08:00
parent 491fef940d
commit 18fa150843
4 changed files with 218 additions and 33 deletions

View File

@@ -643,21 +643,12 @@ RHIShader* OpenGLDevice::CreateShader(const ShaderCompileDesc& desc) {
}
if (desc.sourceLanguage == ShaderLanguage::GLSL && !desc.source.empty()) {
const char* sourceStr = reinterpret_cast<const char*>(desc.source.data());
ShaderType shaderType = ShaderType::Vertex;
const std::string entryPoint = NarrowAscii(desc.entryPoint);
std::string profile = NarrowAscii(desc.profile);
if (profile.find("vs") != std::string::npos) {
shaderType = ShaderType::Vertex;
} else if (profile.find("ps") != std::string::npos || profile.find("fs") != std::string::npos) {
shaderType = ShaderType::Fragment;
} else if (profile.find("gs") != std::string::npos) {
shaderType = ShaderType::Geometry;
} else if (profile.find("cs") != std::string::npos) {
shaderType = ShaderType::Compute;
}
if (shader->Compile(sourceStr, shaderType)) {
const char* entryPointPtr = entryPoint.empty() ? nullptr : entryPoint.c_str();
const char* profilePtr = profile.empty() ? nullptr : profile.c_str();
if (shader->Compile(desc.source.data(), desc.source.size(), entryPointPtr, profilePtr)) {
return shader;
}
delete shader;

View File

@@ -80,6 +80,49 @@ bool ResolveShaderType(const std::string& path, const char* target, ShaderType&
return false;
}
bool ResolveShaderTypeFromSource(const std::string& source, ShaderType& type) {
if (source.find("layout(local_size_x") != std::string::npos ||
source.find("gl_GlobalInvocationID") != std::string::npos ||
source.find("gl_LocalInvocationID") != std::string::npos ||
source.find("gl_WorkGroupID") != std::string::npos) {
type = ShaderType::Compute;
return true;
}
if (source.find("EmitVertex") != std::string::npos ||
source.find("EndPrimitive") != std::string::npos ||
source.find("gl_in[") != std::string::npos) {
type = ShaderType::Geometry;
return true;
}
if (source.find("layout(vertices") != std::string::npos ||
source.find("gl_InvocationID") != std::string::npos) {
type = ShaderType::TessControl;
return true;
}
if (source.find("gl_TessCoord") != std::string::npos ||
source.find("gl_TessLevelOuter") != std::string::npos) {
type = ShaderType::TessEvaluation;
return true;
}
if (source.find("gl_Position") != std::string::npos) {
type = ShaderType::Vertex;
return true;
}
if (source.find("gl_FragCoord") != std::string::npos ||
source.find("gl_FragColor") != std::string::npos ||
source.find("gl_FragData") != std::string::npos) {
type = ShaderType::Fragment;
return true;
}
return false;
}
} // namespace
OpenGLShader::OpenGLShader()
@@ -280,6 +323,7 @@ bool OpenGLShader::Compile(const char* source, ShaderType type) {
glDeleteShader(shader);
m_type = type;
m_uniformsCached = false;
return true;
@@ -316,20 +360,18 @@ bool OpenGLShader::Compile(const void* sourceData, size_t sourceSize, const char
if (!sourceData || sourceSize == 0) {
return false;
}
ShaderType type = ShaderType::Fragment;
if (target) {
if (strstr(target, "vs_")) {
type = ShaderType::Vertex;
} else if (strstr(target, "ps_")) {
type = ShaderType::Fragment;
} else if (strstr(target, "gs_")) {
type = ShaderType::Geometry;
} else if (strstr(target, "cs_")) {
type = ShaderType::Compute;
}
(void)entryPoint;
const std::string source(static_cast<const char*>(sourceData), sourceSize);
ShaderType type = ShaderType::Vertex;
if (!ResolveShaderType(std::string(), target, type) &&
!ResolveShaderTypeFromSource(source, type)) {
return false;
}
return Compile(static_cast<const char*>(sourceData), type);
return Compile(source.c_str(), type);
}
void OpenGLShader::Shutdown() {