From b7428b0ef10686c9402b8d874d199ff1cb8382ac Mon Sep 17 00:00:00 2001 From: ssdfasd <2156608475@qq.com> Date: Mon, 13 Apr 2026 02:23:39 +0800 Subject: [PATCH] Stabilize 3DGS D3D12 phase 3 and sort key setup --- MVS/3DGS-D3D12/CMakeLists.txt | 33 + MVS/3DGS-D3D12/include/XC3DGSD3D12/App.h | 48 +- MVS/3DGS-D3D12/shaders/BuildSortKeysCS.hlsl | 43 + MVS/3DGS-D3D12/shaders/DebugPointsPS.hlsl | 4 + MVS/3DGS-D3D12/shaders/DebugPointsVS.hlsl | 39 + .../shaders/PrepareGaussiansCS.hlsl | 269 ++++++ .../shaders/PreparedSplatView.hlsli | 22 + MVS/3DGS-D3D12/src/App.cpp | 764 +++++++++++++++++- MVS/3DGS-D3D12/src/GaussianPlyLoader.cpp | 89 +- MVS/3DGS-D3D12/src/main.cpp | 3 + 10 files changed, 1281 insertions(+), 33 deletions(-) create mode 100644 MVS/3DGS-D3D12/shaders/BuildSortKeysCS.hlsl create mode 100644 MVS/3DGS-D3D12/shaders/DebugPointsPS.hlsl create mode 100644 MVS/3DGS-D3D12/shaders/DebugPointsVS.hlsl create mode 100644 MVS/3DGS-D3D12/shaders/PrepareGaussiansCS.hlsl create mode 100644 MVS/3DGS-D3D12/shaders/PreparedSplatView.hlsli diff --git a/MVS/3DGS-D3D12/CMakeLists.txt b/MVS/3DGS-D3D12/CMakeLists.txt index 80f5e72d..762692c4 100644 --- a/MVS/3DGS-D3D12/CMakeLists.txt +++ b/MVS/3DGS-D3D12/CMakeLists.txt @@ -5,6 +5,8 @@ project(XC3DGSD3D12MVS LANGUAGES CXX) set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) +find_program(XC_DXC_EXECUTABLE NAMES dxc) + get_filename_component(XCENGINE_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../.." ABSOLUTE) set(XCENGINE_BUILD_DIR "${XCENGINE_ROOT}/build") set(XCENGINE_INCLUDE_DIR "${XCENGINE_ROOT}/engine/include") @@ -30,6 +32,21 @@ add_executable(xc_3dgs_d3d12_mvs src/GaussianPlyLoader.cpp include/XC3DGSD3D12/App.h include/XC3DGSD3D12/GaussianPlyLoader.h + shaders/PreparedSplatView.hlsli + shaders/PrepareGaussiansCS.hlsl + shaders/BuildSortKeysCS.hlsl + shaders/DebugPointsVS.hlsl + shaders/DebugPointsPS.hlsl +) + +set_source_files_properties( + shaders/PreparedSplatView.hlsli + shaders/PrepareGaussiansCS.hlsl + shaders/BuildSortKeysCS.hlsl + shaders/DebugPointsVS.hlsl + shaders/DebugPointsPS.hlsl + PROPERTIES + HEADER_FILE_ONLY TRUE ) target_include_directories(xc_3dgs_d3d12_mvs PRIVATE @@ -69,3 +86,19 @@ add_custom_command(TARGET xc_3dgs_d3d12_mvs POST_BUILD "${CMAKE_CURRENT_SOURCE_DIR}/room.ply" "$/room.ply" ) + +add_custom_command(TARGET xc_3dgs_d3d12_mvs POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_directory + "${CMAKE_CURRENT_SOURCE_DIR}/shaders" + "$/shaders" +) + +if(XC_DXC_EXECUTABLE) + add_custom_command(TARGET xc_3dgs_d3d12_mvs POST_BUILD + COMMAND "${XC_DXC_EXECUTABLE}" + -T cs_6_6 + -E MainCS + -Fo "$/shaders/BuildSortKeysCS.dxil" + "${CMAKE_CURRENT_SOURCE_DIR}/shaders/BuildSortKeysCS.hlsl" + ) +endif() diff --git a/MVS/3DGS-D3D12/include/XC3DGSD3D12/App.h b/MVS/3DGS-D3D12/include/XC3DGSD3D12/App.h index 748e45f1..a7d61681 100644 --- a/MVS/3DGS-D3D12/include/XC3DGSD3D12/App.h +++ b/MVS/3DGS-D3D12/include/XC3DGSD3D12/App.h @@ -20,8 +20,24 @@ #include "XCEngine/RHI/D3D12/D3D12SwapChain.h" #include "XCEngine/RHI/D3D12/D3D12Texture.h" +namespace XCEngine { +namespace RHI { +class RHIDescriptorPool; +class RHIDescriptorSet; +class RHIPipelineLayout; +class RHIPipelineState; +} // namespace RHI +} // namespace XCEngine + namespace XC3DGSD3D12 { +struct PreparedSplatView { + float clipPosition[4] = {}; + float axis1[2] = {}; + float axis2[2] = {}; + uint32_t packedColor[2] = {}; +}; + class App { public: App(); @@ -32,6 +48,7 @@ public: void SetFrameLimit(unsigned int frameLimit); void SetGaussianScenePath(std::wstring scenePath); void SetSummaryPath(std::wstring summaryPath); + void SetScreenshotPath(std::wstring screenshotPath); const std::wstring& GetLastErrorMessage() const; private: @@ -47,9 +64,16 @@ private: bool LoadGaussianScene(); bool InitializeRhi(); bool InitializeGaussianGpuResources(); + bool InitializePreparePassResources(); + bool InitializeSortResources(); + bool InitializeDebugDrawResources(); void ShutdownGaussianGpuResources(); + void ShutdownPreparePassResources(); + void ShutdownSortResources(); + void ShutdownDebugDrawResources(); void Shutdown(); - void RenderFrame(); + bool CaptureSortKeySnapshot(); + void RenderFrame(bool captureScreenshot); HWND m_hwnd = nullptr; HINSTANCE m_instance = nullptr; @@ -62,6 +86,8 @@ private: unsigned int m_renderedFrameCount = 0; std::wstring m_gaussianScenePath = L"room.ply"; std::wstring m_summaryPath; + std::wstring m_screenshotPath = L"phase3_debug_points.ppm"; + std::wstring m_sortKeySnapshotPath = L"phase3_sortkeys.txt"; std::wstring m_lastErrorMessage; GaussianSplatRuntimeData m_gaussianSceneData; XCEngine::RHI::D3D12Buffer m_gaussianPositionBuffer; @@ -73,6 +99,26 @@ private: std::unique_ptr m_gaussianShView; std::unique_ptr m_gaussianColorView; std::vector> m_gaussianUploadBuffers; + XCEngine::RHI::D3D12Buffer* m_preparedViewBuffer = nullptr; + std::unique_ptr m_preparedViewSrv; + std::unique_ptr m_preparedViewUav; + XCEngine::RHI::RHIPipelineLayout* m_preparePipelineLayout = nullptr; + XCEngine::RHI::RHIPipelineState* m_preparePipelineState = nullptr; + XCEngine::RHI::RHIDescriptorPool* m_prepareDescriptorPool = nullptr; + XCEngine::RHI::RHIDescriptorSet* m_prepareDescriptorSet = nullptr; + XCEngine::RHI::D3D12Buffer* m_sortKeyBuffer = nullptr; + XCEngine::RHI::D3D12Buffer* m_orderBuffer = nullptr; + std::unique_ptr m_sortKeySrv; + std::unique_ptr m_sortKeyUav; + std::unique_ptr m_orderBufferSrv; + XCEngine::RHI::RHIPipelineLayout* m_sortPipelineLayout = nullptr; + XCEngine::RHI::RHIPipelineState* m_sortPipelineState = nullptr; + XCEngine::RHI::RHIDescriptorPool* m_sortDescriptorPool = nullptr; + XCEngine::RHI::RHIDescriptorSet* m_sortDescriptorSet = nullptr; + XCEngine::RHI::RHIPipelineLayout* m_debugPipelineLayout = nullptr; + XCEngine::RHI::RHIPipelineState* m_debugPipelineState = nullptr; + XCEngine::RHI::RHIDescriptorPool* m_debugDescriptorPool = nullptr; + XCEngine::RHI::RHIDescriptorSet* m_debugDescriptorSet = nullptr; XCEngine::RHI::D3D12Device m_device; XCEngine::RHI::D3D12CommandQueue m_commandQueue; diff --git a/MVS/3DGS-D3D12/shaders/BuildSortKeysCS.hlsl b/MVS/3DGS-D3D12/shaders/BuildSortKeysCS.hlsl new file mode 100644 index 00000000..da80b4ef --- /dev/null +++ b/MVS/3DGS-D3D12/shaders/BuildSortKeysCS.hlsl @@ -0,0 +1,43 @@ +#define GROUP_SIZE 64 + +cbuffer FrameConstants : register(b0) +{ + float4x4 gViewProjection; + float4x4 gView; + float4x4 gProjection; + float4 gCameraWorldPos; + float4 gScreenParams; + float4 gSettings; +}; + +ByteAddressBuffer gPositions : register(t0); +StructuredBuffer gOrderBuffer : register(t1); +RWStructuredBuffer gSortKeys : register(u0); + +float3 LoadFloat3(ByteAddressBuffer buffer, uint byteOffset) +{ + return asfloat(buffer.Load3(byteOffset)); +} + +uint FloatToSortableUint(float value) +{ + uint bits = asuint(value); + uint mask = (0u - (bits >> 31)) | 0x80000000u; + return bits ^ mask; +} + +[numthreads(GROUP_SIZE, 1, 1)] +void MainCS(uint3 dispatchThreadId : SV_DispatchThreadID) +{ + uint index = dispatchThreadId.x; + uint splatCount = (uint)gSettings.x; + if (index >= splatCount) + { + return; + } + + uint splatIndex = gOrderBuffer[index]; + float3 position = LoadFloat3(gPositions, splatIndex * 12); + float3 viewPosition = mul(float4(position, 1.0), gView).xyz; + gSortKeys[index] = FloatToSortableUint(viewPosition.z); +} diff --git a/MVS/3DGS-D3D12/shaders/DebugPointsPS.hlsl b/MVS/3DGS-D3D12/shaders/DebugPointsPS.hlsl new file mode 100644 index 00000000..170deba0 --- /dev/null +++ b/MVS/3DGS-D3D12/shaders/DebugPointsPS.hlsl @@ -0,0 +1,4 @@ +float4 MainPS(float4 position : SV_Position, float4 color : COLOR0) : SV_Target0 +{ + return color; +} diff --git a/MVS/3DGS-D3D12/shaders/DebugPointsVS.hlsl b/MVS/3DGS-D3D12/shaders/DebugPointsVS.hlsl new file mode 100644 index 00000000..2a33db89 --- /dev/null +++ b/MVS/3DGS-D3D12/shaders/DebugPointsVS.hlsl @@ -0,0 +1,39 @@ +#include "PreparedSplatView.hlsli" + +cbuffer FrameConstants : register(b0) +{ + float4x4 gViewProjection; + float4x4 gView; + float4x4 gProjection; + float4 gCameraWorldPos; + float4 gScreenParams; + float4 gSettings; +}; + +StructuredBuffer gPreparedViews : register(t0); +StructuredBuffer gOrderBuffer : register(t1); + +struct VertexOutput +{ + float4 position : SV_Position; + float4 color : COLOR0; +}; + +VertexOutput MainVS(uint vertexId : SV_VertexID, uint instanceId : SV_InstanceID) +{ + VertexOutput output; + uint splatIndex = gOrderBuffer[instanceId]; + PreparedSplatView view = gPreparedViews[splatIndex]; + float4 color = UnpackPreparedColor(view); + + if (view.clipPosition.w <= 0.0) + { + output.position = float4(2.0, 2.0, 2.0, 1.0); + output.color = 0.0; + return output; + } + + output.position = view.clipPosition; + output.color = float4(color.rgb, 1.0); + return output; +} diff --git a/MVS/3DGS-D3D12/shaders/PrepareGaussiansCS.hlsl b/MVS/3DGS-D3D12/shaders/PrepareGaussiansCS.hlsl new file mode 100644 index 00000000..74254483 --- /dev/null +++ b/MVS/3DGS-D3D12/shaders/PrepareGaussiansCS.hlsl @@ -0,0 +1,269 @@ +#include "PreparedSplatView.hlsli" + +#define GROUP_SIZE 64 + +cbuffer FrameConstants : register(b0) +{ + float4x4 gViewProjection; + float4x4 gView; + float4x4 gProjection; + float4 gCameraWorldPos; + float4 gScreenParams; + float4 gSettings; +}; + +ByteAddressBuffer gPositions : register(t0); +ByteAddressBuffer gOther : register(t1); +Texture2D gColor : register(t2); +ByteAddressBuffer gSh : register(t3); +RWStructuredBuffer gPreparedViews : register(u0); + +static const float SH_C1 = 0.4886025; +static const float SH_C2[] = { 1.0925484, -1.0925484, 0.3153916, -1.0925484, 0.5462742 }; +static const float SH_C3[] = { -0.5900436, 2.8906114, -0.4570458, 0.3731763, -0.4570458, 1.4453057, -0.5900436 }; +static const uint kColorTextureWidth = 2048; +static const uint kOtherStride = 16; +static const uint kShStride = 192; + +struct SplatSHData +{ + float3 col; + float3 sh[15]; +}; + +float3 LoadFloat3(ByteAddressBuffer buffer, uint byteOffset) +{ + return asfloat(buffer.Load3(byteOffset)); +} + +uint EncodeMorton2D_16x16(uint2 c) +{ + uint t = ((c.y & 0xF) << 8) | (c.x & 0xF); + t = (t ^ (t << 2)) & 0x3333; + t = (t ^ (t << 1)) & 0x5555; + return (t | (t >> 7)) & 0xFF; +} + +uint2 DecodeMorton2D_16x16(uint t) +{ + t = (t & 0xFF) | ((t & 0xFE) << 7); + t &= 0x5555; + t = (t ^ (t >> 1)) & 0x3333; + t = (t ^ (t >> 2)) & 0x0F0F; + return uint2(t & 0xF, t >> 8); +} + +uint3 SplatIndexToPixelIndex(uint index) +{ + uint2 xy = DecodeMorton2D_16x16(index); + uint tileWidth = kColorTextureWidth / 16; + index >>= 8; + + uint3 result; + result.x = (index % tileWidth) * 16 + xy.x; + result.y = (index / tileWidth) * 16 + xy.y; + result.z = 0; + return result; +} + +float4 DecodePacked_10_10_10_2(uint encoded) +{ + return float4( + (encoded & 1023) / 1023.0, + ((encoded >> 10) & 1023) / 1023.0, + ((encoded >> 20) & 1023) / 1023.0, + ((encoded >> 30) & 3) / 3.0); +} + +float4 DecodeRotation(float4 packedRotation) +{ + uint droppedIndex = (uint)round(packedRotation.w * 3.0); + float4 rotation; + rotation.xyz = packedRotation.xyz * sqrt(2.0) - (1.0 / sqrt(2.0)); + rotation.w = sqrt(1.0 - saturate(dot(rotation.xyz, rotation.xyz))); + + if (droppedIndex == 0) + { + rotation = rotation.wxyz; + } + if (droppedIndex == 1) + { + rotation = rotation.xwyz; + } + if (droppedIndex == 2) + { + rotation = rotation.xywz; + } + + return rotation; +} + +float3x3 CalcMatrixFromRotationScale(float4 rotation, float3 scale) +{ + float3x3 scaleMatrix = float3x3( + scale.x, 0, 0, + 0, scale.y, 0, + 0, 0, scale.z); + + float x = rotation.x; + float y = rotation.y; + float z = rotation.z; + float w = rotation.w; + + float3x3 rotationMatrix = float3x3( + 1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y), + 2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x), + 2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)); + + return mul(rotationMatrix, scaleMatrix); +} + +void CalcCovariance3D(float3x3 rotationScaleMatrix, out float3 sigma0, out float3 sigma1) +{ + float3x3 sigma = mul(rotationScaleMatrix, transpose(rotationScaleMatrix)); + sigma0 = float3(sigma._m00, sigma._m01, sigma._m02); + sigma1 = float3(sigma._m11, sigma._m12, sigma._m22); +} + +float3 CalcCovariance2D(float3 worldPosition, float3 covariance0, float3 covariance1) +{ + float3 viewPosition = mul(float4(worldPosition, 1.0), gView).xyz; + + float aspect = gProjection._m00 / gProjection._m11; + float tanFovX = rcp(gProjection._m00); + float tanFovY = rcp(gProjection._m11 * aspect); + float clampX = 1.3 * tanFovX; + float clampY = 1.3 * tanFovY; + viewPosition.x = clamp(viewPosition.x / viewPosition.z, -clampX, clampX) * viewPosition.z; + viewPosition.y = clamp(viewPosition.y / viewPosition.z, -clampY, clampY) * viewPosition.z; + + float focal = gScreenParams.x * gProjection._m00 * 0.5; + + float3x3 jacobian = float3x3( + focal / viewPosition.z, 0, -(focal * viewPosition.x) / (viewPosition.z * viewPosition.z), + 0, focal / viewPosition.z, -(focal * viewPosition.y) / (viewPosition.z * viewPosition.z), + 0, 0, 0); + float3x3 worldToView = (float3x3)gView; + float3x3 transform = mul(jacobian, worldToView); + float3x3 covariance = float3x3( + covariance0.x, covariance0.y, covariance0.z, + covariance0.y, covariance1.x, covariance1.y, + covariance0.z, covariance1.y, covariance1.z); + float3x3 projected = mul(transform, mul(covariance, transpose(transform))); + projected._m00 += 0.3; + projected._m11 += 0.3; + return float3(projected._m00, projected._m01, projected._m11); +} + +void DecomposeCovariance(float3 covariance2D, out float2 axis1, out float2 axis2) +{ + float diagonal0 = covariance2D.x; + float diagonal1 = covariance2D.z; + float offDiagonal = covariance2D.y; + float mid = 0.5 * (diagonal0 + diagonal1); + float radius = length(float2((diagonal0 - diagonal1) * 0.5, offDiagonal)); + float lambda0 = mid + radius; + float lambda1 = max(mid - radius, 0.1); + float2 diagonalVector = normalize(float2(offDiagonal, lambda0 - diagonal0)); + diagonalVector.y = -diagonalVector.y; + const float maxSize = 4096.0; + axis1 = min(sqrt(2.0 * lambda0), maxSize) * diagonalVector; + axis2 = min(sqrt(2.0 * lambda1), maxSize) * float2(diagonalVector.y, -diagonalVector.x); +} + +SplatSHData LoadSplatSH(uint index) +{ + SplatSHData sh; + const uint shBaseOffset = index * kShStride; + sh.col = gColor.Load(int3(SplatIndexToPixelIndex(index).xy, 0)).rgb; + + [unroll] + for (uint coefficientIndex = 0; coefficientIndex < 15; ++coefficientIndex) + { + sh.sh[coefficientIndex] = LoadFloat3(gSh, shBaseOffset + coefficientIndex * 12); + } + + return sh; +} + +float3 ShadeSH(SplatSHData sh, float3 direction, int shOrder) +{ + direction *= -1.0; + float x = direction.x; + float y = direction.y; + float z = direction.z; + + float3 result = sh.col; + if (shOrder >= 1) + { + result += SH_C1 * (-sh.sh[0] * y + sh.sh[1] * z - sh.sh[2] * x); + if (shOrder >= 2) + { + float xx = x * x; + float yy = y * y; + float zz = z * z; + float xy = x * y; + float yz = y * z; + float xz = x * z; + result += + (SH_C2[0] * xy) * sh.sh[3] + + (SH_C2[1] * yz) * sh.sh[4] + + (SH_C2[2] * (2 * zz - xx - yy)) * sh.sh[5] + + (SH_C2[3] * xz) * sh.sh[6] + + (SH_C2[4] * (xx - yy)) * sh.sh[7]; + if (shOrder >= 3) + { + result += + (SH_C3[0] * y * (3 * xx - yy)) * sh.sh[8] + + (SH_C3[1] * xy * z) * sh.sh[9] + + (SH_C3[2] * y * (4 * zz - xx - yy)) * sh.sh[10] + + (SH_C3[3] * z * (2 * zz - 3 * xx - 3 * yy)) * sh.sh[11] + + (SH_C3[4] * x * (4 * zz - xx - yy)) * sh.sh[12] + + (SH_C3[5] * z * (xx - yy)) * sh.sh[13] + + (SH_C3[6] * x * (xx - 3 * yy)) * sh.sh[14]; + } + } + } + + return max(result, 0.0); +} + +[numthreads(GROUP_SIZE, 1, 1)] +void MainCS(uint3 dispatchThreadId : SV_DispatchThreadID) +{ + uint index = dispatchThreadId.x; + uint splatCount = (uint)gSettings.x; + if (index >= splatCount) + { + return; + } + + PreparedSplatView view = (PreparedSplatView)0; + + float3 position = LoadFloat3(gPositions, index * 12); + uint packedRotation = gOther.Load(index * kOtherStride); + float4 rotation = DecodeRotation(DecodePacked_10_10_10_2(packedRotation)); + float3 scale = LoadFloat3(gOther, index * kOtherStride + 4); + float4 colorOpacity = gColor.Load(int3(SplatIndexToPixelIndex(index).xy, 0)); + + view.clipPosition = mul(float4(position, 1.0), gViewProjection); + if (view.clipPosition.w > 0.0) + { + float3x3 rotationScale = CalcMatrixFromRotationScale(rotation, scale); + float3 covariance0; + float3 covariance1; + CalcCovariance3D(rotationScale, covariance0, covariance1); + float3 covariance2D = CalcCovariance2D(position, covariance0, covariance1); + DecomposeCovariance(covariance2D, view.axis1, view.axis2); + + SplatSHData sh = LoadSplatSH(index); + float3 viewDirection = normalize(gCameraWorldPos.xyz - position); + float3 shadedColor = ShadeSH(sh, viewDirection, (int)gSettings.z); + float opacity = saturate(colorOpacity.a * gSettings.y); + + view.packedColor.x = (f32tof16(shadedColor.r) << 16) | f32tof16(shadedColor.g); + view.packedColor.y = (f32tof16(shadedColor.b) << 16) | f32tof16(opacity); + } + + gPreparedViews[index] = view; +} diff --git a/MVS/3DGS-D3D12/shaders/PreparedSplatView.hlsli b/MVS/3DGS-D3D12/shaders/PreparedSplatView.hlsli new file mode 100644 index 00000000..01f7659c --- /dev/null +++ b/MVS/3DGS-D3D12/shaders/PreparedSplatView.hlsli @@ -0,0 +1,22 @@ +#ifndef PREPARED_SPLAT_VIEW_HLSLI +#define PREPARED_SPLAT_VIEW_HLSLI + +struct PreparedSplatView +{ + float4 clipPosition; + float2 axis1; + float2 axis2; + uint2 packedColor; +}; + +float4 UnpackPreparedColor(PreparedSplatView view) +{ + float4 color; + color.r = f16tof32((view.packedColor.x >> 16) & 0xFFFF); + color.g = f16tof32(view.packedColor.x & 0xFFFF); + color.b = f16tof32((view.packedColor.y >> 16) & 0xFFFF); + color.a = f16tof32(view.packedColor.y & 0xFFFF); + return color; +} + +#endif diff --git a/MVS/3DGS-D3D12/src/App.cpp b/MVS/3DGS-D3D12/src/App.cpp index 8b54e45c..1e515c9e 100644 --- a/MVS/3DGS-D3D12/src/App.cpp +++ b/MVS/3DGS-D3D12/src/App.cpp @@ -2,17 +2,43 @@ #include #include +#include +#include +#include #include +#include +#include + +#include "XCEngine/RHI/D3D12/D3D12Screenshot.h" +#include "XCEngine/RHI/RHIDescriptorPool.h" +#include "XCEngine/RHI/RHIDescriptorSet.h" +#include "XCEngine/RHI/RHIPipelineLayout.h" +#include "XCEngine/RHI/RHIPipelineState.h" namespace XC3DGSD3D12 { using namespace XCEngine::RHI; namespace { + constexpr wchar_t kWindowClassName[] = L"XC3DGSD3D12WindowClass"; -constexpr wchar_t kWindowTitle[] = L"XC 3DGS D3D12 MVS - Phase 1"; -constexpr float kClearColor[4] = { 0.08f, 0.12f, 0.18f, 1.0f }; +constexpr wchar_t kWindowTitle[] = L"XC 3DGS D3D12 MVS - Phase 3"; +constexpr float kClearColor[4] = { 0.04f, 0.05f, 0.07f, 1.0f }; +constexpr uint32_t kPrepareThreadGroupSize = 64u; +constexpr uint32_t kSortThreadGroupSize = 64u; + +struct FrameConstants { + float viewProjection[16] = {}; + float view[16] = {}; + float projection[16] = {}; + float cameraWorldPos[4] = {}; + float screenParams[4] = {}; + float settings[4] = {}; +}; + +static_assert(sizeof(FrameConstants) % 16 == 0, "Frame constants must stay 16-byte aligned."); +static_assert(sizeof(PreparedSplatView) == 40, "Prepared view buffer layout must match shader."); std::filesystem::path GetExecutableDirectory() { std::wstring pathBuffer; @@ -29,8 +55,106 @@ std::filesystem::path ResolveNearExecutable(const std::wstring& path) { } return GetExecutableDirectory() / inputPath; } + +std::filesystem::path ResolveShaderPath(std::wstring_view fileName) { + return GetExecutableDirectory() / L"shaders" / std::filesystem::path(fileName); } +std::vector LoadBinaryFile(const std::filesystem::path& filePath) { + std::ifstream input(filePath, std::ios::binary); + if (!input.is_open()) { + return {}; + } + + input.seekg(0, std::ios::end); + const std::streamoff size = input.tellg(); + if (size <= 0) { + return {}; + } + input.seekg(0, std::ios::beg); + + std::vector bytes(static_cast(size)); + input.read(reinterpret_cast(bytes.data()), size); + if (!input) { + return {}; + } + return bytes; +} + +std::string NarrowAscii(std::wstring_view text) { + std::string result; + result.reserve(text.size()); + for (wchar_t ch : text) { + result.push_back(ch >= 0 && ch <= 0x7F ? static_cast(ch) : '?'); + } + return result; +} + +void AppendTrace(std::string_view message) { + const std::filesystem::path tracePath = GetExecutableDirectory() / "phase3_trace.log"; + std::ofstream file(tracePath, std::ios::app); + if (!file.is_open()) { + return; + } + + file << GetTickCount64() << " | " << message << '\n'; +} + +void AppendTrace(const std::wstring& message) { + AppendTrace(NarrowAscii(message)); +} + +void StoreMatrixTransposed(const DirectX::XMMATRIX& matrix, float* destination) { + DirectX::XMFLOAT4X4 output = {}; + DirectX::XMStoreFloat4x4(&output, DirectX::XMMatrixTranspose(matrix)); + std::memcpy(destination, &output, sizeof(output)); +} + +FrameConstants BuildFrameConstants(uint32_t width, uint32_t height, uint32_t splatCount) { + using namespace DirectX; + + const XMVECTOR eye = XMVectorSet(0.0f, 0.5f, 1.0f, 1.0f); + const XMVECTOR target = XMVectorSet(0.0f, 0.5f, -5.0f, 1.0f); + const XMVECTOR up = XMVectorSet(0.0f, 1.0f, 0.0f, 0.0f); + + const float aspect = height > 0 ? static_cast(width) / static_cast(height) : 1.0f; + const XMMATRIX view = XMMatrixLookAtRH(eye, target, up); + const XMMATRIX projection = XMMatrixPerspectiveFovRH(XMConvertToRadians(60.0f), aspect, 0.1f, 200.0f); + const XMMATRIX viewProjection = XMMatrixMultiply(view, projection); + + FrameConstants constants = {}; + StoreMatrixTransposed(viewProjection, constants.viewProjection); + StoreMatrixTransposed(view, constants.view); + StoreMatrixTransposed(projection, constants.projection); + + constants.cameraWorldPos[0] = 0.0f; + constants.cameraWorldPos[1] = 0.5f; + constants.cameraWorldPos[2] = 1.0f; + constants.cameraWorldPos[3] = 1.0f; + + constants.screenParams[0] = static_cast(width); + constants.screenParams[1] = static_cast(height); + constants.screenParams[2] = width > 0 ? 1.0f / static_cast(width) : 0.0f; + constants.screenParams[3] = height > 0 ? 1.0f / static_cast(height) : 0.0f; + + constants.settings[0] = static_cast(splatCount); + constants.settings[1] = 1.0f; // opacity scale + constants.settings[2] = 3.0f; // SH order + constants.settings[3] = 1.0f; + return constants; +} + +template +void ShutdownAndDelete(T*& object) { + if (object != nullptr) { + object->Shutdown(); + delete object; + object = nullptr; + } +} + +} // namespace + App::App() = default; App::~App() { @@ -38,36 +162,66 @@ App::~App() { } bool App::Initialize(HINSTANCE instance, int showCommand) { + AppendTrace("Initialize: begin"); m_instance = instance; m_lastErrorMessage.clear(); + AppendTrace("Initialize: LoadGaussianScene"); if (!LoadGaussianScene()) { + AppendTrace(std::string("Initialize: LoadGaussianScene failed: ") + NarrowAscii(m_lastErrorMessage)); return false; } + AppendTrace("Initialize: RegisterWindowClass"); if (!RegisterWindowClass(instance)) { m_lastErrorMessage = L"Failed to register the Win32 window class."; + AppendTrace(std::string("Initialize: RegisterWindowClass failed: ") + NarrowAscii(m_lastErrorMessage)); return false; } + AppendTrace("Initialize: CreateMainWindow"); if (!CreateMainWindow(instance, showCommand)) { m_lastErrorMessage = L"Failed to create the main window."; + AppendTrace(std::string("Initialize: CreateMainWindow failed: ") + NarrowAscii(m_lastErrorMessage)); return false; } + AppendTrace("Initialize: InitializeRhi"); if (!InitializeRhi()) { if (m_lastErrorMessage.empty()) { m_lastErrorMessage = L"Failed to initialize the D3D12 RHI objects."; } + AppendTrace(std::string("Initialize: InitializeRhi failed: ") + NarrowAscii(m_lastErrorMessage)); return false; } + AppendTrace("Initialize: InitializeGaussianGpuResources"); if (!InitializeGaussianGpuResources()) { + AppendTrace(std::string("Initialize: InitializeGaussianGpuResources failed: ") + NarrowAscii(m_lastErrorMessage)); + return false; + } + + AppendTrace("Initialize: InitializePreparePassResources"); + if (!InitializePreparePassResources()) { + AppendTrace(std::string("Initialize: InitializePreparePassResources failed: ") + NarrowAscii(m_lastErrorMessage)); + return false; + } + + AppendTrace("Initialize: InitializeSortResources"); + if (!InitializeSortResources()) { + AppendTrace(std::string("Initialize: InitializeSortResources failed: ") + NarrowAscii(m_lastErrorMessage)); + return false; + } + + AppendTrace("Initialize: InitializeDebugDrawResources"); + if (!InitializeDebugDrawResources()) { + AppendTrace(std::string("Initialize: InitializeDebugDrawResources failed: ") + NarrowAscii(m_lastErrorMessage)); return false; } m_isInitialized = true; m_running = true; + AppendTrace("Initialize: success"); return true; } @@ -83,15 +237,22 @@ void App::SetSummaryPath(std::wstring summaryPath) { m_summaryPath = std::move(summaryPath); } +void App::SetScreenshotPath(std::wstring screenshotPath) { + m_screenshotPath = std::move(screenshotPath); +} + const std::wstring& App::GetLastErrorMessage() const { return m_lastErrorMessage; } int App::Run() { + AppendTrace("Run: begin"); MSG message = {}; + int exitCode = 0; while (m_running) { while (PeekMessage(&message, nullptr, 0, 0, PM_REMOVE)) { if (message.message == WM_QUIT) { + exitCode = static_cast(message.wParam); m_running = false; break; } @@ -104,16 +265,21 @@ int App::Run() { break; } - RenderFrame(); + const bool captureScreenshot = m_frameLimit > 0 && (m_renderedFrameCount + 1) >= m_frameLimit; + AppendTrace(captureScreenshot ? "Run: RenderFrame capture" : "Run: RenderFrame"); + RenderFrame(captureScreenshot); + AppendTrace("Run: RenderFrame complete"); ++m_renderedFrameCount; if (m_frameLimit > 0 && m_renderedFrameCount >= m_frameLimit) { m_running = false; + AppendTrace("Run: frame limit reached, posting WM_CLOSE"); PostMessageW(m_hwnd, WM_CLOSE, 0, 0); } } - return static_cast(message.wParam); + AppendTrace("Run: end"); + return exitCode; } LRESULT CALLBACK App::StaticWindowProc(HWND hwnd, UINT message, WPARAM wParam, LPARAM lParam) { @@ -376,6 +542,361 @@ bool App::InitializeGaussianGpuResources() { return true; } +bool App::InitializePreparePassResources() { + DescriptorSetLayoutBinding bindings[6] = {}; + bindings[0].binding = 0; + bindings[0].type = static_cast(DescriptorType::CBV); + bindings[0].count = 1; + bindings[0].visibility = static_cast(ShaderVisibility::All); + + bindings[1].binding = 1; + bindings[1].type = static_cast(DescriptorType::SRV); + bindings[1].count = 1; + bindings[1].visibility = static_cast(ShaderVisibility::All); + bindings[1].resourceDimension = ResourceViewDimension::RawBuffer; + + bindings[2].binding = 2; + bindings[2].type = static_cast(DescriptorType::SRV); + bindings[2].count = 1; + bindings[2].visibility = static_cast(ShaderVisibility::All); + bindings[2].resourceDimension = ResourceViewDimension::RawBuffer; + + bindings[3].binding = 3; + bindings[3].type = static_cast(DescriptorType::SRV); + bindings[3].count = 1; + bindings[3].visibility = static_cast(ShaderVisibility::All); + bindings[3].resourceDimension = ResourceViewDimension::Texture2D; + + bindings[4].binding = 4; + bindings[4].type = static_cast(DescriptorType::SRV); + bindings[4].count = 1; + bindings[4].visibility = static_cast(ShaderVisibility::All); + bindings[4].resourceDimension = ResourceViewDimension::RawBuffer; + + bindings[5].binding = 5; + bindings[5].type = static_cast(DescriptorType::UAV); + bindings[5].count = 1; + bindings[5].visibility = static_cast(ShaderVisibility::All); + bindings[5].resourceDimension = ResourceViewDimension::StructuredBuffer; + + DescriptorSetLayoutDesc setLayout = {}; + setLayout.bindings = bindings; + setLayout.bindingCount = 6; + + RHIPipelineLayoutDesc pipelineLayoutDesc = {}; + pipelineLayoutDesc.setLayouts = &setLayout; + pipelineLayoutDesc.setLayoutCount = 1; + + m_preparePipelineLayout = m_device.CreatePipelineLayout(pipelineLayoutDesc); + if (m_preparePipelineLayout == nullptr) { + m_lastErrorMessage = L"Failed to create the prepare pass pipeline layout."; + return false; + } + + BufferDesc preparedViewBufferDesc = {}; + preparedViewBufferDesc.size = static_cast(sizeof(PreparedSplatView)) * m_gaussianSceneData.splatCount; + preparedViewBufferDesc.stride = sizeof(PreparedSplatView); + preparedViewBufferDesc.bufferType = static_cast(BufferType::Storage); + preparedViewBufferDesc.flags = static_cast(BufferFlags::AllowUnorderedAccess); + + m_preparedViewBuffer = static_cast(m_device.CreateBuffer(preparedViewBufferDesc)); + if (m_preparedViewBuffer == nullptr) { + m_lastErrorMessage = L"Failed to create the prepared view buffer."; + return false; + } + m_preparedViewBuffer->SetStride(sizeof(PreparedSplatView)); + m_preparedViewBuffer->SetBufferType(BufferType::Storage); + + ResourceViewDesc structuredViewDesc = {}; + structuredViewDesc.dimension = ResourceViewDimension::StructuredBuffer; + structuredViewDesc.structureByteStride = sizeof(PreparedSplatView); + structuredViewDesc.elementCount = m_gaussianSceneData.splatCount; + + m_preparedViewSrv.reset(static_cast(m_device.CreateShaderResourceView(m_preparedViewBuffer, structuredViewDesc))); + if (!m_preparedViewSrv) { + m_lastErrorMessage = L"Failed to create the prepared view SRV."; + return false; + } + + m_preparedViewUav.reset(static_cast(m_device.CreateUnorderedAccessView(m_preparedViewBuffer, structuredViewDesc))); + if (!m_preparedViewUav) { + m_lastErrorMessage = L"Failed to create the prepared view UAV."; + return false; + } + + DescriptorPoolDesc poolDesc = {}; + poolDesc.type = DescriptorHeapType::CBV_SRV_UAV; + poolDesc.descriptorCount = 5; + poolDesc.shaderVisible = true; + m_prepareDescriptorPool = m_device.CreateDescriptorPool(poolDesc); + if (m_prepareDescriptorPool == nullptr) { + m_lastErrorMessage = L"Failed to create the prepare pass descriptor pool."; + return false; + } + + m_prepareDescriptorSet = m_prepareDescriptorPool->AllocateSet(setLayout); + if (m_prepareDescriptorSet == nullptr) { + m_lastErrorMessage = L"Failed to allocate the prepare pass descriptor set."; + return false; + } + + m_prepareDescriptorSet->Update(1, m_gaussianPositionView.get()); + m_prepareDescriptorSet->Update(2, m_gaussianOtherView.get()); + m_prepareDescriptorSet->Update(3, m_gaussianColorView.get()); + m_prepareDescriptorSet->Update(4, m_gaussianShView.get()); + m_prepareDescriptorSet->Update(5, m_preparedViewUav.get()); + + ShaderCompileDesc computeShaderDesc = {}; + computeShaderDesc.fileName = ResolveShaderPath(L"PrepareGaussiansCS.hlsl").wstring(); + computeShaderDesc.entryPoint = L"MainCS"; + computeShaderDesc.profile = L"cs_5_0"; + + ComputePipelineDesc pipelineDesc = {}; + pipelineDesc.pipelineLayout = m_preparePipelineLayout; + pipelineDesc.computeShader = computeShaderDesc; + m_preparePipelineState = m_device.CreateComputePipelineState(pipelineDesc); + if (m_preparePipelineState == nullptr) { + m_lastErrorMessage = L"Failed to create the prepare pass pipeline state."; + return false; + } + + return true; +} + +bool App::InitializeSortResources() { + std::vector initialOrder(static_cast(m_gaussianSceneData.splatCount)); + for (uint32_t index = 0; index < m_gaussianSceneData.splatCount; ++index) { + initialOrder[index] = index; + } + + ID3D12Device* device = m_device.GetDevice(); + ID3D12GraphicsCommandList* commandList = m_commandList.GetCommandList(); + + m_commandAllocator.Reset(); + m_commandList.Reset(); + + const D3D12_RESOURCE_STATES shaderResourceState = + D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE | D3D12_RESOURCE_STATE_PIXEL_SHADER_RESOURCE; + + m_orderBuffer = new D3D12Buffer(); + m_orderBuffer->SetStride(sizeof(uint32_t)); + m_orderBuffer->SetBufferType(BufferType::Storage); + if (!m_orderBuffer->InitializeWithData( + device, + commandList, + initialOrder.data(), + static_cast(initialOrder.size() * sizeof(uint32_t)), + shaderResourceState)) { + m_lastErrorMessage = L"Failed to initialize the order buffer."; + return false; + } + m_orderBuffer->SetState(ResourceStates::NonPixelShaderResource); + + BufferDesc sortKeyBufferDesc = {}; + sortKeyBufferDesc.size = static_cast(m_gaussianSceneData.splatCount) * sizeof(uint32_t); + sortKeyBufferDesc.stride = sizeof(uint32_t); + sortKeyBufferDesc.bufferType = static_cast(BufferType::Storage); + sortKeyBufferDesc.flags = static_cast(BufferFlags::AllowUnorderedAccess); + + m_sortKeyBuffer = static_cast(m_device.CreateBuffer(sortKeyBufferDesc)); + if (m_sortKeyBuffer == nullptr) { + m_lastErrorMessage = L"Failed to create the sort key buffer."; + return false; + } + m_sortKeyBuffer->SetStride(sizeof(uint32_t)); + m_sortKeyBuffer->SetBufferType(BufferType::Storage); + + ResourceViewDesc structuredViewDesc = {}; + structuredViewDesc.dimension = ResourceViewDimension::StructuredBuffer; + structuredViewDesc.structureByteStride = sizeof(uint32_t); + structuredViewDesc.elementCount = m_gaussianSceneData.splatCount; + + m_orderBufferSrv.reset(static_cast(m_device.CreateShaderResourceView(m_orderBuffer, structuredViewDesc))); + if (!m_orderBufferSrv) { + m_lastErrorMessage = L"Failed to create the order buffer SRV."; + return false; + } + + m_sortKeySrv.reset(static_cast(m_device.CreateShaderResourceView(m_sortKeyBuffer, structuredViewDesc))); + if (!m_sortKeySrv) { + m_lastErrorMessage = L"Failed to create the sort key SRV."; + return false; + } + + m_sortKeyUav.reset(static_cast(m_device.CreateUnorderedAccessView(m_sortKeyBuffer, structuredViewDesc))); + if (!m_sortKeyUav) { + m_lastErrorMessage = L"Failed to create the sort key UAV."; + return false; + } + + DescriptorSetLayoutBinding bindings[4] = {}; + bindings[0].binding = 0; + bindings[0].type = static_cast(DescriptorType::CBV); + bindings[0].count = 1; + bindings[0].visibility = static_cast(ShaderVisibility::All); + + bindings[1].binding = 1; + bindings[1].type = static_cast(DescriptorType::SRV); + bindings[1].count = 1; + bindings[1].visibility = static_cast(ShaderVisibility::All); + bindings[1].resourceDimension = ResourceViewDimension::RawBuffer; + + bindings[2].binding = 2; + bindings[2].type = static_cast(DescriptorType::SRV); + bindings[2].count = 1; + bindings[2].visibility = static_cast(ShaderVisibility::All); + bindings[2].resourceDimension = ResourceViewDimension::StructuredBuffer; + + bindings[3].binding = 3; + bindings[3].type = static_cast(DescriptorType::UAV); + bindings[3].count = 1; + bindings[3].visibility = static_cast(ShaderVisibility::All); + bindings[3].resourceDimension = ResourceViewDimension::StructuredBuffer; + + DescriptorSetLayoutDesc setLayout = {}; + setLayout.bindings = bindings; + setLayout.bindingCount = 4; + + RHIPipelineLayoutDesc pipelineLayoutDesc = {}; + pipelineLayoutDesc.setLayouts = &setLayout; + pipelineLayoutDesc.setLayoutCount = 1; + + m_sortPipelineLayout = m_device.CreatePipelineLayout(pipelineLayoutDesc); + if (m_sortPipelineLayout == nullptr) { + m_lastErrorMessage = L"Failed to create the sort pipeline layout."; + return false; + } + + DescriptorPoolDesc poolDesc = {}; + poolDesc.type = DescriptorHeapType::CBV_SRV_UAV; + poolDesc.descriptorCount = 3; + poolDesc.shaderVisible = true; + m_sortDescriptorPool = m_device.CreateDescriptorPool(poolDesc); + if (m_sortDescriptorPool == nullptr) { + m_lastErrorMessage = L"Failed to create the sort descriptor pool."; + return false; + } + + m_sortDescriptorSet = m_sortDescriptorPool->AllocateSet(setLayout); + if (m_sortDescriptorSet == nullptr) { + m_lastErrorMessage = L"Failed to allocate the sort descriptor set."; + return false; + } + + m_sortDescriptorSet->Update(1, m_gaussianPositionView.get()); + m_sortDescriptorSet->Update(2, m_orderBufferSrv.get()); + m_sortDescriptorSet->Update(3, m_sortKeyUav.get()); + + ShaderCompileDesc computeShaderDesc = {}; + const std::filesystem::path compiledShaderPath = ResolveShaderPath(L"BuildSortKeysCS.dxil"); + const std::vector compiledShaderBinary = LoadBinaryFile(compiledShaderPath); + if (!compiledShaderBinary.empty()) { + computeShaderDesc.profile = L"cs_6_6"; + computeShaderDesc.compiledBinaryBackend = ShaderBinaryBackend::D3D12; + computeShaderDesc.compiledBinary = compiledShaderBinary; + } else { + computeShaderDesc.fileName = ResolveShaderPath(L"BuildSortKeysCS.hlsl").wstring(); + computeShaderDesc.entryPoint = L"MainCS"; + computeShaderDesc.profile = L"cs_5_0"; + } + + ComputePipelineDesc pipelineDesc = {}; + pipelineDesc.pipelineLayout = m_sortPipelineLayout; + pipelineDesc.computeShader = computeShaderDesc; + m_sortPipelineState = m_device.CreateComputePipelineState(pipelineDesc); + if (m_sortPipelineState == nullptr) { + m_lastErrorMessage = L"Failed to create the sort pipeline state."; + return false; + } + + m_commandList.Close(); + void* commandLists[] = { &m_commandList }; + m_commandQueue.ExecuteCommandLists(1, commandLists); + m_commandQueue.WaitForIdle(); + + return true; +} + +bool App::InitializeDebugDrawResources() { + DescriptorSetLayoutBinding bindings[3] = {}; + bindings[0].binding = 0; + bindings[0].type = static_cast(DescriptorType::CBV); + bindings[0].count = 1; + bindings[0].visibility = static_cast(ShaderVisibility::All); + + bindings[1].binding = 1; + bindings[1].type = static_cast(DescriptorType::SRV); + bindings[1].count = 1; + bindings[1].visibility = static_cast(ShaderVisibility::All); + bindings[1].resourceDimension = ResourceViewDimension::StructuredBuffer; + + bindings[2].binding = 2; + bindings[2].type = static_cast(DescriptorType::SRV); + bindings[2].count = 1; + bindings[2].visibility = static_cast(ShaderVisibility::All); + bindings[2].resourceDimension = ResourceViewDimension::StructuredBuffer; + + DescriptorSetLayoutDesc setLayout = {}; + setLayout.bindings = bindings; + setLayout.bindingCount = 3; + + RHIPipelineLayoutDesc pipelineLayoutDesc = {}; + pipelineLayoutDesc.setLayouts = &setLayout; + pipelineLayoutDesc.setLayoutCount = 1; + + m_debugPipelineLayout = m_device.CreatePipelineLayout(pipelineLayoutDesc); + if (m_debugPipelineLayout == nullptr) { + m_lastErrorMessage = L"Failed to create the debug draw pipeline layout."; + return false; + } + + DescriptorPoolDesc poolDesc = {}; + poolDesc.type = DescriptorHeapType::CBV_SRV_UAV; + poolDesc.descriptorCount = 2; + poolDesc.shaderVisible = true; + m_debugDescriptorPool = m_device.CreateDescriptorPool(poolDesc); + if (m_debugDescriptorPool == nullptr) { + m_lastErrorMessage = L"Failed to create the debug draw descriptor pool."; + return false; + } + + m_debugDescriptorSet = m_debugDescriptorPool->AllocateSet(setLayout); + if (m_debugDescriptorSet == nullptr) { + m_lastErrorMessage = L"Failed to allocate the debug draw descriptor set."; + return false; + } + m_debugDescriptorSet->Update(1, m_preparedViewSrv.get()); + m_debugDescriptorSet->Update(2, m_orderBufferSrv.get()); + + GraphicsPipelineDesc pipelineDesc = {}; + pipelineDesc.pipelineLayout = m_debugPipelineLayout; + pipelineDesc.topologyType = static_cast(PrimitiveTopologyType::Point); + pipelineDesc.renderTargetCount = 1; + pipelineDesc.renderTargetFormats[0] = static_cast(Format::R8G8B8A8_UNorm); + pipelineDesc.depthStencilFormat = static_cast(Format::D24_UNorm_S8_UInt); + pipelineDesc.sampleCount = 1; + pipelineDesc.sampleQuality = 0; + pipelineDesc.rasterizerState.cullMode = static_cast(CullMode::None); + pipelineDesc.depthStencilState.depthTestEnable = false; + pipelineDesc.depthStencilState.depthWriteEnable = false; + + pipelineDesc.vertexShader.fileName = ResolveShaderPath(L"DebugPointsVS.hlsl").wstring(); + pipelineDesc.vertexShader.entryPoint = L"MainVS"; + pipelineDesc.vertexShader.profile = L"vs_5_0"; + + pipelineDesc.fragmentShader.fileName = ResolveShaderPath(L"DebugPointsPS.hlsl").wstring(); + pipelineDesc.fragmentShader.entryPoint = L"MainPS"; + pipelineDesc.fragmentShader.profile = L"ps_5_0"; + + m_debugPipelineState = m_device.CreatePipelineState(pipelineDesc); + if (m_debugPipelineState == nullptr) { + m_lastErrorMessage = L"Failed to create the debug draw pipeline state."; + return false; + } + + return true; +} + void App::ShutdownGaussianGpuResources() { m_gaussianColorView.reset(); m_gaussianShView.reset(); @@ -388,16 +909,78 @@ void App::ShutdownGaussianGpuResources() { m_gaussianPositionBuffer.Shutdown(); } +void App::ShutdownPreparePassResources() { + if (m_prepareDescriptorSet != nullptr) { + m_prepareDescriptorSet->Shutdown(); + delete m_prepareDescriptorSet; + m_prepareDescriptorSet = nullptr; + } + ShutdownAndDelete(m_prepareDescriptorPool); + ShutdownAndDelete(m_preparePipelineState); + ShutdownAndDelete(m_preparePipelineLayout); + m_preparedViewUav.reset(); + m_preparedViewSrv.reset(); + if (m_preparedViewBuffer != nullptr) { + m_preparedViewBuffer->Shutdown(); + delete m_preparedViewBuffer; + m_preparedViewBuffer = nullptr; + } +} + +void App::ShutdownSortResources() { + if (m_sortDescriptorSet != nullptr) { + m_sortDescriptorSet->Shutdown(); + delete m_sortDescriptorSet; + m_sortDescriptorSet = nullptr; + } + ShutdownAndDelete(m_sortDescriptorPool); + ShutdownAndDelete(m_sortPipelineState); + ShutdownAndDelete(m_sortPipelineLayout); + m_orderBufferSrv.reset(); + m_sortKeyUav.reset(); + m_sortKeySrv.reset(); + + if (m_orderBuffer != nullptr) { + m_orderBuffer->Shutdown(); + delete m_orderBuffer; + m_orderBuffer = nullptr; + } + + if (m_sortKeyBuffer != nullptr) { + m_sortKeyBuffer->Shutdown(); + delete m_sortKeyBuffer; + m_sortKeyBuffer = nullptr; + } +} + +void App::ShutdownDebugDrawResources() { + if (m_debugDescriptorSet != nullptr) { + m_debugDescriptorSet->Shutdown(); + delete m_debugDescriptorSet; + m_debugDescriptorSet = nullptr; + } + ShutdownAndDelete(m_debugDescriptorPool); + ShutdownAndDelete(m_debugPipelineState); + ShutdownAndDelete(m_debugPipelineLayout); +} + void App::Shutdown() { + AppendTrace("Shutdown: begin"); if (!m_isInitialized && m_hwnd == nullptr) { + AppendTrace("Shutdown: skipped"); return; } m_running = false; if (m_commandQueue.GetCommandQueue() != nullptr) { + AppendTrace("Shutdown: WaitForIdle"); m_commandQueue.WaitForIdle(); } + + ShutdownDebugDrawResources(); + ShutdownSortResources(); + ShutdownPreparePassResources(); ShutdownGaussianGpuResources(); m_commandList.Shutdown(); m_commandAllocator.Shutdown(); @@ -413,6 +996,7 @@ void App::Shutdown() { m_device.Shutdown(); if (m_hwnd != nullptr) { + AppendTrace("Shutdown: DestroyWindow"); DestroyWindow(m_hwnd); m_hwnd = nullptr; } @@ -423,10 +1007,85 @@ void App::Shutdown() { } m_isInitialized = false; + AppendTrace("Shutdown: end"); } -void App::RenderFrame() { +bool App::CaptureSortKeySnapshot() { + if (m_sortKeyBuffer == nullptr || m_gaussianSceneData.splatCount == 0 || m_sortKeySnapshotPath.empty()) { + return true; + } + + const uint32_t sampleCount = std::min(16u, m_gaussianSceneData.splatCount); + const uint64_t sampleBytes = static_cast(sampleCount) * sizeof(uint32_t); + + D3D12Buffer readbackBuffer; + if (!readbackBuffer.Initialize( + m_device.GetDevice(), + sampleBytes, + D3D12_RESOURCE_STATE_COPY_DEST, + D3D12_HEAP_TYPE_READBACK, + D3D12_RESOURCE_FLAG_NONE)) { + m_lastErrorMessage = L"Failed to create the sort key readback buffer."; + return false; + } + + m_commandAllocator.Reset(); + m_commandList.Reset(); + + if (m_sortKeyBuffer->GetState() != ResourceStates::CopySrc) { + m_commandList.TransitionBarrier( + m_sortKeyBuffer->GetResource(), + m_sortKeyBuffer->GetState(), + ResourceStates::CopySrc); + m_sortKeyBuffer->SetState(ResourceStates::CopySrc); + } + + m_commandList.GetCommandList()->CopyBufferRegion( + readbackBuffer.GetResource(), + 0, + m_sortKeyBuffer->GetResource(), + 0, + sampleBytes); + + m_commandList.Close(); + void* commandLists[] = { &m_commandList }; + m_commandQueue.ExecuteCommandLists(1, commandLists); + m_commandQueue.WaitForIdle(); + + const uint32_t* keys = static_cast(readbackBuffer.Map()); + if (keys == nullptr) { + m_lastErrorMessage = L"Failed to map the sort key readback buffer."; + readbackBuffer.Shutdown(); + return false; + } + + const std::filesystem::path snapshotPath = ResolveNearExecutable(m_sortKeySnapshotPath); + if (!snapshotPath.parent_path().empty()) { + std::filesystem::create_directories(snapshotPath.parent_path()); + } + + std::ofstream output(snapshotPath, std::ios::binary | std::ios::trunc); + if (!output.is_open()) { + readbackBuffer.Unmap(); + readbackBuffer.Shutdown(); + m_lastErrorMessage = L"Failed to open the sort key snapshot output file."; + return false; + } + + output << "sample_count=" << sampleCount << '\n'; + for (uint32_t index = 0; index < sampleCount; ++index) { + output << "key[" << index << "]=" << keys[index] << '\n'; + } + + readbackBuffer.Unmap(); + readbackBuffer.Shutdown(); + return output.good(); +} + +void App::RenderFrame(bool captureScreenshot) { + AppendTrace(captureScreenshot ? "RenderFrame: begin capture" : "RenderFrame: begin"); if (m_hasRenderedAtLeastOneFrame) { + AppendTrace("RenderFrame: WaitForPreviousFrame"); m_commandQueue.WaitForPreviousFrame(); } @@ -459,14 +1118,99 @@ void App::RenderFrame() { 0, nullptr); - m_commandList.TransitionBarrier(backBuffer.GetResource(), ResourceStates::RenderTarget, ResourceStates::Present); - m_commandList.Close(); + const FrameConstants frameConstants = BuildFrameConstants( + static_cast(m_width), + static_cast(m_height), + m_gaussianSceneData.splatCount); + m_prepareDescriptorSet->WriteConstant(0, &frameConstants, sizeof(frameConstants)); + m_debugDescriptorSet->WriteConstant(0, &frameConstants, sizeof(frameConstants)); - void* commandLists[] = { &m_commandList }; - m_commandQueue.ExecuteCommandLists(1, commandLists); - m_swapChain.Present(1, 0); + if (m_preparedViewBuffer->GetState() != ResourceStates::UnorderedAccess) { + m_commandList.TransitionBarrier( + m_preparedViewBuffer->GetResource(), + m_preparedViewBuffer->GetState(), + ResourceStates::UnorderedAccess); + m_preparedViewBuffer->SetState(ResourceStates::UnorderedAccess); + } + + m_commandList.SetPipelineState(m_preparePipelineState); + RHIDescriptorSet* prepareSets[] = { m_prepareDescriptorSet }; + m_commandList.SetComputeDescriptorSets(0, 1, prepareSets, m_preparePipelineLayout); + m_commandList.Dispatch((m_gaussianSceneData.splatCount + (kPrepareThreadGroupSize - 1)) / kPrepareThreadGroupSize, 1, 1); + m_commandList.UAVBarrier(m_preparedViewBuffer->GetResource()); + + if (m_sortKeyBuffer->GetState() != ResourceStates::UnorderedAccess) { + m_commandList.TransitionBarrier( + m_sortKeyBuffer->GetResource(), + m_sortKeyBuffer->GetState(), + ResourceStates::UnorderedAccess); + m_sortKeyBuffer->SetState(ResourceStates::UnorderedAccess); + } + + m_sortDescriptorSet->WriteConstant(0, &frameConstants, sizeof(frameConstants)); + m_commandList.SetPipelineState(m_sortPipelineState); + RHIDescriptorSet* sortSets[] = { m_sortDescriptorSet }; + m_commandList.SetComputeDescriptorSets(0, 1, sortSets, m_sortPipelineLayout); + m_commandList.Dispatch((m_gaussianSceneData.splatCount + (kSortThreadGroupSize - 1)) / kSortThreadGroupSize, 1, 1); + m_commandList.UAVBarrier(m_sortKeyBuffer->GetResource()); + + m_commandList.TransitionBarrier( + m_preparedViewBuffer->GetResource(), + ResourceStates::UnorderedAccess, + ResourceStates::NonPixelShaderResource); + m_preparedViewBuffer->SetState(ResourceStates::NonPixelShaderResource); + + m_commandList.SetPipelineState(m_debugPipelineState); + RHIDescriptorSet* debugSets[] = { m_debugDescriptorSet }; + m_commandList.SetGraphicsDescriptorSets(0, 1, debugSets, m_debugPipelineLayout); + m_commandList.SetPrimitiveTopology(PrimitiveTopology::PointList); + m_commandList.Draw(1u, m_gaussianSceneData.splatCount, 0u, 0u); + + if (captureScreenshot) { + AppendTrace("RenderFrame: close+execute capture pre-screenshot"); + m_commandList.Close(); + void* commandLists[] = { &m_commandList }; + m_commandQueue.ExecuteCommandLists(1, commandLists); + AppendTrace("RenderFrame: WaitForIdle before screenshot"); + m_commandQueue.WaitForIdle(); + + AppendTrace("RenderFrame: Capture sort key snapshot"); + if (!CaptureSortKeySnapshot()) { + AppendTrace(std::string("RenderFrame: Capture sort key snapshot failed: ") + NarrowAscii(m_lastErrorMessage)); + } + + if (!m_screenshotPath.empty()) { + const std::filesystem::path screenshotPath = ResolveNearExecutable(m_screenshotPath); + if (!screenshotPath.parent_path().empty()) { + std::filesystem::create_directories(screenshotPath.parent_path()); + } + AppendTrace("RenderFrame: Capture screenshot"); + D3D12Screenshot::Capture( + m_device, + m_commandQueue, + backBuffer, + screenshotPath.string().c_str()); + AppendTrace("RenderFrame: Capture screenshot done"); + } + + m_commandAllocator.Reset(); + m_commandList.Reset(); + m_commandList.TransitionBarrier(backBuffer.GetResource(), ResourceStates::RenderTarget, ResourceStates::Present); + m_commandList.Close(); + void* presentCommandLists[] = { &m_commandList }; + AppendTrace("RenderFrame: execute final present-transition list"); + m_commandQueue.ExecuteCommandLists(1, presentCommandLists); + } else { + m_commandList.TransitionBarrier(backBuffer.GetResource(), ResourceStates::RenderTarget, ResourceStates::Present); + m_commandList.Close(); + void* commandLists[] = { &m_commandList }; + AppendTrace("RenderFrame: execute+present"); + m_commandQueue.ExecuteCommandLists(1, commandLists); + m_swapChain.Present(1, 0); + } m_hasRenderedAtLeastOneFrame = true; + AppendTrace("RenderFrame: end"); } } // namespace XC3DGSD3D12 diff --git a/MVS/3DGS-D3D12/src/GaussianPlyLoader.cpp b/MVS/3DGS-D3D12/src/GaussianPlyLoader.cpp index 0e563d83..a4fa1d02 100644 --- a/MVS/3DGS-D3D12/src/GaussianPlyLoader.cpp +++ b/MVS/3DGS-D3D12/src/GaussianPlyLoader.cpp @@ -53,6 +53,15 @@ struct RawGaussianSplat { Float4 rotation = {}; }; +struct GaussianPlyPropertyLayout { + const PlyProperty* position[3] = {}; + const PlyProperty* dc0[3] = {}; + const PlyProperty* opacity = nullptr; + const PlyProperty* scale[3] = {}; + const PlyProperty* rotation[4] = {}; + std::array sh = {}; +}; + std::string TrimTrailingCarriageReturn(std::string line) { if (!line.empty() && line.back() == '\r') { line.pop_back(); @@ -243,6 +252,39 @@ bool RequireProperty( return true; } +bool BuildGaussianPlyPropertyLayout( + const std::unordered_map& propertyMap, + GaussianPlyPropertyLayout& outLayout, + std::string& outErrorMessage) { + outLayout = {}; + + if (!RequireProperty(propertyMap, "x", outLayout.position[0], outErrorMessage) || + !RequireProperty(propertyMap, "y", outLayout.position[1], outErrorMessage) || + !RequireProperty(propertyMap, "z", outLayout.position[2], outErrorMessage) || + !RequireProperty(propertyMap, "f_dc_0", outLayout.dc0[0], outErrorMessage) || + !RequireProperty(propertyMap, "f_dc_1", outLayout.dc0[1], outErrorMessage) || + !RequireProperty(propertyMap, "f_dc_2", outLayout.dc0[2], outErrorMessage) || + !RequireProperty(propertyMap, "opacity", outLayout.opacity, outErrorMessage) || + !RequireProperty(propertyMap, "scale_0", outLayout.scale[0], outErrorMessage) || + !RequireProperty(propertyMap, "scale_1", outLayout.scale[1], outErrorMessage) || + !RequireProperty(propertyMap, "scale_2", outLayout.scale[2], outErrorMessage) || + !RequireProperty(propertyMap, "rot_0", outLayout.rotation[0], outErrorMessage) || + !RequireProperty(propertyMap, "rot_1", outLayout.rotation[1], outErrorMessage) || + !RequireProperty(propertyMap, "rot_2", outLayout.rotation[2], outErrorMessage) || + !RequireProperty(propertyMap, "rot_3", outLayout.rotation[3], outErrorMessage)) { + return false; + } + + for (uint32_t index = 0; index < outLayout.sh.size(); ++index) { + const std::string propertyName = "f_rest_" + std::to_string(index); + if (!RequireProperty(propertyMap, propertyName, outLayout.sh[index], outErrorMessage)) { + return false; + } + } + + return true; +} + Float3 Min(const Float3& a, const Float3& b) { return { std::min(a.x, b.x), @@ -390,32 +432,31 @@ void WriteFloat4(std::vector& bytes, size_t offset, float x, float y, bool ReadGaussianSplat( const std::byte* vertexBytes, - const std::unordered_map& propertyMap, + const GaussianPlyPropertyLayout& propertyLayout, RawGaussianSplat& outSplat, std::string& outErrorMessage) { - const PlyProperty* property = nullptr; - - auto readFloat = [&](std::string_view name, float& outValue) -> bool { - if (!RequireProperty(propertyMap, name, property, outErrorMessage)) { + auto readFloat = [&](const PlyProperty* property, float& outValue) -> bool { + if (property == nullptr) { + outErrorMessage = "Gaussian PLY property layout is incomplete."; return false; } return ReadPropertyAsFloat(vertexBytes, *property, outValue); }; - if (!readFloat("x", outSplat.position.x) || - !readFloat("y", outSplat.position.y) || - !readFloat("z", outSplat.position.z) || - !readFloat("f_dc_0", outSplat.dc0.x) || - !readFloat("f_dc_1", outSplat.dc0.y) || - !readFloat("f_dc_2", outSplat.dc0.z) || - !readFloat("opacity", outSplat.opacity) || - !readFloat("scale_0", outSplat.scale.x) || - !readFloat("scale_1", outSplat.scale.y) || - !readFloat("scale_2", outSplat.scale.z) || - !readFloat("rot_0", outSplat.rotation.x) || - !readFloat("rot_1", outSplat.rotation.y) || - !readFloat("rot_2", outSplat.rotation.z) || - !readFloat("rot_3", outSplat.rotation.w)) { + if (!readFloat(propertyLayout.position[0], outSplat.position.x) || + !readFloat(propertyLayout.position[1], outSplat.position.y) || + !readFloat(propertyLayout.position[2], outSplat.position.z) || + !readFloat(propertyLayout.dc0[0], outSplat.dc0.x) || + !readFloat(propertyLayout.dc0[1], outSplat.dc0.y) || + !readFloat(propertyLayout.dc0[2], outSplat.dc0.z) || + !readFloat(propertyLayout.opacity, outSplat.opacity) || + !readFloat(propertyLayout.scale[0], outSplat.scale.x) || + !readFloat(propertyLayout.scale[1], outSplat.scale.y) || + !readFloat(propertyLayout.scale[2], outSplat.scale.z) || + !readFloat(propertyLayout.rotation[0], outSplat.rotation.x) || + !readFloat(propertyLayout.rotation[1], outSplat.rotation.y) || + !readFloat(propertyLayout.rotation[2], outSplat.rotation.z) || + !readFloat(propertyLayout.rotation[3], outSplat.rotation.w)) { if (outErrorMessage.empty()) { outErrorMessage = "Failed to read required Gaussian splat PLY properties."; } @@ -424,8 +465,7 @@ bool ReadGaussianSplat( std::array shRaw = {}; for (uint32_t index = 0; index < shRaw.size(); ++index) { - const std::string propertyName = "f_rest_" + std::to_string(index); - if (!readFloat(propertyName, shRaw[index])) { + if (!readFloat(propertyLayout.sh[index], shRaw[index])) { if (outErrorMessage.empty()) { outErrorMessage = "Failed to read SH rest coefficients from PLY."; } @@ -478,6 +518,11 @@ bool LoadGaussianSceneFromPly( return false; } + GaussianPlyPropertyLayout propertyLayout; + if (!BuildGaussianPlyPropertyLayout(propertyMap, propertyLayout, outErrorMessage)) { + return false; + } + outData.splatCount = header.vertexCount; outData.colorTextureWidth = GaussianSplatRuntimeData::kColorTextureWidth; outData.colorTextureHeight = @@ -513,7 +558,7 @@ bool LoadGaussianSceneFromPly( } RawGaussianSplat splat; - if (!ReadGaussianSplat(vertexBytes.data(), propertyMap, splat, outErrorMessage)) { + if (!ReadGaussianSplat(vertexBytes.data(), propertyLayout, splat, outErrorMessage)) { return false; } diff --git a/MVS/3DGS-D3D12/src/main.cpp b/MVS/3DGS-D3D12/src/main.cpp index d2394380..0ccb3643 100644 --- a/MVS/3DGS-D3D12/src/main.cpp +++ b/MVS/3DGS-D3D12/src/main.cpp @@ -20,6 +20,9 @@ int WINAPI wWinMain(HINSTANCE instance, HINSTANCE, PWSTR, int showCommand) { } else if (std::wstring(arguments[index]) == L"--summary-file") { app.SetSummaryPath(arguments[index + 1]); ++index; + } else if (std::wstring(arguments[index]) == L"--screenshot-file") { + app.SetScreenshotPath(arguments[index + 1]); + ++index; } } if (arguments != nullptr) {