diff --git a/engine/src/RHI/OpenGL/OpenGLDevice.cpp b/engine/src/RHI/OpenGL/OpenGLDevice.cpp index 7db3ad2c..a16fc9f4 100644 --- a/engine/src/RHI/OpenGL/OpenGLDevice.cpp +++ b/engine/src/RHI/OpenGL/OpenGLDevice.cpp @@ -643,21 +643,12 @@ RHIShader* OpenGLDevice::CreateShader(const ShaderCompileDesc& desc) { } if (desc.sourceLanguage == ShaderLanguage::GLSL && !desc.source.empty()) { - const char* sourceStr = reinterpret_cast(desc.source.data()); - ShaderType shaderType = ShaderType::Vertex; - + const std::string entryPoint = NarrowAscii(desc.entryPoint); std::string profile = NarrowAscii(desc.profile); - if (profile.find("vs") != std::string::npos) { - shaderType = ShaderType::Vertex; - } else if (profile.find("ps") != std::string::npos || profile.find("fs") != std::string::npos) { - shaderType = ShaderType::Fragment; - } else if (profile.find("gs") != std::string::npos) { - shaderType = ShaderType::Geometry; - } else if (profile.find("cs") != std::string::npos) { - shaderType = ShaderType::Compute; - } - - if (shader->Compile(sourceStr, shaderType)) { + const char* entryPointPtr = entryPoint.empty() ? nullptr : entryPoint.c_str(); + const char* profilePtr = profile.empty() ? nullptr : profile.c_str(); + + if (shader->Compile(desc.source.data(), desc.source.size(), entryPointPtr, profilePtr)) { return shader; } delete shader; diff --git a/engine/src/RHI/OpenGL/OpenGLShader.cpp b/engine/src/RHI/OpenGL/OpenGLShader.cpp index 48d93c33..585345b5 100644 --- a/engine/src/RHI/OpenGL/OpenGLShader.cpp +++ b/engine/src/RHI/OpenGL/OpenGLShader.cpp @@ -80,6 +80,49 @@ bool ResolveShaderType(const std::string& path, const char* target, ShaderType& return false; } +bool ResolveShaderTypeFromSource(const std::string& source, ShaderType& type) { + if (source.find("layout(local_size_x") != std::string::npos || + source.find("gl_GlobalInvocationID") != std::string::npos || + source.find("gl_LocalInvocationID") != std::string::npos || + source.find("gl_WorkGroupID") != std::string::npos) { + type = ShaderType::Compute; + return true; + } + + if (source.find("EmitVertex") != std::string::npos || + source.find("EndPrimitive") != std::string::npos || + source.find("gl_in[") != std::string::npos) { + type = ShaderType::Geometry; + return true; + } + + if (source.find("layout(vertices") != std::string::npos || + source.find("gl_InvocationID") != std::string::npos) { + type = ShaderType::TessControl; + return true; + } + + if (source.find("gl_TessCoord") != std::string::npos || + source.find("gl_TessLevelOuter") != std::string::npos) { + type = ShaderType::TessEvaluation; + return true; + } + + if (source.find("gl_Position") != std::string::npos) { + type = ShaderType::Vertex; + return true; + } + + if (source.find("gl_FragCoord") != std::string::npos || + source.find("gl_FragColor") != std::string::npos || + source.find("gl_FragData") != std::string::npos) { + type = ShaderType::Fragment; + return true; + } + + return false; +} + } // namespace OpenGLShader::OpenGLShader() @@ -280,6 +323,7 @@ bool OpenGLShader::Compile(const char* source, ShaderType type) { glDeleteShader(shader); + m_type = type; m_uniformsCached = false; return true; @@ -316,20 +360,18 @@ bool OpenGLShader::Compile(const void* sourceData, size_t sourceSize, const char if (!sourceData || sourceSize == 0) { return false; } - - ShaderType type = ShaderType::Fragment; - if (target) { - if (strstr(target, "vs_")) { - type = ShaderType::Vertex; - } else if (strstr(target, "ps_")) { - type = ShaderType::Fragment; - } else if (strstr(target, "gs_")) { - type = ShaderType::Geometry; - } else if (strstr(target, "cs_")) { - type = ShaderType::Compute; - } + + (void)entryPoint; + + const std::string source(static_cast(sourceData), sourceSize); + + ShaderType type = ShaderType::Vertex; + if (!ResolveShaderType(std::string(), target, type) && + !ResolveShaderTypeFromSource(source, type)) { + return false; } - return Compile(static_cast(sourceData), type); + + return Compile(source.c_str(), type); } void OpenGLShader::Shutdown() { diff --git a/tests/RHI/unit/test_command_list.cpp b/tests/RHI/unit/test_command_list.cpp index 123e948e..de035ad0 100644 --- a/tests/RHI/unit/test_command_list.cpp +++ b/tests/RHI/unit/test_command_list.cpp @@ -854,3 +854,147 @@ void main() { pipelineState->Shutdown(); delete pipelineState; } + +TEST_P(RHITestFixture, CommandList_SetComputeDescriptorSets_UsesSetAwareImageBindingsOnOpenGL) { + if (GetBackendType() != RHIType::OpenGL) { + GTEST_SKIP() << "OpenGL-specific compute descriptor binding"; + } + + auto* openGLDevice = static_cast(GetDevice()); + ASSERT_NE(openGLDevice, nullptr); + ASSERT_TRUE(openGLDevice->MakeContextCurrent()); + + const uint8_t initialPixel[4] = { 0, 0, 0, 0 }; + + TextureDesc textureDesc = {}; + textureDesc.width = 1; + textureDesc.height = 1; + textureDesc.depth = 1; + textureDesc.mipLevels = 1; + textureDesc.arraySize = 1; + textureDesc.format = static_cast(Format::R8G8B8A8_UNorm); + textureDesc.textureType = static_cast(TextureType::Texture2D); + textureDesc.sampleCount = 1; + + RHITexture* texture = GetDevice()->CreateTexture(textureDesc, initialPixel, sizeof(initialPixel), 4); + ASSERT_NE(texture, nullptr); + + ResourceViewDesc uavDesc = {}; + uavDesc.format = textureDesc.format; + uavDesc.dimension = ResourceViewDimension::Texture2D; + RHIResourceView* uav = GetDevice()->CreateUnorderedAccessView(texture, uavDesc); + ASSERT_NE(uav, nullptr); + + DescriptorPoolDesc poolDesc = {}; + poolDesc.type = DescriptorHeapType::CBV_SRV_UAV; + poolDesc.descriptorCount = 1; + poolDesc.shaderVisible = true; + RHIDescriptorPool* pool = GetDevice()->CreateDescriptorPool(poolDesc); + ASSERT_NE(pool, nullptr); + + DescriptorSetLayoutBinding reservedBinding = {}; + reservedBinding.binding = 0; + reservedBinding.type = static_cast(DescriptorType::UAV); + reservedBinding.count = 1; + + DescriptorSetLayoutBinding actualBinding = {}; + actualBinding.binding = 0; + actualBinding.type = static_cast(DescriptorType::UAV); + actualBinding.count = 1; + + DescriptorSetLayoutDesc reservedLayout = {}; + reservedLayout.bindings = &reservedBinding; + reservedLayout.bindingCount = 1; + + DescriptorSetLayoutDesc actualLayout = {}; + actualLayout.bindings = &actualBinding; + actualLayout.bindingCount = 1; + + DescriptorSetLayoutDesc setLayouts[2] = {}; + setLayouts[0] = reservedLayout; + setLayouts[1] = actualLayout; + + RHIPipelineLayoutDesc pipelineLayoutDesc = {}; + pipelineLayoutDesc.setLayouts = setLayouts; + pipelineLayoutDesc.setLayoutCount = 2; + RHIPipelineLayout* pipelineLayout = GetDevice()->CreatePipelineLayout(pipelineLayoutDesc); + ASSERT_NE(pipelineLayout, nullptr); + + RHIDescriptorSet* descriptorSet = pool->AllocateSet(actualLayout); + ASSERT_NE(descriptorSet, nullptr); + descriptorSet->Update(0, uav); + + GraphicsPipelineDesc pipelineDesc = {}; + RHIPipelineState* pipelineState = GetDevice()->CreatePipelineState(pipelineDesc); + ASSERT_NE(pipelineState, nullptr); + + ShaderCompileDesc shaderDesc = {}; + shaderDesc.sourceLanguage = ShaderLanguage::GLSL; + static const char* computeSource = R"( + #version 430 + layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + layout(binding = 1, rgba8) uniform writeonly image2D uImage; + void main() { + imageStore(uImage, ivec2(0, 0), vec4(1.0, 0.0, 0.0, 1.0)); + } + )"; + shaderDesc.source.assign(computeSource, computeSource + strlen(computeSource)); + shaderDesc.profile = L"cs_4_30"; + RHIShader* computeShader = GetDevice()->CreateShader(shaderDesc); + ASSERT_NE(computeShader, nullptr); + pipelineState->SetComputeShader(computeShader); + + CommandListDesc cmdDesc = {}; + cmdDesc.commandListType = static_cast(CommandQueueType::Direct); + RHICommandList* cmdList = GetDevice()->CreateCommandList(cmdDesc); + ASSERT_NE(cmdList, nullptr); + + glBindImageTexture(0, 0, 0, GL_FALSE, 0, GL_READ_WRITE, GL_RGBA8); + glBindImageTexture(1, 0, 0, GL_FALSE, 0, GL_READ_WRITE, GL_RGBA8); + + cmdList->Reset(); + cmdList->SetPipelineState(pipelineState); + RHIDescriptorSet* descriptorSets[] = { descriptorSet }; + cmdList->SetComputeDescriptorSets(1, 1, descriptorSets, pipelineLayout); + + GLint boundImageAtZero = -1; + GLint boundImageAtOne = -1; + glGetIntegeri_v(GL_IMAGE_BINDING_NAME, 0, &boundImageAtZero); + glGetIntegeri_v(GL_IMAGE_BINDING_NAME, 1, &boundImageAtOne); + EXPECT_EQ(boundImageAtZero, 0); + EXPECT_EQ( + static_cast(boundImageAtOne), + static_cast(reinterpret_cast(texture->GetNativeHandle()))); + + cmdList->Dispatch(1, 1, 1); + glMemoryBarrier(GL_ALL_BARRIER_BITS); + + uint8_t pixel[4] = {}; + glBindTexture(GL_TEXTURE_2D, static_cast(reinterpret_cast(texture->GetNativeHandle()))); + glGetTexImage(GL_TEXTURE_2D, 0, GL_RGBA, GL_UNSIGNED_BYTE, pixel); + glBindTexture(GL_TEXTURE_2D, 0); + + EXPECT_EQ(pixel[0], 255u); + EXPECT_EQ(pixel[1], 0u); + EXPECT_EQ(pixel[2], 0u); + EXPECT_EQ(pixel[3], 255u); + + cmdList->Close(); + + cmdList->Shutdown(); + delete cmdList; + computeShader->Shutdown(); + delete computeShader; + pipelineState->Shutdown(); + delete pipelineState; + descriptorSet->Shutdown(); + delete descriptorSet; + pipelineLayout->Shutdown(); + delete pipelineLayout; + pool->Shutdown(); + delete pool; + uav->Shutdown(); + delete uav; + texture->Shutdown(); + delete texture; +} diff --git a/tests/RHI/unit/test_pipeline_layout.cpp b/tests/RHI/unit/test_pipeline_layout.cpp index be7872d3..2f358689 100644 --- a/tests/RHI/unit/test_pipeline_layout.cpp +++ b/tests/RHI/unit/test_pipeline_layout.cpp @@ -368,7 +368,7 @@ TEST_P(RHITestFixture, PipelineLayout_OpenGLSeparatesOverlappingBindingsAcrossSe GTEST_SKIP() << "OpenGL-specific binding point verification"; } - DescriptorSetLayoutBinding set0Bindings[3] = {}; + DescriptorSetLayoutBinding set0Bindings[4] = {}; set0Bindings[0].binding = 0; set0Bindings[0].type = static_cast(DescriptorType::CBV); set0Bindings[0].count = 1; @@ -376,10 +376,13 @@ TEST_P(RHITestFixture, PipelineLayout_OpenGLSeparatesOverlappingBindingsAcrossSe set0Bindings[1].type = static_cast(DescriptorType::SRV); set0Bindings[1].count = 1; set0Bindings[2].binding = 0; - set0Bindings[2].type = static_cast(DescriptorType::Sampler); + set0Bindings[2].type = static_cast(DescriptorType::UAV); set0Bindings[2].count = 1; + set0Bindings[3].binding = 0; + set0Bindings[3].type = static_cast(DescriptorType::Sampler); + set0Bindings[3].count = 1; - DescriptorSetLayoutBinding set1Bindings[3] = {}; + DescriptorSetLayoutBinding set1Bindings[4] = {}; set1Bindings[0].binding = 0; set1Bindings[0].type = static_cast(DescriptorType::CBV); set1Bindings[0].count = 1; @@ -387,14 +390,17 @@ TEST_P(RHITestFixture, PipelineLayout_OpenGLSeparatesOverlappingBindingsAcrossSe set1Bindings[1].type = static_cast(DescriptorType::SRV); set1Bindings[1].count = 1; set1Bindings[2].binding = 0; - set1Bindings[2].type = static_cast(DescriptorType::Sampler); + set1Bindings[2].type = static_cast(DescriptorType::UAV); set1Bindings[2].count = 1; + set1Bindings[3].binding = 0; + set1Bindings[3].type = static_cast(DescriptorType::Sampler); + set1Bindings[3].count = 1; DescriptorSetLayoutDesc setLayouts[2] = {}; setLayouts[0].bindings = set0Bindings; - setLayouts[0].bindingCount = 3; + setLayouts[0].bindingCount = 4; setLayouts[1].bindings = set1Bindings; - setLayouts[1].bindingCount = 3; + setLayouts[1].bindingCount = 4; RHIPipelineLayoutDesc desc = {}; desc.setLayouts = setLayouts; @@ -412,6 +418,8 @@ TEST_P(RHITestFixture, PipelineLayout_OpenGLSeparatesOverlappingBindingsAcrossSe EXPECT_EQ(openGLLayout->GetConstantBufferBindingPoint(1, 0), 1u); EXPECT_EQ(openGLLayout->GetShaderResourceBindingPoint(0, 0), 0u); EXPECT_EQ(openGLLayout->GetShaderResourceBindingPoint(1, 0), 1u); + EXPECT_EQ(openGLLayout->GetUnorderedAccessBindingPoint(0, 0), 0u); + EXPECT_EQ(openGLLayout->GetUnorderedAccessBindingPoint(1, 0), 1u); EXPECT_EQ(openGLLayout->GetSamplerBindingPoint(0, 0), 0u); EXPECT_EQ(openGLLayout->GetSamplerBindingPoint(1, 0), 1u);