Stabilize 3DGS D3D12 phase 3 and sort key setup

This commit is contained in:
2026-04-13 02:23:39 +08:00
parent 1d6f2e290d
commit b7428b0ef1
10 changed files with 1281 additions and 33 deletions

View File

@@ -5,6 +5,8 @@ project(XC3DGSD3D12MVS LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
find_program(XC_DXC_EXECUTABLE NAMES dxc)
get_filename_component(XCENGINE_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../.." ABSOLUTE) get_filename_component(XCENGINE_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../.." ABSOLUTE)
set(XCENGINE_BUILD_DIR "${XCENGINE_ROOT}/build") set(XCENGINE_BUILD_DIR "${XCENGINE_ROOT}/build")
set(XCENGINE_INCLUDE_DIR "${XCENGINE_ROOT}/engine/include") set(XCENGINE_INCLUDE_DIR "${XCENGINE_ROOT}/engine/include")
@@ -30,6 +32,21 @@ add_executable(xc_3dgs_d3d12_mvs
src/GaussianPlyLoader.cpp src/GaussianPlyLoader.cpp
include/XC3DGSD3D12/App.h include/XC3DGSD3D12/App.h
include/XC3DGSD3D12/GaussianPlyLoader.h include/XC3DGSD3D12/GaussianPlyLoader.h
shaders/PreparedSplatView.hlsli
shaders/PrepareGaussiansCS.hlsl
shaders/BuildSortKeysCS.hlsl
shaders/DebugPointsVS.hlsl
shaders/DebugPointsPS.hlsl
)
set_source_files_properties(
shaders/PreparedSplatView.hlsli
shaders/PrepareGaussiansCS.hlsl
shaders/BuildSortKeysCS.hlsl
shaders/DebugPointsVS.hlsl
shaders/DebugPointsPS.hlsl
PROPERTIES
HEADER_FILE_ONLY TRUE
) )
target_include_directories(xc_3dgs_d3d12_mvs PRIVATE target_include_directories(xc_3dgs_d3d12_mvs PRIVATE
@@ -69,3 +86,19 @@ add_custom_command(TARGET xc_3dgs_d3d12_mvs POST_BUILD
"${CMAKE_CURRENT_SOURCE_DIR}/room.ply" "${CMAKE_CURRENT_SOURCE_DIR}/room.ply"
"$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/room.ply" "$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/room.ply"
) )
# Copy the HLSL sources into a "shaders" directory next to the built executable
# after every build of the target.
add_custom_command(TARGET xc_3dgs_d3d12_mvs POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_directory
"${CMAKE_CURRENT_SOURCE_DIR}/shaders"
"$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/shaders"
)
# Optionally precompile the sort-key compute shader to DXIL (cs_6_6) when dxc
# was located by find_program. The output lands in the deployed "shaders"
# directory next to the executable.
if(XC_DXC_EXECUTABLE)
    add_custom_command(TARGET xc_3dgs_d3d12_mvs POST_BUILD
        COMMAND "${XC_DXC_EXECUTABLE}"
            -T cs_6_6
            -E MainCS
            -Fo "$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/shaders/BuildSortKeysCS.dxil"
            "${CMAKE_CURRENT_SOURCE_DIR}/shaders/BuildSortKeysCS.hlsl"
        COMMENT "Compiling BuildSortKeysCS.hlsl to DXIL with dxc"
        # VERBATIM keeps argument quoting consistent across generators/shells.
        VERBATIM
    )
else()
    # Make the skipped step visible at configure time instead of silently
    # producing no .dxil file.
    message(STATUS "dxc not found; BuildSortKeysCS.dxil will not be generated.")
endif()

View File

@@ -20,8 +20,24 @@
#include "XCEngine/RHI/D3D12/D3D12SwapChain.h" #include "XCEngine/RHI/D3D12/D3D12SwapChain.h"
#include "XCEngine/RHI/D3D12/D3D12Texture.h" #include "XCEngine/RHI/D3D12/D3D12Texture.h"
namespace XCEngine {
namespace RHI {
class RHIDescriptorPool;
class RHIDescriptorSet;
class RHIPipelineLayout;
class RHIPipelineState;
} // namespace RHI
} // namespace XCEngine
namespace XC3DGSD3D12 { namespace XC3DGSD3D12 {
// CPU-side mirror of the PreparedSplatView struct declared in
// shaders/PreparedSplatView.hlsli (float4 + float2 + float2 + uint2).
// The field order and sizes must stay in sync with the shader declaration.
struct PreparedSplatView {
// Clip-space position written by the prepare compute pass.
float clipPosition[4] = {};
// First 2D axis produced by the covariance decomposition in the prepare pass.
float axis1[2] = {};
// Second 2D axis produced by the covariance decomposition in the prepare pass.
float axis2[2] = {};
// Half-float packed color: [0] = (r << 16) | g, [1] = (b << 16) | opacity.
uint32_t packedColor[2] = {};
};
class App { class App {
public: public:
App(); App();
@@ -32,6 +48,7 @@ public:
void SetFrameLimit(unsigned int frameLimit); void SetFrameLimit(unsigned int frameLimit);
void SetGaussianScenePath(std::wstring scenePath); void SetGaussianScenePath(std::wstring scenePath);
void SetSummaryPath(std::wstring summaryPath); void SetSummaryPath(std::wstring summaryPath);
void SetScreenshotPath(std::wstring screenshotPath);
const std::wstring& GetLastErrorMessage() const; const std::wstring& GetLastErrorMessage() const;
private: private:
@@ -47,9 +64,16 @@ private:
bool LoadGaussianScene(); bool LoadGaussianScene();
bool InitializeRhi(); bool InitializeRhi();
bool InitializeGaussianGpuResources(); bool InitializeGaussianGpuResources();
bool InitializePreparePassResources();
bool InitializeSortResources();
bool InitializeDebugDrawResources();
void ShutdownGaussianGpuResources(); void ShutdownGaussianGpuResources();
void ShutdownPreparePassResources();
void ShutdownSortResources();
void ShutdownDebugDrawResources();
void Shutdown(); void Shutdown();
void RenderFrame(); bool CaptureSortKeySnapshot();
void RenderFrame(bool captureScreenshot);
HWND m_hwnd = nullptr; HWND m_hwnd = nullptr;
HINSTANCE m_instance = nullptr; HINSTANCE m_instance = nullptr;
@@ -62,6 +86,8 @@ private:
unsigned int m_renderedFrameCount = 0; unsigned int m_renderedFrameCount = 0;
std::wstring m_gaussianScenePath = L"room.ply"; std::wstring m_gaussianScenePath = L"room.ply";
std::wstring m_summaryPath; std::wstring m_summaryPath;
std::wstring m_screenshotPath = L"phase3_debug_points.ppm";
std::wstring m_sortKeySnapshotPath = L"phase3_sortkeys.txt";
std::wstring m_lastErrorMessage; std::wstring m_lastErrorMessage;
GaussianSplatRuntimeData m_gaussianSceneData; GaussianSplatRuntimeData m_gaussianSceneData;
XCEngine::RHI::D3D12Buffer m_gaussianPositionBuffer; XCEngine::RHI::D3D12Buffer m_gaussianPositionBuffer;
@@ -73,6 +99,26 @@ private:
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_gaussianShView; std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_gaussianShView;
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_gaussianColorView; std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_gaussianColorView;
std::vector<Microsoft::WRL::ComPtr<ID3D12Resource>> m_gaussianUploadBuffers; std::vector<Microsoft::WRL::ComPtr<ID3D12Resource>> m_gaussianUploadBuffers;
XCEngine::RHI::D3D12Buffer* m_preparedViewBuffer = nullptr;
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_preparedViewSrv;
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_preparedViewUav;
XCEngine::RHI::RHIPipelineLayout* m_preparePipelineLayout = nullptr;
XCEngine::RHI::RHIPipelineState* m_preparePipelineState = nullptr;
XCEngine::RHI::RHIDescriptorPool* m_prepareDescriptorPool = nullptr;
XCEngine::RHI::RHIDescriptorSet* m_prepareDescriptorSet = nullptr;
XCEngine::RHI::D3D12Buffer* m_sortKeyBuffer = nullptr;
XCEngine::RHI::D3D12Buffer* m_orderBuffer = nullptr;
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_sortKeySrv;
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_sortKeyUav;
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_orderBufferSrv;
XCEngine::RHI::RHIPipelineLayout* m_sortPipelineLayout = nullptr;
XCEngine::RHI::RHIPipelineState* m_sortPipelineState = nullptr;
XCEngine::RHI::RHIDescriptorPool* m_sortDescriptorPool = nullptr;
XCEngine::RHI::RHIDescriptorSet* m_sortDescriptorSet = nullptr;
XCEngine::RHI::RHIPipelineLayout* m_debugPipelineLayout = nullptr;
XCEngine::RHI::RHIPipelineState* m_debugPipelineState = nullptr;
XCEngine::RHI::RHIDescriptorPool* m_debugDescriptorPool = nullptr;
XCEngine::RHI::RHIDescriptorSet* m_debugDescriptorSet = nullptr;
XCEngine::RHI::D3D12Device m_device; XCEngine::RHI::D3D12Device m_device;
XCEngine::RHI::D3D12CommandQueue m_commandQueue; XCEngine::RHI::D3D12CommandQueue m_commandQueue;

View File

@@ -0,0 +1,43 @@
#define GROUP_SIZE 64
// Per-frame constants shared with the host-side FrameConstants struct.
// This shader only reads gView and gSettings.x.
cbuffer FrameConstants : register(b0)
{
float4x4 gViewProjection;
float4x4 gView;
float4x4 gProjection;
// Camera position in world space (w is set to 1 by the host).
float4 gCameraWorldPos;
// x = width, y = height, z = 1 / width, w = 1 / height (as filled by the host).
float4 gScreenParams;
// x = splat count (stored as float), y = opacity scale, z = SH order, w = 1.0 from the host.
float4 gSettings;
};
ByteAddressBuffer gPositions : register(t0);
StructuredBuffer<uint> gOrderBuffer : register(t1);
RWStructuredBuffer<uint> gSortKeys : register(u0);
// Reads three consecutive 32-bit words starting at byteOffset and
// reinterprets them as a float3.
float3 LoadFloat3(ByteAddressBuffer buffer, uint byteOffset)
{
    const uint3 rawWords = buffer.Load3(byteOffset);
    return asfloat(rawWords);
}
// Maps a float to a uint whose unsigned ordering matches the float ordering:
// negative values have all bits flipped, non-negative values only have the
// sign bit flipped.
uint FloatToSortableUint(float value)
{
    const uint rawBits = asuint(value);
    const bool isNegative = (rawBits & 0x80000000u) != 0u;
    const uint flipMask = isNegative ? 0xFFFFFFFFu : 0x80000000u;
    return rawBits ^ flipMask;
}
// One thread per sort slot: looks up the splat referenced by the order buffer,
// transforms its position with gView and stores the sortable encoding of the
// view-space z as the key for that slot.
[numthreads(GROUP_SIZE, 1, 1)]
void MainCS(uint3 dispatchThreadId : SV_DispatchThreadID)
{
    const uint slot = dispatchThreadId.x;

    // Threads past the splat count (stored as a float in gSettings.x) do nothing.
    if (slot >= (uint)gSettings.x)
    {
        return;
    }

    // Positions are tightly packed float3 values (12 bytes each).
    const uint positionOffset = gOrderBuffer[slot] * 12;
    const float4 worldPosition = float4(LoadFloat3(gPositions, positionOffset), 1.0);
    const float viewDepth = mul(worldPosition, gView).z;

    gSortKeys[slot] = FloatToSortableUint(viewDepth);
}

View File

@@ -0,0 +1,4 @@
// Debug point pixel shader: outputs the interpolated vertex color unchanged.
// The position input is not used.
float4 MainPS(float4 position : SV_Position, float4 color : COLOR0) : SV_Target0
{
return color;
}

View File

@@ -0,0 +1,39 @@
#include "PreparedSplatView.hlsli"
// Per-frame constants shared with the host-side FrameConstants struct.
// MainVS in this file does not read any of these values.
cbuffer FrameConstants : register(b0)
{
float4x4 gViewProjection;
float4x4 gView;
float4x4 gProjection;
// Camera position in world space (w is set to 1 by the host).
float4 gCameraWorldPos;
// x = width, y = height, z = 1 / width, w = 1 / height (as filled by the host).
float4 gScreenParams;
// x = splat count (stored as float), y = opacity scale, z = SH order, w = 1.0 from the host.
float4 gSettings;
};
StructuredBuffer<PreparedSplatView> gPreparedViews : register(t0);
StructuredBuffer<uint> gOrderBuffer : register(t1);
// Output of MainVS: clip-space position plus a color passed through COLOR0.
struct VertexOutput
{
float4 position : SV_Position;
float4 color : COLOR0;
};
// Debug point vertex shader: one instance per entry of the order buffer.
// Emits the prepared clip position with an opaque version of the prepared
// color. Splats whose clip w is <= 0 are moved to (2, 2, 2, 1) with a zero
// color. vertexId is not used.
VertexOutput MainVS(uint vertexId : SV_VertexID, uint instanceId : SV_InstanceID)
{
    const PreparedSplatView prepared = gPreparedViews[gOrderBuffer[instanceId]];

    VertexOutput result;
    if (prepared.clipPosition.w <= 0.0)
    {
        result.position = float4(2.0, 2.0, 2.0, 1.0);
        result.color = float4(0.0, 0.0, 0.0, 0.0);
    }
    else
    {
        result.position = prepared.clipPosition;
        result.color = float4(UnpackPreparedColor(prepared).rgb, 1.0);
    }
    return result;
}

View File

@@ -0,0 +1,269 @@
#include "PreparedSplatView.hlsli"
#define GROUP_SIZE 64
// Per-frame constants shared with the host-side FrameConstants struct.
cbuffer FrameConstants : register(b0)
{
float4x4 gViewProjection;
float4x4 gView;
float4x4 gProjection;
// Camera position in world space (w is set to 1 by the host).
float4 gCameraWorldPos;
// x = width, y = height, z = 1 / width, w = 1 / height (as filled by the host).
float4 gScreenParams;
// x = splat count (stored as float), y = opacity scale, z = SH order, w = 1.0 from the host.
float4 gSettings;
};
ByteAddressBuffer gPositions : register(t0);
ByteAddressBuffer gOther : register(t1);
Texture2D<float4> gColor : register(t2);
ByteAddressBuffer gSh : register(t3);
RWStructuredBuffer<PreparedSplatView> gPreparedViews : register(u0);
// Spherical-harmonics basis constants used by ShadeSH for bands 1, 2 and 3.
static const float SH_C1 = 0.4886025;
static const float SH_C2[] = { 1.0925484, -1.0925484, 0.3153916, -1.0925484, 0.5462742 };
static const float SH_C3[] = { -0.5900436, 2.8906114, -0.4570458, 0.3731763, -0.4570458, 1.4453057, -0.5900436 };
// Width in texels assumed for gColor when mapping a splat index to a pixel.
static const uint kColorTextureWidth = 2048;
// Bytes per splat in gOther: a packed rotation word at +0 and a float3 scale at +4.
static const uint kOtherStride = 16;
// Bytes per splat in gSh. LoadSplatSH reads 15 float3 values (180 bytes);
// presumably the remaining 12 bytes are padding - TODO confirm against the loader.
static const uint kShStride = 192;
// Color inputs for ShadeSH: the base color read from gColor plus 15
// higher-order coefficients read from gSh.
struct SplatSHData
{
float3 col;
float3 sh[15];
};
// Reads three consecutive 32-bit words starting at byteOffset and
// reinterprets them as a float3.
float3 LoadFloat3(ByteAddressBuffer buffer, uint byteOffset)
{
    const uint3 rawWords = buffer.Load3(byteOffset);
    return asfloat(rawWords);
}
// Interleaves the low 4 bits of c.x and c.y into an 8-bit code.
// x bits end up at even bit positions (0, 2, 4, 6) and y bits at odd
// positions (1, 3, 5, 7). Inverse of DecodeMorton2D_16x16.
uint EncodeMorton2D_16x16(uint2 c)
{
// Place y's nibble in bits 8..11 and x's nibble in bits 0..3.
uint t = ((c.y & 0xF) << 8) | (c.x & 0xF);
// Spread each nibble into two 2-bit groups separated by 2 bits.
t = (t ^ (t << 2)) & 0x3333;
// Spread further so every source bit sits at an even position of its byte.
t = (t ^ (t << 1)) & 0x5555;
// Fold the y byte down by 7 so its bits land on the odd positions.
return (t | (t >> 7)) & 0xFF;
}
// Splits the low 8 bits of t into two 4-bit coordinates: the even bits form
// the returned x and the odd bits form the returned y.
// Inverse of EncodeMorton2D_16x16.
uint2 DecodeMorton2D_16x16(uint t)
{
// Keep the low byte and add a copy shifted up by 7 so the odd (y) bits land
// on even positions of the second byte.
t = (t & 0xFF) | ((t & 0xFE) << 7);
// Keep only even positions: x bits in the low byte, y bits in the high byte.
t &= 0x5555;
// Compact the spread bits back into contiguous nibbles.
t = (t ^ (t >> 1)) & 0x3333;
t = (t ^ (t >> 2)) & 0x0F0F;
return uint2(t & 0xF, t >> 8);
}
// Converts a splat index into a texel coordinate of the color texture.
// The low 8 bits select a texel inside a 16x16 tile (Morton order); the
// remaining bits select the tile, laid out row by row across a texture of
// kColorTextureWidth texels. z is always 0.
uint3 SplatIndexToPixelIndex(uint index)
{
    const uint tilesPerRow = kColorTextureWidth / 16;
    const uint2 texelInTile = DecodeMorton2D_16x16(index);
    const uint tileIndex = index >> 8;

    const uint tileColumn = tileIndex % tilesPerRow;
    const uint tileRow = tileIndex / tilesPerRow;

    return uint3(tileColumn * 16 + texelInTile.x, tileRow * 16 + texelInTile.y, 0);
}
// Unpacks a 10:10:10:2 word into four normalized floats in [0, 1]:
// x from bits 0..9, y from bits 10..19, z from bits 20..29, w from bits 30..31.
float4 DecodePacked_10_10_10_2(uint encoded)
{
    const uint xBits = encoded & 1023;
    const uint yBits = (encoded >> 10) & 1023;
    const uint zBits = (encoded >> 20) & 1023;
    const uint wBits = (encoded >> 30) & 3;

    float4 decoded;
    decoded.x = xBits / 1023.0;
    decoded.y = yBits / 1023.0;
    decoded.z = zBits / 1023.0;
    decoded.w = wBits / 3.0;
    return decoded;
}
// Rebuilds a quaternion from three stored components plus a 2-bit index.
// xyz are remapped from [0, 1] to [-1/sqrt(2), 1/sqrt(2)], the fourth
// component is reconstructed from the unit-length constraint, and the index
// in packedRotation.w (0..3 after scaling by 3) selects where the
// reconstructed component is swizzled to. Index 3 leaves the order unchanged.
float4 DecodeRotation(float4 packedRotation)
{
    // Rounding guards against the index landing just below an integer.
    const uint reconstructedSlot = (uint)round(packedRotation.w * 3.0);

    float4 q;
    q.xyz = packedRotation.xyz * sqrt(2.0) - (1.0 / sqrt(2.0));
    q.w = sqrt(1.0 - saturate(dot(q.xyz, q.xyz)));

    switch (reconstructedSlot)
    {
    case 0:
        q = q.wxyz;
        break;
    case 1:
        q = q.xwyz;
        break;
    case 2:
        q = q.xywz;
        break;
    default:
        break;
    }
    return q;
}
// Builds the 3x3 matrix R * S from a quaternion (x, y, z, w) and a per-axis
// scale, where S is the diagonal scale matrix and R is the rotation matrix
// expanded from the quaternion components.
// The quaternion is used as-is; no normalization happens here.
float3x3 CalcMatrixFromRotationScale(float4 rotation, float3 scale)
{
float3x3 scaleMatrix = float3x3(
scale.x, 0, 0,
0, scale.y, 0,
0, 0, scale.z);
float x = rotation.x;
float y = rotation.y;
float z = rotation.z;
float w = rotation.w;
float3x3 rotationMatrix = float3x3(
1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y),
2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y));
return mul(rotationMatrix, scaleMatrix);
}
// Computes M * M^T for the given rotation-scale matrix and returns the six
// unique entries of the symmetric result:
//   sigma0 = (m00, m01, m02), sigma1 = (m11, m12, m22).
void CalcCovariance3D(float3x3 rotationScaleMatrix, out float3 sigma0, out float3 sigma1)
{
    const float3x3 covariance = mul(rotationScaleMatrix, transpose(rotationScaleMatrix));

    sigma0.x = covariance._m00;
    sigma0.y = covariance._m01;
    sigma0.z = covariance._m02;

    sigma1.x = covariance._m11;
    sigma1.y = covariance._m12;
    sigma1.z = covariance._m22;
}
// Projects a symmetric 3D covariance (packed as covariance0 = m00, m01, m02 and
// covariance1 = m11, m12, m22) to a 2D covariance around worldPosition.
// Returns (m00, m01, m11) of the projected matrix.
float3 CalcCovariance2D(float3 worldPosition, float3 covariance0, float3 covariance1)
{
float3 viewPosition = mul(float4(worldPosition, 1.0), gView).xyz;
float aspect = gProjection._m00 / gProjection._m11;
float tanFovX = rcp(gProjection._m00);
// NOTE(review): _m11 * aspect reduces to _m00, so tanFovY equals tanFovX here - confirm this is intended.
float tanFovY = rcp(gProjection._m11 * aspect);
float clampX = 1.3 * tanFovX;
float clampY = 1.3 * tanFovY;
// Limit x/z and y/z to 1.3x the tangent computed above before building the Jacobian.
viewPosition.x = clamp(viewPosition.x / viewPosition.z, -clampX, clampX) * viewPosition.z;
viewPosition.y = clamp(viewPosition.y / viewPosition.z, -clampY, clampY) * viewPosition.z;
float focal = gScreenParams.x * gProjection._m00 * 0.5;
// Jacobian of the perspective projection at viewPosition (third row unused).
float3x3 jacobian = float3x3(
focal / viewPosition.z, 0, -(focal * viewPosition.x) / (viewPosition.z * viewPosition.z),
0, focal / viewPosition.z, -(focal * viewPosition.y) / (viewPosition.z * viewPosition.z),
0, 0, 0);
// NOTE(review): positions are transformed as mul(vector, gView) above, while this
// 3x3 is multiplied on the left of the covariance; verify whether a transpose of
// the rotation part is needed for the two conventions to agree.
float3x3 worldToView = (float3x3)gView;
float3x3 transform = mul(jacobian, worldToView);
float3x3 covariance = float3x3(
covariance0.x, covariance0.y, covariance0.z,
covariance0.y, covariance1.x, covariance1.y,
covariance0.z, covariance1.y, covariance1.z);
float3x3 projected = mul(transform, mul(covariance, transpose(transform)));
// Constant added to both diagonal terms; presumably a minimum-footprint filter
// as in reference 3DGS implementations - TODO confirm.
projected._m00 += 0.3;
projected._m11 += 0.3;
return float3(projected._m00, projected._m01, projected._m11);
}
// Decomposes a symmetric 2x2 covariance, packed as (m00, m01, m11), into two
// perpendicular axis vectors.
//   axis1: direction of the larger eigenvalue, scaled by sqrt(2 * lambda0).
//   axis2: perpendicular direction, scaled by sqrt(2 * lambda1), where lambda1
//          is floored at 0.1.
// Both lengths are capped at 4096. The direction's y component is negated
// before the axes are built.
void DecomposeCovariance(float3 covariance2D, out float2 axis1, out float2 axis2)
{
    float diagonal0 = covariance2D.x;
    float diagonal1 = covariance2D.z;
    float offDiagonal = covariance2D.y;

    // Closed-form eigenvalues of a symmetric 2x2 matrix.
    float mid = 0.5 * (diagonal0 + diagonal1);
    float radius = length(float2((diagonal0 - diagonal1) * 0.5, offDiagonal));
    float lambda0 = mid + radius;
    float lambda1 = max(mid - radius, 0.1);

    // Unnormalized eigenvector for lambda0. It is exactly (0, 0) when the matrix
    // is already axis-aligned with diagonal0 >= diagonal1; normalize() would then
    // produce NaN axes, so fall back to the x axis, which is the correct
    // eigenvector in that case.
    float2 eigenDirection = float2(offDiagonal, lambda0 - diagonal0);
    float2 diagonalVector = float2(1.0, 0.0);
    if (dot(eigenDirection, eigenDirection) > 0.0)
    {
        diagonalVector = normalize(eigenDirection);
    }
    diagonalVector.y = -diagonalVector.y;

    const float maxSize = 4096.0;
    axis1 = min(sqrt(2.0 * lambda0), maxSize) * diagonalVector;
    axis2 = min(sqrt(2.0 * lambda1), maxSize) * float2(diagonalVector.y, -diagonalVector.x);
}
// Gathers the color inputs for one splat: the base color from the color
// texture and 15 float3 coefficients read back-to-back (12 bytes each)
// starting at index * kShStride in gSh.
SplatSHData LoadSplatSH(uint index)
{
    SplatSHData data;
    data.col = gColor.Load(int3(SplatIndexToPixelIndex(index).xy, 0)).rgb;

    uint coefficientOffset = index * kShStride;
    [unroll]
    for (uint i = 0; i < 15; ++i)
    {
        data.sh[i] = LoadFloat3(gSh, coefficientOffset);
        coefficientOffset += 12;
    }
    return data;
}
// Evaluates the splat color for a direction using spherical harmonics.
//   sh        - base color (sh.col) plus 15 coefficients (sh.sh[0..14]).
//   direction - evaluation direction; it is negated before use.
//   shOrder   - highest band to add: >= 1 adds sh[0..2], >= 2 adds sh[3..7],
//               >= 3 adds sh[8..14]. Values below 1 return only the base color.
// The result is clamped to be non-negative.
float3 ShadeSH(SplatSHData sh, float3 direction, int shOrder)
{
direction *= -1.0;
float x = direction.x;
float y = direction.y;
float z = direction.z;
float3 result = sh.col;
if (shOrder >= 1)
{
// Band 1: three linear terms.
result += SH_C1 * (-sh.sh[0] * y + sh.sh[1] * z - sh.sh[2] * x);
if (shOrder >= 2)
{
float xx = x * x;
float yy = y * y;
float zz = z * z;
float xy = x * y;
float yz = y * z;
float xz = x * z;
// Band 2: five quadratic terms.
result +=
(SH_C2[0] * xy) * sh.sh[3] +
(SH_C2[1] * yz) * sh.sh[4] +
(SH_C2[2] * (2 * zz - xx - yy)) * sh.sh[5] +
(SH_C2[3] * xz) * sh.sh[6] +
(SH_C2[4] * (xx - yy)) * sh.sh[7];
if (shOrder >= 3)
{
// Band 3: seven cubic terms.
result +=
(SH_C3[0] * y * (3 * xx - yy)) * sh.sh[8] +
(SH_C3[1] * xy * z) * sh.sh[9] +
(SH_C3[2] * y * (4 * zz - xx - yy)) * sh.sh[10] +
(SH_C3[3] * z * (2 * zz - 3 * xx - 3 * yy)) * sh.sh[11] +
(SH_C3[4] * x * (4 * zz - xx - yy)) * sh.sh[12] +
(SH_C3[5] * z * (xx - yy)) * sh.sh[13] +
(SH_C3[6] * x * (xx - 3 * yy)) * sh.sh[14];
}
}
}
return max(result, 0.0);
}
// One thread per splat: writes a PreparedSplatView for the splat.
// The clip position is always written. The screen-space axes and the packed
// color are only computed when clip w > 0; otherwise they stay zero.
[numthreads(GROUP_SIZE, 1, 1)]
void MainCS(uint3 dispatchThreadId : SV_DispatchThreadID)
{
    const uint splatIndex = dispatchThreadId.x;

    // Threads past the splat count (stored as a float in gSettings.x) do nothing.
    if (splatIndex >= (uint)gSettings.x)
    {
        return;
    }

    PreparedSplatView prepared = (PreparedSplatView)0;

    // Positions are tightly packed float3 values (12 bytes each).
    const float3 worldPosition = LoadFloat3(gPositions, splatIndex * 12);
    prepared.clipPosition = mul(float4(worldPosition, 1.0), gViewProjection);

    if (prepared.clipPosition.w > 0.0)
    {
        // gOther layout per splat: packed rotation word, then a float3 scale.
        const uint otherOffset = splatIndex * kOtherStride;
        const float4 rotation = DecodeRotation(DecodePacked_10_10_10_2(gOther.Load(otherOffset)));
        const float3 scale = LoadFloat3(gOther, otherOffset + 4);
        const float4 colorOpacity = gColor.Load(int3(SplatIndexToPixelIndex(splatIndex).xy, 0));

        // 3D covariance -> projected 2D covariance -> two screen-space axes.
        float3 covariance0;
        float3 covariance1;
        CalcCovariance3D(CalcMatrixFromRotationScale(rotation, scale), covariance0, covariance1);
        const float3 covariance2D = CalcCovariance2D(worldPosition, covariance0, covariance1);
        DecomposeCovariance(covariance2D, prepared.axis1, prepared.axis2);

        // View-dependent color; gSettings.z is the SH order, gSettings.y the opacity scale.
        const float3 viewDirection = normalize(gCameraWorldPos.xyz - worldPosition);
        const float3 shadedColor = ShadeSH(LoadSplatSH(splatIndex), viewDirection, (int)gSettings.z);
        const float opacity = saturate(colorOpacity.a * gSettings.y);

        // Pack as half floats: x = (r << 16) | g, y = (b << 16) | opacity.
        prepared.packedColor.x = (f32tof16(shadedColor.r) << 16) | f32tof16(shadedColor.g);
        prepared.packedColor.y = (f32tof16(shadedColor.b) << 16) | f32tof16(opacity);
    }

    gPreparedViews[splatIndex] = prepared;
}

View File

@@ -0,0 +1,22 @@
#ifndef PREPARED_SPLAT_VIEW_HLSLI
#define PREPARED_SPLAT_VIEW_HLSLI
// Per-splat data exchanged between passes through a structured buffer.
// 40 bytes: float4 + float2 + float2 + uint2.
struct PreparedSplatView
{
// Clip-space position of the splat.
float4 clipPosition;
// Two 2D axis vectors of the splat.
float2 axis1;
float2 axis2;
// Half-float packed color, decoded by UnpackPreparedColor.
uint2 packedColor;
};
// Expands the four half floats stored in view.packedColor:
//   packedColor.x = (r << 16) | g, packedColor.y = (b << 16) | a.
float4 UnpackPreparedColor(PreparedSplatView view)
{
    const uint redGreen = view.packedColor.x;
    const uint blueAlpha = view.packedColor.y;

    return float4(
        f16tof32((redGreen >> 16) & 0xFFFF),
        f16tof32(redGreen & 0xFFFF),
        f16tof32((blueAlpha >> 16) & 0xFFFF),
        f16tof32(blueAlpha & 0xFFFF));
}
#endif

View File

@@ -2,17 +2,43 @@
#include <d3d12.h> #include <d3d12.h>
#include <dxgi1_4.h> #include <dxgi1_4.h>
#include <DirectXMath.h>
#include <algorithm>
#include <cstring>
#include <filesystem> #include <filesystem>
#include <fstream>
#include <string_view>
#include "XCEngine/RHI/D3D12/D3D12Screenshot.h"
#include "XCEngine/RHI/RHIDescriptorPool.h"
#include "XCEngine/RHI/RHIDescriptorSet.h"
#include "XCEngine/RHI/RHIPipelineLayout.h"
#include "XCEngine/RHI/RHIPipelineState.h"
namespace XC3DGSD3D12 { namespace XC3DGSD3D12 {
using namespace XCEngine::RHI; using namespace XCEngine::RHI;
namespace { namespace {
constexpr wchar_t kWindowClassName[] = L"XC3DGSD3D12WindowClass"; constexpr wchar_t kWindowClassName[] = L"XC3DGSD3D12WindowClass";
constexpr wchar_t kWindowTitle[] = L"XC 3DGS D3D12 MVS - Phase 1"; constexpr wchar_t kWindowTitle[] = L"XC 3DGS D3D12 MVS - Phase 3";
constexpr float kClearColor[4] = { 0.08f, 0.12f, 0.18f, 1.0f }; constexpr float kClearColor[4] = { 0.04f, 0.05f, 0.07f, 1.0f };
constexpr uint32_t kPrepareThreadGroupSize = 64u;
constexpr uint32_t kSortThreadGroupSize = 64u;
// CPU-side image of the FrameConstants cbuffer declared in the shaders
// (three 4x4 matrices followed by three float4 values, in the same order).
struct FrameConstants {
float viewProjection[16] = {};
float view[16] = {};
float projection[16] = {};
// Camera position in world space; w is set to 1.
float cameraWorldPos[4] = {};
// width, height, 1 / width, 1 / height.
float screenParams[4] = {};
// splat count (as float), opacity scale, SH order, 1.0.
float settings[4] = {};
};
static_assert(sizeof(FrameConstants) % 16 == 0, "Frame constants must stay 16-byte aligned.");
static_assert(sizeof(PreparedSplatView) == 40, "Prepared view buffer layout must match shader.");
std::filesystem::path GetExecutableDirectory() { std::filesystem::path GetExecutableDirectory() {
std::wstring pathBuffer; std::wstring pathBuffer;
@@ -29,8 +55,106 @@ std::filesystem::path ResolveNearExecutable(const std::wstring& path) {
} }
return GetExecutableDirectory() / inputPath; return GetExecutableDirectory() / inputPath;
} }
std::filesystem::path ResolveShaderPath(std::wstring_view fileName) {
return GetExecutableDirectory() / L"shaders" / std::filesystem::path(fileName);
} }
// Reads a whole file into memory.
// Returns an empty vector when the file cannot be opened, is empty, or
// cannot be read completely.
std::vector<uint8_t> LoadBinaryFile(const std::filesystem::path& filePath) {
    std::vector<uint8_t> contents;

    std::ifstream stream(filePath, std::ios::binary);
    if (stream.is_open()) {
        // Measure the file by seeking to its end.
        stream.seekg(0, std::ios::end);
        const std::streamoff byteCount = stream.tellg();

        if (byteCount > 0) {
            stream.seekg(0, std::ios::beg);
            contents.resize(static_cast<size_t>(byteCount));
            stream.read(reinterpret_cast<char*>(contents.data()), byteCount);

            // A short or failed read invalidates the whole result.
            if (!stream) {
                contents.clear();
            }
        }
    }

    return contents;
}
// Converts wide text to a narrow string one character at a time.
// Characters in the 7-bit ASCII range are copied; everything else becomes '?'.
std::string NarrowAscii(std::wstring_view text) {
    // Start with all placeholders and overwrite the representable characters.
    std::string narrowed(text.size(), '?');
    for (size_t position = 0; position < text.size(); ++position) {
        const wchar_t wide = text[position];
        if (wide >= 0 && wide <= 0x7F) {
            narrowed[position] = static_cast<char>(wide);
        }
    }
    return narrowed;
}
// Appends one line of the form "<tick count> | <message>" to
// phase3_trace.log in the executable directory.
// The file is reopened in append mode on every call; if it cannot be opened
// the message is dropped.
void AppendTrace(std::string_view message) {
    std::ofstream log(GetExecutableDirectory() / "phase3_trace.log", std::ios::app);
    if (log.is_open()) {
        log << GetTickCount64() << " | " << message << '\n';
    }
}
void AppendTrace(const std::wstring& message) {
AppendTrace(NarrowAscii(message));
}
// Writes the transpose of `matrix` as 16 consecutive floats to `destination`.
// `destination` must have room for a full XMFLOAT4X4.
void StoreMatrixTransposed(const DirectX::XMMATRIX& matrix, float* destination) {
    const DirectX::XMMATRIX transposed = DirectX::XMMatrixTranspose(matrix);

    DirectX::XMFLOAT4X4 packed = {};
    DirectX::XMStoreFloat4x4(&packed, transposed);

    std::memcpy(destination, &packed, sizeof(DirectX::XMFLOAT4X4));
}
// Builds the per-frame constant block for a fixed debug camera.
//   width, height - render target size in pixels; a zero height falls back to
//                   an aspect ratio of 1, and zero dimensions give reciprocals of 0.
//   splatCount    - number of splats, stored as a float in settings[0].
// The camera sits at (0, 0.5, 1) looking at (0, 0.5, -5) with a right-handed
// 60 degree perspective projection (near 0.1, far 200). Matrices are stored
// transposed.
// NOTE(review): a float represents integers exactly only up to 2^24, so larger
// splat counts would be rounded when the shaders cast settings[0] back to uint.
FrameConstants BuildFrameConstants(uint32_t width, uint32_t height, uint32_t splatCount) {
    using namespace DirectX;

    const float widthPixels = static_cast<float>(width);
    const float heightPixels = static_cast<float>(height);

    float aspect = 1.0f;
    if (height > 0) {
        aspect = widthPixels / heightPixels;
    }

    const XMVECTOR eye = XMVectorSet(0.0f, 0.5f, 1.0f, 1.0f);
    const XMVECTOR target = XMVectorSet(0.0f, 0.5f, -5.0f, 1.0f);
    const XMVECTOR up = XMVectorSet(0.0f, 1.0f, 0.0f, 0.0f);

    const XMMATRIX view = XMMatrixLookAtRH(eye, target, up);
    const XMMATRIX projection = XMMatrixPerspectiveFovRH(XMConvertToRadians(60.0f), aspect, 0.1f, 200.0f);

    FrameConstants constants = {};
    StoreMatrixTransposed(XMMatrixMultiply(view, projection), constants.viewProjection);
    StoreMatrixTransposed(view, constants.view);
    StoreMatrixTransposed(projection, constants.projection);

    // The camera position is taken from the same vector used for the view matrix.
    XMFLOAT4 eyePosition = {};
    XMStoreFloat4(&eyePosition, eye);
    constants.cameraWorldPos[0] = eyePosition.x;
    constants.cameraWorldPos[1] = eyePosition.y;
    constants.cameraWorldPos[2] = eyePosition.z;
    constants.cameraWorldPos[3] = eyePosition.w;

    constants.screenParams[0] = widthPixels;
    constants.screenParams[1] = heightPixels;
    constants.screenParams[2] = width > 0 ? 1.0f / widthPixels : 0.0f;
    constants.screenParams[3] = height > 0 ? 1.0f / heightPixels : 0.0f;

    constants.settings[0] = static_cast<float>(splatCount);
    constants.settings[1] = 1.0f; // opacity scale
    constants.settings[2] = 3.0f; // SH order
    constants.settings[3] = 1.0f;

    return constants;
}
// Shuts down, deletes and nulls an owning raw pointer.
// A null pointer is left untouched, so repeated calls are harmless.
// T must provide a Shutdown() member function.
template <typename T>
void ShutdownAndDelete(T*& object) {
    if (object == nullptr) {
        return;
    }

    // Shutdown runs while the pointer is still set; the pointer is cleared last.
    object->Shutdown();
    delete object;
    object = nullptr;
}
} // namespace
App::App() = default; App::App() = default;
App::~App() { App::~App() {
@@ -38,36 +162,66 @@ App::~App() {
} }
bool App::Initialize(HINSTANCE instance, int showCommand) { bool App::Initialize(HINSTANCE instance, int showCommand) {
AppendTrace("Initialize: begin");
m_instance = instance; m_instance = instance;
m_lastErrorMessage.clear(); m_lastErrorMessage.clear();
AppendTrace("Initialize: LoadGaussianScene");
if (!LoadGaussianScene()) { if (!LoadGaussianScene()) {
AppendTrace(std::string("Initialize: LoadGaussianScene failed: ") + NarrowAscii(m_lastErrorMessage));
return false; return false;
} }
AppendTrace("Initialize: RegisterWindowClass");
if (!RegisterWindowClass(instance)) { if (!RegisterWindowClass(instance)) {
m_lastErrorMessage = L"Failed to register the Win32 window class."; m_lastErrorMessage = L"Failed to register the Win32 window class.";
AppendTrace(std::string("Initialize: RegisterWindowClass failed: ") + NarrowAscii(m_lastErrorMessage));
return false; return false;
} }
AppendTrace("Initialize: CreateMainWindow");
if (!CreateMainWindow(instance, showCommand)) { if (!CreateMainWindow(instance, showCommand)) {
m_lastErrorMessage = L"Failed to create the main window."; m_lastErrorMessage = L"Failed to create the main window.";
AppendTrace(std::string("Initialize: CreateMainWindow failed: ") + NarrowAscii(m_lastErrorMessage));
return false; return false;
} }
AppendTrace("Initialize: InitializeRhi");
if (!InitializeRhi()) { if (!InitializeRhi()) {
if (m_lastErrorMessage.empty()) { if (m_lastErrorMessage.empty()) {
m_lastErrorMessage = L"Failed to initialize the D3D12 RHI objects."; m_lastErrorMessage = L"Failed to initialize the D3D12 RHI objects.";
} }
AppendTrace(std::string("Initialize: InitializeRhi failed: ") + NarrowAscii(m_lastErrorMessage));
return false; return false;
} }
AppendTrace("Initialize: InitializeGaussianGpuResources");
if (!InitializeGaussianGpuResources()) { if (!InitializeGaussianGpuResources()) {
AppendTrace(std::string("Initialize: InitializeGaussianGpuResources failed: ") + NarrowAscii(m_lastErrorMessage));
return false;
}
AppendTrace("Initialize: InitializePreparePassResources");
if (!InitializePreparePassResources()) {
AppendTrace(std::string("Initialize: InitializePreparePassResources failed: ") + NarrowAscii(m_lastErrorMessage));
return false;
}
AppendTrace("Initialize: InitializeSortResources");
if (!InitializeSortResources()) {
AppendTrace(std::string("Initialize: InitializeSortResources failed: ") + NarrowAscii(m_lastErrorMessage));
return false;
}
AppendTrace("Initialize: InitializeDebugDrawResources");
if (!InitializeDebugDrawResources()) {
AppendTrace(std::string("Initialize: InitializeDebugDrawResources failed: ") + NarrowAscii(m_lastErrorMessage));
return false; return false;
} }
m_isInitialized = true; m_isInitialized = true;
m_running = true; m_running = true;
AppendTrace("Initialize: success");
return true; return true;
} }
@@ -83,15 +237,22 @@ void App::SetSummaryPath(std::wstring summaryPath) {
m_summaryPath = std::move(summaryPath); m_summaryPath = std::move(summaryPath);
} }
// Replaces the stored screenshot output path, taking ownership of the argument.
void App::SetScreenshotPath(std::wstring screenshotPath) {
    m_screenshotPath.assign(std::move(screenshotPath));
}
const std::wstring& App::GetLastErrorMessage() const { const std::wstring& App::GetLastErrorMessage() const {
return m_lastErrorMessage; return m_lastErrorMessage;
} }
int App::Run() { int App::Run() {
AppendTrace("Run: begin");
MSG message = {}; MSG message = {};
int exitCode = 0;
while (m_running) { while (m_running) {
while (PeekMessage(&message, nullptr, 0, 0, PM_REMOVE)) { while (PeekMessage(&message, nullptr, 0, 0, PM_REMOVE)) {
if (message.message == WM_QUIT) { if (message.message == WM_QUIT) {
exitCode = static_cast<int>(message.wParam);
m_running = false; m_running = false;
break; break;
} }
@@ -104,16 +265,21 @@ int App::Run() {
break; break;
} }
RenderFrame(); const bool captureScreenshot = m_frameLimit > 0 && (m_renderedFrameCount + 1) >= m_frameLimit;
AppendTrace(captureScreenshot ? "Run: RenderFrame capture" : "Run: RenderFrame");
RenderFrame(captureScreenshot);
AppendTrace("Run: RenderFrame complete");
++m_renderedFrameCount; ++m_renderedFrameCount;
if (m_frameLimit > 0 && m_renderedFrameCount >= m_frameLimit) { if (m_frameLimit > 0 && m_renderedFrameCount >= m_frameLimit) {
m_running = false; m_running = false;
AppendTrace("Run: frame limit reached, posting WM_CLOSE");
PostMessageW(m_hwnd, WM_CLOSE, 0, 0); PostMessageW(m_hwnd, WM_CLOSE, 0, 0);
} }
} }
return static_cast<int>(message.wParam); AppendTrace("Run: end");
return exitCode;
} }
LRESULT CALLBACK App::StaticWindowProc(HWND hwnd, UINT message, WPARAM wParam, LPARAM lParam) { LRESULT CALLBACK App::StaticWindowProc(HWND hwnd, UINT message, WPARAM wParam, LPARAM lParam) {
@@ -376,6 +542,361 @@ bool App::InitializeGaussianGpuResources() {
return true; return true;
} }
bool App::InitializePreparePassResources() {
DescriptorSetLayoutBinding bindings[6] = {};
bindings[0].binding = 0;
bindings[0].type = static_cast<uint32_t>(DescriptorType::CBV);
bindings[0].count = 1;
bindings[0].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[1].binding = 1;
bindings[1].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[1].count = 1;
bindings[1].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[1].resourceDimension = ResourceViewDimension::RawBuffer;
bindings[2].binding = 2;
bindings[2].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[2].count = 1;
bindings[2].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[2].resourceDimension = ResourceViewDimension::RawBuffer;
bindings[3].binding = 3;
bindings[3].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[3].count = 1;
bindings[3].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[3].resourceDimension = ResourceViewDimension::Texture2D;
bindings[4].binding = 4;
bindings[4].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[4].count = 1;
bindings[4].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[4].resourceDimension = ResourceViewDimension::RawBuffer;
bindings[5].binding = 5;
bindings[5].type = static_cast<uint32_t>(DescriptorType::UAV);
bindings[5].count = 1;
bindings[5].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[5].resourceDimension = ResourceViewDimension::StructuredBuffer;
DescriptorSetLayoutDesc setLayout = {};
setLayout.bindings = bindings;
setLayout.bindingCount = 6;
RHIPipelineLayoutDesc pipelineLayoutDesc = {};
pipelineLayoutDesc.setLayouts = &setLayout;
pipelineLayoutDesc.setLayoutCount = 1;
m_preparePipelineLayout = m_device.CreatePipelineLayout(pipelineLayoutDesc);
if (m_preparePipelineLayout == nullptr) {
m_lastErrorMessage = L"Failed to create the prepare pass pipeline layout.";
return false;
}
BufferDesc preparedViewBufferDesc = {};
preparedViewBufferDesc.size = static_cast<uint64_t>(sizeof(PreparedSplatView)) * m_gaussianSceneData.splatCount;
preparedViewBufferDesc.stride = sizeof(PreparedSplatView);
preparedViewBufferDesc.bufferType = static_cast<uint32_t>(BufferType::Storage);
preparedViewBufferDesc.flags = static_cast<uint64_t>(BufferFlags::AllowUnorderedAccess);
m_preparedViewBuffer = static_cast<D3D12Buffer*>(m_device.CreateBuffer(preparedViewBufferDesc));
if (m_preparedViewBuffer == nullptr) {
m_lastErrorMessage = L"Failed to create the prepared view buffer.";
return false;
}
m_preparedViewBuffer->SetStride(sizeof(PreparedSplatView));
m_preparedViewBuffer->SetBufferType(BufferType::Storage);
ResourceViewDesc structuredViewDesc = {};
structuredViewDesc.dimension = ResourceViewDimension::StructuredBuffer;
structuredViewDesc.structureByteStride = sizeof(PreparedSplatView);
structuredViewDesc.elementCount = m_gaussianSceneData.splatCount;
m_preparedViewSrv.reset(static_cast<D3D12ResourceView*>(m_device.CreateShaderResourceView(m_preparedViewBuffer, structuredViewDesc)));
if (!m_preparedViewSrv) {
m_lastErrorMessage = L"Failed to create the prepared view SRV.";
return false;
}
m_preparedViewUav.reset(static_cast<D3D12ResourceView*>(m_device.CreateUnorderedAccessView(m_preparedViewBuffer, structuredViewDesc)));
if (!m_preparedViewUav) {
m_lastErrorMessage = L"Failed to create the prepared view UAV.";
return false;
}
DescriptorPoolDesc poolDesc = {};
poolDesc.type = DescriptorHeapType::CBV_SRV_UAV;
poolDesc.descriptorCount = 5;
poolDesc.shaderVisible = true;
m_prepareDescriptorPool = m_device.CreateDescriptorPool(poolDesc);
if (m_prepareDescriptorPool == nullptr) {
m_lastErrorMessage = L"Failed to create the prepare pass descriptor pool.";
return false;
}
m_prepareDescriptorSet = m_prepareDescriptorPool->AllocateSet(setLayout);
if (m_prepareDescriptorSet == nullptr) {
m_lastErrorMessage = L"Failed to allocate the prepare pass descriptor set.";
return false;
}
m_prepareDescriptorSet->Update(1, m_gaussianPositionView.get());
m_prepareDescriptorSet->Update(2, m_gaussianOtherView.get());
m_prepareDescriptorSet->Update(3, m_gaussianColorView.get());
m_prepareDescriptorSet->Update(4, m_gaussianShView.get());
m_prepareDescriptorSet->Update(5, m_preparedViewUav.get());
ShaderCompileDesc computeShaderDesc = {};
computeShaderDesc.fileName = ResolveShaderPath(L"PrepareGaussiansCS.hlsl").wstring();
computeShaderDesc.entryPoint = L"MainCS";
computeShaderDesc.profile = L"cs_5_0";
ComputePipelineDesc pipelineDesc = {};
pipelineDesc.pipelineLayout = m_preparePipelineLayout;
pipelineDesc.computeShader = computeShaderDesc;
m_preparePipelineState = m_device.CreateComputePipelineState(pipelineDesc);
if (m_preparePipelineState == nullptr) {
m_lastErrorMessage = L"Failed to create the prepare pass pipeline state.";
return false;
}
return true;
}
bool App::InitializeSortResources() {
std::vector<uint32_t> initialOrder(static_cast<size_t>(m_gaussianSceneData.splatCount));
for (uint32_t index = 0; index < m_gaussianSceneData.splatCount; ++index) {
initialOrder[index] = index;
}
ID3D12Device* device = m_device.GetDevice();
ID3D12GraphicsCommandList* commandList = m_commandList.GetCommandList();
m_commandAllocator.Reset();
m_commandList.Reset();
const D3D12_RESOURCE_STATES shaderResourceState =
D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE | D3D12_RESOURCE_STATE_PIXEL_SHADER_RESOURCE;
m_orderBuffer = new D3D12Buffer();
m_orderBuffer->SetStride(sizeof(uint32_t));
m_orderBuffer->SetBufferType(BufferType::Storage);
if (!m_orderBuffer->InitializeWithData(
device,
commandList,
initialOrder.data(),
static_cast<uint64_t>(initialOrder.size() * sizeof(uint32_t)),
shaderResourceState)) {
m_lastErrorMessage = L"Failed to initialize the order buffer.";
return false;
}
m_orderBuffer->SetState(ResourceStates::NonPixelShaderResource);
BufferDesc sortKeyBufferDesc = {};
sortKeyBufferDesc.size = static_cast<uint64_t>(m_gaussianSceneData.splatCount) * sizeof(uint32_t);
sortKeyBufferDesc.stride = sizeof(uint32_t);
sortKeyBufferDesc.bufferType = static_cast<uint32_t>(BufferType::Storage);
sortKeyBufferDesc.flags = static_cast<uint64_t>(BufferFlags::AllowUnorderedAccess);
m_sortKeyBuffer = static_cast<D3D12Buffer*>(m_device.CreateBuffer(sortKeyBufferDesc));
if (m_sortKeyBuffer == nullptr) {
m_lastErrorMessage = L"Failed to create the sort key buffer.";
return false;
}
m_sortKeyBuffer->SetStride(sizeof(uint32_t));
m_sortKeyBuffer->SetBufferType(BufferType::Storage);
ResourceViewDesc structuredViewDesc = {};
structuredViewDesc.dimension = ResourceViewDimension::StructuredBuffer;
structuredViewDesc.structureByteStride = sizeof(uint32_t);
structuredViewDesc.elementCount = m_gaussianSceneData.splatCount;
m_orderBufferSrv.reset(static_cast<D3D12ResourceView*>(m_device.CreateShaderResourceView(m_orderBuffer, structuredViewDesc)));
if (!m_orderBufferSrv) {
m_lastErrorMessage = L"Failed to create the order buffer SRV.";
return false;
}
m_sortKeySrv.reset(static_cast<D3D12ResourceView*>(m_device.CreateShaderResourceView(m_sortKeyBuffer, structuredViewDesc)));
if (!m_sortKeySrv) {
m_lastErrorMessage = L"Failed to create the sort key SRV.";
return false;
}
m_sortKeyUav.reset(static_cast<D3D12ResourceView*>(m_device.CreateUnorderedAccessView(m_sortKeyBuffer, structuredViewDesc)));
if (!m_sortKeyUav) {
m_lastErrorMessage = L"Failed to create the sort key UAV.";
return false;
}
DescriptorSetLayoutBinding bindings[4] = {};
bindings[0].binding = 0;
bindings[0].type = static_cast<uint32_t>(DescriptorType::CBV);
bindings[0].count = 1;
bindings[0].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[1].binding = 1;
bindings[1].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[1].count = 1;
bindings[1].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[1].resourceDimension = ResourceViewDimension::RawBuffer;
bindings[2].binding = 2;
bindings[2].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[2].count = 1;
bindings[2].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[2].resourceDimension = ResourceViewDimension::StructuredBuffer;
bindings[3].binding = 3;
bindings[3].type = static_cast<uint32_t>(DescriptorType::UAV);
bindings[3].count = 1;
bindings[3].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[3].resourceDimension = ResourceViewDimension::StructuredBuffer;
DescriptorSetLayoutDesc setLayout = {};
setLayout.bindings = bindings;
setLayout.bindingCount = 4;
RHIPipelineLayoutDesc pipelineLayoutDesc = {};
pipelineLayoutDesc.setLayouts = &setLayout;
pipelineLayoutDesc.setLayoutCount = 1;
m_sortPipelineLayout = m_device.CreatePipelineLayout(pipelineLayoutDesc);
if (m_sortPipelineLayout == nullptr) {
m_lastErrorMessage = L"Failed to create the sort pipeline layout.";
return false;
}
DescriptorPoolDesc poolDesc = {};
poolDesc.type = DescriptorHeapType::CBV_SRV_UAV;
poolDesc.descriptorCount = 3;
poolDesc.shaderVisible = true;
m_sortDescriptorPool = m_device.CreateDescriptorPool(poolDesc);
if (m_sortDescriptorPool == nullptr) {
m_lastErrorMessage = L"Failed to create the sort descriptor pool.";
return false;
}
m_sortDescriptorSet = m_sortDescriptorPool->AllocateSet(setLayout);
if (m_sortDescriptorSet == nullptr) {
m_lastErrorMessage = L"Failed to allocate the sort descriptor set.";
return false;
}
m_sortDescriptorSet->Update(1, m_gaussianPositionView.get());
m_sortDescriptorSet->Update(2, m_orderBufferSrv.get());
m_sortDescriptorSet->Update(3, m_sortKeyUav.get());
ShaderCompileDesc computeShaderDesc = {};
const std::filesystem::path compiledShaderPath = ResolveShaderPath(L"BuildSortKeysCS.dxil");
const std::vector<uint8_t> compiledShaderBinary = LoadBinaryFile(compiledShaderPath);
if (!compiledShaderBinary.empty()) {
computeShaderDesc.profile = L"cs_6_6";
computeShaderDesc.compiledBinaryBackend = ShaderBinaryBackend::D3D12;
computeShaderDesc.compiledBinary = compiledShaderBinary;
} else {
computeShaderDesc.fileName = ResolveShaderPath(L"BuildSortKeysCS.hlsl").wstring();
computeShaderDesc.entryPoint = L"MainCS";
computeShaderDesc.profile = L"cs_5_0";
}
ComputePipelineDesc pipelineDesc = {};
pipelineDesc.pipelineLayout = m_sortPipelineLayout;
pipelineDesc.computeShader = computeShaderDesc;
m_sortPipelineState = m_device.CreateComputePipelineState(pipelineDesc);
if (m_sortPipelineState == nullptr) {
m_lastErrorMessage = L"Failed to create the sort pipeline state.";
return false;
}
m_commandList.Close();
void* commandLists[] = { &m_commandList };
m_commandQueue.ExecuteCommandLists(1, commandLists);
m_commandQueue.WaitForIdle();
return true;
}
bool App::InitializeDebugDrawResources() {
DescriptorSetLayoutBinding bindings[3] = {};
bindings[0].binding = 0;
bindings[0].type = static_cast<uint32_t>(DescriptorType::CBV);
bindings[0].count = 1;
bindings[0].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[1].binding = 1;
bindings[1].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[1].count = 1;
bindings[1].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[1].resourceDimension = ResourceViewDimension::StructuredBuffer;
bindings[2].binding = 2;
bindings[2].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[2].count = 1;
bindings[2].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[2].resourceDimension = ResourceViewDimension::StructuredBuffer;
DescriptorSetLayoutDesc setLayout = {};
setLayout.bindings = bindings;
setLayout.bindingCount = 3;
RHIPipelineLayoutDesc pipelineLayoutDesc = {};
pipelineLayoutDesc.setLayouts = &setLayout;
pipelineLayoutDesc.setLayoutCount = 1;
m_debugPipelineLayout = m_device.CreatePipelineLayout(pipelineLayoutDesc);
if (m_debugPipelineLayout == nullptr) {
m_lastErrorMessage = L"Failed to create the debug draw pipeline layout.";
return false;
}
DescriptorPoolDesc poolDesc = {};
poolDesc.type = DescriptorHeapType::CBV_SRV_UAV;
poolDesc.descriptorCount = 2;
poolDesc.shaderVisible = true;
m_debugDescriptorPool = m_device.CreateDescriptorPool(poolDesc);
if (m_debugDescriptorPool == nullptr) {
m_lastErrorMessage = L"Failed to create the debug draw descriptor pool.";
return false;
}
m_debugDescriptorSet = m_debugDescriptorPool->AllocateSet(setLayout);
if (m_debugDescriptorSet == nullptr) {
m_lastErrorMessage = L"Failed to allocate the debug draw descriptor set.";
return false;
}
m_debugDescriptorSet->Update(1, m_preparedViewSrv.get());
m_debugDescriptorSet->Update(2, m_orderBufferSrv.get());
GraphicsPipelineDesc pipelineDesc = {};
pipelineDesc.pipelineLayout = m_debugPipelineLayout;
pipelineDesc.topologyType = static_cast<uint32_t>(PrimitiveTopologyType::Point);
pipelineDesc.renderTargetCount = 1;
pipelineDesc.renderTargetFormats[0] = static_cast<uint32_t>(Format::R8G8B8A8_UNorm);
pipelineDesc.depthStencilFormat = static_cast<uint32_t>(Format::D24_UNorm_S8_UInt);
pipelineDesc.sampleCount = 1;
pipelineDesc.sampleQuality = 0;
pipelineDesc.rasterizerState.cullMode = static_cast<uint32_t>(CullMode::None);
pipelineDesc.depthStencilState.depthTestEnable = false;
pipelineDesc.depthStencilState.depthWriteEnable = false;
pipelineDesc.vertexShader.fileName = ResolveShaderPath(L"DebugPointsVS.hlsl").wstring();
pipelineDesc.vertexShader.entryPoint = L"MainVS";
pipelineDesc.vertexShader.profile = L"vs_5_0";
pipelineDesc.fragmentShader.fileName = ResolveShaderPath(L"DebugPointsPS.hlsl").wstring();
pipelineDesc.fragmentShader.entryPoint = L"MainPS";
pipelineDesc.fragmentShader.profile = L"ps_5_0";
m_debugPipelineState = m_device.CreatePipelineState(pipelineDesc);
if (m_debugPipelineState == nullptr) {
m_lastErrorMessage = L"Failed to create the debug draw pipeline state.";
return false;
}
return true;
}
void App::ShutdownGaussianGpuResources() { void App::ShutdownGaussianGpuResources() {
m_gaussianColorView.reset(); m_gaussianColorView.reset();
m_gaussianShView.reset(); m_gaussianShView.reset();
@@ -388,16 +909,78 @@ void App::ShutdownGaussianGpuResources() {
m_gaussianPositionBuffer.Shutdown(); m_gaussianPositionBuffer.Shutdown();
} }
// Releases everything owned by the prepare pass: descriptor set, pool,
// pipeline state, pipeline layout, the prepared-view UAV/SRV and finally the
// prepared-view buffer. Safe to call when some or all of them are null.
void App::ShutdownPreparePassResources() {
    // Shut down, delete and null out a raw owning pointer; no-op when already null.
    const auto destroyOwned = [](auto*& resource) {
        if (resource == nullptr) {
            return;
        }
        resource->Shutdown();
        delete resource;
        resource = nullptr;
    };

    destroyOwned(m_prepareDescriptorSet);

    ShutdownAndDelete(m_prepareDescriptorPool);
    ShutdownAndDelete(m_preparePipelineState);
    ShutdownAndDelete(m_preparePipelineLayout);

    // The views are dropped before the buffer they refer to.
    m_preparedViewUav.reset();
    m_preparedViewSrv.reset();

    destroyOwned(m_preparedViewBuffer);
}
// Releases everything owned by the sort pass: descriptor set, pool, pipeline
// state, pipeline layout, the order/sort-key views and finally the order and
// sort key buffers. Safe to call when some or all of them are null.
void App::ShutdownSortResources() {
    // Shut down, delete and null out a raw owning pointer; no-op when already null.
    const auto destroyOwned = [](auto*& resource) {
        if (resource == nullptr) {
            return;
        }
        resource->Shutdown();
        delete resource;
        resource = nullptr;
    };

    destroyOwned(m_sortDescriptorSet);

    ShutdownAndDelete(m_sortDescriptorPool);
    ShutdownAndDelete(m_sortPipelineState);
    ShutdownAndDelete(m_sortPipelineLayout);

    // The views are dropped before the buffers they refer to.
    m_orderBufferSrv.reset();
    m_sortKeyUav.reset();
    m_sortKeySrv.reset();

    destroyOwned(m_orderBuffer);
    destroyOwned(m_sortKeyBuffer);
}
// Releases the debug draw descriptor set, descriptor pool, pipeline state and
// pipeline layout, in that order. Safe to call when they are already null.
void App::ShutdownDebugDrawResources() {
    // Shut down, delete and null out a raw owning pointer; no-op when already null.
    const auto destroyOwned = [](auto*& resource) {
        if (resource == nullptr) {
            return;
        }
        resource->Shutdown();
        delete resource;
        resource = nullptr;
    };

    destroyOwned(m_debugDescriptorSet);

    ShutdownAndDelete(m_debugDescriptorPool);
    ShutdownAndDelete(m_debugPipelineState);
    ShutdownAndDelete(m_debugPipelineLayout);
}
void App::Shutdown() { void App::Shutdown() {
AppendTrace("Shutdown: begin");
if (!m_isInitialized && m_hwnd == nullptr) { if (!m_isInitialized && m_hwnd == nullptr) {
AppendTrace("Shutdown: skipped");
return; return;
} }
m_running = false; m_running = false;
if (m_commandQueue.GetCommandQueue() != nullptr) { if (m_commandQueue.GetCommandQueue() != nullptr) {
AppendTrace("Shutdown: WaitForIdle");
m_commandQueue.WaitForIdle(); m_commandQueue.WaitForIdle();
} }
ShutdownDebugDrawResources();
ShutdownSortResources();
ShutdownPreparePassResources();
ShutdownGaussianGpuResources(); ShutdownGaussianGpuResources();
m_commandList.Shutdown(); m_commandList.Shutdown();
m_commandAllocator.Shutdown(); m_commandAllocator.Shutdown();
@@ -413,6 +996,7 @@ void App::Shutdown() {
m_device.Shutdown(); m_device.Shutdown();
if (m_hwnd != nullptr) { if (m_hwnd != nullptr) {
AppendTrace("Shutdown: DestroyWindow");
DestroyWindow(m_hwnd); DestroyWindow(m_hwnd);
m_hwnd = nullptr; m_hwnd = nullptr;
} }
@@ -423,10 +1007,85 @@ void App::Shutdown() {
} }
m_isInitialized = false; m_isInitialized = false;
AppendTrace("Shutdown: end");
} }
void App::RenderFrame() { bool App::CaptureSortKeySnapshot() {
if (m_sortKeyBuffer == nullptr || m_gaussianSceneData.splatCount == 0 || m_sortKeySnapshotPath.empty()) {
return true;
}
const uint32_t sampleCount = std::min<uint32_t>(16u, m_gaussianSceneData.splatCount);
const uint64_t sampleBytes = static_cast<uint64_t>(sampleCount) * sizeof(uint32_t);
D3D12Buffer readbackBuffer;
if (!readbackBuffer.Initialize(
m_device.GetDevice(),
sampleBytes,
D3D12_RESOURCE_STATE_COPY_DEST,
D3D12_HEAP_TYPE_READBACK,
D3D12_RESOURCE_FLAG_NONE)) {
m_lastErrorMessage = L"Failed to create the sort key readback buffer.";
return false;
}
m_commandAllocator.Reset();
m_commandList.Reset();
if (m_sortKeyBuffer->GetState() != ResourceStates::CopySrc) {
m_commandList.TransitionBarrier(
m_sortKeyBuffer->GetResource(),
m_sortKeyBuffer->GetState(),
ResourceStates::CopySrc);
m_sortKeyBuffer->SetState(ResourceStates::CopySrc);
}
m_commandList.GetCommandList()->CopyBufferRegion(
readbackBuffer.GetResource(),
0,
m_sortKeyBuffer->GetResource(),
0,
sampleBytes);
m_commandList.Close();
void* commandLists[] = { &m_commandList };
m_commandQueue.ExecuteCommandLists(1, commandLists);
m_commandQueue.WaitForIdle();
const uint32_t* keys = static_cast<const uint32_t*>(readbackBuffer.Map());
if (keys == nullptr) {
m_lastErrorMessage = L"Failed to map the sort key readback buffer.";
readbackBuffer.Shutdown();
return false;
}
const std::filesystem::path snapshotPath = ResolveNearExecutable(m_sortKeySnapshotPath);
if (!snapshotPath.parent_path().empty()) {
std::filesystem::create_directories(snapshotPath.parent_path());
}
std::ofstream output(snapshotPath, std::ios::binary | std::ios::trunc);
if (!output.is_open()) {
readbackBuffer.Unmap();
readbackBuffer.Shutdown();
m_lastErrorMessage = L"Failed to open the sort key snapshot output file.";
return false;
}
output << "sample_count=" << sampleCount << '\n';
for (uint32_t index = 0; index < sampleCount; ++index) {
output << "key[" << index << "]=" << keys[index] << '\n';
}
readbackBuffer.Unmap();
readbackBuffer.Shutdown();
return output.good();
}
void App::RenderFrame(bool captureScreenshot) {
AppendTrace(captureScreenshot ? "RenderFrame: begin capture" : "RenderFrame: begin");
if (m_hasRenderedAtLeastOneFrame) { if (m_hasRenderedAtLeastOneFrame) {
AppendTrace("RenderFrame: WaitForPreviousFrame");
m_commandQueue.WaitForPreviousFrame(); m_commandQueue.WaitForPreviousFrame();
} }
@@ -459,14 +1118,99 @@ void App::RenderFrame() {
0, 0,
nullptr); nullptr);
m_commandList.TransitionBarrier(backBuffer.GetResource(), ResourceStates::RenderTarget, ResourceStates::Present); const FrameConstants frameConstants = BuildFrameConstants(
m_commandList.Close(); static_cast<uint32_t>(m_width),
static_cast<uint32_t>(m_height),
m_gaussianSceneData.splatCount);
m_prepareDescriptorSet->WriteConstant(0, &frameConstants, sizeof(frameConstants));
m_debugDescriptorSet->WriteConstant(0, &frameConstants, sizeof(frameConstants));
void* commandLists[] = { &m_commandList }; if (m_preparedViewBuffer->GetState() != ResourceStates::UnorderedAccess) {
m_commandQueue.ExecuteCommandLists(1, commandLists); m_commandList.TransitionBarrier(
m_swapChain.Present(1, 0); m_preparedViewBuffer->GetResource(),
m_preparedViewBuffer->GetState(),
ResourceStates::UnorderedAccess);
m_preparedViewBuffer->SetState(ResourceStates::UnorderedAccess);
}
m_commandList.SetPipelineState(m_preparePipelineState);
RHIDescriptorSet* prepareSets[] = { m_prepareDescriptorSet };
m_commandList.SetComputeDescriptorSets(0, 1, prepareSets, m_preparePipelineLayout);
m_commandList.Dispatch((m_gaussianSceneData.splatCount + (kPrepareThreadGroupSize - 1)) / kPrepareThreadGroupSize, 1, 1);
m_commandList.UAVBarrier(m_preparedViewBuffer->GetResource());
if (m_sortKeyBuffer->GetState() != ResourceStates::UnorderedAccess) {
m_commandList.TransitionBarrier(
m_sortKeyBuffer->GetResource(),
m_sortKeyBuffer->GetState(),
ResourceStates::UnorderedAccess);
m_sortKeyBuffer->SetState(ResourceStates::UnorderedAccess);
}
m_sortDescriptorSet->WriteConstant(0, &frameConstants, sizeof(frameConstants));
m_commandList.SetPipelineState(m_sortPipelineState);
RHIDescriptorSet* sortSets[] = { m_sortDescriptorSet };
m_commandList.SetComputeDescriptorSets(0, 1, sortSets, m_sortPipelineLayout);
m_commandList.Dispatch((m_gaussianSceneData.splatCount + (kSortThreadGroupSize - 1)) / kSortThreadGroupSize, 1, 1);
m_commandList.UAVBarrier(m_sortKeyBuffer->GetResource());
m_commandList.TransitionBarrier(
m_preparedViewBuffer->GetResource(),
ResourceStates::UnorderedAccess,
ResourceStates::NonPixelShaderResource);
m_preparedViewBuffer->SetState(ResourceStates::NonPixelShaderResource);
m_commandList.SetPipelineState(m_debugPipelineState);
RHIDescriptorSet* debugSets[] = { m_debugDescriptorSet };
m_commandList.SetGraphicsDescriptorSets(0, 1, debugSets, m_debugPipelineLayout);
m_commandList.SetPrimitiveTopology(PrimitiveTopology::PointList);
m_commandList.Draw(1u, m_gaussianSceneData.splatCount, 0u, 0u);
if (captureScreenshot) {
AppendTrace("RenderFrame: close+execute capture pre-screenshot");
m_commandList.Close();
void* commandLists[] = { &m_commandList };
m_commandQueue.ExecuteCommandLists(1, commandLists);
AppendTrace("RenderFrame: WaitForIdle before screenshot");
m_commandQueue.WaitForIdle();
AppendTrace("RenderFrame: Capture sort key snapshot");
if (!CaptureSortKeySnapshot()) {
AppendTrace(std::string("RenderFrame: Capture sort key snapshot failed: ") + NarrowAscii(m_lastErrorMessage));
}
if (!m_screenshotPath.empty()) {
const std::filesystem::path screenshotPath = ResolveNearExecutable(m_screenshotPath);
if (!screenshotPath.parent_path().empty()) {
std::filesystem::create_directories(screenshotPath.parent_path());
}
AppendTrace("RenderFrame: Capture screenshot");
D3D12Screenshot::Capture(
m_device,
m_commandQueue,
backBuffer,
screenshotPath.string().c_str());
AppendTrace("RenderFrame: Capture screenshot done");
}
m_commandAllocator.Reset();
m_commandList.Reset();
m_commandList.TransitionBarrier(backBuffer.GetResource(), ResourceStates::RenderTarget, ResourceStates::Present);
m_commandList.Close();
void* presentCommandLists[] = { &m_commandList };
AppendTrace("RenderFrame: execute final present-transition list");
m_commandQueue.ExecuteCommandLists(1, presentCommandLists);
} else {
m_commandList.TransitionBarrier(backBuffer.GetResource(), ResourceStates::RenderTarget, ResourceStates::Present);
m_commandList.Close();
void* commandLists[] = { &m_commandList };
AppendTrace("RenderFrame: execute+present");
m_commandQueue.ExecuteCommandLists(1, commandLists);
m_swapChain.Present(1, 0);
}
m_hasRenderedAtLeastOneFrame = true; m_hasRenderedAtLeastOneFrame = true;
AppendTrace("RenderFrame: end");
} }
} // namespace XC3DGSD3D12 } // namespace XC3DGSD3D12

View File

@@ -53,6 +53,15 @@ struct RawGaussianSplat {
Float4 rotation = {}; Float4 rotation = {};
}; };
// Property descriptors for every per-vertex field a Gaussian splat needs.
// All entries default to null.
struct GaussianPlyPropertyLayout {
    // Three position components.
    const PlyProperty* position[3]{};
    // Three DC colour coefficients.
    const PlyProperty* dc0[3]{};
    // Single opacity value.
    const PlyProperty* opacity{nullptr};
    // Three scale components.
    const PlyProperty* scale[3]{};
    // Four rotation components.
    const PlyProperty* rotation[4]{};
    // Remaining SH coefficients: kShCoefficientCount * 3 entries.
    std::array<const PlyProperty*, GaussianSplatRuntimeData::kShCoefficientCount * 3> sh{};
};
std::string TrimTrailingCarriageReturn(std::string line) { std::string TrimTrailingCarriageReturn(std::string line) {
if (!line.empty() && line.back() == '\r') { if (!line.empty() && line.back() == '\r') {
line.pop_back(); line.pop_back();
@@ -243,6 +252,39 @@ bool RequireProperty(
return true; return true;
} }
// Resolves every PLY property a Gaussian splat requires into outLayout.
// outLayout is reset first; lookups run in a fixed order (position, DC colour,
// opacity, scale, rotation, then f_rest_0..N) and stop at the first one that
// RequireProperty rejects, in which case false is returned and outErrorMessage
// is whatever RequireProperty left in it.
bool BuildGaussianPlyPropertyLayout(
    const std::unordered_map<std::string_view, const PlyProperty*>& propertyMap,
    GaussianPlyPropertyLayout& outLayout,
    std::string& outErrorMessage) {
    outLayout = {};

    // Fixed-name properties paired with the layout slot each one fills.
    struct RequiredSlot {
        const char* name;
        const PlyProperty** slot;
    };
    const RequiredSlot requiredSlots[] = {
        {"x", &outLayout.position[0]},
        {"y", &outLayout.position[1]},
        {"z", &outLayout.position[2]},
        {"f_dc_0", &outLayout.dc0[0]},
        {"f_dc_1", &outLayout.dc0[1]},
        {"f_dc_2", &outLayout.dc0[2]},
        {"opacity", &outLayout.opacity},
        {"scale_0", &outLayout.scale[0]},
        {"scale_1", &outLayout.scale[1]},
        {"scale_2", &outLayout.scale[2]},
        {"rot_0", &outLayout.rotation[0]},
        {"rot_1", &outLayout.rotation[1]},
        {"rot_2", &outLayout.rotation[2]},
        {"rot_3", &outLayout.rotation[3]},
    };
    for (const RequiredSlot& required : requiredSlots) {
        if (!RequireProperty(propertyMap, required.name, *required.slot, outErrorMessage)) {
            return false;
        }
    }

    // Numbered f_rest_<i> properties, one per entry of the sh array.
    uint32_t restIndex = 0;
    for (const PlyProperty*& shSlot : outLayout.sh) {
        const std::string restName = "f_rest_" + std::to_string(restIndex);
        if (!RequireProperty(propertyMap, restName, shSlot, outErrorMessage)) {
            return false;
        }
        ++restIndex;
    }

    return true;
}
Float3 Min(const Float3& a, const Float3& b) { Float3 Min(const Float3& a, const Float3& b) {
return { return {
std::min(a.x, b.x), std::min(a.x, b.x),
@@ -390,32 +432,31 @@ void WriteFloat4(std::vector<std::byte>& bytes, size_t offset, float x, float y,
bool ReadGaussianSplat( bool ReadGaussianSplat(
const std::byte* vertexBytes, const std::byte* vertexBytes,
const std::unordered_map<std::string_view, const PlyProperty*>& propertyMap, const GaussianPlyPropertyLayout& propertyLayout,
RawGaussianSplat& outSplat, RawGaussianSplat& outSplat,
std::string& outErrorMessage) { std::string& outErrorMessage) {
const PlyProperty* property = nullptr; auto readFloat = [&](const PlyProperty* property, float& outValue) -> bool {
if (property == nullptr) {
auto readFloat = [&](std::string_view name, float& outValue) -> bool { outErrorMessage = "Gaussian PLY property layout is incomplete.";
if (!RequireProperty(propertyMap, name, property, outErrorMessage)) {
return false; return false;
} }
return ReadPropertyAsFloat(vertexBytes, *property, outValue); return ReadPropertyAsFloat(vertexBytes, *property, outValue);
}; };
if (!readFloat("x", outSplat.position.x) || if (!readFloat(propertyLayout.position[0], outSplat.position.x) ||
!readFloat("y", outSplat.position.y) || !readFloat(propertyLayout.position[1], outSplat.position.y) ||
!readFloat("z", outSplat.position.z) || !readFloat(propertyLayout.position[2], outSplat.position.z) ||
!readFloat("f_dc_0", outSplat.dc0.x) || !readFloat(propertyLayout.dc0[0], outSplat.dc0.x) ||
!readFloat("f_dc_1", outSplat.dc0.y) || !readFloat(propertyLayout.dc0[1], outSplat.dc0.y) ||
!readFloat("f_dc_2", outSplat.dc0.z) || !readFloat(propertyLayout.dc0[2], outSplat.dc0.z) ||
!readFloat("opacity", outSplat.opacity) || !readFloat(propertyLayout.opacity, outSplat.opacity) ||
!readFloat("scale_0", outSplat.scale.x) || !readFloat(propertyLayout.scale[0], outSplat.scale.x) ||
!readFloat("scale_1", outSplat.scale.y) || !readFloat(propertyLayout.scale[1], outSplat.scale.y) ||
!readFloat("scale_2", outSplat.scale.z) || !readFloat(propertyLayout.scale[2], outSplat.scale.z) ||
!readFloat("rot_0", outSplat.rotation.x) || !readFloat(propertyLayout.rotation[0], outSplat.rotation.x) ||
!readFloat("rot_1", outSplat.rotation.y) || !readFloat(propertyLayout.rotation[1], outSplat.rotation.y) ||
!readFloat("rot_2", outSplat.rotation.z) || !readFloat(propertyLayout.rotation[2], outSplat.rotation.z) ||
!readFloat("rot_3", outSplat.rotation.w)) { !readFloat(propertyLayout.rotation[3], outSplat.rotation.w)) {
if (outErrorMessage.empty()) { if (outErrorMessage.empty()) {
outErrorMessage = "Failed to read required Gaussian splat PLY properties."; outErrorMessage = "Failed to read required Gaussian splat PLY properties.";
} }
@@ -424,8 +465,7 @@ bool ReadGaussianSplat(
std::array<float, GaussianSplatRuntimeData::kShCoefficientCount * 3> shRaw = {}; std::array<float, GaussianSplatRuntimeData::kShCoefficientCount * 3> shRaw = {};
for (uint32_t index = 0; index < shRaw.size(); ++index) { for (uint32_t index = 0; index < shRaw.size(); ++index) {
const std::string propertyName = "f_rest_" + std::to_string(index); if (!readFloat(propertyLayout.sh[index], shRaw[index])) {
if (!readFloat(propertyName, shRaw[index])) {
if (outErrorMessage.empty()) { if (outErrorMessage.empty()) {
outErrorMessage = "Failed to read SH rest coefficients from PLY."; outErrorMessage = "Failed to read SH rest coefficients from PLY.";
} }
@@ -478,6 +518,11 @@ bool LoadGaussianSceneFromPly(
return false; return false;
} }
GaussianPlyPropertyLayout propertyLayout;
if (!BuildGaussianPlyPropertyLayout(propertyMap, propertyLayout, outErrorMessage)) {
return false;
}
outData.splatCount = header.vertexCount; outData.splatCount = header.vertexCount;
outData.colorTextureWidth = GaussianSplatRuntimeData::kColorTextureWidth; outData.colorTextureWidth = GaussianSplatRuntimeData::kColorTextureWidth;
outData.colorTextureHeight = outData.colorTextureHeight =
@@ -513,7 +558,7 @@ bool LoadGaussianSceneFromPly(
} }
RawGaussianSplat splat; RawGaussianSplat splat;
if (!ReadGaussianSplat(vertexBytes.data(), propertyMap, splat, outErrorMessage)) { if (!ReadGaussianSplat(vertexBytes.data(), propertyLayout, splat, outErrorMessage)) {
return false; return false;
} }

View File

@@ -20,6 +20,9 @@ int WINAPI wWinMain(HINSTANCE instance, HINSTANCE, PWSTR, int showCommand) {
} else if (std::wstring(arguments[index]) == L"--summary-file") { } else if (std::wstring(arguments[index]) == L"--summary-file") {
app.SetSummaryPath(arguments[index + 1]); app.SetSummaryPath(arguments[index + 1]);
++index; ++index;
} else if (std::wstring(arguments[index]) == L"--screenshot-file") {
app.SetScreenshotPath(arguments[index + 1]);
++index;
} }
} }
if (arguments != nullptr) { if (arguments != nullptr) {