Files
XCEngine/MVS/3DGS-D3D12/src/App.cpp

2066 lines
83 KiB
C++

#include "XC3DGSD3D12/App.h"
#include <d3d12.h>
#include <dxgi1_4.h>
#include <DirectXMath.h>
#include <algorithm>
#include <bit>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <string_view>
#include "XCEngine/RHI/D3D12/D3D12Screenshot.h"
#include "XCEngine/RHI/RHIDescriptorPool.h"
#include "XCEngine/RHI/RHIDescriptorSet.h"
#include "XCEngine/RHI/RHIPipelineLayout.h"
#include "XCEngine/RHI/RHIPipelineState.h"
namespace XC3DGSD3D12 {
using namespace XCEngine::RHI;
namespace {
// Win32 window class name and title-bar text for the main window.
constexpr wchar_t kWindowClassName[] = L"XC3DGSD3D12WindowClass";
constexpr wchar_t kWindowTitle[] = L"XC 3DGS D3D12 MVS - Phase 3";

// Back-buffer clear color, 4 floats (presumably RGBA).
constexpr float kClearColor[4]{ 0.04f, 0.05f, 0.07f, 1.0f };

// Thread-group sizes for the prepare and sort compute passes.
// NOTE(review): presumably must match [numthreads] in the corresponding shaders — confirm.
constexpr uint32_t kPrepareThreadGroupSize{ 64u };
constexpr uint32_t kSortThreadGroupSize{ 64u };

// Device radix sort tuning: keys handled per thread block, buckets per digit, and number of
// digit passes. 256 buckets x 4 passes covers 32-bit sort keys 8 bits at a time.
constexpr uint32_t kDeviceRadixSortPartitionSize{ 3840u };
constexpr uint32_t kDeviceRadixSortRadix{ 256u };
constexpr uint32_t kDeviceRadixSortPassCount{ 4u };

// When true, the initial splat order uploaded to the GPU comes from BuildCpuSortedOrder()
// instead of the identity order 0..N-1.
constexpr bool kUseCpuSortBaseline{ true };
// Per-frame constant data produced by BuildFrameConstants().
// NOTE(review): presumably mirrored by an HLSL cbuffer; keep field order and sizes in sync with
// the shaders (the static_asserts below only guard total sizes).
struct FrameConstants {
// View-projection, view and projection matrices, each written transposed by StoreMatrixTransposed().
float viewProjection[16] = {};
float view[16] = {};
float projection[16] = {};
// Camera position as (x, y, z, 1); BuildFrameConstants writes (0, 0.5, 1, 1).
float cameraWorldPos[4] = {};
// (width, height, 1/width, 1/height); each reciprocal is 0 when its dimension is 0.
float screenParams[4] = {};
// (splat count, opacity scale, SH order, 0.9f). NOTE(review): meaning of settings[3] is not visible here.
float settings[4] = {};
};
// Constants for the radix-sort compute passes.
// NOTE(review): field meanings below are inferred from names; the code that fills them is outside this view.
struct RadixSortConstants {
// Number of keys to sort (presumably the splat count).
uint32_t numKeys = 0;
// Bit offset of the digit handled by the current pass (presumably 0, 8, 16, 24).
uint32_t radixShift = 0;
// Thread blocks per dispatch (presumably ceil(numKeys / kDeviceRadixSortPartitionSize)).
uint32_t threadBlocks = 0;
// Pads the struct to 16 bytes.
uint32_t padding = 0;
};
// Size guards for data the shaders read: constant data must be a whole number of 16-byte
// registers, and PreparedSplatView must match the shader-side struct stride.
static_assert(sizeof(FrameConstants) % 16 == 0, "Frame constants must stay 16-byte aligned.");
static_assert(sizeof(RadixSortConstants) % 16 == 0, "Radix sort constants must stay 16-byte aligned.");
static_assert(sizeof(PreparedSplatView) == 40, "Prepared view buffer layout must match shader.");
std::filesystem::path GetExecutableDirectory() {
std::wstring pathBuffer;
pathBuffer.resize(MAX_PATH);
const DWORD pathLength = GetModuleFileNameW(nullptr, pathBuffer.data(), static_cast<DWORD>(pathBuffer.size()));
pathBuffer.resize(pathLength);
return std::filesystem::path(pathBuffer).parent_path();
}
// Interprets `path` relative to the executable's directory unless it is already absolute,
// in which case it is returned unchanged.
std::filesystem::path ResolveNearExecutable(const std::wstring& path) {
    std::filesystem::path resolved(path);
    if (!resolved.is_absolute()) {
        resolved = GetExecutableDirectory() / resolved;
    }
    return resolved;
}
// Builds "<executable directory>/shaders/<fileName>".
std::filesystem::path ResolveShaderPath(std::wstring_view fileName) {
    std::filesystem::path shaderPath = GetExecutableDirectory();
    shaderPath /= L"shaders";
    shaderPath /= std::filesystem::path(fileName);
    return shaderPath;
}
// Reads the entire file at `filePath` into memory.
// Returns an empty vector when the file cannot be opened, is empty, or cannot be read in full;
// callers treat an empty result as "not loaded".
std::vector<uint8_t> LoadBinaryFile(const std::filesystem::path& filePath) {
    // Opening with `ate` positions the read cursor at the end, so tellg() is the file size.
    std::ifstream stream(filePath, std::ios::binary | std::ios::ate);
    if (!stream.is_open()) {
        return {};
    }

    const std::streamoff byteCount = stream.tellg();
    if (byteCount <= 0) {
        return {};
    }

    stream.seekg(0, std::ios::beg);
    std::vector<uint8_t> contents(static_cast<size_t>(byteCount));
    if (!stream.read(reinterpret_cast<char*>(contents.data()), byteCount)) {
        return {};
    }
    return contents;
}
// Converts wide text to a narrow string, one char per wchar_t: 7-bit ASCII passes through
// unchanged and every other code unit becomes '?'. Used to feed wide messages into the trace log.
std::string NarrowAscii(std::wstring_view text) {
    std::string narrowed(text.size(), '\0');
    std::transform(text.begin(), text.end(), narrowed.begin(), [](wchar_t wide) {
        const bool isAscii = wide >= 0 && wide <= 0x7F;
        return isAscii ? static_cast<char>(wide) : '?';
    });
    return narrowed;
}
void AppendTrace(std::string_view message) {
const std::filesystem::path tracePath = GetExecutableDirectory() / "phase3_trace.log";
std::ofstream file(tracePath, std::ios::app);
if (!file.is_open()) {
return;
}
file << GetTickCount64() << " | " << message << '\n';
}
void AppendTrace(const std::wstring& message) {
AppendTrace(NarrowAscii(message));
}
// Describes a precompiled D3D12 shader binary with the given `profile`.
// compiledBinary is left empty when the file cannot be read; callers check for that.
ShaderCompileDesc BuildDxilShaderDesc(const std::filesystem::path& compiledShaderPath, const std::wstring& profile) {
    ShaderCompileDesc desc = {};
    desc.compiledBinary = LoadBinaryFile(compiledShaderPath);
    desc.compiledBinaryBackend = ShaderBinaryBackend::D3D12;
    desc.profile = profile;
    return desc;
}
// Writes the transpose of `matrix` as 16 consecutive floats to `destination`, which must have
// room for 16 floats. NOTE(review): the transpose is presumably there because the shaders read
// matrices with column-major packing — confirm against the HLSL.
void StoreMatrixTransposed(const DirectX::XMMATRIX& matrix, float* destination) {
    const DirectX::XMMATRIX transposed = DirectX::XMMatrixTranspose(matrix);
    DirectX::XMFLOAT4X4 packed = {};
    DirectX::XMStoreFloat4x4(&packed, transposed);
    std::memcpy(destination, &packed, sizeof(packed));
}
// Fixed right-handed look-at view used both for CPU depth sorting and for per-frame rendering
// constants: eye (0, 0.5, 1) looking toward (0, 0.5, -5) with +Y up.
DirectX::XMMATRIX BuildSortViewMatrix() {
    using namespace DirectX;
    const XMVECTOR upDirection = XMVectorSet(0.0f, 1.0f, 0.0f, 0.0f);
    const XMVECTOR lookTarget = XMVectorSet(0.0f, 0.5f, -5.0f, 1.0f);
    const XMVECTOR cameraPosition = XMVectorSet(0.0f, 0.5f, 1.0f, 1.0f);
    return XMMatrixLookAtRH(cameraPosition, lookTarget, upDirection);
}
// Maps an IEEE-754 float to a uint32_t whose unsigned order matches the float order
// (all negatives below all non-negatives, -0.0f immediately below +0.0f), so float keys
// can be sorted with integer comparisons / radix passes.
uint32_t FloatToSortableUint(float value) {
    uint32_t bits = 0;
    std::memcpy(&bits, &value, sizeof(bits));
    constexpr uint32_t kSignBit = 0x80000000u;
    // Negative: invert every bit so larger magnitudes sort lower.
    // Non-negative: set the sign bit so they sort above every negative value.
    return (bits & kSignBit) != 0u ? ~bits : (bits | kSignBit);
}
void BuildCpuSortedOrder(
const GaussianSplatRuntimeData& sceneData,
std::vector<uint32_t>& outOrder,
std::vector<uint32_t>* outSortedKeys = nullptr) {
using namespace DirectX;
const XMMATRIX view = BuildSortViewMatrix();
const float* positionBytes = reinterpret_cast<const float*>(sceneData.positionData.data());
std::vector<std::pair<uint32_t, uint32_t>> sortablePairs(sceneData.splatCount);
for (uint32_t index = 0; index < sceneData.splatCount; ++index) {
const float* position = positionBytes + index * 3u;
const XMVECTOR worldPosition = XMVectorSet(position[0], position[1], position[2], 1.0f);
const XMVECTOR viewPosition = XMVector4Transform(worldPosition, view);
sortablePairs[index] = {
FloatToSortableUint(XMVectorGetZ(viewPosition)),
index,
};
}
std::stable_sort(
sortablePairs.begin(),
sortablePairs.end(),
[](const auto& left, const auto& right) {
return left.first < right.first;
});
outOrder.resize(sortablePairs.size());
if (outSortedKeys != nullptr) {
outSortedKeys->resize(sortablePairs.size());
}
for (size_t index = 0; index < sortablePairs.size(); ++index) {
outOrder[index] = sortablePairs[index].second;
if (outSortedKeys != nullptr) {
(*outSortedKeys)[index] = sortablePairs[index].first;
}
}
}
// Fills the per-frame constants for a width x height target.
// Uses the fixed sort camera (BuildSortViewMatrix) and a right-handed 60-degree perspective
// with near 0.1 and far 200. Aspect falls back to 1 when height is 0.
FrameConstants BuildFrameConstants(uint32_t width, uint32_t height, uint32_t splatCount) {
    using namespace DirectX;
    const float widthF = static_cast<float>(width);
    const float heightF = static_cast<float>(height);
    const float aspectRatio = (height == 0) ? 1.0f : widthF / heightF;

    const XMMATRIX viewMatrix = BuildSortViewMatrix();
    const XMMATRIX projectionMatrix =
        XMMatrixPerspectiveFovRH(XMConvertToRadians(60.0f), aspectRatio, 0.1f, 200.0f);

    FrameConstants constants = {};
    StoreMatrixTransposed(XMMatrixMultiply(viewMatrix, projectionMatrix), constants.viewProjection);
    StoreMatrixTransposed(viewMatrix, constants.view);
    StoreMatrixTransposed(projectionMatrix, constants.projection);

    // Same eye position that BuildSortViewMatrix uses, with w = 1.
    const float cameraPosition[4] = { 0.0f, 0.5f, 1.0f, 1.0f };
    std::memcpy(constants.cameraWorldPos, cameraPosition, sizeof(cameraPosition));

    constants.screenParams[0] = widthF;
    constants.screenParams[1] = heightF;
    constants.screenParams[2] = (width == 0) ? 0.0f : 1.0f / widthF;
    constants.screenParams[3] = (height == 0) ? 0.0f : 1.0f / heightF;

    constants.settings[0] = static_cast<float>(splatCount);
    constants.settings[1] = 1.0f; // opacity scale
    constants.settings[2] = 3.0f; // SH order
    constants.settings[3] = 0.9f; // NOTE(review): meaning not visible here — confirm against the shader
    return constants;
}
// Calls Shutdown() on `object`, deletes it, then nulls the caller's pointer.
// A null pointer is ignored, so repeated calls are harmless.
template <typename T>
void ShutdownAndDelete(T*& object) {
    if (object == nullptr) {
        return;
    }
    object->Shutdown();
    delete object;
    object = nullptr;
}
} // namespace
// Defaulted constructor; member defaults come from the class definition in App.h.
App::App() = default;
// Delegates all cleanup to Shutdown().
// NOTE(review): Initialize() can fail partway, so Shutdown() must tolerate partially
// initialized state — confirm in its implementation.
App::~App() {
Shutdown();
}
// Runs every start-up stage in order: scene load, window class and window creation, RHI setup,
// Gaussian GPU uploads, prepare pass, sort resources, and debug-draw resources.
// Each stage is traced to phase3_trace.log. The first failing stage stops initialization, leaves
// its reason in m_lastErrorMessage, and makes this return false. On success the app is marked
// initialized and running.
bool App::Initialize(HINSTANCE instance, int showCommand) {
    AppendTrace("Initialize: begin");
    m_instance = instance;
    m_lastErrorMessage.clear();

    // Traces "Initialize: <stage>", runs the stage, and on failure traces
    // "Initialize: <stage> failed: <m_lastErrorMessage>".
    //   fallbackError == nullptr   -> keep whatever the stage wrote to m_lastErrorMessage
    //   overwriteExisting == true  -> always replace it with fallbackError
    //   overwriteExisting == false -> use fallbackError only if the stage left it empty
    const auto runStage = [this](const char* stageName,
                                 auto&& stage,
                                 const wchar_t* fallbackError = nullptr,
                                 bool overwriteExisting = true) -> bool {
        const std::string tracePrefix = std::string("Initialize: ") + stageName;
        AppendTrace(tracePrefix);
        if (stage()) {
            return true;
        }
        if (fallbackError != nullptr && (overwriteExisting || m_lastErrorMessage.empty())) {
            m_lastErrorMessage = fallbackError;
        }
        AppendTrace(tracePrefix + " failed: " + NarrowAscii(m_lastErrorMessage));
        return false;
    };

    // && short-circuits, so stages after the first failure never run.
    const bool allStagesSucceeded =
        runStage("LoadGaussianScene", [this] { return LoadGaussianScene(); }) &&
        runStage("RegisterWindowClass",
                 [&] { return RegisterWindowClass(instance); },
                 L"Failed to register the Win32 window class.") &&
        runStage("CreateMainWindow",
                 [&] { return CreateMainWindow(instance, showCommand); },
                 L"Failed to create the main window.") &&
        runStage("InitializeRhi",
                 [this] { return InitializeRhi(); },
                 L"Failed to initialize the D3D12 RHI objects.",
                 false) &&
        runStage("InitializeGaussianGpuResources", [this] { return InitializeGaussianGpuResources(); }) &&
        runStage("InitializePreparePassResources", [this] { return InitializePreparePassResources(); }) &&
        runStage("InitializeSortResources", [this] { return InitializeSortResources(); }) &&
        runStage("InitializeDebugDrawResources", [this] { return InitializeDebugDrawResources(); });
    if (!allStagesSucceeded) {
        return false;
    }

    m_isInitialized = true;
    m_running = true;
    AppendTrace("Initialize: success");
    return true;
}
// Number of frames Run() renders before stopping; 0 means no limit. When a limit is set,
// the last frame is rendered with screenshot capture enabled.
void App::SetFrameLimit(unsigned int frameLimit) {
m_frameLimit = frameLimit;
}
// PLY scene to load; a relative path is resolved against the executable's directory.
void App::SetGaussianScenePath(std::wstring scenePath) {
m_gaussianScenePath = std::move(scenePath);
}
// Optional scene summary output; an empty path disables writing the summary.
// A relative path is resolved against the executable's directory.
void App::SetSummaryPath(std::wstring summaryPath) {
m_summaryPath = std::move(summaryPath);
}
// Screenshot output path. NOTE(review): consumed outside this view (presumably by the
// capture path of RenderFrame) — confirm how an empty path is handled.
void App::SetScreenshotPath(std::wstring screenshotPath) {
m_screenshotPath = std::move(screenshotPath);
}
// Human-readable reason for the most recent Initialize() failure (cleared at the start of Initialize()).
const std::wstring& App::GetLastErrorMessage() const {
return m_lastErrorMessage;
}
int App::Run() {
AppendTrace("Run: begin");
MSG message = {};
int exitCode = 0;
while (m_running) {
while (PeekMessage(&message, nullptr, 0, 0, PM_REMOVE)) {
if (message.message == WM_QUIT) {
exitCode = static_cast<int>(message.wParam);
m_running = false;
break;
}
TranslateMessage(&message);
DispatchMessage(&message);
}
if (!m_running) {
break;
}
const bool captureScreenshot = m_frameLimit > 0 && (m_renderedFrameCount + 1) >= m_frameLimit;
AppendTrace(captureScreenshot ? "Run: RenderFrame capture" : "Run: RenderFrame");
RenderFrame(captureScreenshot);
AppendTrace("Run: RenderFrame complete");
++m_renderedFrameCount;
if (m_frameLimit > 0 && m_renderedFrameCount >= m_frameLimit) {
m_running = false;
AppendTrace("Run: frame limit reached, posting WM_CLOSE");
PostMessageW(m_hwnd, WM_CLOSE, 0, 0);
}
}
AppendTrace("Run: end");
return exitCode;
}
// Win32 window procedure registered for the class. On WM_NCCREATE it stores the App* passed as
// the CreateWindowExW creation parameter in GWLP_USERDATA. Later messages are forwarded to that
// App's WindowProc. Messages arriving before an App is attached go to DefWindowProcW.
LRESULT CALLBACK App::StaticWindowProc(HWND hwnd, UINT message, WPARAM wParam, LPARAM lParam) {
    App* owner = nullptr;
    if (message != WM_NCCREATE) {
        owner = reinterpret_cast<App*>(GetWindowLongPtrW(hwnd, GWLP_USERDATA));
    } else {
        const CREATESTRUCTW* const createInfo = reinterpret_cast<const CREATESTRUCTW*>(lParam);
        owner = static_cast<App*>(createInfo->lpCreateParams);
        SetWindowLongPtrW(hwnd, GWLP_USERDATA, reinterpret_cast<LONG_PTR>(owner));
    }
    return owner != nullptr
        ? owner->WindowProc(hwnd, message, wParam, lParam)
        : DefWindowProcW(hwnd, message, wParam, lParam);
}
// Per-instance message handler. Closing the window destroys it, and destruction posts WM_QUIT,
// which Run() picks up to end the loop. Everything else gets default handling.
LRESULT App::WindowProc(HWND hwnd, UINT message, WPARAM wParam, LPARAM lParam) {
    if (message == WM_CLOSE) {
        DestroyWindow(hwnd);
        return 0;
    }
    if (message == WM_DESTROY) {
        PostQuitMessage(0);
        return 0;
    }
    return DefWindowProcW(hwnd, message, wParam, lParam);
}
// Registers the main window class (kWindowClassName), routing its messages to StaticWindowProc.
// Returns true if the class is available after the call.
//
// Application-local window classes are keyed by (class name, hInstance). ERROR_CLASS_ALREADY_EXISTS
// therefore means this module already registered the class earlier in the process, e.g. on a
// previous Initialize(). That case is treated as success instead of failing re-initialization.
bool App::RegisterWindowClass(HINSTANCE instance) {
    WNDCLASSEXW windowClass = {};
    windowClass.cbSize = sizeof(WNDCLASSEXW);
    windowClass.style = CS_HREDRAW | CS_VREDRAW;
    windowClass.lpfnWndProc = &App::StaticWindowProc;
    windowClass.hInstance = instance;
    windowClass.hCursor = LoadCursorW(nullptr, IDC_ARROW);
    windowClass.lpszClassName = kWindowClassName;
    if (RegisterClassExW(&windowClass) != 0) {
        return true;
    }
    return GetLastError() == ERROR_CLASS_ALREADY_EXISTS;
}
// Creates and shows the main overlapped window. The outer size is chosen so the client area is
// m_width x m_height. `this` is passed as the creation parameter so StaticWindowProc can attach
// the App on WM_NCCREATE. Returns false if CreateWindowExW fails.
bool App::CreateMainWindow(HINSTANCE instance, int showCommand) {
    constexpr DWORD kWindowStyle = WS_OVERLAPPEDWINDOW;

    RECT frame = { 0, 0, m_width, m_height };
    AdjustWindowRect(&frame, kWindowStyle, FALSE);
    const int outerWidth = frame.right - frame.left;
    const int outerHeight = frame.bottom - frame.top;

    m_hwnd = CreateWindowExW(
        0,
        kWindowClassName,
        kWindowTitle,
        kWindowStyle,
        CW_USEDEFAULT,
        CW_USEDEFAULT,
        outerWidth,
        outerHeight,
        nullptr,
        nullptr,
        instance,
        this);
    if (m_hwnd == nullptr) {
        return false;
    }

    ShowWindow(m_hwnd, showCommand);
    UpdateWindow(m_hwnd);
    return true;
}
bool App::LoadGaussianScene() {
std::string errorMessage;
const std::filesystem::path scenePath = ResolveNearExecutable(m_gaussianScenePath);
if (!LoadGaussianSceneFromPly(scenePath, m_gaussianSceneData, errorMessage)) {
m_lastErrorMessage.assign(errorMessage.begin(), errorMessage.end());
return false;
}
if (!m_summaryPath.empty()) {
const std::filesystem::path summaryPath = ResolveNearExecutable(m_summaryPath);
if (!WriteGaussianSceneSummary(summaryPath, m_gaussianSceneData, errorMessage)) {
m_lastErrorMessage.assign(errorMessage.begin(), errorMessage.end());
return false;
}
}
return true;
}
bool App::InitializeRhi() {
RHIDeviceDesc deviceDesc = {};
deviceDesc.adapterIndex = 0;
deviceDesc.enableDebugLayer = false;
deviceDesc.enableGPUValidation = false;
if (!m_device.Initialize(deviceDesc)) {
m_lastErrorMessage = L"Failed to initialize the XCEngine D3D12 device.";
return false;
}
ID3D12Device* device = m_device.GetDevice();
IDXGIFactory4* factory = m_device.GetFactory();
D3D12_FEATURE_DATA_D3D12_OPTIONS1 options1 = {};
if (SUCCEEDED(device->CheckFeatureSupport(D3D12_FEATURE_D3D12_OPTIONS1, &options1, sizeof(options1)))) {
AppendTrace(
"InitializeRhi: wave ops min=" + std::to_string(options1.WaveLaneCountMin) +
" max=" + std::to_string(options1.WaveLaneCountMax) +
" total_lane_count=" + std::to_string(options1.TotalLaneCount));
}
if (!m_commandQueue.Initialize(device, CommandQueueType::Direct)) {
m_lastErrorMessage = L"Failed to initialize the direct command queue.";
return false;
}
if (!m_swapChain.Initialize(factory, m_commandQueue.GetCommandQueue(), m_hwnd, m_width, m_height, kBackBufferCount)) {
m_lastErrorMessage = L"Failed to initialize the swap chain.";
return false;
}
if (!m_depthStencil.InitializeDepthStencil(device, m_width, m_height)) {
m_lastErrorMessage = L"Failed to initialize the depth stencil texture.";
return false;
}
if (!m_rtvHeap.Initialize(device, DescriptorHeapType::RTV, kBackBufferCount)) {
m_lastErrorMessage = L"Failed to initialize the RTV descriptor heap.";
return false;
}
if (!m_dsvHeap.Initialize(device, DescriptorHeapType::DSV, 1)) {
m_lastErrorMessage = L"Failed to initialize the DSV descriptor heap.";
return false;
}
for (int index = 0; index < kBackBufferCount; ++index) {
D3D12Texture& backBuffer = m_swapChain.GetBackBuffer(index);
D3D12_RENDER_TARGET_VIEW_DESC renderTargetDesc =
D3D12ResourceView::CreateRenderTargetDesc(Format::R8G8B8A8_UNorm, D3D12_RTV_DIMENSION_TEXTURE2D);
m_rtvs[index].InitializeAsRenderTarget(device, backBuffer.GetResource(), &renderTargetDesc, &m_rtvHeap, index);
}
D3D12_DEPTH_STENCIL_VIEW_DESC depthStencilDesc =
D3D12ResourceView::CreateDepthStencilDesc(Format::D24_UNorm_S8_UInt, D3D12_DSV_DIMENSION_TEXTURE2D);
m_dsv.InitializeAsDepthStencil(device, m_depthStencil.GetResource(), &depthStencilDesc, &m_dsvHeap, 0);
if (!m_commandAllocator.Initialize(device, CommandQueueType::Direct)) {
m_lastErrorMessage = L"Failed to initialize the command allocator.";
return false;
}
if (!m_commandList.Initialize(device, CommandQueueType::Direct, m_commandAllocator.GetCommandAllocator())) {
m_lastErrorMessage = L"Failed to initialize the command list.";
return false;
}
return true;
}
bool App::InitializeGaussianGpuResources() {
ID3D12Device* device = m_device.GetDevice();
ID3D12GraphicsCommandList* commandList = m_commandList.GetCommandList();
m_commandAllocator.Reset();
m_commandList.Reset();
const D3D12_RESOURCE_STATES shaderResourceState =
D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE | D3D12_RESOURCE_STATE_PIXEL_SHADER_RESOURCE;
m_gaussianPositionBuffer.SetStride(4);
if (!m_gaussianPositionBuffer.InitializeWithData(
device,
commandList,
m_gaussianSceneData.positionData.data(),
static_cast<uint64_t>(m_gaussianSceneData.positionData.size()),
shaderResourceState)) {
m_lastErrorMessage = L"Failed to upload the Gaussian position buffer.";
return false;
}
m_gaussianOtherBuffer.SetStride(4);
if (!m_gaussianOtherBuffer.InitializeWithData(
device,
commandList,
m_gaussianSceneData.otherData.data(),
static_cast<uint64_t>(m_gaussianSceneData.otherData.size()),
shaderResourceState)) {
m_lastErrorMessage = L"Failed to upload the Gaussian other buffer.";
return false;
}
m_gaussianShBuffer.SetStride(4);
if (!m_gaussianShBuffer.InitializeWithData(
device,
commandList,
m_gaussianSceneData.shData.data(),
static_cast<uint64_t>(m_gaussianSceneData.shData.size()),
shaderResourceState)) {
m_lastErrorMessage = L"Failed to upload the Gaussian SH buffer.";
return false;
}
D3D12_RESOURCE_DESC colorTextureDesc = {};
colorTextureDesc.Dimension = D3D12_RESOURCE_DIMENSION_TEXTURE2D;
colorTextureDesc.Alignment = 0;
colorTextureDesc.Width = m_gaussianSceneData.colorTextureWidth;
colorTextureDesc.Height = m_gaussianSceneData.colorTextureHeight;
colorTextureDesc.DepthOrArraySize = 1;
colorTextureDesc.MipLevels = 1;
colorTextureDesc.Format = DXGI_FORMAT_R32G32B32A32_FLOAT;
colorTextureDesc.SampleDesc.Count = 1;
colorTextureDesc.SampleDesc.Quality = 0;
colorTextureDesc.Layout = D3D12_TEXTURE_LAYOUT_UNKNOWN;
colorTextureDesc.Flags = D3D12_RESOURCE_FLAG_NONE;
Microsoft::WRL::ComPtr<ID3D12Resource> colorUploadBuffer;
if (!m_gaussianColorTexture.InitializeFromData(
device,
commandList,
colorTextureDesc,
TextureType::Texture2D,
m_gaussianSceneData.colorData.data(),
m_gaussianSceneData.colorData.size(),
m_gaussianSceneData.colorTextureWidth * GaussianSplatRuntimeData::kColorStride,
&colorUploadBuffer)) {
m_lastErrorMessage = L"Failed to upload the Gaussian color texture.";
return false;
}
m_gaussianUploadBuffers.push_back(colorUploadBuffer);
ResourceViewDesc rawBufferViewDesc = {};
rawBufferViewDesc.dimension = ResourceViewDimension::RawBuffer;
m_gaussianPositionView.reset(static_cast<D3D12ResourceView*>(m_device.CreateShaderResourceView(&m_gaussianPositionBuffer, rawBufferViewDesc)));
if (!m_gaussianPositionView) {
m_lastErrorMessage = L"Failed to create the Gaussian position SRV.";
return false;
}
m_gaussianOtherView.reset(static_cast<D3D12ResourceView*>(m_device.CreateShaderResourceView(&m_gaussianOtherBuffer, rawBufferViewDesc)));
if (!m_gaussianOtherView) {
m_lastErrorMessage = L"Failed to create the Gaussian other SRV.";
return false;
}
m_gaussianShView.reset(static_cast<D3D12ResourceView*>(m_device.CreateShaderResourceView(&m_gaussianShBuffer, rawBufferViewDesc)));
if (!m_gaussianShView) {
m_lastErrorMessage = L"Failed to create the Gaussian SH SRV.";
return false;
}
ResourceViewDesc textureViewDesc = {};
textureViewDesc.dimension = ResourceViewDimension::Texture2D;
textureViewDesc.format = static_cast<uint32_t>(Format::R32G32B32A32_Float);
m_gaussianColorView.reset(static_cast<D3D12ResourceView*>(m_device.CreateShaderResourceView(&m_gaussianColorTexture, textureViewDesc)));
if (!m_gaussianColorView) {
m_lastErrorMessage = L"Failed to create the Gaussian color texture SRV.";
return false;
}
m_commandList.Close();
void* commandLists[] = { &m_commandList };
m_commandQueue.ExecuteCommandLists(1, commandLists);
m_commandQueue.WaitForIdle();
m_gaussianUploadBuffers.clear();
return true;
}
bool App::InitializePreparePassResources() {
DescriptorSetLayoutBinding bindings[6] = {};
bindings[0].binding = 0;
bindings[0].type = static_cast<uint32_t>(DescriptorType::CBV);
bindings[0].count = 1;
bindings[0].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[1].binding = 1;
bindings[1].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[1].count = 1;
bindings[1].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[1].resourceDimension = ResourceViewDimension::RawBuffer;
bindings[2].binding = 2;
bindings[2].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[2].count = 1;
bindings[2].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[2].resourceDimension = ResourceViewDimension::RawBuffer;
bindings[3].binding = 3;
bindings[3].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[3].count = 1;
bindings[3].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[3].resourceDimension = ResourceViewDimension::Texture2D;
bindings[4].binding = 4;
bindings[4].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[4].count = 1;
bindings[4].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[4].resourceDimension = ResourceViewDimension::RawBuffer;
bindings[5].binding = 5;
bindings[5].type = static_cast<uint32_t>(DescriptorType::UAV);
bindings[5].count = 1;
bindings[5].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[5].resourceDimension = ResourceViewDimension::StructuredBuffer;
DescriptorSetLayoutDesc setLayout = {};
setLayout.bindings = bindings;
setLayout.bindingCount = 6;
RHIPipelineLayoutDesc pipelineLayoutDesc = {};
pipelineLayoutDesc.setLayouts = &setLayout;
pipelineLayoutDesc.setLayoutCount = 1;
m_preparePipelineLayout = m_device.CreatePipelineLayout(pipelineLayoutDesc);
if (m_preparePipelineLayout == nullptr) {
m_lastErrorMessage = L"Failed to create the prepare pass pipeline layout.";
return false;
}
BufferDesc preparedViewBufferDesc = {};
preparedViewBufferDesc.size = static_cast<uint64_t>(sizeof(PreparedSplatView)) * m_gaussianSceneData.splatCount;
preparedViewBufferDesc.stride = sizeof(PreparedSplatView);
preparedViewBufferDesc.bufferType = static_cast<uint32_t>(BufferType::Storage);
preparedViewBufferDesc.flags = static_cast<uint64_t>(BufferFlags::AllowUnorderedAccess);
m_preparedViewBuffer = static_cast<D3D12Buffer*>(m_device.CreateBuffer(preparedViewBufferDesc));
if (m_preparedViewBuffer == nullptr) {
m_lastErrorMessage = L"Failed to create the prepared view buffer.";
return false;
}
m_preparedViewBuffer->SetStride(sizeof(PreparedSplatView));
m_preparedViewBuffer->SetBufferType(BufferType::Storage);
ResourceViewDesc structuredViewDesc = {};
structuredViewDesc.dimension = ResourceViewDimension::StructuredBuffer;
structuredViewDesc.structureByteStride = sizeof(PreparedSplatView);
structuredViewDesc.elementCount = m_gaussianSceneData.splatCount;
m_preparedViewSrv.reset(static_cast<D3D12ResourceView*>(m_device.CreateShaderResourceView(m_preparedViewBuffer, structuredViewDesc)));
if (!m_preparedViewSrv) {
m_lastErrorMessage = L"Failed to create the prepared view SRV.";
return false;
}
m_preparedViewUav.reset(static_cast<D3D12ResourceView*>(m_device.CreateUnorderedAccessView(m_preparedViewBuffer, structuredViewDesc)));
if (!m_preparedViewUav) {
m_lastErrorMessage = L"Failed to create the prepared view UAV.";
return false;
}
DescriptorPoolDesc poolDesc = {};
poolDesc.type = DescriptorHeapType::CBV_SRV_UAV;
poolDesc.descriptorCount = 5;
poolDesc.shaderVisible = true;
m_prepareDescriptorPool = m_device.CreateDescriptorPool(poolDesc);
if (m_prepareDescriptorPool == nullptr) {
m_lastErrorMessage = L"Failed to create the prepare pass descriptor pool.";
return false;
}
m_prepareDescriptorSet = m_prepareDescriptorPool->AllocateSet(setLayout);
if (m_prepareDescriptorSet == nullptr) {
m_lastErrorMessage = L"Failed to allocate the prepare pass descriptor set.";
return false;
}
m_prepareDescriptorSet->Update(1, m_gaussianPositionView.get());
m_prepareDescriptorSet->Update(2, m_gaussianOtherView.get());
m_prepareDescriptorSet->Update(3, m_gaussianColorView.get());
m_prepareDescriptorSet->Update(4, m_gaussianShView.get());
m_prepareDescriptorSet->Update(5, m_preparedViewUav.get());
ShaderCompileDesc computeShaderDesc = {};
computeShaderDesc.fileName = ResolveShaderPath(L"PrepareGaussiansCS.hlsl").wstring();
computeShaderDesc.entryPoint = L"MainCS";
computeShaderDesc.profile = L"cs_5_0";
ComputePipelineDesc pipelineDesc = {};
pipelineDesc.pipelineLayout = m_preparePipelineLayout;
pipelineDesc.computeShader = computeShaderDesc;
m_preparePipelineState = m_device.CreateComputePipelineState(pipelineDesc);
if (m_preparePipelineState == nullptr) {
m_lastErrorMessage = L"Failed to create the prepare pass pipeline state.";
return false;
}
return true;
}
bool App::InitializeSortResources() {
std::vector<uint32_t> initialOrder;
if (kUseCpuSortBaseline) {
BuildCpuSortedOrder(m_gaussianSceneData, initialOrder, nullptr);
} else {
initialOrder.resize(static_cast<size_t>(m_gaussianSceneData.splatCount));
for (uint32_t index = 0; index < m_gaussianSceneData.splatCount; ++index) {
initialOrder[index] = index;
}
}
const uint64_t sortBufferBytes = static_cast<uint64_t>(m_gaussianSceneData.splatCount) * sizeof(uint32_t);
const uint32_t threadBlocks =
static_cast<uint32_t>((m_gaussianSceneData.splatCount + (kDeviceRadixSortPartitionSize - 1u)) / kDeviceRadixSortPartitionSize);
const uint64_t passHistogramElements = static_cast<uint64_t>(threadBlocks) * kDeviceRadixSortRadix;
const uint64_t globalHistogramElements = static_cast<uint64_t>(kDeviceRadixSortRadix) * kDeviceRadixSortPassCount;
auto initializeStorageBuffer = [this](D3D12Buffer*& buffer, uint64_t sizeInBytes) -> bool {
BufferDesc bufferDesc = {};
bufferDesc.size = sizeInBytes;
bufferDesc.stride = sizeof(uint32_t);
bufferDesc.bufferType = static_cast<uint32_t>(BufferType::Storage);
bufferDesc.flags = static_cast<uint64_t>(BufferFlags::AllowUnorderedAccess);
buffer = static_cast<D3D12Buffer*>(m_device.CreateBuffer(bufferDesc));
if (buffer == nullptr) {
return false;
}
buffer->SetStride(sizeof(uint32_t));
buffer->SetBufferType(BufferType::Storage);
return true;
};
BufferDesc orderBufferDesc = {};
orderBufferDesc.size = sortBufferBytes;
orderBufferDesc.stride = sizeof(uint32_t);
orderBufferDesc.bufferType = static_cast<uint32_t>(BufferType::Storage);
orderBufferDesc.flags = static_cast<uint64_t>(BufferFlags::AllowUnorderedAccess);
m_orderBuffer = static_cast<D3D12Buffer*>(m_device.CreateBuffer(
orderBufferDesc,
initialOrder.data(),
static_cast<size_t>(sortBufferBytes),
ResourceStates::NonPixelShaderResource));
if (m_orderBuffer == nullptr) {
m_lastErrorMessage = L"Failed to create the primary order buffer.";
return false;
}
m_orderBuffer->SetStride(sizeof(uint32_t));
m_orderBuffer->SetBufferType(BufferType::Storage);
if (!initializeStorageBuffer(m_orderScratchBuffer, sortBufferBytes)) {
m_lastErrorMessage = L"Failed to create the scratch order buffer.";
return false;
}
if (!initializeStorageBuffer(m_sortKeyBuffer, sortBufferBytes)) {
m_lastErrorMessage = L"Failed to create the primary sort key buffer.";
return false;
}
if (!initializeStorageBuffer(m_sortKeyScratchBuffer, sortBufferBytes)) {
m_lastErrorMessage = L"Failed to create the scratch sort key buffer.";
return false;
}
if (!initializeStorageBuffer(m_passHistogramBuffer, passHistogramElements * sizeof(uint32_t))) {
m_lastErrorMessage = L"Failed to create the pass histogram buffer.";
return false;
}
if (!initializeStorageBuffer(m_globalHistogramBuffer, globalHistogramElements * sizeof(uint32_t))) {
m_lastErrorMessage = L"Failed to create the global histogram buffer.";
return false;
}
ResourceViewDesc structuredViewDesc = {};
structuredViewDesc.dimension = ResourceViewDimension::StructuredBuffer;
structuredViewDesc.structureByteStride = sizeof(uint32_t);
structuredViewDesc.elementCount = m_gaussianSceneData.splatCount;
ResourceViewDesc passHistogramViewDesc = structuredViewDesc;
passHistogramViewDesc.elementCount = static_cast<uint32_t>(passHistogramElements);
ResourceViewDesc globalHistogramViewDesc = structuredViewDesc;
globalHistogramViewDesc.elementCount = static_cast<uint32_t>(globalHistogramElements);
m_orderBufferSrv.reset(static_cast<D3D12ResourceView*>(m_device.CreateShaderResourceView(m_orderBuffer, structuredViewDesc)));
if (!m_orderBufferSrv) {
m_lastErrorMessage = L"Failed to create the primary order SRV.";
return false;
}
m_orderBufferUav.reset(static_cast<D3D12ResourceView*>(m_device.CreateUnorderedAccessView(m_orderBuffer, structuredViewDesc)));
if (!m_orderBufferUav) {
m_lastErrorMessage = L"Failed to create the primary order UAV.";
return false;
}
m_orderScratchUav.reset(static_cast<D3D12ResourceView*>(m_device.CreateUnorderedAccessView(m_orderScratchBuffer, structuredViewDesc)));
if (!m_orderScratchUav) {
m_lastErrorMessage = L"Failed to create the scratch order UAV.";
return false;
}
m_sortKeyUav.reset(static_cast<D3D12ResourceView*>(m_device.CreateUnorderedAccessView(m_sortKeyBuffer, structuredViewDesc)));
if (!m_sortKeyUav) {
m_lastErrorMessage = L"Failed to create the primary sort key UAV.";
return false;
}
m_sortKeyScratchUav.reset(static_cast<D3D12ResourceView*>(m_device.CreateUnorderedAccessView(m_sortKeyScratchBuffer, structuredViewDesc)));
if (!m_sortKeyScratchUav) {
m_lastErrorMessage = L"Failed to create the scratch sort key UAV.";
return false;
}
m_passHistogramUav.reset(static_cast<D3D12ResourceView*>(m_device.CreateUnorderedAccessView(m_passHistogramBuffer, passHistogramViewDesc)));
if (!m_passHistogramUav) {
m_lastErrorMessage = L"Failed to create the pass histogram UAV.";
return false;
}
m_globalHistogramUav.reset(static_cast<D3D12ResourceView*>(m_device.CreateUnorderedAccessView(m_globalHistogramBuffer, globalHistogramViewDesc)));
if (!m_globalHistogramUav) {
m_lastErrorMessage = L"Failed to create the global histogram UAV.";
return false;
}
DescriptorSetLayoutBinding buildSortKeyBindings[4] = {};
buildSortKeyBindings[0].binding = 0;
buildSortKeyBindings[0].type = static_cast<uint32_t>(DescriptorType::CBV);
buildSortKeyBindings[0].count = 1;
buildSortKeyBindings[0].visibility = static_cast<uint32_t>(ShaderVisibility::All);
buildSortKeyBindings[1].binding = 1;
buildSortKeyBindings[1].type = static_cast<uint32_t>(DescriptorType::SRV);
buildSortKeyBindings[1].count = 1;
buildSortKeyBindings[1].visibility = static_cast<uint32_t>(ShaderVisibility::All);
buildSortKeyBindings[1].resourceDimension = ResourceViewDimension::RawBuffer;
buildSortKeyBindings[2].binding = 2;
buildSortKeyBindings[2].type = static_cast<uint32_t>(DescriptorType::SRV);
buildSortKeyBindings[2].count = 1;
buildSortKeyBindings[2].visibility = static_cast<uint32_t>(ShaderVisibility::All);
buildSortKeyBindings[2].resourceDimension = ResourceViewDimension::StructuredBuffer;
buildSortKeyBindings[3].binding = 3;
buildSortKeyBindings[3].type = static_cast<uint32_t>(DescriptorType::UAV);
buildSortKeyBindings[3].count = 1;
buildSortKeyBindings[3].visibility = static_cast<uint32_t>(ShaderVisibility::All);
buildSortKeyBindings[3].resourceDimension = ResourceViewDimension::StructuredBuffer;
DescriptorSetLayoutDesc buildSortKeySetLayout = {};
buildSortKeySetLayout.bindings = buildSortKeyBindings;
buildSortKeySetLayout.bindingCount = 4;
RHIPipelineLayoutDesc buildSortKeyLayoutDesc = {};
buildSortKeyLayoutDesc.setLayouts = &buildSortKeySetLayout;
buildSortKeyLayoutDesc.setLayoutCount = 1;
m_buildSortKeyPipelineLayout = m_device.CreatePipelineLayout(buildSortKeyLayoutDesc);
if (m_buildSortKeyPipelineLayout == nullptr) {
m_lastErrorMessage = L"Failed to create the build-sort-key pipeline layout.";
return false;
}
DescriptorPoolDesc buildSortKeyPoolDesc = {};
buildSortKeyPoolDesc.type = DescriptorHeapType::CBV_SRV_UAV;
buildSortKeyPoolDesc.descriptorCount = 3;
buildSortKeyPoolDesc.shaderVisible = true;
m_buildSortKeyDescriptorPool = m_device.CreateDescriptorPool(buildSortKeyPoolDesc);
if (m_buildSortKeyDescriptorPool == nullptr) {
m_lastErrorMessage = L"Failed to create the build-sort-key descriptor pool.";
return false;
}
m_buildSortKeyDescriptorSet = m_buildSortKeyDescriptorPool->AllocateSet(buildSortKeySetLayout);
if (m_buildSortKeyDescriptorSet == nullptr) {
m_lastErrorMessage = L"Failed to allocate the build-sort-key descriptor set.";
return false;
}
m_buildSortKeyDescriptorSet->Update(1, m_gaussianPositionView.get());
m_buildSortKeyDescriptorSet->Update(2, m_orderBufferSrv.get());
m_buildSortKeyDescriptorSet->Update(3, m_sortKeyUav.get());
ShaderCompileDesc buildSortKeyShaderDesc = BuildDxilShaderDesc(ResolveShaderPath(L"BuildSortKeysCS.dxil"), L"cs_6_6");
if (buildSortKeyShaderDesc.compiledBinary.empty()) {
m_lastErrorMessage = L"Failed to load BuildSortKeysCS.dxil.";
return false;
}
ComputePipelineDesc buildSortKeyPipelineDesc = {};
buildSortKeyPipelineDesc.pipelineLayout = m_buildSortKeyPipelineLayout;
buildSortKeyPipelineDesc.computeShader = buildSortKeyShaderDesc;
m_buildSortKeyPipelineState = m_device.CreateComputePipelineState(buildSortKeyPipelineDesc);
if (m_buildSortKeyPipelineState == nullptr) {
m_lastErrorMessage = L"Failed to create the build-sort-key pipeline state.";
return false;
}
DescriptorSetLayoutBinding radixSortBindings[7] = {};
radixSortBindings[0].binding = 0;
radixSortBindings[0].type = static_cast<uint32_t>(DescriptorType::CBV);
radixSortBindings[0].count = 1;
radixSortBindings[0].visibility = static_cast<uint32_t>(ShaderVisibility::All);
for (uint32_t bindingIndex = 1; bindingIndex <= 6; ++bindingIndex) {
radixSortBindings[bindingIndex].binding = bindingIndex;
radixSortBindings[bindingIndex].type = static_cast<uint32_t>(DescriptorType::UAV);
radixSortBindings[bindingIndex].count = 1;
radixSortBindings[bindingIndex].visibility = static_cast<uint32_t>(ShaderVisibility::All);
radixSortBindings[bindingIndex].resourceDimension = ResourceViewDimension::StructuredBuffer;
}
DescriptorSetLayoutDesc radixSortSetLayout = {};
radixSortSetLayout.bindings = radixSortBindings;
radixSortSetLayout.bindingCount = 7;
RHIPipelineLayoutDesc radixSortLayoutDesc = {};
radixSortLayoutDesc.setLayouts = &radixSortSetLayout;
radixSortLayoutDesc.setLayoutCount = 1;
m_radixSortPipelineLayout = m_device.CreatePipelineLayout(radixSortLayoutDesc);
if (m_radixSortPipelineLayout == nullptr) {
m_lastErrorMessage = L"Failed to create the radix-sort pipeline layout.";
return false;
}
DescriptorPoolDesc radixSortPoolDesc = {};
radixSortPoolDesc.type = DescriptorHeapType::CBV_SRV_UAV;
radixSortPoolDesc.descriptorCount = 12;
radixSortPoolDesc.shaderVisible = true;
m_radixSortDescriptorPool = m_device.CreateDescriptorPool(radixSortPoolDesc);
if (m_radixSortDescriptorPool == nullptr) {
m_lastErrorMessage = L"Failed to create the radix-sort descriptor pool.";
return false;
}
m_radixSortDescriptorSetPrimaryToScratch = m_radixSortDescriptorPool->AllocateSet(radixSortSetLayout);
m_radixSortDescriptorSetScratchToPrimary = m_radixSortDescriptorPool->AllocateSet(radixSortSetLayout);
if (m_radixSortDescriptorSetPrimaryToScratch == nullptr || m_radixSortDescriptorSetScratchToPrimary == nullptr) {
m_lastErrorMessage = L"Failed to allocate the radix-sort descriptor sets.";
return false;
}
m_radixSortDescriptorSetPrimaryToScratch->Update(1, m_sortKeyUav.get());
m_radixSortDescriptorSetPrimaryToScratch->Update(2, m_sortKeyScratchUav.get());
m_radixSortDescriptorSetPrimaryToScratch->Update(3, m_orderBufferUav.get());
m_radixSortDescriptorSetPrimaryToScratch->Update(4, m_orderScratchUav.get());
m_radixSortDescriptorSetPrimaryToScratch->Update(5, m_passHistogramUav.get());
m_radixSortDescriptorSetPrimaryToScratch->Update(6, m_globalHistogramUav.get());
m_radixSortDescriptorSetScratchToPrimary->Update(1, m_sortKeyScratchUav.get());
m_radixSortDescriptorSetScratchToPrimary->Update(2, m_sortKeyUav.get());
m_radixSortDescriptorSetScratchToPrimary->Update(3, m_orderScratchUav.get());
m_radixSortDescriptorSetScratchToPrimary->Update(4, m_orderBufferUav.get());
m_radixSortDescriptorSetScratchToPrimary->Update(5, m_passHistogramUav.get());
m_radixSortDescriptorSetScratchToPrimary->Update(6, m_globalHistogramUav.get());
const ShaderCompileDesc radixInitShaderDesc = BuildDxilShaderDesc(ResolveShaderPath(L"RadixInit.dxil"), L"cs_6_6");
const ShaderCompileDesc radixUpsweepShaderDesc = BuildDxilShaderDesc(ResolveShaderPath(L"RadixUpsweep.dxil"), L"cs_6_6");
const ShaderCompileDesc radixGlobalHistogramShaderDesc = BuildDxilShaderDesc(ResolveShaderPath(L"RadixGlobalHistogram.dxil"), L"cs_6_6");
const ShaderCompileDesc radixScanShaderDesc = BuildDxilShaderDesc(ResolveShaderPath(L"RadixScan.dxil"), L"cs_6_6");
const ShaderCompileDesc radixDownsweepShaderDesc = BuildDxilShaderDesc(ResolveShaderPath(L"RadixDownsweep.dxil"), L"cs_6_6");
if (radixInitShaderDesc.compiledBinary.empty() ||
radixUpsweepShaderDesc.compiledBinary.empty() ||
radixGlobalHistogramShaderDesc.compiledBinary.empty() ||
radixScanShaderDesc.compiledBinary.empty() ||
radixDownsweepShaderDesc.compiledBinary.empty()) {
m_lastErrorMessage = L"Failed to load one or more radix-sort DXIL shaders.";
return false;
}
ComputePipelineDesc radixPipelineDesc = {};
radixPipelineDesc.pipelineLayout = m_radixSortPipelineLayout;
radixPipelineDesc.computeShader = radixInitShaderDesc;
m_radixSortInitPipelineState = m_device.CreateComputePipelineState(radixPipelineDesc);
if (m_radixSortInitPipelineState == nullptr) {
m_lastErrorMessage = L"Failed to create the radix-sort init pipeline state.";
return false;
}
radixPipelineDesc.computeShader = radixUpsweepShaderDesc;
m_radixSortUpsweepPipelineState = m_device.CreateComputePipelineState(radixPipelineDesc);
if (m_radixSortUpsweepPipelineState == nullptr) {
m_lastErrorMessage = L"Failed to create the radix-sort upsweep pipeline state.";
return false;
}
radixPipelineDesc.computeShader = radixGlobalHistogramShaderDesc;
m_radixSortGlobalHistogramPipelineState = m_device.CreateComputePipelineState(radixPipelineDesc);
if (m_radixSortGlobalHistogramPipelineState == nullptr) {
m_lastErrorMessage = L"Failed to create the radix-sort global-histogram pipeline state.";
return false;
}
radixPipelineDesc.computeShader = radixScanShaderDesc;
m_radixSortScanPipelineState = m_device.CreateComputePipelineState(radixPipelineDesc);
if (m_radixSortScanPipelineState == nullptr) {
m_lastErrorMessage = L"Failed to create the radix-sort scan pipeline state.";
return false;
}
radixPipelineDesc.computeShader = radixDownsweepShaderDesc;
m_radixSortDownsweepPipelineState = m_device.CreateComputePipelineState(radixPipelineDesc);
if (m_radixSortDownsweepPipelineState == nullptr) {
m_lastErrorMessage = L"Failed to create the radix-sort downsweep pipeline state.";
return false;
}
return true;
}
bool App::InitializeDebugDrawResources() {
DescriptorSetLayoutBinding bindings[3] = {};
bindings[0].binding = 0;
bindings[0].type = static_cast<uint32_t>(DescriptorType::CBV);
bindings[0].count = 1;
bindings[0].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[1].binding = 1;
bindings[1].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[1].count = 1;
bindings[1].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[1].resourceDimension = ResourceViewDimension::StructuredBuffer;
bindings[2].binding = 2;
bindings[2].type = static_cast<uint32_t>(DescriptorType::SRV);
bindings[2].count = 1;
bindings[2].visibility = static_cast<uint32_t>(ShaderVisibility::All);
bindings[2].resourceDimension = ResourceViewDimension::StructuredBuffer;
DescriptorSetLayoutDesc setLayout = {};
setLayout.bindings = bindings;
setLayout.bindingCount = 3;
RHIPipelineLayoutDesc pipelineLayoutDesc = {};
pipelineLayoutDesc.setLayouts = &setLayout;
pipelineLayoutDesc.setLayoutCount = 1;
m_debugPipelineLayout = m_device.CreatePipelineLayout(pipelineLayoutDesc);
if (m_debugPipelineLayout == nullptr) {
m_lastErrorMessage = L"Failed to create the debug draw pipeline layout.";
return false;
}
DescriptorPoolDesc poolDesc = {};
poolDesc.type = DescriptorHeapType::CBV_SRV_UAV;
poolDesc.descriptorCount = 2;
poolDesc.shaderVisible = true;
m_debugDescriptorPool = m_device.CreateDescriptorPool(poolDesc);
if (m_debugDescriptorPool == nullptr) {
m_lastErrorMessage = L"Failed to create the debug draw descriptor pool.";
return false;
}
m_debugDescriptorSet = m_debugDescriptorPool->AllocateSet(setLayout);
if (m_debugDescriptorSet == nullptr) {
m_lastErrorMessage = L"Failed to allocate the debug draw descriptor set.";
return false;
}
m_debugDescriptorSet->Update(1, m_preparedViewSrv.get());
m_debugDescriptorSet->Update(2, m_orderBufferSrv.get());
GraphicsPipelineDesc pipelineDesc = {};
pipelineDesc.pipelineLayout = m_debugPipelineLayout;
pipelineDesc.topologyType = static_cast<uint32_t>(PrimitiveTopologyType::Triangle);
pipelineDesc.renderTargetCount = 1;
pipelineDesc.renderTargetFormats[0] = static_cast<uint32_t>(Format::R8G8B8A8_UNorm);
pipelineDesc.depthStencilFormat = static_cast<uint32_t>(Format::D24_UNorm_S8_UInt);
pipelineDesc.sampleCount = 1;
pipelineDesc.sampleQuality = 0;
pipelineDesc.rasterizerState.cullMode = static_cast<uint32_t>(CullMode::None);
pipelineDesc.depthStencilState.depthTestEnable = false;
pipelineDesc.depthStencilState.depthWriteEnable = false;
pipelineDesc.blendState.blendEnable = true;
pipelineDesc.blendState.srcBlend = static_cast<uint32_t>(BlendFactor::One);
pipelineDesc.blendState.dstBlend = static_cast<uint32_t>(BlendFactor::InvSrcAlpha);
pipelineDesc.blendState.srcBlendAlpha = static_cast<uint32_t>(BlendFactor::One);
pipelineDesc.blendState.dstBlendAlpha = static_cast<uint32_t>(BlendFactor::InvSrcAlpha);
pipelineDesc.vertexShader.fileName = ResolveShaderPath(L"DebugPointsVS.hlsl").wstring();
pipelineDesc.vertexShader.entryPoint = L"MainVS";
pipelineDesc.vertexShader.profile = L"vs_5_0";
pipelineDesc.fragmentShader.fileName = ResolveShaderPath(L"DebugPointsPS.hlsl").wstring();
pipelineDesc.fragmentShader.entryPoint = L"MainPS";
pipelineDesc.fragmentShader.profile = L"ps_5_0";
m_debugPipelineState = m_device.CreatePipelineState(pipelineDesc);
if (m_debugPipelineState == nullptr) {
m_lastErrorMessage = L"Failed to create the debug draw pipeline state.";
return false;
}
return true;
}
// Releases the Gaussian scene's GPU data. The views are reset first, then the
// upload buffers are dropped, and finally the colour texture and the SH,
// "other" and position buffers are shut down.
void App::ShutdownGaussianGpuResources() {
    // Views are released before the textures/buffers (presumably the resources they were created on).
    m_gaussianColorView.reset();
    m_gaussianShView.reset();
    m_gaussianOtherView.reset();
    m_gaussianPositionView.reset();

    m_gaussianUploadBuffers.clear();

    m_gaussianColorTexture.Shutdown();
    m_gaussianShBuffer.Shutdown();
    m_gaussianOtherBuffer.Shutdown();
    m_gaussianPositionBuffer.Shutdown();
}
// Tears down the prepare compute pass. Order: descriptor set, the pool it came
// from, the pipeline state and layout, the UAV/SRV views, and last the
// prepared-view buffer. Owned pointers are nulled, so calling this twice is harmless.
void App::ShutdownPreparePassResources() {
    // Shuts down, deletes and nulls a heap-owned object; a null pointer is skipped.
    const auto destroy = [](auto*& owned) {
        if (owned == nullptr) {
            return;
        }
        owned->Shutdown();
        delete owned;
        owned = nullptr;
    };

    destroy(m_prepareDescriptorSet);
    ShutdownAndDelete(m_prepareDescriptorPool);
    ShutdownAndDelete(m_preparePipelineState);
    ShutdownAndDelete(m_preparePipelineLayout);

    m_preparedViewUav.reset();
    m_preparedViewSrv.reset();

    destroy(m_preparedViewBuffer);
}
// Tears down everything used for sorting: the sort-key build pass, both radix-sort
// ping-pong descriptor sets with their pool and pipelines, every sort view, and
// the key/order/histogram buffers. Owned pointers are nulled, so a second call is a no-op.
void App::ShutdownSortResources() {
    // Shuts down, deletes and nulls a heap-owned object; a null pointer is skipped.
    const auto destroy = [](auto*& owned) {
        if (owned == nullptr) {
            return;
        }
        owned->Shutdown();
        delete owned;
        owned = nullptr;
    };

    // Sort-key build pass.
    destroy(m_buildSortKeyDescriptorSet);
    ShutdownAndDelete(m_buildSortKeyDescriptorPool);
    ShutdownAndDelete(m_buildSortKeyPipelineState);
    ShutdownAndDelete(m_buildSortKeyPipelineLayout);

    // Radix sort: both ping-pong sets are released before the pool they were allocated from.
    destroy(m_radixSortDescriptorSetPrimaryToScratch);
    destroy(m_radixSortDescriptorSetScratchToPrimary);
    ShutdownAndDelete(m_radixSortDescriptorPool);
    ShutdownAndDelete(m_radixSortDownsweepPipelineState);
    ShutdownAndDelete(m_radixSortScanPipelineState);
    ShutdownAndDelete(m_radixSortGlobalHistogramPipelineState);
    ShutdownAndDelete(m_radixSortUpsweepPipelineState);
    ShutdownAndDelete(m_radixSortInitPipelineState);
    ShutdownAndDelete(m_radixSortPipelineLayout);

    // Views go before the buffers (presumably their underlying resources).
    m_globalHistogramUav.reset();
    m_passHistogramUav.reset();
    m_orderScratchUav.reset();
    m_orderBufferUav.reset();
    m_orderBufferSrv.reset();
    m_sortKeyScratchUav.reset();
    m_sortKeyUav.reset();

    destroy(m_orderBuffer);
    destroy(m_orderScratchBuffer);
    destroy(m_sortKeyBuffer);
    destroy(m_sortKeyScratchBuffer);
    destroy(m_passHistogramBuffer);
    destroy(m_globalHistogramBuffer);
}
// Releases the debug splat-draw objects in this order: descriptor set, its pool,
// pipeline state, pipeline layout.
void App::ShutdownDebugDrawResources() {
    if (auto*& debugSet = m_debugDescriptorSet; debugSet != nullptr) {
        debugSet->Shutdown();
        delete debugSet;
        debugSet = nullptr;
    }
    ShutdownAndDelete(m_debugDescriptorPool);
    ShutdownAndDelete(m_debugPipelineState);
    ShutdownAndDelete(m_debugPipelineLayout);
}
void App::Shutdown() {
AppendTrace("Shutdown: begin");
if (!m_isInitialized && m_hwnd == nullptr) {
AppendTrace("Shutdown: skipped");
return;
}
m_running = false;
if (m_commandQueue.GetCommandQueue() != nullptr) {
AppendTrace("Shutdown: WaitForIdle");
m_commandQueue.WaitForIdle();
}
ShutdownDebugDrawResources();
ShutdownSortResources();
ShutdownPreparePassResources();
ShutdownGaussianGpuResources();
m_commandList.Shutdown();
m_commandAllocator.Shutdown();
m_dsv.Shutdown();
for (D3D12ResourceView& rtv : m_rtvs) {
rtv.Shutdown();
}
m_dsvHeap.Shutdown();
m_rtvHeap.Shutdown();
m_depthStencil.Shutdown();
m_swapChain.Shutdown();
m_commandQueue.Shutdown();
m_device.Shutdown();
if (m_hwnd != nullptr) {
AppendTrace("Shutdown: DestroyWindow");
DestroyWindow(m_hwnd);
m_hwnd = nullptr;
}
if (m_instance != nullptr) {
UnregisterClassW(kWindowClassName, m_instance);
m_instance = nullptr;
}
m_isInitialized = false;
AppendTrace("Shutdown: end");
}
bool App::CaptureSortSnapshot() {
if (m_sortKeyBuffer == nullptr || m_orderBuffer == nullptr || m_gaussianSceneData.splatCount == 0 || m_sortKeySnapshotPath.empty()) {
return true;
}
const uint32_t sampleCount = std::min<uint32_t>(16u, m_gaussianSceneData.splatCount);
const uint64_t keyBufferBytes = static_cast<uint64_t>(m_gaussianSceneData.splatCount) * sizeof(uint32_t);
const uint64_t sampleBytes = static_cast<uint64_t>(sampleCount) * sizeof(uint32_t);
D3D12Buffer sortKeyReadbackBuffer;
if (!sortKeyReadbackBuffer.Initialize(
m_device.GetDevice(),
keyBufferBytes,
D3D12_RESOURCE_STATE_COPY_DEST,
D3D12_HEAP_TYPE_READBACK,
D3D12_RESOURCE_FLAG_NONE)) {
m_lastErrorMessage = L"Failed to create the sort key readback buffer.";
return false;
}
D3D12Buffer orderReadbackBuffer;
if (!orderReadbackBuffer.Initialize(
m_device.GetDevice(),
sampleBytes,
D3D12_RESOURCE_STATE_COPY_DEST,
D3D12_HEAP_TYPE_READBACK,
D3D12_RESOURCE_FLAG_NONE)) {
m_lastErrorMessage = L"Failed to create the order readback buffer.";
sortKeyReadbackBuffer.Shutdown();
return false;
}
m_commandAllocator.Reset();
m_commandList.Reset();
if (m_sortKeyBuffer->GetState() != ResourceStates::CopySrc) {
m_commandList.TransitionBarrier(
m_sortKeyBuffer->GetResource(),
m_sortKeyBuffer->GetState(),
ResourceStates::CopySrc);
m_sortKeyBuffer->SetState(ResourceStates::CopySrc);
}
if (m_orderBuffer->GetState() != ResourceStates::CopySrc) {
m_commandList.TransitionBarrier(
m_orderBuffer->GetResource(),
m_orderBuffer->GetState(),
ResourceStates::CopySrc);
m_orderBuffer->SetState(ResourceStates::CopySrc);
}
m_commandList.GetCommandList()->CopyBufferRegion(
sortKeyReadbackBuffer.GetResource(),
0,
m_sortKeyBuffer->GetResource(),
0,
keyBufferBytes);
m_commandList.GetCommandList()->CopyBufferRegion(
orderReadbackBuffer.GetResource(),
0,
m_orderBuffer->GetResource(),
0,
sampleBytes);
m_commandList.Close();
void* commandLists[] = { &m_commandList };
m_commandQueue.ExecuteCommandLists(1, commandLists);
m_commandQueue.WaitForIdle();
const uint32_t* keys = static_cast<const uint32_t*>(sortKeyReadbackBuffer.Map());
if (keys == nullptr) {
m_lastErrorMessage = L"Failed to map the sort key readback buffer.";
sortKeyReadbackBuffer.Shutdown();
orderReadbackBuffer.Shutdown();
return false;
}
const uint32_t* order = static_cast<const uint32_t*>(orderReadbackBuffer.Map());
if (order == nullptr) {
sortKeyReadbackBuffer.Unmap();
sortKeyReadbackBuffer.Shutdown();
orderReadbackBuffer.Shutdown();
m_lastErrorMessage = L"Failed to map the order readback buffer.";
return false;
}
const std::filesystem::path snapshotPath = ResolveNearExecutable(m_sortKeySnapshotPath);
if (!snapshotPath.parent_path().empty()) {
std::filesystem::create_directories(snapshotPath.parent_path());
}
std::ofstream output(snapshotPath, std::ios::binary | std::ios::trunc);
if (!output.is_open()) {
orderReadbackBuffer.Unmap();
sortKeyReadbackBuffer.Unmap();
orderReadbackBuffer.Shutdown();
sortKeyReadbackBuffer.Shutdown();
m_lastErrorMessage = L"Failed to open the sort key snapshot output file.";
return false;
}
output << "sample_count=" << sampleCount << '\n';
bool isSorted = true;
uint32_t firstInversionIndex = 0u;
uint32_t firstInversionPrevious = 0u;
uint32_t firstInversionCurrent = 0u;
for (uint32_t index = 1; index < m_gaussianSceneData.splatCount; ++index) {
if (keys[index - 1u] > keys[index]) {
isSorted = false;
firstInversionIndex = index;
firstInversionPrevious = keys[index - 1u];
firstInversionCurrent = keys[index];
break;
}
}
std::vector<uint32_t> cpuReferenceOrder;
std::vector<uint32_t> cpuReferenceKeys;
BuildCpuSortedOrder(m_gaussianSceneData, cpuReferenceOrder, &cpuReferenceKeys);
uint32_t firstCpuMismatchIndex = 0u;
uint32_t firstCpuMismatchGpu = 0u;
uint32_t firstCpuMismatchCpu = 0u;
bool matchesCpuReference = true;
for (uint32_t index = 0; index < m_gaussianSceneData.splatCount; ++index) {
if (keys[index] != cpuReferenceKeys[index]) {
matchesCpuReference = false;
firstCpuMismatchIndex = index;
firstCpuMismatchGpu = keys[index];
firstCpuMismatchCpu = cpuReferenceKeys[index];
break;
}
}
output << "sorted=" << (isSorted ? 1 : 0) << '\n';
output << "first_inversion_index=" << firstInversionIndex << '\n';
output << "first_inversion_prev=" << firstInversionPrevious << '\n';
output << "first_inversion_curr=" << firstInversionCurrent << '\n';
output << "matches_cpu_reference=" << (matchesCpuReference ? 1 : 0) << '\n';
output << "first_cpu_mismatch_index=" << firstCpuMismatchIndex << '\n';
output << "first_cpu_mismatch_gpu=" << firstCpuMismatchGpu << '\n';
output << "first_cpu_mismatch_cpu=" << firstCpuMismatchCpu << '\n';
for (uint32_t index = 0; index < sampleCount; ++index) {
output << "key[" << index << "]=" << keys[index] << '\n';
output << "order[" << index << "]=" << order[index] << '\n';
output << "cpu_order[" << index << "]=" << cpuReferenceOrder[index] << '\n';
output << "cpu_key[" << index << "]=" << cpuReferenceKeys[index] << '\n';
}
orderReadbackBuffer.Unmap();
sortKeyReadbackBuffer.Unmap();
orderReadbackBuffer.Shutdown();
sortKeyReadbackBuffer.Shutdown();
return output.good();
}
bool App::CapturePass3HistogramDebug() {
if (m_sortKeyScratchBuffer == nullptr ||
m_sortKeyBuffer == nullptr ||
m_globalHistogramBuffer == nullptr ||
m_gaussianSceneData.splatCount == 0) {
return true;
}
const uint64_t keyBufferBytes = static_cast<uint64_t>(m_gaussianSceneData.splatCount) * sizeof(uint32_t);
const uint32_t histogramElementCount = kDeviceRadixSortRadix * kDeviceRadixSortPassCount;
const uint64_t histogramBytes = static_cast<uint64_t>(histogramElementCount) * sizeof(uint32_t);
D3D12Buffer keyReadbackBuffer;
if (!keyReadbackBuffer.Initialize(
m_device.GetDevice(),
keyBufferBytes,
D3D12_RESOURCE_STATE_COPY_DEST,
D3D12_HEAP_TYPE_READBACK,
D3D12_RESOURCE_FLAG_NONE)) {
m_lastErrorMessage = L"Failed to create the pass3 histogram key readback buffer.";
return false;
}
D3D12Buffer histogramReadbackBuffer;
if (!histogramReadbackBuffer.Initialize(
m_device.GetDevice(),
histogramBytes,
D3D12_RESOURCE_STATE_COPY_DEST,
D3D12_HEAP_TYPE_READBACK,
D3D12_RESOURCE_FLAG_NONE)) {
keyReadbackBuffer.Shutdown();
m_lastErrorMessage = L"Failed to create the pass3 histogram readback buffer.";
return false;
}
D3D12Buffer primaryKeyReadbackBuffer;
if (!primaryKeyReadbackBuffer.Initialize(
m_device.GetDevice(),
keyBufferBytes,
D3D12_RESOURCE_STATE_COPY_DEST,
D3D12_HEAP_TYPE_READBACK,
D3D12_RESOURCE_FLAG_NONE)) {
histogramReadbackBuffer.Shutdown();
keyReadbackBuffer.Shutdown();
m_lastErrorMessage = L"Failed to create the pass3 primary-key readback buffer.";
return false;
}
m_commandAllocator.Reset();
m_commandList.Reset();
if (m_sortKeyScratchBuffer->GetState() != ResourceStates::CopySrc) {
m_commandList.TransitionBarrier(
m_sortKeyScratchBuffer->GetResource(),
m_sortKeyScratchBuffer->GetState(),
ResourceStates::CopySrc);
m_sortKeyScratchBuffer->SetState(ResourceStates::CopySrc);
}
if (m_sortKeyBuffer->GetState() != ResourceStates::CopySrc) {
m_commandList.TransitionBarrier(
m_sortKeyBuffer->GetResource(),
m_sortKeyBuffer->GetState(),
ResourceStates::CopySrc);
m_sortKeyBuffer->SetState(ResourceStates::CopySrc);
}
if (m_globalHistogramBuffer->GetState() != ResourceStates::CopySrc) {
m_commandList.TransitionBarrier(
m_globalHistogramBuffer->GetResource(),
m_globalHistogramBuffer->GetState(),
ResourceStates::CopySrc);
m_globalHistogramBuffer->SetState(ResourceStates::CopySrc);
}
m_commandList.GetCommandList()->CopyBufferRegion(
keyReadbackBuffer.GetResource(),
0,
m_sortKeyScratchBuffer->GetResource(),
0,
keyBufferBytes);
m_commandList.GetCommandList()->CopyBufferRegion(
primaryKeyReadbackBuffer.GetResource(),
0,
m_sortKeyBuffer->GetResource(),
0,
keyBufferBytes);
m_commandList.GetCommandList()->CopyBufferRegion(
histogramReadbackBuffer.GetResource(),
0,
m_globalHistogramBuffer->GetResource(),
0,
histogramBytes);
m_commandList.Close();
void* commandLists[] = { &m_commandList };
m_commandQueue.ExecuteCommandLists(1, commandLists);
m_commandQueue.WaitForIdle();
const uint32_t* keys = static_cast<const uint32_t*>(keyReadbackBuffer.Map());
if (keys == nullptr) {
primaryKeyReadbackBuffer.Shutdown();
histogramReadbackBuffer.Shutdown();
keyReadbackBuffer.Shutdown();
m_lastErrorMessage = L"Failed to map the pass3 key readback buffer.";
return false;
}
const uint32_t* primaryKeys = static_cast<const uint32_t*>(primaryKeyReadbackBuffer.Map());
if (primaryKeys == nullptr) {
keyReadbackBuffer.Unmap();
primaryKeyReadbackBuffer.Shutdown();
histogramReadbackBuffer.Shutdown();
keyReadbackBuffer.Shutdown();
m_lastErrorMessage = L"Failed to map the pass3 primary-key readback buffer.";
return false;
}
const uint32_t* histogram = static_cast<const uint32_t*>(histogramReadbackBuffer.Map());
if (histogram == nullptr) {
primaryKeyReadbackBuffer.Unmap();
keyReadbackBuffer.Unmap();
primaryKeyReadbackBuffer.Shutdown();
histogramReadbackBuffer.Shutdown();
keyReadbackBuffer.Shutdown();
m_lastErrorMessage = L"Failed to map the pass3 histogram readback buffer.";
return false;
}
std::vector<uint32_t> cpuCounts(kDeviceRadixSortRadix, 0u);
std::vector<uint32_t> primaryCpuCounts(kDeviceRadixSortRadix, 0u);
for (uint32_t index = 0; index < m_gaussianSceneData.splatCount; ++index) {
const uint32_t bin = (keys[index] >> 24u) & 0xffu;
++cpuCounts[bin];
const uint32_t primaryBin = (primaryKeys[index] >> 24u) & 0xffu;
++primaryCpuCounts[primaryBin];
}
std::vector<uint32_t> cpuExclusiveOffsets(kDeviceRadixSortRadix, 0u);
std::vector<uint32_t> primaryCpuExclusiveOffsets(kDeviceRadixSortRadix, 0u);
uint32_t runningOffset = 0u;
uint32_t primaryRunningOffset = 0u;
for (uint32_t bin = 0; bin < kDeviceRadixSortRadix; ++bin) {
cpuExclusiveOffsets[bin] = runningOffset;
runningOffset += cpuCounts[bin];
primaryCpuExclusiveOffsets[bin] = primaryRunningOffset;
primaryRunningOffset += primaryCpuCounts[bin];
}
const std::filesystem::path debugPath = ResolveNearExecutable(L"phase3_hist_debug.txt");
if (!debugPath.parent_path().empty()) {
std::filesystem::create_directories(debugPath.parent_path());
}
std::ofstream output(debugPath, std::ios::binary | std::ios::trunc);
if (!output.is_open()) {
histogramReadbackBuffer.Unmap();
keyReadbackBuffer.Unmap();
histogramReadbackBuffer.Shutdown();
keyReadbackBuffer.Shutdown();
m_lastErrorMessage = L"Failed to open the pass3 histogram debug output file.";
return false;
}
uint32_t mismatchCount = 0u;
uint32_t primaryMismatchCount = 0u;
constexpr uint32_t kPass3GlobalHistogramBase = kDeviceRadixSortRadix * 3u;
output << "splat_count=" << m_gaussianSceneData.splatCount << '\n';
output << "final_offset=" << runningOffset << '\n';
output << "primary_final_offset=" << primaryRunningOffset << '\n';
for (uint32_t bin = 0; bin < kDeviceRadixSortRadix; ++bin) {
const uint32_t gpuOffset = histogram[kPass3GlobalHistogramBase + bin];
const uint32_t cpuOffset = cpuExclusiveOffsets[bin];
const uint32_t primaryCpuOffset = primaryCpuExclusiveOffsets[bin];
if (gpuOffset != cpuOffset) {
++mismatchCount;
}
if (gpuOffset != primaryCpuOffset) {
++primaryMismatchCount;
}
output << "bin[" << bin << "].count=" << cpuCounts[bin]
<< " cpu_offset=" << cpuOffset
<< " primary_count=" << primaryCpuCounts[bin]
<< " primary_cpu_offset=" << primaryCpuOffset
<< " gpu_offset=" << gpuOffset << '\n';
}
output << "mismatch_count=" << mismatchCount << '\n';
output << "primary_mismatch_count=" << primaryMismatchCount << '\n';
histogramReadbackBuffer.Unmap();
primaryKeyReadbackBuffer.Unmap();
keyReadbackBuffer.Unmap();
primaryKeyReadbackBuffer.Shutdown();
histogramReadbackBuffer.Shutdown();
keyReadbackBuffer.Shutdown();
return output.good();
}
void App::RenderFrame(bool captureScreenshot) {
AppendTrace(captureScreenshot ? "RenderFrame: begin capture" : "RenderFrame: begin");
if (m_hasRenderedAtLeastOneFrame) {
AppendTrace("RenderFrame: WaitForPreviousFrame");
m_commandQueue.WaitForPreviousFrame();
}
m_commandAllocator.Reset();
m_commandList.Reset();
const int currentBackBufferIndex = m_swapChain.GetCurrentBackBufferIndex();
D3D12Texture& backBuffer = m_swapChain.GetBackBuffer(currentBackBufferIndex);
m_commandList.TransitionBarrier(backBuffer.GetResource(), ResourceStates::Present, ResourceStates::RenderTarget);
const CPUDescriptorHandle rtvCpuHandle = m_rtvHeap.GetCPUDescriptorHandle(currentBackBufferIndex);
const CPUDescriptorHandle dsvCpuHandle = m_dsvHeap.GetCPUDescriptorHandle(0);
const D3D12_CPU_DESCRIPTOR_HANDLE rtvHandle = { rtvCpuHandle.ptr };
const D3D12_CPU_DESCRIPTOR_HANDLE dsvHandle = { dsvCpuHandle.ptr };
m_commandList.SetRenderTargetsHandle(1, &rtvHandle, &dsvHandle);
const Viewport viewport = { 0.0f, 0.0f, static_cast<float>(m_width), static_cast<float>(m_height), 0.0f, 1.0f };
const Rect scissorRect = { 0, 0, m_width, m_height };
m_commandList.SetViewport(viewport);
m_commandList.SetScissorRect(scissorRect);
m_commandList.ClearRenderTargetView(rtvHandle, kClearColor, 0, nullptr);
m_commandList.ClearDepthStencilView(
dsvHandle,
D3D12_CLEAR_FLAG_DEPTH | D3D12_CLEAR_FLAG_STENCIL,
1.0f,
0,
0,
nullptr);
const FrameConstants frameConstants = BuildFrameConstants(
static_cast<uint32_t>(m_width),
static_cast<uint32_t>(m_height),
m_gaussianSceneData.splatCount);
const uint32_t threadBlocks =
static_cast<uint32_t>((m_gaussianSceneData.splatCount + (kDeviceRadixSortPartitionSize - 1u)) / kDeviceRadixSortPartitionSize);
const uint32_t passHistogramElementCount = threadBlocks * kDeviceRadixSortRadix;
constexpr uint32_t kSortDebugStageCount = 5u;
const char* const sortDebugStageNames[kSortDebugStageCount] = {
"build_sort_keys",
"radix_pass_0",
"radix_pass_1",
"radix_pass_2",
"radix_pass_3",
};
const uint32_t sortDebugSampleCount = 0u;
const uint64_t sortDebugSampleBytes = static_cast<uint64_t>(sortDebugSampleCount) * sizeof(uint32_t);
D3D12Buffer sortDebugKeyReadbackBuffer;
D3D12Buffer sortDebugOrderReadbackBuffer;
const bool sortDebugEnabled =
sortDebugSampleCount > 0 &&
sortDebugKeyReadbackBuffer.Initialize(
m_device.GetDevice(),
sortDebugSampleBytes * kSortDebugStageCount,
D3D12_RESOURCE_STATE_COPY_DEST,
D3D12_HEAP_TYPE_READBACK,
D3D12_RESOURCE_FLAG_NONE) &&
sortDebugOrderReadbackBuffer.Initialize(
m_device.GetDevice(),
sortDebugSampleBytes * kSortDebugStageCount,
D3D12_RESOURCE_STATE_COPY_DEST,
D3D12_HEAP_TYPE_READBACK,
D3D12_RESOURCE_FLAG_NONE);
D3D12Buffer pass3PassHistogramReadbackBuffer;
D3D12Buffer pass3PreScanGlobalHistogramReadbackBuffer;
const uint64_t pass3PassHistogramReadbackBytes =
static_cast<uint64_t>(passHistogramElementCount) * sizeof(uint32_t);
const uint64_t pass3PreScanGlobalHistogramReadbackBytes =
static_cast<uint64_t>(kDeviceRadixSortRadix * kDeviceRadixSortPassCount) * sizeof(uint32_t);
const bool pass3PassHistogramDebugEnabled =
captureScreenshot &&
pass3PassHistogramReadbackBuffer.Initialize(
m_device.GetDevice(),
pass3PassHistogramReadbackBytes,
D3D12_RESOURCE_STATE_COPY_DEST,
D3D12_HEAP_TYPE_READBACK,
D3D12_RESOURCE_FLAG_NONE);
const bool pass3PreScanGlobalHistogramDebugEnabled =
captureScreenshot &&
pass3PreScanGlobalHistogramReadbackBuffer.Initialize(
m_device.GetDevice(),
pass3PreScanGlobalHistogramReadbackBytes,
D3D12_RESOURCE_STATE_COPY_DEST,
D3D12_HEAP_TYPE_READBACK,
D3D12_RESOURCE_FLAG_NONE);
bool pass3PassHistogramCaptured = false;
bool pass3PreScanGlobalHistogramCaptured = false;
auto transitionBuffer = [this](D3D12Buffer* buffer, ResourceStates newState) {
if (buffer != nullptr && buffer->GetState() != newState) {
m_commandList.TransitionBarrier(
buffer->GetResource(),
buffer->GetState(),
newState);
buffer->SetState(newState);
}
};
auto recordSortDebugStage = [&](uint32_t stageIndex,
D3D12Buffer* keyBuffer,
D3D12Buffer* orderBuffer,
ResourceStates keyRestoreState,
ResourceStates orderRestoreState) {
if (!sortDebugEnabled ||
keyBuffer == nullptr ||
orderBuffer == nullptr ||
sortDebugSampleBytes == 0 ||
stageIndex >= kSortDebugStageCount) {
return;
}
transitionBuffer(keyBuffer, ResourceStates::CopySrc);
transitionBuffer(orderBuffer, ResourceStates::CopySrc);
m_commandList.GetCommandList()->CopyBufferRegion(
sortDebugKeyReadbackBuffer.GetResource(),
static_cast<uint64_t>(stageIndex) * sortDebugSampleBytes,
keyBuffer->GetResource(),
0,
sortDebugSampleBytes);
m_commandList.GetCommandList()->CopyBufferRegion(
sortDebugOrderReadbackBuffer.GetResource(),
static_cast<uint64_t>(stageIndex) * sortDebugSampleBytes,
orderBuffer->GetResource(),
0,
sortDebugSampleBytes);
transitionBuffer(keyBuffer, keyRestoreState);
transitionBuffer(orderBuffer, orderRestoreState);
};
m_prepareDescriptorSet->WriteConstant(0, &frameConstants, sizeof(frameConstants));
m_debugDescriptorSet->WriteConstant(0, &frameConstants, sizeof(frameConstants));
transitionBuffer(m_preparedViewBuffer, ResourceStates::UnorderedAccess);
m_commandList.SetPipelineState(m_preparePipelineState);
RHIDescriptorSet* prepareSets[] = { m_prepareDescriptorSet };
m_commandList.SetComputeDescriptorSets(0, 1, prepareSets, m_preparePipelineLayout);
m_commandList.Dispatch((m_gaussianSceneData.splatCount + (kPrepareThreadGroupSize - 1)) / kPrepareThreadGroupSize, 1, 1);
m_commandList.UAVBarrier(m_preparedViewBuffer->GetResource());
transitionBuffer(m_orderBuffer, ResourceStates::NonPixelShaderResource);
transitionBuffer(m_sortKeyBuffer, ResourceStates::UnorderedAccess);
m_buildSortKeyDescriptorSet->WriteConstant(0, &frameConstants, sizeof(frameConstants));
m_commandList.SetPipelineState(m_buildSortKeyPipelineState);
RHIDescriptorSet* buildSortKeySets[] = { m_buildSortKeyDescriptorSet };
m_commandList.SetComputeDescriptorSets(0, 1, buildSortKeySets, m_buildSortKeyPipelineLayout);
m_commandList.Dispatch((m_gaussianSceneData.splatCount + (kSortThreadGroupSize - 1)) / kSortThreadGroupSize, 1, 1);
m_commandList.UAVBarrier(m_sortKeyBuffer->GetResource());
recordSortDebugStage(
0u,
m_sortKeyBuffer,
m_orderBuffer,
ResourceStates::UnorderedAccess,
ResourceStates::NonPixelShaderResource);
if (!kUseCpuSortBaseline) {
transitionBuffer(m_orderBuffer, ResourceStates::UnorderedAccess);
transitionBuffer(m_orderScratchBuffer, ResourceStates::UnorderedAccess);
transitionBuffer(m_sortKeyScratchBuffer, ResourceStates::UnorderedAccess);
transitionBuffer(m_passHistogramBuffer, ResourceStates::UnorderedAccess);
transitionBuffer(m_globalHistogramBuffer, ResourceStates::UnorderedAccess);
RadixSortConstants radixSortConstants = {};
radixSortConstants.numKeys = m_gaussianSceneData.splatCount;
radixSortConstants.threadBlocks = threadBlocks;
m_radixSortDescriptorSetPrimaryToScratch->WriteConstant(0, &radixSortConstants, sizeof(radixSortConstants));
m_commandList.SetPipelineState(m_radixSortInitPipelineState);
RHIDescriptorSet* radixInitSets[] = { m_radixSortDescriptorSetPrimaryToScratch };
m_commandList.SetComputeDescriptorSets(0, 1, radixInitSets, m_radixSortPipelineLayout);
m_commandList.Dispatch(1, 1, 1);
m_commandList.UAVBarrier(m_globalHistogramBuffer->GetResource());
for (uint32_t passIndex = 0; passIndex < kDeviceRadixSortPassCount; ++passIndex) {
radixSortConstants.radixShift = passIndex * 8u;
RHIDescriptorSet* activeRadixSet =
(passIndex & 1u) == 0u
? m_radixSortDescriptorSetPrimaryToScratch
: m_radixSortDescriptorSetScratchToPrimary;
D3D12Buffer* destinationKeyBuffer =
(passIndex & 1u) == 0u
? m_sortKeyScratchBuffer
: m_sortKeyBuffer;
D3D12Buffer* destinationOrderBuffer =
(passIndex & 1u) == 0u
? m_orderScratchBuffer
: m_orderBuffer;
activeRadixSet->WriteConstant(0, &radixSortConstants, sizeof(radixSortConstants));
m_commandList.SetPipelineState(m_radixSortUpsweepPipelineState);
m_commandList.SetComputeDescriptorSets(0, 1, &activeRadixSet, m_radixSortPipelineLayout);
m_commandList.Dispatch(threadBlocks, 1, 1);
m_commandList.UAVBarrier(m_passHistogramBuffer->GetResource());
if (pass3PassHistogramDebugEnabled && passIndex == 3u && !pass3PassHistogramCaptured) {
transitionBuffer(m_passHistogramBuffer, ResourceStates::CopySrc);
m_commandList.GetCommandList()->CopyBufferRegion(
pass3PassHistogramReadbackBuffer.GetResource(),
0,
m_passHistogramBuffer->GetResource(),
0,
pass3PassHistogramReadbackBytes);
transitionBuffer(m_passHistogramBuffer, ResourceStates::UnorderedAccess);
pass3PassHistogramCaptured = true;
}
m_commandList.SetPipelineState(m_radixSortGlobalHistogramPipelineState);
m_commandList.SetComputeDescriptorSets(0, 1, &activeRadixSet, m_radixSortPipelineLayout);
m_commandList.Dispatch(1, 1, 1);
m_commandList.UAVBarrier(m_globalHistogramBuffer->GetResource());
if (pass3PreScanGlobalHistogramDebugEnabled && passIndex == 3u && !pass3PreScanGlobalHistogramCaptured) {
transitionBuffer(m_globalHistogramBuffer, ResourceStates::CopySrc);
m_commandList.GetCommandList()->CopyBufferRegion(
pass3PreScanGlobalHistogramReadbackBuffer.GetResource(),
0,
m_globalHistogramBuffer->GetResource(),
0,
pass3PreScanGlobalHistogramReadbackBytes);
transitionBuffer(m_globalHistogramBuffer, ResourceStates::UnorderedAccess);
pass3PreScanGlobalHistogramCaptured = true;
}
m_commandList.SetPipelineState(m_radixSortScanPipelineState);
m_commandList.SetComputeDescriptorSets(0, 1, &activeRadixSet, m_radixSortPipelineLayout);
m_commandList.Dispatch(kDeviceRadixSortRadix, 1, 1);
m_commandList.UAVBarrier(m_passHistogramBuffer->GetResource());
m_commandList.SetPipelineState(m_radixSortDownsweepPipelineState);
m_commandList.SetComputeDescriptorSets(0, 1, &activeRadixSet, m_radixSortPipelineLayout);
m_commandList.Dispatch(threadBlocks, 1, 1);
m_commandList.UAVBarrier(destinationKeyBuffer->GetResource());
m_commandList.UAVBarrier(destinationOrderBuffer->GetResource());
recordSortDebugStage(
passIndex + 1u,
destinationKeyBuffer,
destinationOrderBuffer,
ResourceStates::UnorderedAccess,
ResourceStates::UnorderedAccess);
}
}
transitionBuffer(m_preparedViewBuffer, ResourceStates::NonPixelShaderResource);
transitionBuffer(m_orderBuffer, ResourceStates::NonPixelShaderResource);
m_commandList.SetPipelineState(m_debugPipelineState);
RHIDescriptorSet* debugSets[] = { m_debugDescriptorSet };
m_commandList.SetGraphicsDescriptorSets(0, 1, debugSets, m_debugPipelineLayout);
m_commandList.SetPrimitiveTopology(PrimitiveTopology::TriangleStrip);
m_commandList.Draw(4u, m_gaussianSceneData.splatCount, 0u, 0u);
if (captureScreenshot) {
AppendTrace("RenderFrame: close+execute capture pre-screenshot");
m_commandList.Close();
void* commandLists[] = { &m_commandList };
m_commandQueue.ExecuteCommandLists(1, commandLists);
AppendTrace("RenderFrame: WaitForIdle before screenshot");
m_commandQueue.WaitForIdle();
if (sortDebugEnabled) {
const std::filesystem::path debugPath = ResolveNearExecutable(L"phase3_sort_debug.txt");
std::ofstream debugOutput(debugPath, std::ios::binary | std::ios::trunc);
const uint32_t* stageKeys = static_cast<const uint32_t*>(sortDebugKeyReadbackBuffer.Map());
const uint32_t* stageOrder = static_cast<const uint32_t*>(sortDebugOrderReadbackBuffer.Map());
if (debugOutput.is_open() && stageKeys != nullptr && stageOrder != nullptr) {
debugOutput << "sample_count=" << sortDebugSampleCount << '\n';
for (uint32_t stageIndex = 0; stageIndex < kSortDebugStageCount; ++stageIndex) {
debugOutput << "stage=" << sortDebugStageNames[stageIndex] << '\n';
const uint32_t* stageKeyBase = stageKeys + stageIndex * sortDebugSampleCount;
const uint32_t* stageOrderBase = stageOrder + stageIndex * sortDebugSampleCount;
for (uint32_t sampleIndex = 0; sampleIndex < sortDebugSampleCount; ++sampleIndex) {
debugOutput << "key[" << sampleIndex << "]=" << stageKeyBase[sampleIndex] << '\n';
debugOutput << "order[" << sampleIndex << "]=" << stageOrderBase[sampleIndex] << '\n';
}
}
}
if (stageOrder != nullptr) {
sortDebugOrderReadbackBuffer.Unmap();
}
if (stageKeys != nullptr) {
sortDebugKeyReadbackBuffer.Unmap();
}
}
if (pass3PassHistogramDebugEnabled && pass3PassHistogramCaptured) {
const uint32_t* passHistogram =
static_cast<const uint32_t*>(pass3PassHistogramReadbackBuffer.Map());
if (passHistogram != nullptr) {
const std::filesystem::path passHistogramDebugPath =
ResolveNearExecutable(L"phase3_passhist_debug.txt");
if (!passHistogramDebugPath.parent_path().empty()) {
std::filesystem::create_directories(passHistogramDebugPath.parent_path());
}
std::ofstream passHistogramOutput(
passHistogramDebugPath,
std::ios::binary | std::ios::trunc);
if (passHistogramOutput.is_open()) {
uint64_t totalCount = 0u;
passHistogramOutput << "thread_blocks=" << threadBlocks << '\n';
for (uint32_t bin = 0; bin < kDeviceRadixSortRadix; ++bin) {
uint64_t binTotal = 0u;
for (uint32_t blockIndex = 0; blockIndex < threadBlocks; ++blockIndex) {
binTotal += passHistogram[bin * threadBlocks + blockIndex];
}
totalCount += binTotal;
passHistogramOutput << "bin[" << bin << "]=" << binTotal << '\n';
}
passHistogramOutput << "total_count=" << totalCount << '\n';
}
pass3PassHistogramReadbackBuffer.Unmap();
}
}
if (pass3PreScanGlobalHistogramDebugEnabled && pass3PreScanGlobalHistogramCaptured) {
const uint32_t* preScanGlobalHistogram =
static_cast<const uint32_t*>(pass3PreScanGlobalHistogramReadbackBuffer.Map());
if (preScanGlobalHistogram != nullptr) {
const std::filesystem::path preScanGlobalHistogramPath =
ResolveNearExecutable(L"phase3_pre_scan_globalhist.txt");
if (!preScanGlobalHistogramPath.parent_path().empty()) {
std::filesystem::create_directories(preScanGlobalHistogramPath.parent_path());
}
std::ofstream preScanGlobalHistogramOutput(
preScanGlobalHistogramPath,
std::ios::binary | std::ios::trunc);
if (preScanGlobalHistogramOutput.is_open()) {
constexpr uint32_t kPass3GlobalHistogramBase = kDeviceRadixSortRadix * 3u;
for (uint32_t bin = 0; bin < kDeviceRadixSortRadix; ++bin) {
preScanGlobalHistogramOutput
<< "bin[" << bin << "]="
<< preScanGlobalHistogram[kPass3GlobalHistogramBase + bin]
<< '\n';
}
}
pass3PreScanGlobalHistogramReadbackBuffer.Unmap();
}
}
AppendTrace("RenderFrame: Capture sort snapshot");
if (!CaptureSortSnapshot()) {
AppendTrace(std::string("RenderFrame: Capture sort snapshot failed: ") + NarrowAscii(m_lastErrorMessage));
}
if (!kUseCpuSortBaseline) {
AppendTrace("RenderFrame: Capture pass3 histogram debug");
if (!CapturePass3HistogramDebug()) {
AppendTrace(std::string("RenderFrame: Capture pass3 histogram debug failed: ") + NarrowAscii(m_lastErrorMessage));
}
}
if (!m_screenshotPath.empty()) {
const std::filesystem::path screenshotPath = ResolveNearExecutable(m_screenshotPath);
if (!screenshotPath.parent_path().empty()) {
std::filesystem::create_directories(screenshotPath.parent_path());
}
AppendTrace("RenderFrame: Capture screenshot");
D3D12Screenshot::Capture(
m_device,
m_commandQueue,
backBuffer,
screenshotPath.string().c_str());
AppendTrace("RenderFrame: Capture screenshot done");
}
m_commandAllocator.Reset();
m_commandList.Reset();
m_commandList.TransitionBarrier(backBuffer.GetResource(), ResourceStates::RenderTarget, ResourceStates::Present);
m_commandList.Close();
void* presentCommandLists[] = { &m_commandList };
AppendTrace("RenderFrame: execute final present-transition list");
m_commandQueue.ExecuteCommandLists(1, presentCommandLists);
} else {
m_commandList.TransitionBarrier(backBuffer.GetResource(), ResourceStates::RenderTarget, ResourceStates::Present);
m_commandList.Close();
void* commandLists[] = { &m_commandList };
AppendTrace("RenderFrame: execute+present");
m_commandQueue.ExecuteCommandLists(1, commandLists);
m_swapChain.Present(1, 0);
}
if (sortDebugEnabled) {
sortDebugOrderReadbackBuffer.Shutdown();
sortDebugKeyReadbackBuffer.Shutdown();
}
if (pass3PassHistogramDebugEnabled) {
pass3PassHistogramReadbackBuffer.Shutdown();
}
if (pass3PreScanGlobalHistogramDebugEnabled) {
pass3PreScanGlobalHistogramReadbackBuffer.Shutdown();
}
m_hasRenderedAtLeastOneFrame = true;
AppendTrace("RenderFrame: end");
}
} // namespace XC3DGSD3D12