fix(rhi): validate opengl compute uav set bindings
This commit is contained in:
@@ -643,21 +643,12 @@ RHIShader* OpenGLDevice::CreateShader(const ShaderCompileDesc& desc) {
|
||||
}
|
||||
|
||||
if (desc.sourceLanguage == ShaderLanguage::GLSL && !desc.source.empty()) {
|
||||
const char* sourceStr = reinterpret_cast<const char*>(desc.source.data());
|
||||
ShaderType shaderType = ShaderType::Vertex;
|
||||
|
||||
const std::string entryPoint = NarrowAscii(desc.entryPoint);
|
||||
std::string profile = NarrowAscii(desc.profile);
|
||||
if (profile.find("vs") != std::string::npos) {
|
||||
shaderType = ShaderType::Vertex;
|
||||
} else if (profile.find("ps") != std::string::npos || profile.find("fs") != std::string::npos) {
|
||||
shaderType = ShaderType::Fragment;
|
||||
} else if (profile.find("gs") != std::string::npos) {
|
||||
shaderType = ShaderType::Geometry;
|
||||
} else if (profile.find("cs") != std::string::npos) {
|
||||
shaderType = ShaderType::Compute;
|
||||
}
|
||||
|
||||
if (shader->Compile(sourceStr, shaderType)) {
|
||||
const char* entryPointPtr = entryPoint.empty() ? nullptr : entryPoint.c_str();
|
||||
const char* profilePtr = profile.empty() ? nullptr : profile.c_str();
|
||||
|
||||
if (shader->Compile(desc.source.data(), desc.source.size(), entryPointPtr, profilePtr)) {
|
||||
return shader;
|
||||
}
|
||||
delete shader;
|
||||
|
||||
@@ -80,6 +80,49 @@ bool ResolveShaderType(const std::string& path, const char* target, ShaderType&
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ResolveShaderTypeFromSource(const std::string& source, ShaderType& type) {
|
||||
if (source.find("layout(local_size_x") != std::string::npos ||
|
||||
source.find("gl_GlobalInvocationID") != std::string::npos ||
|
||||
source.find("gl_LocalInvocationID") != std::string::npos ||
|
||||
source.find("gl_WorkGroupID") != std::string::npos) {
|
||||
type = ShaderType::Compute;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (source.find("EmitVertex") != std::string::npos ||
|
||||
source.find("EndPrimitive") != std::string::npos ||
|
||||
source.find("gl_in[") != std::string::npos) {
|
||||
type = ShaderType::Geometry;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (source.find("layout(vertices") != std::string::npos ||
|
||||
source.find("gl_InvocationID") != std::string::npos) {
|
||||
type = ShaderType::TessControl;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (source.find("gl_TessCoord") != std::string::npos ||
|
||||
source.find("gl_TessLevelOuter") != std::string::npos) {
|
||||
type = ShaderType::TessEvaluation;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (source.find("gl_Position") != std::string::npos) {
|
||||
type = ShaderType::Vertex;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (source.find("gl_FragCoord") != std::string::npos ||
|
||||
source.find("gl_FragColor") != std::string::npos ||
|
||||
source.find("gl_FragData") != std::string::npos) {
|
||||
type = ShaderType::Fragment;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
OpenGLShader::OpenGLShader()
|
||||
@@ -280,6 +323,7 @@ bool OpenGLShader::Compile(const char* source, ShaderType type) {
|
||||
|
||||
glDeleteShader(shader);
|
||||
|
||||
m_type = type;
|
||||
m_uniformsCached = false;
|
||||
|
||||
return true;
|
||||
@@ -316,20 +360,18 @@ bool OpenGLShader::Compile(const void* sourceData, size_t sourceSize, const char
|
||||
if (!sourceData || sourceSize == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ShaderType type = ShaderType::Fragment;
|
||||
if (target) {
|
||||
if (strstr(target, "vs_")) {
|
||||
type = ShaderType::Vertex;
|
||||
} else if (strstr(target, "ps_")) {
|
||||
type = ShaderType::Fragment;
|
||||
} else if (strstr(target, "gs_")) {
|
||||
type = ShaderType::Geometry;
|
||||
} else if (strstr(target, "cs_")) {
|
||||
type = ShaderType::Compute;
|
||||
}
|
||||
|
||||
(void)entryPoint;
|
||||
|
||||
const std::string source(static_cast<const char*>(sourceData), sourceSize);
|
||||
|
||||
ShaderType type = ShaderType::Vertex;
|
||||
if (!ResolveShaderType(std::string(), target, type) &&
|
||||
!ResolveShaderTypeFromSource(source, type)) {
|
||||
return false;
|
||||
}
|
||||
return Compile(static_cast<const char*>(sourceData), type);
|
||||
|
||||
return Compile(source.c_str(), type);
|
||||
}
|
||||
|
||||
void OpenGLShader::Shutdown() {
|
||||
|
||||
Reference in New Issue
Block a user