Stabilize 3DGS D3D12 phase 3 and sort key setup

This commit is contained in:
2026-04-13 02:23:39 +08:00
parent 1d6f2e290d
commit b7428b0ef1
10 changed files with 1281 additions and 33 deletions

View File

@@ -2,17 +2,43 @@
#include <d3d12.h>
#include <dxgi1_4.h>
#include <DirectXMath.h>
#include <algorithm>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <string_view>
#include "XCEngine/RHI/D3D12/D3D12Screenshot.h"
#include "XCEngine/RHI/RHIDescriptorPool.h"
#include "XCEngine/RHI/RHIDescriptorSet.h"
#include "XCEngine/RHI/RHIPipelineLayout.h"
#include "XCEngine/RHI/RHIPipelineState.h"
namespace XC3DGSD3D12 {
using namespace XCEngine::RHI;
namespace {
constexpr wchar_t kWindowClassName[] = L"XC3DGSD3D12WindowClass";
constexpr wchar_t kWindowTitle[] = L"XC 3DGS D3D12 MVS - Phase 1";
constexpr float kClearColor[4] = { 0.08f, 0.12f, 0.18f, 1.0f };
constexpr wchar_t kWindowTitle[] = L"XC 3DGS D3D12 MVS - Phase 3";
constexpr float kClearColor[4] = { 0.04f, 0.05f, 0.07f, 1.0f };
constexpr uint32_t kPrepareThreadGroupSize = 64u;
constexpr uint32_t kSortThreadGroupSize = 64u;
struct FrameConstants {
float viewProjection[16] = {};
float view[16] = {};
float projection[16] = {};
float cameraWorldPos[4] = {};
float screenParams[4] = {};
float settings[4] = {};
};
static_assert(sizeof(FrameConstants) % 16 == 0, "Frame constants must stay 16-byte aligned.");
static_assert(sizeof(PreparedSplatView) == 40, "Prepared view buffer layout must match shader.");
std::filesystem::path GetExecutableDirectory() {
std::wstring pathBuffer;
@@ -29,8 +55,106 @@ std::filesystem::path ResolveNearExecutable(const std::wstring& path) {
}
return GetExecutableDirectory() / inputPath;
}
std::filesystem::path ResolveShaderPath(std::wstring_view fileName) {
return GetExecutableDirectory() / L"shaders" / std::filesystem::path(fileName);
}
std::vector<uint8_t> LoadBinaryFile(const std::filesystem::path& filePath) {
std::ifstream input(filePath, std::ios::binary);
if (!input.is_open()) {
return {};
}
input.seekg(0, std::ios::end);
const std::streamoff size = input.tellg();
if (size <= 0) {
return {};
}
input.seekg(0, std::ios::beg);
std::vector<uint8_t> bytes(static_cast<size_t>(size));
input.read(reinterpret_cast<char*>(bytes.data()), size);
if (!input) {
return {};
}
return bytes;
}
std::string NarrowAscii(std::wstring_view text) {
std::string result;
result.reserve(text.size());
for (wchar_t ch : text) {
result.push_back(ch >= 0 && ch <= 0x7F ? static_cast<char>(ch) : '?');
}
return result;
}
void AppendTrace(std::string_view message) {
const std::filesystem::path tracePath = GetExecutableDirectory() / "phase3_trace.log";
std::ofstream file(tracePath, std::ios::app);
if (!file.is_open()) {
return;
}
file << GetTickCount64() << " | " << message << '\n';
}
void AppendTrace(const std::wstring& message) {
AppendTrace(NarrowAscii(message));
}
void StoreMatrixTransposed(const DirectX::XMMATRIX& matrix, float* destination) {
DirectX::XMFLOAT4X4 output = {};
DirectX::XMStoreFloat4x4(&output, DirectX::XMMatrixTranspose(matrix));
std::memcpy(destination, &output, sizeof(output));
}
FrameConstants BuildFrameConstants(uint32_t width, uint32_t height, uint32_t splatCount) {
using namespace DirectX;
const XMVECTOR eye = XMVectorSet(0.0f, 0.5f, 1.0f, 1.0f);
const XMVECTOR target = XMVectorSet(0.0f, 0.5f, -5.0f, 1.0f);
const XMVECTOR up = XMVectorSet(0.0f, 1.0f, 0.0f, 0.0f);
const float aspect = height > 0 ? static_cast<float>(width) / static_cast<float>(height) : 1.0f;
const XMMATRIX view = XMMatrixLookAtRH(eye, target, up);
const XMMATRIX projection = XMMatrixPerspectiveFovRH(XMConvertToRadians(60.0f), aspect, 0.1f, 200.0f);
const XMMATRIX viewProjection = XMMatrixMultiply(view, projection);
FrameConstants constants = {};
StoreMatrixTransposed(viewProjection, constants.viewProjection);
StoreMatrixTransposed(view, constants.view);
StoreMatrixTransposed(projection, constants.projection);
constants.cameraWorldPos[0] = 0.0f;
constants.cameraWorldPos[1] = 0.5f;
constants.cameraWorldPos[2] = 1.0f;
constants.cameraWorldPos[3] = 1.0f;
constants.screenParams[0] = static_cast<float>(width);
constants.screenParams[1] = static_cast<float>(height);
constants.screenParams[2] = width > 0 ? 1.0f / static_cast<float>(width) : 0.0f;
constants.screenParams[3] = height > 0 ? 1.0f / static_cast<float>(height) : 0.0f;
constants.settings[0] = static_cast<float>(splatCount);
constants.settings[1] = 1.0f; // opacity scale
constants.settings[2] = 3.0f; // SH order
constants.settings[3] = 1.0f;
return constants;
}
template <typename T>
void ShutdownAndDelete(T*& object) {
if (object != nullptr) {
object->Shutdown();
delete object;
object = nullptr;
}
}
} // namespace
App::App() = default;
App::~App() {
@@ -38,36 +162,66 @@ App::~App() {
}
bool App::Initialize(HINSTANCE instance, int showCommand) {
AppendTrace("Initialize: begin");
m_instance = instance;
m_lastErrorMessage.clear();
AppendTrace("Initialize: LoadGaussianScene");
if (!LoadGaussianScene()) {
AppendTrace(std::string("Initialize: LoadGaussianScene failed: ") + NarrowAscii(m_lastErrorMessage));
return false;
}
AppendTrace("Initialize: RegisterWindowClass");
if (!RegisterWindowClass(instance)) {
m_lastErrorMessage = L"Failed to register the Win32 window class.";
AppendTrace(std::string("Initialize: RegisterWindowClass failed: ") + NarrowAscii(m_lastErrorMessage));
return false;
}
AppendTrace("Initialize: CreateMainWindow");
if (!CreateMainWindow(instance, showCommand)) {
m_lastErrorMessage = L"Failed to create the main window.";
AppendTrace(std::string("Initialize: CreateMainWindow failed: ") + NarrowAscii(m_lastErrorMessage));
return false;
}
AppendTrace("Initialize: InitializeRhi");
if (!InitializeRhi()) {
if (m_lastErrorMessage.empty()) {
m_lastErrorMessage = L"Failed to initialize the D3D12 RHI objects.";
}
AppendTrace(std::string("Initialize: InitializeRhi failed: ") + NarrowAscii(m_lastErrorMessage));
return false;
}
AppendTrace("Initialize: InitializeGaussianGpuResources");
if (!InitializeGaussianGpuResources()) {
AppendTrace(std::string("Initialize: InitializeGaussianGpuResources failed: ") + NarrowAscii(m_lastErrorMessage));
return false;
}
AppendTrace("Initialize: InitializePreparePassResources");
if (!InitializePreparePassResources()) {
AppendTrace(std::string("Initialize: InitializePreparePassResources failed: ") + NarrowAscii(m_lastErrorMessage));
return false;
}
AppendTrace("Initialize: InitializeSortResources");
if (!InitializeSortResources()) {
AppendTrace(std::string("Initialize: InitializeSortResources failed: ") + NarrowAscii(m_lastErrorMessage));
return false;
}
AppendTrace("Initialize: InitializeDebugDrawResources");
if (!InitializeDebugDrawResources()) {
AppendTrace(std::string("Initialize: InitializeDebugDrawResources failed: ") + NarrowAscii(m_lastErrorMessage));
return false;
}
m_isInitialized = true;
m_running = true;
AppendTrace("Initialize: success");
return true;
}
@@ -83,15 +237,22 @@ void App::SetSummaryPath(std::wstring summaryPath) {
m_summaryPath = std::move(summaryPath);
}
void App::SetScreenshotPath(std::wstring screenshotPath) {
m_screenshotPath = std::move(screenshotPath);
}
const std::wstring& App::GetLastErrorMessage() const {
return m_lastErrorMessage;
}
int App::Run() {
AppendTrace("Run: begin");
MSG message = {};
int exitCode = 0;
while (m_running) {
while (PeekMessage(&message, nullptr, 0, 0, PM_REMOVE)) {
if (message.message == WM_QUIT) {
exitCode = static_cast<int>(message.wParam);
m_running = false;
break;
}
@@ -104,16 +265,21 @@ int App::Run() {
break;
}
RenderFrame();
const bool captureScreenshot = m_frameLimit > 0 && (m_renderedFrameCount + 1) >= m_frameLimit;
AppendTrace(captureScreenshot ? "Run: RenderFrame capture" : "Run: RenderFrame");
RenderFrame(captureScreenshot);
AppendTrace("Run: RenderFrame complete");
++m_renderedFrameCount;
if (m_frameLimit > 0 && m_renderedFrameCount >= m_frameLimit) {
m_running = false;
AppendTrace("Run: frame limit reached, posting WM_CLOSE");
PostMessageW(m_hwnd, WM_CLOSE, 0, 0);
}
}
return static_cast<int>(message.wParam);
AppendTrace("Run: end");
return exitCode;
}
LRESULT CALLBACK App::StaticWindowProc(HWND hwnd, UINT message, WPARAM wParam, LPARAM lParam) {
@@ -376,6 +542,361 @@ bool App::InitializeGaussianGpuResources() {
return true;
}
bool App::InitializePreparePassResources() {
DescriptorSetLayoutBinding bindings[6] = {};
bindings[0].binding = 0;
bindings[0].type = static_cast<uint32_t>(DescriptorType::CBV);
bindings[0].count = 1;
bindings[0].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[1].binding = 1;
bindings[1].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[1].count = 1;
bindings[1].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[1].resourceDimension = ResourceViewDimension::RawBuffer;
bindings[2].binding = 2;
bindings[2].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[2].count = 1;
bindings[2].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[2].resourceDimension = ResourceViewDimension::RawBuffer;
bindings[3].binding = 3;
bindings[3].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[3].count = 1;
bindings[3].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[3].resourceDimension = ResourceViewDimension::Texture2D;
bindings[4].binding = 4;
bindings[4].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[4].count = 1;
bindings[4].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[4].resourceDimension = ResourceViewDimension::RawBuffer;
bindings[5].binding = 5;
bindings[5].type = static_cast<uint32_t>(DescriptorType::UAV);
bindings[5].count = 1;
bindings[5].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[5].resourceDimension = ResourceViewDimension::StructuredBuffer;
DescriptorSetLayoutDesc setLayout = {};
setLayout.bindings = bindings;
setLayout.bindingCount = 6;
RHIPipelineLayoutDesc pipelineLayoutDesc = {};
pipelineLayoutDesc.setLayouts = &setLayout;
pipelineLayoutDesc.setLayoutCount = 1;
m_preparePipelineLayout = m_device.CreatePipelineLayout(pipelineLayoutDesc);
if (m_preparePipelineLayout == nullptr) {
m_lastErrorMessage = L"Failed to create the prepare pass pipeline layout.";
return false;
}
BufferDesc preparedViewBufferDesc = {};
preparedViewBufferDesc.size = static_cast<uint64_t>(sizeof(PreparedSplatView)) * m_gaussianSceneData.splatCount;
preparedViewBufferDesc.stride = sizeof(PreparedSplatView);
preparedViewBufferDesc.bufferType = static_cast<uint32_t>(BufferType::Storage);
preparedViewBufferDesc.flags = static_cast<uint64_t>(BufferFlags::AllowUnorderedAccess);
m_preparedViewBuffer = static_cast<D3D12Buffer*>(m_device.CreateBuffer(preparedViewBufferDesc));
if (m_preparedViewBuffer == nullptr) {
m_lastErrorMessage = L"Failed to create the prepared view buffer.";
return false;
}
m_preparedViewBuffer->SetStride(sizeof(PreparedSplatView));
m_preparedViewBuffer->SetBufferType(BufferType::Storage);
ResourceViewDesc structuredViewDesc = {};
structuredViewDesc.dimension = ResourceViewDimension::StructuredBuffer;
structuredViewDesc.structureByteStride = sizeof(PreparedSplatView);
structuredViewDesc.elementCount = m_gaussianSceneData.splatCount;
m_preparedViewSrv.reset(static_cast<D3D12ResourceView*>(m_device.CreateShaderResourceView(m_preparedViewBuffer, structuredViewDesc)));
if (!m_preparedViewSrv) {
m_lastErrorMessage = L"Failed to create the prepared view SRV.";
return false;
}
m_preparedViewUav.reset(static_cast<D3D12ResourceView*>(m_device.CreateUnorderedAccessView(m_preparedViewBuffer, structuredViewDesc)));
if (!m_preparedViewUav) {
m_lastErrorMessage = L"Failed to create the prepared view UAV.";
return false;
}
DescriptorPoolDesc poolDesc = {};
poolDesc.type = DescriptorHeapType::CBV_SRV_UAV;
poolDesc.descriptorCount = 5;
poolDesc.shaderVisible = true;
m_prepareDescriptorPool = m_device.CreateDescriptorPool(poolDesc);
if (m_prepareDescriptorPool == nullptr) {
m_lastErrorMessage = L"Failed to create the prepare pass descriptor pool.";
return false;
}
m_prepareDescriptorSet = m_prepareDescriptorPool->AllocateSet(setLayout);
if (m_prepareDescriptorSet == nullptr) {
m_lastErrorMessage = L"Failed to allocate the prepare pass descriptor set.";
return false;
}
m_prepareDescriptorSet->Update(1, m_gaussianPositionView.get());
m_prepareDescriptorSet->Update(2, m_gaussianOtherView.get());
m_prepareDescriptorSet->Update(3, m_gaussianColorView.get());
m_prepareDescriptorSet->Update(4, m_gaussianShView.get());
m_prepareDescriptorSet->Update(5, m_preparedViewUav.get());
ShaderCompileDesc computeShaderDesc = {};
computeShaderDesc.fileName = ResolveShaderPath(L"PrepareGaussiansCS.hlsl").wstring();
computeShaderDesc.entryPoint = L"MainCS";
computeShaderDesc.profile = L"cs_5_0";
ComputePipelineDesc pipelineDesc = {};
pipelineDesc.pipelineLayout = m_preparePipelineLayout;
pipelineDesc.computeShader = computeShaderDesc;
m_preparePipelineState = m_device.CreateComputePipelineState(pipelineDesc);
if (m_preparePipelineState == nullptr) {
m_lastErrorMessage = L"Failed to create the prepare pass pipeline state.";
return false;
}
return true;
}
bool App::InitializeSortResources() {
std::vector<uint32_t> initialOrder(static_cast<size_t>(m_gaussianSceneData.splatCount));
for (uint32_t index = 0; index < m_gaussianSceneData.splatCount; ++index) {
initialOrder[index] = index;
}
ID3D12Device* device = m_device.GetDevice();
ID3D12GraphicsCommandList* commandList = m_commandList.GetCommandList();
m_commandAllocator.Reset();
m_commandList.Reset();
const D3D12_RESOURCE_STATES shaderResourceState =
D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE | D3D12_RESOURCE_STATE_PIXEL_SHADER_RESOURCE;
m_orderBuffer = new D3D12Buffer();
m_orderBuffer->SetStride(sizeof(uint32_t));
m_orderBuffer->SetBufferType(BufferType::Storage);
if (!m_orderBuffer->InitializeWithData(
device,
commandList,
initialOrder.data(),
static_cast<uint64_t>(initialOrder.size() * sizeof(uint32_t)),
shaderResourceState)) {
m_lastErrorMessage = L"Failed to initialize the order buffer.";
return false;
}
m_orderBuffer->SetState(ResourceStates::NonPixelShaderResource);
BufferDesc sortKeyBufferDesc = {};
sortKeyBufferDesc.size = static_cast<uint64_t>(m_gaussianSceneData.splatCount) * sizeof(uint32_t);
sortKeyBufferDesc.stride = sizeof(uint32_t);
sortKeyBufferDesc.bufferType = static_cast<uint32_t>(BufferType::Storage);
sortKeyBufferDesc.flags = static_cast<uint64_t>(BufferFlags::AllowUnorderedAccess);
m_sortKeyBuffer = static_cast<D3D12Buffer*>(m_device.CreateBuffer(sortKeyBufferDesc));
if (m_sortKeyBuffer == nullptr) {
m_lastErrorMessage = L"Failed to create the sort key buffer.";
return false;
}
m_sortKeyBuffer->SetStride(sizeof(uint32_t));
m_sortKeyBuffer->SetBufferType(BufferType::Storage);
ResourceViewDesc structuredViewDesc = {};
structuredViewDesc.dimension = ResourceViewDimension::StructuredBuffer;
structuredViewDesc.structureByteStride = sizeof(uint32_t);
structuredViewDesc.elementCount = m_gaussianSceneData.splatCount;
m_orderBufferSrv.reset(static_cast<D3D12ResourceView*>(m_device.CreateShaderResourceView(m_orderBuffer, structuredViewDesc)));
if (!m_orderBufferSrv) {
m_lastErrorMessage = L"Failed to create the order buffer SRV.";
return false;
}
m_sortKeySrv.reset(static_cast<D3D12ResourceView*>(m_device.CreateShaderResourceView(m_sortKeyBuffer, structuredViewDesc)));
if (!m_sortKeySrv) {
m_lastErrorMessage = L"Failed to create the sort key SRV.";
return false;
}
m_sortKeyUav.reset(static_cast<D3D12ResourceView*>(m_device.CreateUnorderedAccessView(m_sortKeyBuffer, structuredViewDesc)));
if (!m_sortKeyUav) {
m_lastErrorMessage = L"Failed to create the sort key UAV.";
return false;
}
DescriptorSetLayoutBinding bindings[4] = {};
bindings[0].binding = 0;
bindings[0].type = static_cast<uint32_t>(DescriptorType::CBV);
bindings[0].count = 1;
bindings[0].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[1].binding = 1;
bindings[1].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[1].count = 1;
bindings[1].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[1].resourceDimension = ResourceViewDimension::RawBuffer;
bindings[2].binding = 2;
bindings[2].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[2].count = 1;
bindings[2].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[2].resourceDimension = ResourceViewDimension::StructuredBuffer;
bindings[3].binding = 3;
bindings[3].type = static_cast<uint32_t>(DescriptorType::UAV);
bindings[3].count = 1;
bindings[3].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[3].resourceDimension = ResourceViewDimension::StructuredBuffer;
DescriptorSetLayoutDesc setLayout = {};
setLayout.bindings = bindings;
setLayout.bindingCount = 4;
RHIPipelineLayoutDesc pipelineLayoutDesc = {};
pipelineLayoutDesc.setLayouts = &setLayout;
pipelineLayoutDesc.setLayoutCount = 1;
m_sortPipelineLayout = m_device.CreatePipelineLayout(pipelineLayoutDesc);
if (m_sortPipelineLayout == nullptr) {
m_lastErrorMessage = L"Failed to create the sort pipeline layout.";
return false;
}
DescriptorPoolDesc poolDesc = {};
poolDesc.type = DescriptorHeapType::CBV_SRV_UAV;
poolDesc.descriptorCount = 3;
poolDesc.shaderVisible = true;
m_sortDescriptorPool = m_device.CreateDescriptorPool(poolDesc);
if (m_sortDescriptorPool == nullptr) {
m_lastErrorMessage = L"Failed to create the sort descriptor pool.";
return false;
}
m_sortDescriptorSet = m_sortDescriptorPool->AllocateSet(setLayout);
if (m_sortDescriptorSet == nullptr) {
m_lastErrorMessage = L"Failed to allocate the sort descriptor set.";
return false;
}
m_sortDescriptorSet->Update(1, m_gaussianPositionView.get());
m_sortDescriptorSet->Update(2, m_orderBufferSrv.get());
m_sortDescriptorSet->Update(3, m_sortKeyUav.get());
ShaderCompileDesc computeShaderDesc = {};
const std::filesystem::path compiledShaderPath = ResolveShaderPath(L"BuildSortKeysCS.dxil");
const std::vector<uint8_t> compiledShaderBinary = LoadBinaryFile(compiledShaderPath);
if (!compiledShaderBinary.empty()) {
computeShaderDesc.profile = L"cs_6_6";
computeShaderDesc.compiledBinaryBackend = ShaderBinaryBackend::D3D12;
computeShaderDesc.compiledBinary = compiledShaderBinary;
} else {
computeShaderDesc.fileName = ResolveShaderPath(L"BuildSortKeysCS.hlsl").wstring();
computeShaderDesc.entryPoint = L"MainCS";
computeShaderDesc.profile = L"cs_5_0";
}
ComputePipelineDesc pipelineDesc = {};
pipelineDesc.pipelineLayout = m_sortPipelineLayout;
pipelineDesc.computeShader = computeShaderDesc;
m_sortPipelineState = m_device.CreateComputePipelineState(pipelineDesc);
if (m_sortPipelineState == nullptr) {
m_lastErrorMessage = L"Failed to create the sort pipeline state.";
return false;
}
m_commandList.Close();
void* commandLists[] = { &m_commandList };
m_commandQueue.ExecuteCommandLists(1, commandLists);
m_commandQueue.WaitForIdle();
return true;
}
bool App::InitializeDebugDrawResources() {
DescriptorSetLayoutBinding bindings[3] = {};
bindings[0].binding = 0;
bindings[0].type = static_cast<uint32_t>(DescriptorType::CBV);
bindings[0].count = 1;
bindings[0].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[1].binding = 1;
bindings[1].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[1].count = 1;
bindings[1].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[1].resourceDimension = ResourceViewDimension::StructuredBuffer;
bindings[2].binding = 2;
bindings[2].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[2].count = 1;
bindings[2].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[2].resourceDimension = ResourceViewDimension::StructuredBuffer;
DescriptorSetLayoutDesc setLayout = {};
setLayout.bindings = bindings;
setLayout.bindingCount = 3;
RHIPipelineLayoutDesc pipelineLayoutDesc = {};
pipelineLayoutDesc.setLayouts = &setLayout;
pipelineLayoutDesc.setLayoutCount = 1;
m_debugPipelineLayout = m_device.CreatePipelineLayout(pipelineLayoutDesc);
if (m_debugPipelineLayout == nullptr) {
m_lastErrorMessage = L"Failed to create the debug draw pipeline layout.";
return false;
}
DescriptorPoolDesc poolDesc = {};
poolDesc.type = DescriptorHeapType::CBV_SRV_UAV;
poolDesc.descriptorCount = 2;
poolDesc.shaderVisible = true;
m_debugDescriptorPool = m_device.CreateDescriptorPool(poolDesc);
if (m_debugDescriptorPool == nullptr) {
m_lastErrorMessage = L"Failed to create the debug draw descriptor pool.";
return false;
}
m_debugDescriptorSet = m_debugDescriptorPool->AllocateSet(setLayout);
if (m_debugDescriptorSet == nullptr) {
m_lastErrorMessage = L"Failed to allocate the debug draw descriptor set.";
return false;
}
m_debugDescriptorSet->Update(1, m_preparedViewSrv.get());
m_debugDescriptorSet->Update(2, m_orderBufferSrv.get());
GraphicsPipelineDesc pipelineDesc = {};
pipelineDesc.pipelineLayout = m_debugPipelineLayout;
pipelineDesc.topologyType = static_cast<uint32_t>(PrimitiveTopologyType::Point);
pipelineDesc.renderTargetCount = 1;
pipelineDesc.renderTargetFormats[0] = static_cast<uint32_t>(Format::R8G8B8A8_UNorm);
pipelineDesc.depthStencilFormat = static_cast<uint32_t>(Format::D24_UNorm_S8_UInt);
pipelineDesc.sampleCount = 1;
pipelineDesc.sampleQuality = 0;
pipelineDesc.rasterizerState.cullMode = static_cast<uint32_t>(CullMode::None);
pipelineDesc.depthStencilState.depthTestEnable = false;
pipelineDesc.depthStencilState.depthWriteEnable = false;
pipelineDesc.vertexShader.fileName = ResolveShaderPath(L"DebugPointsVS.hlsl").wstring();
pipelineDesc.vertexShader.entryPoint = L"MainVS";
pipelineDesc.vertexShader.profile = L"vs_5_0";
pipelineDesc.fragmentShader.fileName = ResolveShaderPath(L"DebugPointsPS.hlsl").wstring();
pipelineDesc.fragmentShader.entryPoint = L"MainPS";
pipelineDesc.fragmentShader.profile = L"ps_5_0";
m_debugPipelineState = m_device.CreatePipelineState(pipelineDesc);
if (m_debugPipelineState == nullptr) {
m_lastErrorMessage = L"Failed to create the debug draw pipeline state.";
return false;
}
return true;
}
void App::ShutdownGaussianGpuResources() {
m_gaussianColorView.reset();
m_gaussianShView.reset();
@@ -388,16 +909,78 @@ void App::ShutdownGaussianGpuResources() {
m_gaussianPositionBuffer.Shutdown();
}
void App::ShutdownPreparePassResources() {
if (m_prepareDescriptorSet != nullptr) {
m_prepareDescriptorSet->Shutdown();
delete m_prepareDescriptorSet;
m_prepareDescriptorSet = nullptr;
}
ShutdownAndDelete(m_prepareDescriptorPool);
ShutdownAndDelete(m_preparePipelineState);
ShutdownAndDelete(m_preparePipelineLayout);
m_preparedViewUav.reset();
m_preparedViewSrv.reset();
if (m_preparedViewBuffer != nullptr) {
m_preparedViewBuffer->Shutdown();
delete m_preparedViewBuffer;
m_preparedViewBuffer = nullptr;
}
}
void App::ShutdownSortResources() {
if (m_sortDescriptorSet != nullptr) {
m_sortDescriptorSet->Shutdown();
delete m_sortDescriptorSet;
m_sortDescriptorSet = nullptr;
}
ShutdownAndDelete(m_sortDescriptorPool);
ShutdownAndDelete(m_sortPipelineState);
ShutdownAndDelete(m_sortPipelineLayout);
m_orderBufferSrv.reset();
m_sortKeyUav.reset();
m_sortKeySrv.reset();
if (m_orderBuffer != nullptr) {
m_orderBuffer->Shutdown();
delete m_orderBuffer;
m_orderBuffer = nullptr;
}
if (m_sortKeyBuffer != nullptr) {
m_sortKeyBuffer->Shutdown();
delete m_sortKeyBuffer;
m_sortKeyBuffer = nullptr;
}
}
void App::ShutdownDebugDrawResources() {
if (m_debugDescriptorSet != nullptr) {
m_debugDescriptorSet->Shutdown();
delete m_debugDescriptorSet;
m_debugDescriptorSet = nullptr;
}
ShutdownAndDelete(m_debugDescriptorPool);
ShutdownAndDelete(m_debugPipelineState);
ShutdownAndDelete(m_debugPipelineLayout);
}
void App::Shutdown() {
AppendTrace("Shutdown: begin");
if (!m_isInitialized && m_hwnd == nullptr) {
AppendTrace("Shutdown: skipped");
return;
}
m_running = false;
if (m_commandQueue.GetCommandQueue() != nullptr) {
AppendTrace("Shutdown: WaitForIdle");
m_commandQueue.WaitForIdle();
}
ShutdownDebugDrawResources();
ShutdownSortResources();
ShutdownPreparePassResources();
ShutdownGaussianGpuResources();
m_commandList.Shutdown();
m_commandAllocator.Shutdown();
@@ -413,6 +996,7 @@ void App::Shutdown() {
m_device.Shutdown();
if (m_hwnd != nullptr) {
AppendTrace("Shutdown: DestroyWindow");
DestroyWindow(m_hwnd);
m_hwnd = nullptr;
}
@@ -423,10 +1007,85 @@ void App::Shutdown() {
}
m_isInitialized = false;
AppendTrace("Shutdown: end");
}
void App::RenderFrame() {
bool App::CaptureSortKeySnapshot() {
if (m_sortKeyBuffer == nullptr || m_gaussianSceneData.splatCount == 0 || m_sortKeySnapshotPath.empty()) {
return true;
}
const uint32_t sampleCount = std::min<uint32_t>(16u, m_gaussianSceneData.splatCount);
const uint64_t sampleBytes = static_cast<uint64_t>(sampleCount) * sizeof(uint32_t);
D3D12Buffer readbackBuffer;
if (!readbackBuffer.Initialize(
m_device.GetDevice(),
sampleBytes,
D3D12_RESOURCE_STATE_COPY_DEST,
D3D12_HEAP_TYPE_READBACK,
D3D12_RESOURCE_FLAG_NONE)) {
m_lastErrorMessage = L"Failed to create the sort key readback buffer.";
return false;
}
m_commandAllocator.Reset();
m_commandList.Reset();
if (m_sortKeyBuffer->GetState() != ResourceStates::CopySrc) {
m_commandList.TransitionBarrier(
m_sortKeyBuffer->GetResource(),
m_sortKeyBuffer->GetState(),
ResourceStates::CopySrc);
m_sortKeyBuffer->SetState(ResourceStates::CopySrc);
}
m_commandList.GetCommandList()->CopyBufferRegion(
readbackBuffer.GetResource(),
0,
m_sortKeyBuffer->GetResource(),
0,
sampleBytes);
m_commandList.Close();
void* commandLists[] = { &m_commandList };
m_commandQueue.ExecuteCommandLists(1, commandLists);
m_commandQueue.WaitForIdle();
const uint32_t* keys = static_cast<const uint32_t*>(readbackBuffer.Map());
if (keys == nullptr) {
m_lastErrorMessage = L"Failed to map the sort key readback buffer.";
readbackBuffer.Shutdown();
return false;
}
const std::filesystem::path snapshotPath = ResolveNearExecutable(m_sortKeySnapshotPath);
if (!snapshotPath.parent_path().empty()) {
std::filesystem::create_directories(snapshotPath.parent_path());
}
std::ofstream output(snapshotPath, std::ios::binary | std::ios::trunc);
if (!output.is_open()) {
readbackBuffer.Unmap();
readbackBuffer.Shutdown();
m_lastErrorMessage = L"Failed to open the sort key snapshot output file.";
return false;
}
output << "sample_count=" << sampleCount << '\n';
for (uint32_t index = 0; index < sampleCount; ++index) {
output << "key[" << index << "]=" << keys[index] << '\n';
}
readbackBuffer.Unmap();
readbackBuffer.Shutdown();
return output.good();
}
void App::RenderFrame(bool captureScreenshot) {
AppendTrace(captureScreenshot ? "RenderFrame: begin capture" : "RenderFrame: begin");
if (m_hasRenderedAtLeastOneFrame) {
AppendTrace("RenderFrame: WaitForPreviousFrame");
m_commandQueue.WaitForPreviousFrame();
}
@@ -459,14 +1118,99 @@ void App::RenderFrame() {
0,
nullptr);
m_commandList.TransitionBarrier(backBuffer.GetResource(), ResourceStates::RenderTarget, ResourceStates::Present);
m_commandList.Close();
const FrameConstants frameConstants = BuildFrameConstants(
static_cast<uint32_t>(m_width),
static_cast<uint32_t>(m_height),
m_gaussianSceneData.splatCount);
m_prepareDescriptorSet->WriteConstant(0, &frameConstants, sizeof(frameConstants));
m_debugDescriptorSet->WriteConstant(0, &frameConstants, sizeof(frameConstants));
void* commandLists[] = { &m_commandList };
m_commandQueue.ExecuteCommandLists(1, commandLists);
m_swapChain.Present(1, 0);
if (m_preparedViewBuffer->GetState() != ResourceStates::UnorderedAccess) {
m_commandList.TransitionBarrier(
m_preparedViewBuffer->GetResource(),
m_preparedViewBuffer->GetState(),
ResourceStates::UnorderedAccess);
m_preparedViewBuffer->SetState(ResourceStates::UnorderedAccess);
}
m_commandList.SetPipelineState(m_preparePipelineState);
RHIDescriptorSet* prepareSets[] = { m_prepareDescriptorSet };
m_commandList.SetComputeDescriptorSets(0, 1, prepareSets, m_preparePipelineLayout);
m_commandList.Dispatch((m_gaussianSceneData.splatCount + (kPrepareThreadGroupSize - 1)) / kPrepareThreadGroupSize, 1, 1);
m_commandList.UAVBarrier(m_preparedViewBuffer->GetResource());
if (m_sortKeyBuffer->GetState() != ResourceStates::UnorderedAccess) {
m_commandList.TransitionBarrier(
m_sortKeyBuffer->GetResource(),
m_sortKeyBuffer->GetState(),
ResourceStates::UnorderedAccess);
m_sortKeyBuffer->SetState(ResourceStates::UnorderedAccess);
}
m_sortDescriptorSet->WriteConstant(0, &frameConstants, sizeof(frameConstants));
m_commandList.SetPipelineState(m_sortPipelineState);
RHIDescriptorSet* sortSets[] = { m_sortDescriptorSet };
m_commandList.SetComputeDescriptorSets(0, 1, sortSets, m_sortPipelineLayout);
m_commandList.Dispatch((m_gaussianSceneData.splatCount + (kSortThreadGroupSize - 1)) / kSortThreadGroupSize, 1, 1);
m_commandList.UAVBarrier(m_sortKeyBuffer->GetResource());
m_commandList.TransitionBarrier(
m_preparedViewBuffer->GetResource(),
ResourceStates::UnorderedAccess,
ResourceStates::NonPixelShaderResource);
m_preparedViewBuffer->SetState(ResourceStates::NonPixelShaderResource);
m_commandList.SetPipelineState(m_debugPipelineState);
RHIDescriptorSet* debugSets[] = { m_debugDescriptorSet };
m_commandList.SetGraphicsDescriptorSets(0, 1, debugSets, m_debugPipelineLayout);
m_commandList.SetPrimitiveTopology(PrimitiveTopology::PointList);
m_commandList.Draw(1u, m_gaussianSceneData.splatCount, 0u, 0u);
if (captureScreenshot) {
AppendTrace("RenderFrame: close+execute capture pre-screenshot");
m_commandList.Close();
void* commandLists[] = { &m_commandList };
m_commandQueue.ExecuteCommandLists(1, commandLists);
AppendTrace("RenderFrame: WaitForIdle before screenshot");
m_commandQueue.WaitForIdle();
AppendTrace("RenderFrame: Capture sort key snapshot");
if (!CaptureSortKeySnapshot()) {
AppendTrace(std::string("RenderFrame: Capture sort key snapshot failed: ") + NarrowAscii(m_lastErrorMessage));
}
if (!m_screenshotPath.empty()) {
const std::filesystem::path screenshotPath = ResolveNearExecutable(m_screenshotPath);
if (!screenshotPath.parent_path().empty()) {
std::filesystem::create_directories(screenshotPath.parent_path());
}
AppendTrace("RenderFrame: Capture screenshot");
D3D12Screenshot::Capture(
m_device,
m_commandQueue,
backBuffer,
screenshotPath.string().c_str());
AppendTrace("RenderFrame: Capture screenshot done");
}
m_commandAllocator.Reset();
m_commandList.Reset();
m_commandList.TransitionBarrier(backBuffer.GetResource(), ResourceStates::RenderTarget, ResourceStates::Present);
m_commandList.Close();
void* presentCommandLists[] = { &m_commandList };
AppendTrace("RenderFrame: execute final present-transition list");
m_commandQueue.ExecuteCommandLists(1, presentCommandLists);
} else {
m_commandList.TransitionBarrier(backBuffer.GetResource(), ResourceStates::RenderTarget, ResourceStates::Present);
m_commandList.Close();
void* commandLists[] = { &m_commandList };
AppendTrace("RenderFrame: execute+present");
m_commandQueue.ExecuteCommandLists(1, commandLists);
m_swapChain.Present(1, 0);
}
m_hasRenderedAtLeastOneFrame = true;
AppendTrace("RenderFrame: end");
}
} // namespace XC3DGSD3D12