diff --git a/engine/src/Rendering/Graph/RenderGraphExecutor.cpp b/engine/src/Rendering/Graph/RenderGraphExecutor.cpp index 9e09159f..26729460 100644 --- a/engine/src/Rendering/Graph/RenderGraphExecutor.cpp +++ b/engine/src/Rendering/Graph/RenderGraphExecutor.cpp @@ -9,9 +9,11 @@ namespace Rendering { namespace { bool IsGraphManagedTransientState(RHI::ResourceStates state) { - return state == RHI::ResourceStates::RenderTarget || + return state == RHI::ResourceStates::Common || + state == RHI::ResourceStates::RenderTarget || state == RHI::ResourceStates::PixelShaderResource || - state == RHI::ResourceStates::GenericRead; + state == RHI::ResourceStates::GenericRead || + state == RHI::ResourceStates::Present; } RenderGraphTextureViewType ResolveBarrierViewType(RHI::ResourceStates state) { @@ -42,6 +44,10 @@ public: for (size_t textureIndex = 0u; textureIndex < m_graph.m_textures.size(); ++textureIndex) { const CompiledRenderGraph::CompiledTexture& texture = m_graph.m_textures[textureIndex]; const RenderGraphTextureLifetime& lifetime = m_graph.m_textureLifetimes[textureIndex]; + if (texture.kind == RenderGraphTextureKind::Imported) { + m_textureStates[textureIndex] = texture.importedOptions.initialState; + } + if (texture.kind != RenderGraphTextureKind::Transient || !lifetime.used) { continue; } @@ -110,6 +116,32 @@ public: m_graph.m_textures[handle.index].kind == RenderGraphTextureKind::Transient; } + bool TransitionGraphOwnedImportsToFinalStates( + const RenderContext& renderContext, + Containers::String* outErrorMessage) { + for (size_t textureIndex = 0u; textureIndex < m_graph.m_textures.size(); ++textureIndex) { + const CompiledRenderGraph::CompiledTexture& texture = m_graph.m_textures[textureIndex]; + if (texture.kind != RenderGraphTextureKind::Imported || + !texture.importedOptions.graphOwnsTransitions || + texture.importedView == nullptr || + textureIndex >= m_graph.m_textureLifetimes.size() || + !m_graph.m_textureLifetimes[textureIndex].used || + !IsGraphManagedTransientState(texture.importedOptions.finalState)) { + continue; + } + + if (!TransitionTexture( + { static_cast(textureIndex) }, + texture.importedOptions.finalState, + renderContext, + outErrorMessage)) { + return false; + } + } + + return true; + } + bool TransitionPassResources( const CompiledRenderGraph::CompiledPass& pass, const RenderContext& renderContext, @@ -117,41 +149,18 @@ public: for (const CompiledRenderGraph::CompiledTextureAccess& access : pass.accesses) { if (!access.texture.IsValid() || access.texture.index >= m_graph.m_textures.size() || - !IsTransientTexture(access.texture) || + !ShouldGraphManageTransitions(access.texture) || !IsGraphManagedTransientState(access.requiredState)) { continue; } - if (renderContext.commandList == nullptr) { - if (outErrorMessage != nullptr) { - *outErrorMessage = - Containers::String("RenderGraph cannot transition transient texture without a valid command list: ") + - m_graph.m_textures[access.texture.index].name; - } + if (!TransitionTexture( + access.texture, + access.requiredState, + renderContext, + outErrorMessage)) { return false; } - - RHI::ResourceStates& currentState = m_textureStates[access.texture.index]; - if (currentState == access.requiredState) { - continue; - } - - RHI::RHIResourceView* resourceView = - ResolveTextureView(access.texture, ResolveBarrierViewType(access.requiredState)); - if (resourceView == nullptr) { - if (outErrorMessage != nullptr) { - *outErrorMessage = - Containers::String("RenderGraph cannot resolve transient texture view for state transition: ") + - m_graph.m_textures[access.texture.index].name; - } - return false; - } - - renderContext.commandList->TransitionBarrier( - resourceView, - currentState, - access.requiredState); - currentState = access.requiredState; } return true; @@ -164,6 +173,55 @@ private: RHI::RHIResourceView* shaderResourceView = nullptr; }; + bool ShouldGraphManageTransitions(RenderGraphTextureHandle handle) const { + if (!handle.IsValid() || handle.index >= m_graph.m_textures.size()) { + return false; + } + + const CompiledRenderGraph::CompiledTexture& texture = m_graph.m_textures[handle.index]; + return texture.kind == RenderGraphTextureKind::Transient || + (texture.kind == RenderGraphTextureKind::Imported && + texture.importedOptions.graphOwnsTransitions); + } + + bool TransitionTexture( + RenderGraphTextureHandle handle, + RHI::ResourceStates targetState, + const RenderContext& renderContext, + Containers::String* outErrorMessage) { + if (renderContext.commandList == nullptr) { + if (outErrorMessage != nullptr) { + *outErrorMessage = + Containers::String("RenderGraph cannot transition texture without a valid command list: ") + + m_graph.m_textures[handle.index].name; + } + return false; + } + + RHI::ResourceStates& currentState = m_textureStates[handle.index]; + if (currentState == targetState) { + return true; + } + + RHI::RHIResourceView* resourceView = + ResolveTextureView(handle, ResolveBarrierViewType(targetState)); + if (resourceView == nullptr) { + if (outErrorMessage != nullptr) { + *outErrorMessage = + Containers::String("RenderGraph cannot resolve texture view for state transition: ") + + m_graph.m_textures[handle.index].name; + } + return false; + } + + renderContext.commandList->TransitionBarrier( + resourceView, + currentState, + targetState); + currentState = targetState; + return true; + } + static void DestroyTextureAllocation(TextureAllocation& allocation) { if (allocation.renderTargetView != nullptr) { allocation.renderTargetView->Shutdown(); @@ -286,6 +344,12 @@ bool RenderGraphExecutor::Execute( } } + if (!runtimeResources.TransitionGraphOwnedImportsToFinalStates( + renderContext, + outErrorMessage)) { + return false; + } + return true; } diff --git a/tests/Rendering/unit/test_render_graph.cpp b/tests/Rendering/unit/test_render_graph.cpp index 13b1fe92..4eb00a56 100644 --- a/tests/Rendering/unit/test_render_graph.cpp +++ b/tests/Rendering/unit/test_render_graph.cpp @@ -108,6 +108,16 @@ private: Format m_format = Format::Unknown; }; +class MockImportedView final : public RHIResourceView { +public: + void Shutdown() override {} + void* GetNativeHandle() override { return nullptr; } + bool IsValid() const override { return true; } + ResourceViewType GetViewType() const override { return ResourceViewType::RenderTarget; } + ResourceViewDimension GetDimension() const override { return ResourceViewDimension::Texture2D; } + Format GetFormat() const override { return Format::R8G8B8A8_UNorm; } +}; + class MockTransientDevice final : public RHIDevice { public: explicit MockTransientDevice(std::shared_ptr state) @@ -421,6 +431,51 @@ TEST(RenderGraph_Test, PreservesImportedTextureStateContractAcrossCompile) { EXPECT_TRUE(resolvedOptions.graphOwnsTransitions); } +TEST(RenderGraph_Test, ExecutesGraphOwnedImportedTextureTransitionsAtGraphBoundaries) { + RenderGraph graph; + RenderGraphBuilder builder(graph); + + const RenderGraphTextureDesc desc = BuildTestTextureDesc(); + RenderGraphImportedTextureOptions importedOptions = {}; + importedOptions.initialState = ResourceStates::Present; + importedOptions.finalState = ResourceStates::Present; + importedOptions.graphOwnsTransitions = true; + + MockImportedView importedView; + const RenderGraphTextureHandle backBuffer = builder.ImportTexture( + "BackBuffer", + desc, + &importedView, + importedOptions); + + builder.AddRasterPass( + "FinalBlit", + [&](RenderGraphPassBuilder& pass) { + pass.WriteTexture(backBuffer); + }); + + CompiledRenderGraph compiledGraph; + XCEngine::Containers::String errorMessage; + ASSERT_TRUE(RenderGraphCompiler::Compile(graph, compiledGraph, &errorMessage)) + << errorMessage.CStr(); + + MockTransientCommandList commandList; + RenderContext renderContext = {}; + renderContext.device = reinterpret_cast(1); + renderContext.commandList = &commandList; + renderContext.commandQueue = reinterpret_cast(1); + ASSERT_TRUE(RenderGraphExecutor::Execute(compiledGraph, renderContext, &errorMessage)) + << errorMessage.CStr(); + + ASSERT_EQ(commandList.transitionCalls.size(), 2u); + EXPECT_EQ(commandList.transitionCalls[0].resource, &importedView); + EXPECT_EQ(commandList.transitionCalls[0].before, ResourceStates::Present); + EXPECT_EQ(commandList.transitionCalls[0].after, ResourceStates::RenderTarget); + EXPECT_EQ(commandList.transitionCalls[1].resource, &importedView); + EXPECT_EQ(commandList.transitionCalls[1].before, ResourceStates::RenderTarget); + EXPECT_EQ(commandList.transitionCalls[1].after, ResourceStates::Present); +} + TEST(RenderGraph_Test, RejectsTransientTextureReadBeforeWrite) { RenderGraph graph; RenderGraphBuilder builder(graph);