diff --git a/MVS/3DGS-D3D12/CMakeLists.txt b/MVS/3DGS-D3D12/CMakeLists.txt index 762692c4..83e4dbe4 100644 --- a/MVS/3DGS-D3D12/CMakeLists.txt +++ b/MVS/3DGS-D3D12/CMakeLists.txt @@ -6,6 +6,9 @@ set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) find_program(XC_DXC_EXECUTABLE NAMES dxc) +if(NOT XC_DXC_EXECUTABLE) + message(FATAL_ERROR "dxc is required to build the 3DGS D3D12 MVS sort shaders.") +endif() get_filename_component(XCENGINE_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../.." ABSOLUTE) set(XCENGINE_BUILD_DIR "${XCENGINE_ROOT}/build") @@ -35,6 +38,8 @@ add_executable(xc_3dgs_d3d12_mvs shaders/PreparedSplatView.hlsli shaders/PrepareGaussiansCS.hlsl shaders/BuildSortKeysCS.hlsl + shaders/SortCommon.hlsl + shaders/DeviceRadixSort.hlsl shaders/DebugPointsVS.hlsl shaders/DebugPointsPS.hlsl ) @@ -43,6 +48,8 @@ set_source_files_properties( shaders/PreparedSplatView.hlsli shaders/PrepareGaussiansCS.hlsl shaders/BuildSortKeysCS.hlsl + shaders/SortCommon.hlsl + shaders/DeviceRadixSort.hlsl shaders/DebugPointsVS.hlsl shaders/DebugPointsPS.hlsl PROPERTIES @@ -93,12 +100,55 @@ add_custom_command(TARGET xc_3dgs_d3d12_mvs POST_BUILD "$/shaders" ) -if(XC_DXC_EXECUTABLE) - add_custom_command(TARGET xc_3dgs_d3d12_mvs POST_BUILD - COMMAND "${XC_DXC_EXECUTABLE}" - -T cs_6_6 - -E MainCS - -Fo "$/shaders/BuildSortKeysCS.dxil" - "${CMAKE_CURRENT_SOURCE_DIR}/shaders/BuildSortKeysCS.hlsl" - ) -endif() +add_custom_command(TARGET xc_3dgs_d3d12_mvs POST_BUILD + COMMAND "${XC_DXC_EXECUTABLE}" + -T cs_6_6 + -E MainCS + -Fo "$/shaders/BuildSortKeysCS.dxil" + "${CMAKE_CURRENT_SOURCE_DIR}/shaders/BuildSortKeysCS.hlsl" + COMMAND "${XC_DXC_EXECUTABLE}" + -T cs_6_6 + -E InitDeviceRadixSort + -D KEY_UINT=1 + -D PAYLOAD_UINT=1 + -D SORT_PAIRS=1 + -D SHOULD_ASCEND=1 + -Fo "$/shaders/RadixInit.dxil" + "${CMAKE_CURRENT_SOURCE_DIR}/shaders/DeviceRadixSort.hlsl" + COMMAND "${XC_DXC_EXECUTABLE}" + -T cs_6_6 + -E Upsweep + -D KEY_UINT=1 + -D PAYLOAD_UINT=1 + -D SORT_PAIRS=1 + -D SHOULD_ASCEND=1 + 
-Fo "$/shaders/RadixUpsweep.dxil" + "${CMAKE_CURRENT_SOURCE_DIR}/shaders/DeviceRadixSort.hlsl" + COMMAND "${XC_DXC_EXECUTABLE}" + -T cs_6_6 + -E BuildGlobalHistogram + -D KEY_UINT=1 + -D PAYLOAD_UINT=1 + -D SORT_PAIRS=1 + -D SHOULD_ASCEND=1 + -Fo "$/shaders/RadixGlobalHistogram.dxil" + "${CMAKE_CURRENT_SOURCE_DIR}/shaders/DeviceRadixSort.hlsl" + COMMAND "${XC_DXC_EXECUTABLE}" + -T cs_6_6 + -E Scan + -D KEY_UINT=1 + -D PAYLOAD_UINT=1 + -D SORT_PAIRS=1 + -D SHOULD_ASCEND=1 + -Fo "$/shaders/RadixScan.dxil" + "${CMAKE_CURRENT_SOURCE_DIR}/shaders/DeviceRadixSort.hlsl" + COMMAND "${XC_DXC_EXECUTABLE}" + -T cs_6_6 + -E Downsweep + -D KEY_UINT=1 + -D PAYLOAD_UINT=1 + -D SORT_PAIRS=1 + -D SHOULD_ASCEND=1 + -Fo "$/shaders/RadixDownsweep.dxil" + "${CMAKE_CURRENT_SOURCE_DIR}/shaders/DeviceRadixSort.hlsl" +) diff --git a/MVS/3DGS-D3D12/include/XC3DGSD3D12/App.h b/MVS/3DGS-D3D12/include/XC3DGSD3D12/App.h index a7d61681..4282c018 100644 --- a/MVS/3DGS-D3D12/include/XC3DGSD3D12/App.h +++ b/MVS/3DGS-D3D12/include/XC3DGSD3D12/App.h @@ -72,7 +72,8 @@ private: void ShutdownSortResources(); void ShutdownDebugDrawResources(); void Shutdown(); - bool CaptureSortKeySnapshot(); + bool CaptureSortSnapshot(); + bool CapturePass3HistogramDebug(); void RenderFrame(bool captureScreenshot); HWND m_hwnd = nullptr; @@ -107,14 +108,31 @@ private: XCEngine::RHI::RHIDescriptorPool* m_prepareDescriptorPool = nullptr; XCEngine::RHI::RHIDescriptorSet* m_prepareDescriptorSet = nullptr; XCEngine::RHI::D3D12Buffer* m_sortKeyBuffer = nullptr; + XCEngine::RHI::D3D12Buffer* m_sortKeyScratchBuffer = nullptr; XCEngine::RHI::D3D12Buffer* m_orderBuffer = nullptr; - std::unique_ptr m_sortKeySrv; + XCEngine::RHI::D3D12Buffer* m_orderScratchBuffer = nullptr; + XCEngine::RHI::D3D12Buffer* m_passHistogramBuffer = nullptr; + XCEngine::RHI::D3D12Buffer* m_globalHistogramBuffer = nullptr; std::unique_ptr m_sortKeyUav; + std::unique_ptr m_sortKeyScratchUav; std::unique_ptr m_orderBufferSrv; - 
XCEngine::RHI::RHIPipelineLayout* m_sortPipelineLayout = nullptr; - XCEngine::RHI::RHIPipelineState* m_sortPipelineState = nullptr; - XCEngine::RHI::RHIDescriptorPool* m_sortDescriptorPool = nullptr; - XCEngine::RHI::RHIDescriptorSet* m_sortDescriptorSet = nullptr; + std::unique_ptr m_orderBufferUav; + std::unique_ptr m_orderScratchUav; + std::unique_ptr m_passHistogramUav; + std::unique_ptr m_globalHistogramUav; + XCEngine::RHI::RHIPipelineLayout* m_buildSortKeyPipelineLayout = nullptr; + XCEngine::RHI::RHIPipelineState* m_buildSortKeyPipelineState = nullptr; + XCEngine::RHI::RHIDescriptorPool* m_buildSortKeyDescriptorPool = nullptr; + XCEngine::RHI::RHIDescriptorSet* m_buildSortKeyDescriptorSet = nullptr; + XCEngine::RHI::RHIPipelineLayout* m_radixSortPipelineLayout = nullptr; + XCEngine::RHI::RHIPipelineState* m_radixSortInitPipelineState = nullptr; + XCEngine::RHI::RHIPipelineState* m_radixSortUpsweepPipelineState = nullptr; + XCEngine::RHI::RHIPipelineState* m_radixSortGlobalHistogramPipelineState = nullptr; + XCEngine::RHI::RHIPipelineState* m_radixSortScanPipelineState = nullptr; + XCEngine::RHI::RHIPipelineState* m_radixSortDownsweepPipelineState = nullptr; + XCEngine::RHI::RHIDescriptorPool* m_radixSortDescriptorPool = nullptr; + XCEngine::RHI::RHIDescriptorSet* m_radixSortDescriptorSetPrimaryToScratch = nullptr; + XCEngine::RHI::RHIDescriptorSet* m_radixSortDescriptorSetScratchToPrimary = nullptr; XCEngine::RHI::RHIPipelineLayout* m_debugPipelineLayout = nullptr; XCEngine::RHI::RHIPipelineState* m_debugPipelineState = nullptr; XCEngine::RHI::RHIDescriptorPool* m_debugDescriptorPool = nullptr; diff --git a/MVS/3DGS-D3D12/shaders/DeviceRadixSort.hlsl b/MVS/3DGS-D3D12/shaders/DeviceRadixSort.hlsl new file mode 100644 index 00000000..c8cd20e3 --- /dev/null +++ b/MVS/3DGS-D3D12/shaders/DeviceRadixSort.hlsl @@ -0,0 +1,477 @@ +/****************************************************************************** + * DeviceRadixSort + * Device Level 8-bit LSD Radix Sort 
using reduce then scan + * + * SPDX-License-Identifier: MIT + * Copyright Thomas Smith 5/17/2024 + * https://github.com/b0nes164/GPUSorting + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ ******************************************************************************/ +#include "SortCommon.hlsl" + +#define US_DIM 128U //The number of threads in an Upsweep threadblock +#define SCAN_DIM 128U //The number of threads in a Scan threadblock + +RWStructuredBuffer b_globalHist : register(u5); //buffer holding device level offsets for each binning pass +RWStructuredBuffer b_passHist : register(u4); //buffer used to store reduced sums of partition tiles + +groupshared uint g_us[RADIX * 2]; //Shared memory for upsweep +groupshared uint g_scan[SCAN_DIM]; //Shared memory for the scan + +//***************************************************************************** +//INIT KERNEL +//***************************************************************************** +//Clear the global histogram, as we will be adding to it atomically +[numthreads(1024, 1, 1)] +void InitDeviceRadixSort(int3 id : SV_DispatchThreadID) +{ + b_globalHist[id.x] = 0; +} + +//***************************************************************************** +//UPSWEEP KERNEL +//***************************************************************************** +//histogram, 64 threads to a histogram +inline void HistogramDigitCounts(uint gtid, uint gid) +{ + const uint histOffset = gtid / 64 * RADIX; + const uint partitionEnd = gid == e_threadBlocks - 1 ? 
+ e_numKeys : (gid + 1) * PART_SIZE; + for (uint i = gtid + gid * PART_SIZE; i < partitionEnd; i += US_DIM) + { +#if defined(KEY_UINT) + InterlockedAdd(g_us[ExtractDigit(b_sort[i]) + histOffset], 1); +#elif defined(KEY_INT) + InterlockedAdd(g_us[ExtractDigit(IntToUint(b_sort[i])) + histOffset], 1); +#elif defined(KEY_FLOAT) + InterlockedAdd(g_us[ExtractDigit(FloatToUint(b_sort[i])) + histOffset], 1); +#endif + } +} + +//reduce and pass to tile histogram +inline void ReduceWriteDigitCounts(uint gtid, uint gid) +{ + for (uint i = gtid; i < RADIX; i += US_DIM) + { + g_us[i] += g_us[i + RADIX]; + b_passHist[i * e_threadBlocks + gid] = g_us[i]; + } +} + +//Build the per-pass 256-bin exclusive prefix from the reduced pass histogram. +inline void BuildGlobalHistogramExclusive(uint gtid) +{ + uint digitIndices[2]; + uint digitTotals[2]; + uint digitCount = 0; + for (uint i = gtid; i < RADIX; i += US_DIM) + { + uint total = 0u; + const uint baseOffset = i * e_threadBlocks; + for (uint blockIndex = 0; blockIndex < e_threadBlocks; ++blockIndex) + { + total += b_passHist[baseOffset + blockIndex]; + } + + g_us[i] = total; + digitIndices[digitCount] = i; + digitTotals[digitCount] = total; + ++digitCount; + } + + GroupMemoryBarrierWithGroupSync(); + + for (uint offset = 1; offset < RADIX; offset <<= 1) + { + for (uint i = gtid; i < RADIX; i += US_DIM) + { + g_us[i + RADIX] = g_us[i] + (i >= offset ? 
g_us[i - offset] : 0u); + } + GroupMemoryBarrierWithGroupSync(); + + for (uint i = gtid; i < RADIX; i += US_DIM) + { + g_us[i] = g_us[i + RADIX]; + } + GroupMemoryBarrierWithGroupSync(); + } + + const uint globalHistOffset = GlobalHistOffset(); + for (uint localIndex = 0; localIndex < digitCount; ++localIndex) + { + const uint digitIndex = digitIndices[localIndex]; + b_globalHist[digitIndex + globalHistOffset] = g_us[digitIndex] - digitTotals[localIndex]; + } +} + +[numthreads(US_DIM, 1, 1)] +void Upsweep(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID) +{ + //get the wave size + const uint waveSize = getWaveSize(); + + //clear shared memory + const uint histsEnd = RADIX * 2; + for (uint i = gtid.x; i < histsEnd; i += US_DIM) + g_us[i] = 0; + GroupMemoryBarrierWithGroupSync(); + + HistogramDigitCounts(gtid.x, gid.x); + GroupMemoryBarrierWithGroupSync(); + + ReduceWriteDigitCounts(gtid.x, gid.x); +} + +[numthreads(US_DIM, 1, 1)] +void BuildGlobalHistogram(uint3 gtid : SV_GroupThreadID) +{ + const uint histsEnd = RADIX * 2; + for (uint i = gtid.x; i < histsEnd; i += US_DIM) + g_us[i] = 0; + GroupMemoryBarrierWithGroupSync(); + + BuildGlobalHistogramExclusive(gtid.x); +} + +//***************************************************************************** +//SCAN KERNEL +//***************************************************************************** +inline void ExclusiveThreadBlockScanFullWGE16( + uint gtid, + uint laneMask, + uint circularLaneShift, + uint partEnd, + uint deviceOffset, + uint waveSize, + inout uint reduction) +{ + for (uint i = gtid; i < partEnd; i += SCAN_DIM) + { + g_scan[gtid] = b_passHist[i + deviceOffset]; + g_scan[gtid] += WavePrefixSum(g_scan[gtid]); + GroupMemoryBarrierWithGroupSync(); + + if (gtid < SCAN_DIM / waveSize) + { + g_scan[(gtid + 1) * waveSize - 1] += + WavePrefixSum(g_scan[(gtid + 1) * waveSize - 1]); + } + GroupMemoryBarrierWithGroupSync(); + + uint t = (WaveGetLaneIndex() != laneMask ? 
g_scan[gtid] : 0) + reduction; + if (gtid >= waveSize) + t += WaveReadLaneAt(g_scan[gtid - 1], 0); + b_passHist[circularLaneShift + (i & ~laneMask) + deviceOffset] = t; + + reduction += g_scan[SCAN_DIM - 1]; + GroupMemoryBarrierWithGroupSync(); + } +} + +inline void ExclusiveThreadBlockScanPartialWGE16( + uint gtid, + uint laneMask, + uint circularLaneShift, + uint partEnd, + uint deviceOffset, + uint waveSize, + uint reduction) +{ + uint i = gtid + partEnd; + if (i < e_threadBlocks) + g_scan[gtid] = b_passHist[deviceOffset + i]; + g_scan[gtid] += WavePrefixSum(g_scan[gtid]); + GroupMemoryBarrierWithGroupSync(); + + if (gtid < SCAN_DIM / waveSize) + { + g_scan[(gtid + 1) * waveSize - 1] += + WavePrefixSum(g_scan[(gtid + 1) * waveSize - 1]); + } + GroupMemoryBarrierWithGroupSync(); + + const uint index = circularLaneShift + (i & ~laneMask); + if (index < e_threadBlocks) + { + uint t = (WaveGetLaneIndex() != laneMask ? g_scan[gtid] : 0) + reduction; + if (gtid >= waveSize) + t += g_scan[(gtid & ~laneMask) - 1]; + b_passHist[index + deviceOffset] = t; + } +} + +inline void ExclusiveThreadBlockScanWGE16(uint gtid, uint gid, uint waveSize) +{ + uint reduction = 0; + const uint laneMask = waveSize - 1; + const uint circularLaneShift = WaveGetLaneIndex() + 1 & laneMask; + const uint partionsEnd = e_threadBlocks / SCAN_DIM * SCAN_DIM; + const uint deviceOffset = gid * e_threadBlocks; + + ExclusiveThreadBlockScanFullWGE16( + gtid, + laneMask, + circularLaneShift, + partionsEnd, + deviceOffset, + waveSize, + reduction); + + ExclusiveThreadBlockScanPartialWGE16( + gtid, + laneMask, + circularLaneShift, + partionsEnd, + deviceOffset, + waveSize, + reduction); +} + +inline void ExclusiveThreadBlockScanFullWLT16( + uint gtid, + uint partitions, + uint deviceOffset, + uint laneLog, + uint circularLaneShift, + uint waveSize, + inout uint reduction) +{ + for (uint k = 0; k < partitions; ++k) + { + g_scan[gtid] = b_passHist[gtid + k * SCAN_DIM + deviceOffset]; + g_scan[gtid] += 
WavePrefixSum(g_scan[gtid]); + GroupMemoryBarrierWithGroupSync(); + if (gtid < waveSize) + { + b_passHist[circularLaneShift + k * SCAN_DIM + deviceOffset] = + (circularLaneShift ? g_scan[gtid] : 0) + reduction; + } + + uint offset = laneLog; + uint j = waveSize; + for (; j < (SCAN_DIM >> 1); j <<= laneLog) + { + if (gtid < (SCAN_DIM >> offset)) + { + g_scan[((gtid + 1) << offset) - 1] += + WavePrefixSum(g_scan[((gtid + 1) << offset) - 1]); + } + GroupMemoryBarrierWithGroupSync(); + + if ((gtid & ((j << laneLog) - 1)) >= j) + { + if (gtid < (j << laneLog)) + { + b_passHist[gtid + k * SCAN_DIM + deviceOffset] = + WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0) + + ((gtid & (j - 1)) ? g_scan[gtid - 1] : 0) + reduction; + } + else + { + if ((gtid + 1) & (j - 1)) + { + g_scan[gtid] += + WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0); + } + } + } + offset += laneLog; + } + GroupMemoryBarrierWithGroupSync(); + + //If SCAN_DIM is not a power of lanecount + for (uint i = gtid + j; i < SCAN_DIM; i += SCAN_DIM) + { + b_passHist[i + k * SCAN_DIM + deviceOffset] = + WaveReadLaneAt(g_scan[((i >> offset) << offset) - 1], 0) + + ((i & (j - 1)) ? g_scan[i - 1] : 0) + reduction; + } + + reduction += WaveReadLaneAt(g_scan[SCAN_DIM - 1], 0) + + WaveReadLaneAt(g_scan[(((SCAN_DIM - 1) >> offset) << offset) - 1], 0); + GroupMemoryBarrierWithGroupSync(); + } +} + +inline void ExclusiveThreadBlockScanParitalWLT16( + uint gtid, + uint partitions, + uint deviceOffset, + uint laneLog, + uint circularLaneShift, + uint waveSize, + uint reduction) +{ + const uint finalPartSize = e_threadBlocks - partitions * SCAN_DIM; + if (gtid < finalPartSize) + { + g_scan[gtid] = b_passHist[gtid + partitions * SCAN_DIM + deviceOffset]; + g_scan[gtid] += WavePrefixSum(g_scan[gtid]); + } + GroupMemoryBarrierWithGroupSync(); + if (gtid < waveSize && circularLaneShift < finalPartSize) + { + b_passHist[circularLaneShift + partitions * SCAN_DIM + deviceOffset] = + (circularLaneShift ? 
g_scan[gtid] : 0) + reduction; + } + + uint offset = laneLog; + for (uint j = waveSize; j < finalPartSize; j <<= laneLog) + { + if (gtid < (finalPartSize >> offset)) + { + g_scan[((gtid + 1) << offset) - 1] += + WavePrefixSum(g_scan[((gtid + 1) << offset) - 1]); + } + GroupMemoryBarrierWithGroupSync(); + + if ((gtid & ((j << laneLog) - 1)) >= j && gtid < finalPartSize) + { + if (gtid < (j << laneLog)) + { + b_passHist[gtid + partitions * SCAN_DIM + deviceOffset] = + WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0) + + ((gtid & (j - 1)) ? g_scan[gtid - 1] : 0) + reduction; + } + else + { + if ((gtid + 1) & (j - 1)) + { + g_scan[gtid] += + WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0); + } + } + } + offset += laneLog; + } +} + +inline void ExclusiveThreadBlockScanWLT16(uint gtid, uint gid, uint waveSize) +{ + uint reduction = 0; + const uint partitions = e_threadBlocks / SCAN_DIM; + const uint deviceOffset = gid * e_threadBlocks; + const uint laneLog = countbits(waveSize - 1); + const uint circularLaneShift = WaveGetLaneIndex() + 1 & waveSize - 1; + + ExclusiveThreadBlockScanFullWLT16( + gtid, + partitions, + deviceOffset, + laneLog, + circularLaneShift, + waveSize, + reduction); + + ExclusiveThreadBlockScanParitalWLT16( + gtid, + partitions, + deviceOffset, + laneLog, + circularLaneShift, + waveSize, + reduction); +} + +//Scan does not need flattening of gids +[numthreads(SCAN_DIM, 1, 1)] +void Scan(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID) +{ + if (gtid.x != 0u) + { + return; + } + + const uint deviceOffset = gid.x * e_threadBlocks; + uint runningOffset = 0u; + for (uint blockIndex = 0u; blockIndex < e_threadBlocks; ++blockIndex) + { + const uint index = deviceOffset + blockIndex; + const uint count = b_passHist[index]; + b_passHist[index] = runningOffset; + runningOffset += count; + } +} + +//***************************************************************************** +//DOWNSWEEP KERNEL 
+//***************************************************************************** +inline void LoadThreadBlockReductions(uint gtid, uint gid, uint exclusiveHistReduction) +{ + if (gtid < RADIX) + { + g_d[gtid + PART_SIZE] = b_globalHist[gtid + GlobalHistOffset()] + + b_passHist[gtid * e_threadBlocks + gid] - exclusiveHistReduction; + } +} + +[numthreads(D_DIM, 1, 1)] +void Downsweep(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID) +{ + if (gtid.x != 0u) + { + return; + } + + const uint partitionStart = gid.x * PART_SIZE; + const uint partitionEnd = min(partitionStart + PART_SIZE, e_numKeys); + uint digitOffsets[RADIX]; + const uint globalHistOffset = GlobalHistOffset(); + for (uint digit = 0u; digit < RADIX; ++digit) + { + digitOffsets[digit] = + b_globalHist[globalHistOffset + digit] + + b_passHist[digit * e_threadBlocks + gid.x]; + } + + for (uint index = partitionStart; index < partitionEnd; ++index) + { + uint key; +#if defined(KEY_UINT) + key = b_sort[index]; +#elif defined(KEY_INT) + key = IntToUint(b_sort[index]); +#elif defined(KEY_FLOAT) + key = FloatToUint(b_sort[index]); +#endif + + const uint digit = ExtractDigit(key); + const uint destinationIndex = digitOffsets[digit]++; + +#if defined(KEY_UINT) + b_alt[destinationIndex] = key; +#elif defined(KEY_INT) + b_alt[destinationIndex] = UintToInt(key); +#elif defined(KEY_FLOAT) + b_alt[destinationIndex] = UintToFloat(key); +#endif + +#if defined(SORT_PAIRS) +#if defined(PAYLOAD_UINT) + b_altPayload[destinationIndex] = b_sortPayload[index]; +#elif defined(PAYLOAD_INT) + b_altPayload[destinationIndex] = b_sortPayload[index]; +#elif defined(PAYLOAD_FLOAT) + b_altPayload[destinationIndex] = b_sortPayload[index]; +#endif +#endif + } +} diff --git a/MVS/3DGS-D3D12/shaders/SortCommon.hlsl b/MVS/3DGS-D3D12/shaders/SortCommon.hlsl new file mode 100644 index 00000000..58bd629f --- /dev/null +++ b/MVS/3DGS-D3D12/shaders/SortCommon.hlsl @@ -0,0 +1,959 @@ 
+/****************************************************************************** + * SortCommon + * Common functions for GPUSorting + * + * SPDX-License-Identifier: MIT + * Copyright Thomas Smith 5/17/2024 + * https://github.com/b0nes164/GPUSorting + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ ******************************************************************************/ +#define KEYS_PER_THREAD 15U +#define D_DIM 256U +#define PART_SIZE 3840U +#define D_TOTAL_SMEM 4096U + +#define RADIX 256U //Number of digit bins +#define RADIX_MASK 255U //Mask of digit bins +#define HALF_RADIX 128U //For smaller waves where bit packing is necessary +#define HALF_MASK 127U // '' +#define RADIX_LOG 8U //log2(RADIX) +#define RADIX_PASSES 4U //(Key width) / RADIX_LOG + +cbuffer cbGpuSorting : register(b0) +{ + uint e_numKeys; + uint e_radixShift; + uint e_threadBlocks; + uint padding; +}; + +#if defined(KEY_UINT) +RWStructuredBuffer b_sort : register(u0); +RWStructuredBuffer b_alt : register(u1); +#elif defined(KEY_INT) +RWStructuredBuffer b_sort : register(u0); +RWStructuredBuffer b_alt : register(u1); +#elif defined(KEY_FLOAT) +RWStructuredBuffer b_sort : register(u0); +RWStructuredBuffer b_alt : register(u1); +#endif + +#if defined(PAYLOAD_UINT) +RWStructuredBuffer b_sortPayload : register(u2); +RWStructuredBuffer b_altPayload : register(u3); +#elif defined(PAYLOAD_INT) +RWStructuredBuffer b_sortPayload : register(u2); +RWStructuredBuffer b_altPayload : register(u3); +#elif defined(PAYLOAD_FLOAT) +RWStructuredBuffer b_sortPayload : register(u2); +RWStructuredBuffer b_altPayload : register(u3); +#endif + +groupshared uint g_d[D_TOTAL_SMEM]; //Shared memory for DigitBinningPass and DownSweep kernels + +struct KeyStruct +{ + uint k[KEYS_PER_THREAD]; +}; + +struct OffsetStruct +{ +#if defined(ENABLE_16_BIT) + uint16_t o[KEYS_PER_THREAD]; +#else + uint o[KEYS_PER_THREAD]; +#endif +}; + +struct DigitStruct +{ +#if defined(ENABLE_16_BIT) + uint16_t d[KEYS_PER_THREAD]; +#else + uint d[KEYS_PER_THREAD]; +#endif +}; + +//***************************************************************************** +//HELPER FUNCTIONS +//***************************************************************************** +//Due to a bug with SPIRV pre 1.6, we cannot use WaveGetLaneCount() to get the 
currently active wavesize +inline uint getWaveSize() +{ +#if defined(VULKAN) + GroupMemoryBarrierWithGroupSync(); //Make absolutely sure the wave is not diverged here + return dot(countbits(WaveActiveBallot(true)), uint4(1, 1, 1, 1)); +#else + return WaveGetLaneCount(); +#endif +} + +inline uint getWaveIndex(uint gtid, uint waveSize) +{ + return gtid / waveSize; +} + +//Radix Tricks by Michael Herf +//http://stereopsis.com/radix.html +inline uint FloatToUint(float f) +{ + uint mask = -((int) (asuint(f) >> 31)) | 0x80000000; + return asuint(f) ^ mask; +} + +inline float UintToFloat(uint u) +{ + uint mask = ((u >> 31) - 1) | 0x80000000; + return asfloat(u ^ mask); +} + +inline uint IntToUint(int i) +{ + return asuint(i ^ 0x80000000); +} + +inline int UintToInt(uint u) +{ + return asint(u ^ 0x80000000); +} + +inline uint getWaveCountPass(uint waveSize) +{ + return D_DIM / waveSize; +} + +inline uint ExtractDigit(uint key) +{ + return key >> e_radixShift & RADIX_MASK; +} + +inline uint ExtractDigit(uint key, uint shift) +{ + return key >> shift & RADIX_MASK; +} + +inline uint ExtractPackedIndex(uint key) +{ + return key >> (e_radixShift + 1) & HALF_MASK; +} + +inline uint ExtractPackedShift(uint key) +{ + return (key >> e_radixShift & 1) ? 
16 : 0; +} + +inline uint ExtractPackedValue(uint packed, uint key) +{ + return packed >> ExtractPackedShift(key) & 0xffff; +} + +inline uint SubPartSizeWGE16(uint waveSize) +{ + return KEYS_PER_THREAD * waveSize; +} + +inline uint SharedOffsetWGE16(uint gtid, uint waveSize) +{ + return WaveGetLaneIndex() + getWaveIndex(gtid, waveSize) * SubPartSizeWGE16(waveSize); +} + +inline uint SubPartSizeWLT16(uint waveSize, uint _serialIterations) +{ + return KEYS_PER_THREAD * waveSize * _serialIterations; +} + +inline uint SharedOffsetWLT16(uint gtid, uint waveSize, uint _serialIterations) +{ + return WaveGetLaneIndex() + + (getWaveIndex(gtid, waveSize) / _serialIterations * SubPartSizeWLT16(waveSize, _serialIterations)) + + (getWaveIndex(gtid, waveSize) % _serialIterations * waveSize); +} + +inline uint DeviceOffsetWGE16(uint gtid, uint waveSize, uint partIndex) +{ + return SharedOffsetWGE16(gtid, waveSize) + partIndex * PART_SIZE; +} + +inline uint DeviceOffsetWLT16(uint gtid, uint waveSize, uint partIndex, uint serialIterations) +{ + return SharedOffsetWLT16(gtid, waveSize, serialIterations) + partIndex * PART_SIZE; +} + +inline uint GlobalHistOffset() +{ + return e_radixShift << 5; +} + +inline uint WaveHistsSizeWGE16(uint waveSize) +{ + return D_DIM / waveSize * RADIX; +} + +inline uint WaveHistsSizeWLT16() +{ + return D_TOTAL_SMEM; +} + +//***************************************************************************** +//FUNCTIONS COMMON TO THE DOWNSWEEP / DIGIT BINNING PASS +//***************************************************************************** +//If the size of a wave is too small, we do not have enough space in +//shared memory to assign a histogram to each wave, so instead, +//some operations are peformed serially. +inline uint SerialIterations(uint waveSize) +{ + return (D_DIM / waveSize + 31) >> 5; +} + +inline void ClearWaveHists(uint gtid, uint waveSize) +{ + const uint histsEnd = waveSize >= 16 ? 
+ WaveHistsSizeWGE16(waveSize) : WaveHistsSizeWLT16(); + for (uint i = gtid; i < histsEnd; i += D_DIM) + g_d[i] = 0; +} + +inline void LoadKey(inout uint key, uint index) +{ +#if defined(KEY_UINT) + key = b_sort[index]; +#elif defined(KEY_INT) + key = UintToInt(b_sort[index]); +#elif defined(KEY_FLOAT) + key = FloatToUint(b_sort[index]); +#endif +} + +inline void LoadDummyKey(inout uint key) +{ + key = 0xffffffff; +} + +inline KeyStruct LoadKeysWGE16(uint gtid, uint waveSize, uint partIndex) +{ + KeyStruct keys; + [unroll] + for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, partIndex); + i < KEYS_PER_THREAD; + ++i, t += waveSize) + { + LoadKey(keys.k[i], t); + } + return keys; +} + +inline KeyStruct LoadKeysWLT16(uint gtid, uint waveSize, uint partIndex, uint serialIterations) +{ + KeyStruct keys; + [unroll] + for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations); + i < KEYS_PER_THREAD; + ++i, t += waveSize * serialIterations) + { + LoadKey(keys.k[i], t); + } + return keys; +} + +inline KeyStruct LoadKeysPartialWGE16(uint gtid, uint waveSize, uint partIndex) +{ + KeyStruct keys; + [unroll] + for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, partIndex); + i < KEYS_PER_THREAD; + ++i, t += waveSize) + { + if (t < e_numKeys) + LoadKey(keys.k[i], t); + else + LoadDummyKey(keys.k[i]); + } + return keys; +} + +inline KeyStruct LoadKeysPartialWLT16(uint gtid, uint waveSize, uint partIndex, uint serialIterations) +{ + KeyStruct keys; + [unroll] + for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations); + i < KEYS_PER_THREAD; + ++i, t += waveSize * serialIterations) + { + if (t < e_numKeys) + LoadKey(keys.k[i], t); + else + LoadDummyKey(keys.k[i]); + } + return keys; +} + +inline uint WaveFlagsWGE16(uint waveSize) +{ + return (waveSize & 31) ? 
(1U << waveSize) - 1 : 0xffffffff; +} + +inline uint WaveFlagsWLT16(uint waveSize) +{ + return (1U << waveSize) - 1;; +} + +inline void WarpLevelMultiSplitWGE16(uint key, inout uint4 waveFlags) +{ + [unroll] + for (uint k = 0; k < RADIX_LOG; ++k) + { + const uint currentBit = 1U << (k + e_radixShift); + const bool t = (key & currentBit) != 0; + GroupMemoryBarrierWithGroupSync(); //Play on the safe side, throw in a barrier for convergence + const uint4 ballot = WaveActiveBallot(t); + if(t) + waveFlags &= ballot; + else + waveFlags &= (~ballot); + } +} + +inline uint2 CountBitsWGE16(uint waveSize, uint ltMask, uint4 waveFlags) +{ + uint2 count = uint2(0, 0); + + for(uint wavePart = 0; wavePart < waveSize; wavePart += 32) + { + uint t = countbits(waveFlags[wavePart >> 5]); + if (WaveGetLaneIndex() >= wavePart) + { + if (WaveGetLaneIndex() >= wavePart + 32) + count.x += t; + else + count.x += countbits(waveFlags[wavePart >> 5] & ltMask); + } + count.y += t; + } + + return count; +} + +inline void WarpLevelMultiSplitWLT16(uint key, inout uint waveFlags) +{ + [unroll] + for (uint k = 0; k < RADIX_LOG; ++k) + { + const bool t = key >> (k + e_radixShift) & 1; + waveFlags &= (t ? 
0 : 0xffffffff) ^ (uint) WaveActiveBallot(t); + } +} + +inline OffsetStruct RankKeysWGE16( + uint waveSize, + uint waveOffset, + KeyStruct keys) +{ + OffsetStruct offsets; + const uint initialFlags = WaveFlagsWGE16(waveSize); + const uint ltMask = (1U << (WaveGetLaneIndex() & 31)) - 1; + + [unroll] + for (uint i = 0; i < KEYS_PER_THREAD; ++i) + { + uint4 waveFlags = initialFlags; + WarpLevelMultiSplitWGE16(keys.k[i], waveFlags); + + const uint index = ExtractDigit(keys.k[i]) + waveOffset; + const uint2 bitCount = CountBitsWGE16(waveSize, ltMask, waveFlags); + + offsets.o[i] = g_d[index] + bitCount.x; + GroupMemoryBarrierWithGroupSync(); + if (bitCount.x == 0) + g_d[index] += bitCount.y; + GroupMemoryBarrierWithGroupSync(); + } + + return offsets; +} + +inline OffsetStruct RankKeysWLT16(uint waveSize, uint waveIndex, KeyStruct keys, uint serialIterations) +{ + OffsetStruct offsets; + const uint ltMask = (1U << WaveGetLaneIndex()) - 1; + const uint initialFlags = WaveFlagsWLT16(waveSize); + + [unroll] + for (uint i = 0; i < KEYS_PER_THREAD; ++i) + { + uint waveFlags = initialFlags; + WarpLevelMultiSplitWLT16(keys.k[i], waveFlags); + + const uint index = ExtractPackedIndex(keys.k[i]) + + (waveIndex / serialIterations * HALF_RADIX); + + const uint peerBits = countbits(waveFlags & ltMask); + for (uint k = 0; k < serialIterations; ++k) + { + if (waveIndex % serialIterations == k) + offsets.o[i] = ExtractPackedValue(g_d[index], keys.k[i]) + peerBits; + + GroupMemoryBarrierWithGroupSync(); + if (waveIndex % serialIterations == k && peerBits == 0) + { + InterlockedAdd(g_d[index], + countbits(waveFlags) << ExtractPackedShift(keys.k[i])); + } + GroupMemoryBarrierWithGroupSync(); + } + } + + return offsets; +} + +inline uint WaveHistInclusiveScanCircularShiftWGE16(uint gtid, uint waveSize) +{ + uint histReduction = g_d[gtid]; + for (uint i = gtid + RADIX; i < WaveHistsSizeWGE16(waveSize); i += RADIX) + { + histReduction += g_d[i]; + g_d[i] = histReduction - g_d[i]; + } + 
return histReduction; +} + +inline uint WaveHistInclusiveScanCircularShiftWLT16(uint gtid) +{ + uint histReduction = g_d[gtid]; + for (uint i = gtid + HALF_RADIX; i < WaveHistsSizeWLT16(); i += HALF_RADIX) + { + histReduction += g_d[i]; + g_d[i] = histReduction - g_d[i]; + } + return histReduction; +} + +inline void WaveHistReductionExclusiveScanWGE16(uint gtid, uint waveSize, uint histReduction) +{ + if (gtid < RADIX) + { + const uint laneMask = waveSize - 1; + g_d[((WaveGetLaneIndex() + 1) & laneMask) + (gtid & ~laneMask)] = histReduction; + } + GroupMemoryBarrierWithGroupSync(); + + if (gtid < RADIX / waveSize) + { + g_d[gtid * waveSize] = + WavePrefixSum(g_d[gtid * waveSize]); + } + GroupMemoryBarrierWithGroupSync(); + + uint t = WaveReadLaneAt(g_d[gtid], 0); + if (gtid < RADIX && WaveGetLaneIndex()) + g_d[gtid] += t; +} + +//inclusive/exclusive prefix sum up the histograms, +//use a blelloch scan for in place packed exclusive +inline void WaveHistReductionExclusiveScanWLT16(uint gtid) +{ + uint shift = 1; + for (uint j = RADIX >> 2; j > 0; j >>= 1) + { + GroupMemoryBarrierWithGroupSync(); + if (gtid < j) + { + g_d[((((gtid << 1) + 2) << shift) - 1) >> 1] += + g_d[((((gtid << 1) + 1) << shift) - 1) >> 1] & 0xffff0000; + } + shift++; + } + GroupMemoryBarrierWithGroupSync(); + + if (gtid == 0) + g_d[HALF_RADIX - 1] &= 0xffff; + + for (uint j = 1; j < RADIX >> 1; j <<= 1) + { + --shift; + GroupMemoryBarrierWithGroupSync(); + if (gtid < j) + { + const uint t = ((((gtid << 1) + 1) << shift) - 1) >> 1; + const uint t2 = ((((gtid << 1) + 2) << shift) - 1) >> 1; + const uint t3 = g_d[t]; + g_d[t] = (g_d[t] & 0xffff) | (g_d[t2] & 0xffff0000); + g_d[t2] += t3 & 0xffff0000; + } + } + + GroupMemoryBarrierWithGroupSync(); + if (gtid < HALF_RADIX) + { + const uint t = g_d[gtid]; + g_d[gtid] = (t >> 16) + (t << 16) + (t & 0xffff0000); + } +} + +inline void UpdateOffsetsWGE16( + uint gtid, + uint waveSize, + inout OffsetStruct offsets, + KeyStruct keys) +{ + if (gtid >= 
waveSize) + { + const uint t = getWaveIndex(gtid, waveSize) * RADIX; + [unroll] + for (uint i = 0; i < KEYS_PER_THREAD; ++i) + { + const uint t2 = ExtractDigit(keys.k[i]); + offsets.o[i] += g_d[t2 + t] + g_d[t2]; + } + } + else + { + [unroll] + for (uint i = 0; i < KEYS_PER_THREAD; ++i) + offsets.o[i] += g_d[ExtractDigit(keys.k[i])]; + } +} + +inline void UpdateOffsetsWLT16( + uint gtid, + uint waveSize, + uint serialIterations, + inout OffsetStruct offsets, + KeyStruct keys) +{ + if (gtid >= waveSize * serialIterations) + { + const uint t = getWaveIndex(gtid, waveSize) / serialIterations * HALF_RADIX; + [unroll] + for (uint i = 0; i < KEYS_PER_THREAD; ++i) + { + const uint t2 = ExtractPackedIndex(keys.k[i]); + offsets.o[i] += ExtractPackedValue(g_d[t2 + t] + g_d[t2], keys.k[i]); + } + } + else + { + [unroll] + for (uint i = 0; i < KEYS_PER_THREAD; ++i) + offsets.o[i] += ExtractPackedValue(g_d[ExtractPackedIndex(keys.k[i])], keys.k[i]); + } +} + +inline void ScatterKeysShared(OffsetStruct offsets, KeyStruct keys) +{ + [unroll] + for (uint i = 0; i < KEYS_PER_THREAD; ++i) + g_d[offsets.o[i]] = keys.k[i]; +} + +inline uint DescendingIndex(uint deviceIndex) +{ + return e_numKeys - deviceIndex - 1; +} + +inline void WriteKey(uint deviceIndex, uint groupSharedIndex) +{ +#if defined(KEY_UINT) + b_alt[deviceIndex] = g_d[groupSharedIndex]; +#elif defined(KEY_INT) + b_alt[deviceIndex] = UintToInt(g_d[groupSharedIndex]); +#elif defined(KEY_FLOAT) + b_alt[deviceIndex] = UintToFloat(g_d[groupSharedIndex]); +#endif +} + +inline void LoadPayload(inout uint payload, uint deviceIndex) +{ +#if defined(PAYLOAD_UINT) + payload = b_sortPayload[deviceIndex]; +#elif defined(PAYLOAD_INT) || defined(PAYLOAD_FLOAT) + payload = asuint(b_sortPayload[deviceIndex]); +#endif +} + +inline void ScatterPayloadsShared(OffsetStruct offsets, KeyStruct payloads) +{ + ScatterKeysShared(offsets, payloads); +} + +inline void WritePayload(uint deviceIndex, uint groupSharedIndex) +{ +#if 
defined(PAYLOAD_UINT) + b_altPayload[deviceIndex] = g_d[groupSharedIndex]; +#elif defined(PAYLOAD_INT) + b_altPayload[deviceIndex] = asint(g_d[groupSharedIndex]); +#elif defined(PAYLOAD_FLOAT) + b_altPayload[deviceIndex] = asfloat(g_d[groupSharedIndex]); +#endif +} + +//***************************************************************************** +//SCATTERING: FULL PARTITIONS +//***************************************************************************** +//KEYS ONLY +inline void ScatterKeysOnlyDeviceAscending(uint gtid) +{ + for (uint i = gtid; i < PART_SIZE; i += D_DIM) + WriteKey(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i, i); +} + +inline void ScatterKeysOnlyDeviceDescending(uint gtid) +{ + if (e_radixShift == 24) + { + for (uint i = gtid; i < PART_SIZE; i += D_DIM) + WriteKey(DescendingIndex(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i), i); + } + else + { + ScatterKeysOnlyDeviceAscending(gtid); + } +} + +inline void ScatterKeysOnlyDevice(uint gtid) +{ +#if defined(SHOULD_ASCEND) + ScatterKeysOnlyDeviceAscending(gtid); +#else + ScatterKeysOnlyDeviceDescending(gtid); +#endif +} + +//KEY VALUE PAIRS +inline void ScatterPairsKeyPhaseAscending( + uint gtid, + inout DigitStruct digits) +{ + [unroll] + for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM) + { + digits.d[i] = ExtractDigit(g_d[t]); + WriteKey(g_d[digits.d[i] + PART_SIZE] + t, t); + } +} + +inline void ScatterPairsKeyPhaseDescending( + uint gtid, + inout DigitStruct digits) +{ + if (e_radixShift == 24) + { + [unroll] + for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM) + { + digits.d[i] = ExtractDigit(g_d[t]); + WriteKey(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t); + } + } + else + { + ScatterPairsKeyPhaseAscending(gtid, digits); + } +} + +inline void LoadPayloadsWGE16( + uint gtid, + uint waveSize, + uint partIndex, + inout KeyStruct payloads) +{ + [unroll] + for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, partIndex); + i < KEYS_PER_THREAD; + ++i, t += waveSize) 
+ { + LoadPayload(payloads.k[i], t); + } +} + +inline void LoadPayloadsWLT16( + uint gtid, + uint waveSize, + uint partIndex, + uint serialIterations, + inout KeyStruct payloads) +{ + [unroll] + for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations); + i < KEYS_PER_THREAD; + ++i, t += waveSize * serialIterations) + { + LoadPayload(payloads.k[i], t); + } +} + +inline void ScatterPayloadsAscending(uint gtid, DigitStruct digits) +{ + [unroll] + for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM) + WritePayload(g_d[digits.d[i] + PART_SIZE] + t, t); +} + +inline void ScatterPayloadsDescending(uint gtid, DigitStruct digits) +{ + if (e_radixShift == 24) + { + [unroll] + for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM) + WritePayload(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t); + } + else + { + ScatterPayloadsAscending(gtid, digits); + } +} + +inline void ScatterPairsDevice( + uint gtid, + uint waveSize, + uint partIndex, + OffsetStruct offsets) +{ + DigitStruct digits; +#if defined(SHOULD_ASCEND) + ScatterPairsKeyPhaseAscending(gtid, digits); +#else + ScatterPairsKeyPhaseDescending(gtid, digits); +#endif + GroupMemoryBarrierWithGroupSync(); + + KeyStruct payloads; + if (waveSize >= 16) + LoadPayloadsWGE16(gtid, waveSize, partIndex, payloads); + else + LoadPayloadsWLT16(gtid, waveSize, partIndex, SerialIterations(waveSize), payloads); + ScatterPayloadsShared(offsets, payloads); + GroupMemoryBarrierWithGroupSync(); + +#if defined(SHOULD_ASCEND) + ScatterPayloadsAscending(gtid, digits); +#else + ScatterPayloadsDescending(gtid, digits); +#endif +} + +inline void ScatterDevice( + uint gtid, + uint waveSize, + uint partIndex, + OffsetStruct offsets) +{ +#if defined(SORT_PAIRS) + ScatterPairsDevice( + gtid, + waveSize, + partIndex, + offsets); +#else + ScatterKeysOnlyDevice(gtid); +#endif +} + +//***************************************************************************** +//SCATTERING: PARTIAL PARTITIONS 
+//***************************************************************************** +//KEYS ONLY +inline void ScatterKeysOnlyDevicePartialAscending(uint gtid, uint finalPartSize) +{ + for (uint i = gtid; i < PART_SIZE; i += D_DIM) + { + if (i < finalPartSize) + WriteKey(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i, i); + } +} + +inline void ScatterKeysOnlyDevicePartialDescending(uint gtid, uint finalPartSize) +{ + if (e_radixShift == 24) + { + for (uint i = gtid; i < PART_SIZE; i += D_DIM) + { + if (i < finalPartSize) + WriteKey(DescendingIndex(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i), i); + } + } + else + { + ScatterKeysOnlyDevicePartialAscending(gtid, finalPartSize); + } +} + +inline void ScatterKeysOnlyDevicePartial(uint gtid, uint partIndex) +{ + const uint finalPartSize = e_numKeys - partIndex * PART_SIZE; +#if defined(SHOULD_ASCEND) + ScatterKeysOnlyDevicePartialAscending(gtid, finalPartSize); +#else + ScatterKeysOnlyDevicePartialDescending(gtid, finalPartSize); +#endif +} + +//KEY VALUE PAIRS +inline void ScatterPairsKeyPhaseAscendingPartial( + uint gtid, + uint finalPartSize, + inout DigitStruct digits) +{ + [unroll] + for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM) + { + if (t < finalPartSize) + { + digits.d[i] = ExtractDigit(g_d[t]); + WriteKey(g_d[digits.d[i] + PART_SIZE] + t, t); + } + } +} + +inline void ScatterPairsKeyPhaseDescendingPartial( + uint gtid, + uint finalPartSize, + inout DigitStruct digits) +{ + if (e_radixShift == 24) + { + [unroll] + for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM) + { + if (t < finalPartSize) + { + digits.d[i] = ExtractDigit(g_d[t]); + WriteKey(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t); + } + } + } + else + { + ScatterPairsKeyPhaseAscendingPartial(gtid, finalPartSize, digits); + } +} + +inline void LoadPayloadsPartialWGE16( + uint gtid, + uint waveSize, + uint partIndex, + inout KeyStruct payloads) +{ + [unroll] + for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, 
partIndex); + i < KEYS_PER_THREAD; + ++i, t += waveSize) + { + if (t < e_numKeys) + LoadPayload(payloads.k[i], t); + } +} + +inline void LoadPayloadsPartialWLT16( + uint gtid, + uint waveSize, + uint partIndex, + uint serialIterations, + inout KeyStruct payloads) +{ + [unroll] + for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations); + i < KEYS_PER_THREAD; + ++i, t += waveSize * serialIterations) + { + if (t < e_numKeys) + LoadPayload(payloads.k[i], t); + } +} + +inline void ScatterPayloadsAscendingPartial( + uint gtid, + uint finalPartSize, + DigitStruct digits) +{ + [unroll] + for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM) + { + if (t < finalPartSize) + WritePayload(g_d[digits.d[i] + PART_SIZE] + t, t); + } +} + +inline void ScatterPayloadsDescendingPartial( + uint gtid, + uint finalPartSize, + DigitStruct digits) +{ + if (e_radixShift == 24) + { + [unroll] + for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM) + { + if (t < finalPartSize) + WritePayload(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t); + } + } + else + { + ScatterPayloadsAscendingPartial(gtid, finalPartSize, digits); + } +} + +inline void ScatterPairsDevicePartial( + uint gtid, + uint waveSize, + uint partIndex, + OffsetStruct offsets) +{ + DigitStruct digits; + const uint finalPartSize = e_numKeys - partIndex * PART_SIZE; +#if defined(SHOULD_ASCEND) + ScatterPairsKeyPhaseAscendingPartial(gtid, finalPartSize, digits); +#else + ScatterPairsKeyPhaseDescendingPartial(gtid, finalPartSize, digits); +#endif + GroupMemoryBarrierWithGroupSync(); + + KeyStruct payloads; + if (waveSize >= 16) + LoadPayloadsPartialWGE16(gtid, waveSize, partIndex, payloads); + else + LoadPayloadsPartialWLT16(gtid, waveSize, partIndex, SerialIterations(waveSize), payloads); + ScatterPayloadsShared(offsets, payloads); + GroupMemoryBarrierWithGroupSync(); + +#if defined(SHOULD_ASCEND) + ScatterPayloadsAscendingPartial(gtid, finalPartSize, digits); +#else + 
ScatterPayloadsDescendingPartial(gtid, finalPartSize, digits); +#endif +} + +inline void ScatterDevicePartial( + uint gtid, + uint waveSize, + uint partIndex, + OffsetStruct offsets) +{ +#if defined(SORT_PAIRS) + ScatterPairsDevicePartial( + gtid, + waveSize, + partIndex, + offsets); +#else + ScatterKeysOnlyDevicePartial(gtid, partIndex); +#endif +} diff --git a/MVS/3DGS-D3D12/src/App.cpp b/MVS/3DGS-D3D12/src/App.cpp index 1e515c9e..a93f2ec7 100644 --- a/MVS/3DGS-D3D12/src/App.cpp +++ b/MVS/3DGS-D3D12/src/App.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -27,6 +28,10 @@ constexpr wchar_t kWindowTitle[] = L"XC 3DGS D3D12 MVS - Phase 3"; constexpr float kClearColor[4] = { 0.04f, 0.05f, 0.07f, 1.0f }; constexpr uint32_t kPrepareThreadGroupSize = 64u; constexpr uint32_t kSortThreadGroupSize = 64u; +constexpr uint32_t kDeviceRadixSortPartitionSize = 3840u; +constexpr uint32_t kDeviceRadixSortRadix = 256u; +constexpr uint32_t kDeviceRadixSortPassCount = 4u; +constexpr bool kUseCpuSortBaseline = true; struct FrameConstants { float viewProjection[16] = {}; @@ -37,7 +42,15 @@ struct FrameConstants { float settings[4] = {}; }; +struct RadixSortConstants { + uint32_t numKeys = 0; + uint32_t radixShift = 0; + uint32_t threadBlocks = 0; + uint32_t padding = 0; +}; + static_assert(sizeof(FrameConstants) % 16 == 0, "Frame constants must stay 16-byte aligned."); +static_assert(sizeof(RadixSortConstants) % 16 == 0, "Radix sort constants must stay 16-byte aligned."); static_assert(sizeof(PreparedSplatView) == 40, "Prepared view buffer layout must match shader."); std::filesystem::path GetExecutableDirectory() { @@ -104,21 +117,77 @@ void AppendTrace(const std::wstring& message) { AppendTrace(NarrowAscii(message)); } +ShaderCompileDesc BuildDxilShaderDesc(const std::filesystem::path& compiledShaderPath, const std::wstring& profile) { + ShaderCompileDesc shaderDesc = {}; + shaderDesc.profile = profile; + shaderDesc.compiledBinaryBackend = 
ShaderBinaryBackend::D3D12; + shaderDesc.compiledBinary = LoadBinaryFile(compiledShaderPath); + return shaderDesc; +} + void StoreMatrixTransposed(const DirectX::XMMATRIX& matrix, float* destination) { DirectX::XMFLOAT4X4 output = {}; DirectX::XMStoreFloat4x4(&output, DirectX::XMMatrixTranspose(matrix)); std::memcpy(destination, &output, sizeof(output)); } -FrameConstants BuildFrameConstants(uint32_t width, uint32_t height, uint32_t splatCount) { +DirectX::XMMATRIX BuildSortViewMatrix() { using namespace DirectX; - const XMVECTOR eye = XMVectorSet(0.0f, 0.5f, 1.0f, 1.0f); const XMVECTOR target = XMVectorSet(0.0f, 0.5f, -5.0f, 1.0f); const XMVECTOR up = XMVectorSet(0.0f, 1.0f, 0.0f, 0.0f); + return XMMatrixLookAtRH(eye, target, up); +} + +uint32_t FloatToSortableUint(float value) { + const uint32_t bits = std::bit_cast(value); + const uint32_t mask = (0u - (bits >> 31)) | 0x80000000u; + return bits ^ mask; +} + +void BuildCpuSortedOrder( + const GaussianSplatRuntimeData& sceneData, + std::vector& outOrder, + std::vector* outSortedKeys = nullptr) { + using namespace DirectX; + + const XMMATRIX view = BuildSortViewMatrix(); + const float* positionBytes = reinterpret_cast(sceneData.positionData.data()); + std::vector> sortablePairs(sceneData.splatCount); + for (uint32_t index = 0; index < sceneData.splatCount; ++index) { + const float* position = positionBytes + index * 3u; + const XMVECTOR worldPosition = XMVectorSet(position[0], position[1], position[2], 1.0f); + const XMVECTOR viewPosition = XMVector4Transform(worldPosition, view); + sortablePairs[index] = { + FloatToSortableUint(XMVectorGetZ(viewPosition)), + index, + }; + } + + std::stable_sort( + sortablePairs.begin(), + sortablePairs.end(), + [](const auto& left, const auto& right) { + return left.first < right.first; + }); + + outOrder.resize(sortablePairs.size()); + if (outSortedKeys != nullptr) { + outSortedKeys->resize(sortablePairs.size()); + } + for (size_t index = 0; index < sortablePairs.size(); ++index) 
{ + outOrder[index] = sortablePairs[index].second; + if (outSortedKeys != nullptr) { + (*outSortedKeys)[index] = sortablePairs[index].first; + } + } +} + +FrameConstants BuildFrameConstants(uint32_t width, uint32_t height, uint32_t splatCount) { + using namespace DirectX; const float aspect = height > 0 ? static_cast(width) / static_cast(height) : 1.0f; - const XMMATRIX view = XMMatrixLookAtRH(eye, target, up); + const XMMATRIX view = BuildSortViewMatrix(); const XMMATRIX projection = XMMatrixPerspectiveFovRH(XMConvertToRadians(60.0f), aspect, 0.1f, 200.0f); const XMMATRIX viewProjection = XMMatrixMultiply(view, projection); @@ -383,6 +452,14 @@ bool App::InitializeRhi() { ID3D12Device* device = m_device.GetDevice(); IDXGIFactory4* factory = m_device.GetFactory(); + D3D12_FEATURE_DATA_D3D12_OPTIONS1 options1 = {}; + if (SUCCEEDED(device->CheckFeatureSupport(D3D12_FEATURE_D3D12_OPTIONS1, &options1, sizeof(options1)))) { + AppendTrace( + "InitializeRhi: wave ops min=" + std::to_string(options1.WaveLaneCountMin) + + " max=" + std::to_string(options1.WaveLaneCountMax) + + " total_lane_count=" + std::to_string(options1.TotalLaneCount)); + } + if (!m_commandQueue.Initialize(device, CommandQueueType::Direct)) { m_lastErrorMessage = L"Failed to initialize the direct command queue."; return false; @@ -664,155 +741,316 @@ bool App::InitializePreparePassResources() { } bool App::InitializeSortResources() { - std::vector initialOrder(static_cast(m_gaussianSceneData.splatCount)); - for (uint32_t index = 0; index < m_gaussianSceneData.splatCount; ++index) { - initialOrder[index] = index; + std::vector initialOrder; + if (kUseCpuSortBaseline) { + BuildCpuSortedOrder(m_gaussianSceneData, initialOrder, nullptr); + } else { + initialOrder.resize(static_cast(m_gaussianSceneData.splatCount)); + for (uint32_t index = 0; index < m_gaussianSceneData.splatCount; ++index) { + initialOrder[index] = index; + } } - ID3D12Device* device = m_device.GetDevice(); - ID3D12GraphicsCommandList* 
commandList = m_commandList.GetCommandList(); + const uint64_t sortBufferBytes = static_cast(m_gaussianSceneData.splatCount) * sizeof(uint32_t); + const uint32_t threadBlocks = + static_cast((m_gaussianSceneData.splatCount + (kDeviceRadixSortPartitionSize - 1u)) / kDeviceRadixSortPartitionSize); + const uint64_t passHistogramElements = static_cast(threadBlocks) * kDeviceRadixSortRadix; + const uint64_t globalHistogramElements = static_cast(kDeviceRadixSortRadix) * kDeviceRadixSortPassCount; - m_commandAllocator.Reset(); - m_commandList.Reset(); + auto initializeStorageBuffer = [this](D3D12Buffer*& buffer, uint64_t sizeInBytes) -> bool { + BufferDesc bufferDesc = {}; + bufferDesc.size = sizeInBytes; + bufferDesc.stride = sizeof(uint32_t); + bufferDesc.bufferType = static_cast(BufferType::Storage); + bufferDesc.flags = static_cast(BufferFlags::AllowUnorderedAccess); - const D3D12_RESOURCE_STATES shaderResourceState = - D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE | D3D12_RESOURCE_STATE_PIXEL_SHADER_RESOURCE; + buffer = static_cast(m_device.CreateBuffer(bufferDesc)); + if (buffer == nullptr) { + return false; + } + buffer->SetStride(sizeof(uint32_t)); + buffer->SetBufferType(BufferType::Storage); + return true; + }; - m_orderBuffer = new D3D12Buffer(); + BufferDesc orderBufferDesc = {}; + orderBufferDesc.size = sortBufferBytes; + orderBufferDesc.stride = sizeof(uint32_t); + orderBufferDesc.bufferType = static_cast(BufferType::Storage); + orderBufferDesc.flags = static_cast(BufferFlags::AllowUnorderedAccess); + m_orderBuffer = static_cast(m_device.CreateBuffer( + orderBufferDesc, + initialOrder.data(), + static_cast(sortBufferBytes), + ResourceStates::NonPixelShaderResource)); + if (m_orderBuffer == nullptr) { + m_lastErrorMessage = L"Failed to create the primary order buffer."; + return false; + } m_orderBuffer->SetStride(sizeof(uint32_t)); m_orderBuffer->SetBufferType(BufferType::Storage); - if (!m_orderBuffer->InitializeWithData( - device, - commandList, - 
initialOrder.data(), - static_cast(initialOrder.size() * sizeof(uint32_t)), - shaderResourceState)) { - m_lastErrorMessage = L"Failed to initialize the order buffer."; + + if (!initializeStorageBuffer(m_orderScratchBuffer, sortBufferBytes)) { + m_lastErrorMessage = L"Failed to create the scratch order buffer."; return false; } - m_orderBuffer->SetState(ResourceStates::NonPixelShaderResource); - BufferDesc sortKeyBufferDesc = {}; - sortKeyBufferDesc.size = static_cast(m_gaussianSceneData.splatCount) * sizeof(uint32_t); - sortKeyBufferDesc.stride = sizeof(uint32_t); - sortKeyBufferDesc.bufferType = static_cast(BufferType::Storage); - sortKeyBufferDesc.flags = static_cast(BufferFlags::AllowUnorderedAccess); - - m_sortKeyBuffer = static_cast(m_device.CreateBuffer(sortKeyBufferDesc)); - if (m_sortKeyBuffer == nullptr) { - m_lastErrorMessage = L"Failed to create the sort key buffer."; + if (!initializeStorageBuffer(m_sortKeyBuffer, sortBufferBytes)) { + m_lastErrorMessage = L"Failed to create the primary sort key buffer."; + return false; + } + + if (!initializeStorageBuffer(m_sortKeyScratchBuffer, sortBufferBytes)) { + m_lastErrorMessage = L"Failed to create the scratch sort key buffer."; + return false; + } + + if (!initializeStorageBuffer(m_passHistogramBuffer, passHistogramElements * sizeof(uint32_t))) { + m_lastErrorMessage = L"Failed to create the pass histogram buffer."; + return false; + } + + if (!initializeStorageBuffer(m_globalHistogramBuffer, globalHistogramElements * sizeof(uint32_t))) { + m_lastErrorMessage = L"Failed to create the global histogram buffer."; return false; } - m_sortKeyBuffer->SetStride(sizeof(uint32_t)); - m_sortKeyBuffer->SetBufferType(BufferType::Storage); ResourceViewDesc structuredViewDesc = {}; structuredViewDesc.dimension = ResourceViewDimension::StructuredBuffer; structuredViewDesc.structureByteStride = sizeof(uint32_t); structuredViewDesc.elementCount = m_gaussianSceneData.splatCount; + ResourceViewDesc passHistogramViewDesc = 
structuredViewDesc; + passHistogramViewDesc.elementCount = static_cast(passHistogramElements); + + ResourceViewDesc globalHistogramViewDesc = structuredViewDesc; + globalHistogramViewDesc.elementCount = static_cast(globalHistogramElements); + m_orderBufferSrv.reset(static_cast(m_device.CreateShaderResourceView(m_orderBuffer, structuredViewDesc))); if (!m_orderBufferSrv) { - m_lastErrorMessage = L"Failed to create the order buffer SRV."; + m_lastErrorMessage = L"Failed to create the primary order SRV."; return false; } - m_sortKeySrv.reset(static_cast(m_device.CreateShaderResourceView(m_sortKeyBuffer, structuredViewDesc))); - if (!m_sortKeySrv) { - m_lastErrorMessage = L"Failed to create the sort key SRV."; + m_orderBufferUav.reset(static_cast(m_device.CreateUnorderedAccessView(m_orderBuffer, structuredViewDesc))); + if (!m_orderBufferUav) { + m_lastErrorMessage = L"Failed to create the primary order UAV."; + return false; + } + + m_orderScratchUav.reset(static_cast(m_device.CreateUnorderedAccessView(m_orderScratchBuffer, structuredViewDesc))); + if (!m_orderScratchUav) { + m_lastErrorMessage = L"Failed to create the scratch order UAV."; return false; } m_sortKeyUav.reset(static_cast(m_device.CreateUnorderedAccessView(m_sortKeyBuffer, structuredViewDesc))); if (!m_sortKeyUav) { - m_lastErrorMessage = L"Failed to create the sort key UAV."; + m_lastErrorMessage = L"Failed to create the primary sort key UAV."; return false; } - DescriptorSetLayoutBinding bindings[4] = {}; - bindings[0].binding = 0; - bindings[0].type = static_cast(DescriptorType::CBV); - bindings[0].count = 1; - bindings[0].visibility = static_cast(ShaderVisibility::All); - - bindings[1].binding = 1; - bindings[1].type = static_cast(DescriptorType::SRV); - bindings[1].count = 1; - bindings[1].visibility = static_cast(ShaderVisibility::All); - bindings[1].resourceDimension = ResourceViewDimension::RawBuffer; - - bindings[2].binding = 2; - bindings[2].type = static_cast(DescriptorType::SRV); - 
bindings[2].count = 1; - bindings[2].visibility = static_cast(ShaderVisibility::All); - bindings[2].resourceDimension = ResourceViewDimension::StructuredBuffer; - - bindings[3].binding = 3; - bindings[3].type = static_cast(DescriptorType::UAV); - bindings[3].count = 1; - bindings[3].visibility = static_cast(ShaderVisibility::All); - bindings[3].resourceDimension = ResourceViewDimension::StructuredBuffer; - - DescriptorSetLayoutDesc setLayout = {}; - setLayout.bindings = bindings; - setLayout.bindingCount = 4; - - RHIPipelineLayoutDesc pipelineLayoutDesc = {}; - pipelineLayoutDesc.setLayouts = &setLayout; - pipelineLayoutDesc.setLayoutCount = 1; - - m_sortPipelineLayout = m_device.CreatePipelineLayout(pipelineLayoutDesc); - if (m_sortPipelineLayout == nullptr) { - m_lastErrorMessage = L"Failed to create the sort pipeline layout."; + m_sortKeyScratchUav.reset(static_cast(m_device.CreateUnorderedAccessView(m_sortKeyScratchBuffer, structuredViewDesc))); + if (!m_sortKeyScratchUav) { + m_lastErrorMessage = L"Failed to create the scratch sort key UAV."; return false; } - DescriptorPoolDesc poolDesc = {}; - poolDesc.type = DescriptorHeapType::CBV_SRV_UAV; - poolDesc.descriptorCount = 3; - poolDesc.shaderVisible = true; - m_sortDescriptorPool = m_device.CreateDescriptorPool(poolDesc); - if (m_sortDescriptorPool == nullptr) { - m_lastErrorMessage = L"Failed to create the sort descriptor pool."; + m_passHistogramUav.reset(static_cast(m_device.CreateUnorderedAccessView(m_passHistogramBuffer, passHistogramViewDesc))); + if (!m_passHistogramUav) { + m_lastErrorMessage = L"Failed to create the pass histogram UAV."; return false; } - m_sortDescriptorSet = m_sortDescriptorPool->AllocateSet(setLayout); - if (m_sortDescriptorSet == nullptr) { - m_lastErrorMessage = L"Failed to allocate the sort descriptor set."; + m_globalHistogramUav.reset(static_cast(m_device.CreateUnorderedAccessView(m_globalHistogramBuffer, globalHistogramViewDesc))); + if (!m_globalHistogramUav) { + 
m_lastErrorMessage = L"Failed to create the global histogram UAV."; return false; } - m_sortDescriptorSet->Update(1, m_gaussianPositionView.get()); - m_sortDescriptorSet->Update(2, m_orderBufferSrv.get()); - m_sortDescriptorSet->Update(3, m_sortKeyUav.get()); + DescriptorSetLayoutBinding buildSortKeyBindings[4] = {}; + buildSortKeyBindings[0].binding = 0; + buildSortKeyBindings[0].type = static_cast(DescriptorType::CBV); + buildSortKeyBindings[0].count = 1; + buildSortKeyBindings[0].visibility = static_cast(ShaderVisibility::All); - ShaderCompileDesc computeShaderDesc = {}; - const std::filesystem::path compiledShaderPath = ResolveShaderPath(L"BuildSortKeysCS.dxil"); - const std::vector compiledShaderBinary = LoadBinaryFile(compiledShaderPath); - if (!compiledShaderBinary.empty()) { - computeShaderDesc.profile = L"cs_6_6"; - computeShaderDesc.compiledBinaryBackend = ShaderBinaryBackend::D3D12; - computeShaderDesc.compiledBinary = compiledShaderBinary; - } else { - computeShaderDesc.fileName = ResolveShaderPath(L"BuildSortKeysCS.hlsl").wstring(); - computeShaderDesc.entryPoint = L"MainCS"; - computeShaderDesc.profile = L"cs_5_0"; - } + buildSortKeyBindings[1].binding = 1; + buildSortKeyBindings[1].type = static_cast(DescriptorType::SRV); + buildSortKeyBindings[1].count = 1; + buildSortKeyBindings[1].visibility = static_cast(ShaderVisibility::All); + buildSortKeyBindings[1].resourceDimension = ResourceViewDimension::RawBuffer; - ComputePipelineDesc pipelineDesc = {}; - pipelineDesc.pipelineLayout = m_sortPipelineLayout; - pipelineDesc.computeShader = computeShaderDesc; - m_sortPipelineState = m_device.CreateComputePipelineState(pipelineDesc); - if (m_sortPipelineState == nullptr) { - m_lastErrorMessage = L"Failed to create the sort pipeline state."; + buildSortKeyBindings[2].binding = 2; + buildSortKeyBindings[2].type = static_cast(DescriptorType::SRV); + buildSortKeyBindings[2].count = 1; + buildSortKeyBindings[2].visibility = static_cast(ShaderVisibility::All); + 
buildSortKeyBindings[2].resourceDimension = ResourceViewDimension::StructuredBuffer; + + buildSortKeyBindings[3].binding = 3; + buildSortKeyBindings[3].type = static_cast(DescriptorType::UAV); + buildSortKeyBindings[3].count = 1; + buildSortKeyBindings[3].visibility = static_cast(ShaderVisibility::All); + buildSortKeyBindings[3].resourceDimension = ResourceViewDimension::StructuredBuffer; + + DescriptorSetLayoutDesc buildSortKeySetLayout = {}; + buildSortKeySetLayout.bindings = buildSortKeyBindings; + buildSortKeySetLayout.bindingCount = 4; + + RHIPipelineLayoutDesc buildSortKeyLayoutDesc = {}; + buildSortKeyLayoutDesc.setLayouts = &buildSortKeySetLayout; + buildSortKeyLayoutDesc.setLayoutCount = 1; + + m_buildSortKeyPipelineLayout = m_device.CreatePipelineLayout(buildSortKeyLayoutDesc); + if (m_buildSortKeyPipelineLayout == nullptr) { + m_lastErrorMessage = L"Failed to create the build-sort-key pipeline layout."; return false; } - m_commandList.Close(); - void* commandLists[] = { &m_commandList }; - m_commandQueue.ExecuteCommandLists(1, commandLists); - m_commandQueue.WaitForIdle(); + DescriptorPoolDesc buildSortKeyPoolDesc = {}; + buildSortKeyPoolDesc.type = DescriptorHeapType::CBV_SRV_UAV; + buildSortKeyPoolDesc.descriptorCount = 3; + buildSortKeyPoolDesc.shaderVisible = true; + m_buildSortKeyDescriptorPool = m_device.CreateDescriptorPool(buildSortKeyPoolDesc); + if (m_buildSortKeyDescriptorPool == nullptr) { + m_lastErrorMessage = L"Failed to create the build-sort-key descriptor pool."; + return false; + } + + m_buildSortKeyDescriptorSet = m_buildSortKeyDescriptorPool->AllocateSet(buildSortKeySetLayout); + if (m_buildSortKeyDescriptorSet == nullptr) { + m_lastErrorMessage = L"Failed to allocate the build-sort-key descriptor set."; + return false; + } + + m_buildSortKeyDescriptorSet->Update(1, m_gaussianPositionView.get()); + m_buildSortKeyDescriptorSet->Update(2, m_orderBufferSrv.get()); + m_buildSortKeyDescriptorSet->Update(3, m_sortKeyUav.get()); + + 
ShaderCompileDesc buildSortKeyShaderDesc = BuildDxilShaderDesc(ResolveShaderPath(L"BuildSortKeysCS.dxil"), L"cs_6_6"); + if (buildSortKeyShaderDesc.compiledBinary.empty()) { + m_lastErrorMessage = L"Failed to load BuildSortKeysCS.dxil."; + return false; + } + + ComputePipelineDesc buildSortKeyPipelineDesc = {}; + buildSortKeyPipelineDesc.pipelineLayout = m_buildSortKeyPipelineLayout; + buildSortKeyPipelineDesc.computeShader = buildSortKeyShaderDesc; + m_buildSortKeyPipelineState = m_device.CreateComputePipelineState(buildSortKeyPipelineDesc); + if (m_buildSortKeyPipelineState == nullptr) { + m_lastErrorMessage = L"Failed to create the build-sort-key pipeline state."; + return false; + } + + DescriptorSetLayoutBinding radixSortBindings[7] = {}; + radixSortBindings[0].binding = 0; + radixSortBindings[0].type = static_cast(DescriptorType::CBV); + radixSortBindings[0].count = 1; + radixSortBindings[0].visibility = static_cast(ShaderVisibility::All); + + for (uint32_t bindingIndex = 1; bindingIndex <= 6; ++bindingIndex) { + radixSortBindings[bindingIndex].binding = bindingIndex; + radixSortBindings[bindingIndex].type = static_cast(DescriptorType::UAV); + radixSortBindings[bindingIndex].count = 1; + radixSortBindings[bindingIndex].visibility = static_cast(ShaderVisibility::All); + radixSortBindings[bindingIndex].resourceDimension = ResourceViewDimension::StructuredBuffer; + } + + DescriptorSetLayoutDesc radixSortSetLayout = {}; + radixSortSetLayout.bindings = radixSortBindings; + radixSortSetLayout.bindingCount = 7; + + RHIPipelineLayoutDesc radixSortLayoutDesc = {}; + radixSortLayoutDesc.setLayouts = &radixSortSetLayout; + radixSortLayoutDesc.setLayoutCount = 1; + + m_radixSortPipelineLayout = m_device.CreatePipelineLayout(radixSortLayoutDesc); + if (m_radixSortPipelineLayout == nullptr) { + m_lastErrorMessage = L"Failed to create the radix-sort pipeline layout."; + return false; + } + + DescriptorPoolDesc radixSortPoolDesc = {}; + radixSortPoolDesc.type = 
DescriptorHeapType::CBV_SRV_UAV; + radixSortPoolDesc.descriptorCount = 12; + radixSortPoolDesc.shaderVisible = true; + m_radixSortDescriptorPool = m_device.CreateDescriptorPool(radixSortPoolDesc); + if (m_radixSortDescriptorPool == nullptr) { + m_lastErrorMessage = L"Failed to create the radix-sort descriptor pool."; + return false; + } + + m_radixSortDescriptorSetPrimaryToScratch = m_radixSortDescriptorPool->AllocateSet(radixSortSetLayout); + m_radixSortDescriptorSetScratchToPrimary = m_radixSortDescriptorPool->AllocateSet(radixSortSetLayout); + if (m_radixSortDescriptorSetPrimaryToScratch == nullptr || m_radixSortDescriptorSetScratchToPrimary == nullptr) { + m_lastErrorMessage = L"Failed to allocate the radix-sort descriptor sets."; + return false; + } + + m_radixSortDescriptorSetPrimaryToScratch->Update(1, m_sortKeyUav.get()); + m_radixSortDescriptorSetPrimaryToScratch->Update(2, m_sortKeyScratchUav.get()); + m_radixSortDescriptorSetPrimaryToScratch->Update(3, m_orderBufferUav.get()); + m_radixSortDescriptorSetPrimaryToScratch->Update(4, m_orderScratchUav.get()); + m_radixSortDescriptorSetPrimaryToScratch->Update(5, m_passHistogramUav.get()); + m_radixSortDescriptorSetPrimaryToScratch->Update(6, m_globalHistogramUav.get()); + + m_radixSortDescriptorSetScratchToPrimary->Update(1, m_sortKeyScratchUav.get()); + m_radixSortDescriptorSetScratchToPrimary->Update(2, m_sortKeyUav.get()); + m_radixSortDescriptorSetScratchToPrimary->Update(3, m_orderScratchUav.get()); + m_radixSortDescriptorSetScratchToPrimary->Update(4, m_orderBufferUav.get()); + m_radixSortDescriptorSetScratchToPrimary->Update(5, m_passHistogramUav.get()); + m_radixSortDescriptorSetScratchToPrimary->Update(6, m_globalHistogramUav.get()); + + const ShaderCompileDesc radixInitShaderDesc = BuildDxilShaderDesc(ResolveShaderPath(L"RadixInit.dxil"), L"cs_6_6"); + const ShaderCompileDesc radixUpsweepShaderDesc = BuildDxilShaderDesc(ResolveShaderPath(L"RadixUpsweep.dxil"), L"cs_6_6"); + const ShaderCompileDesc 
radixGlobalHistogramShaderDesc = BuildDxilShaderDesc(ResolveShaderPath(L"RadixGlobalHistogram.dxil"), L"cs_6_6"); + const ShaderCompileDesc radixScanShaderDesc = BuildDxilShaderDesc(ResolveShaderPath(L"RadixScan.dxil"), L"cs_6_6"); + const ShaderCompileDesc radixDownsweepShaderDesc = BuildDxilShaderDesc(ResolveShaderPath(L"RadixDownsweep.dxil"), L"cs_6_6"); + if (radixInitShaderDesc.compiledBinary.empty() || + radixUpsweepShaderDesc.compiledBinary.empty() || + radixGlobalHistogramShaderDesc.compiledBinary.empty() || + radixScanShaderDesc.compiledBinary.empty() || + radixDownsweepShaderDesc.compiledBinary.empty()) { + m_lastErrorMessage = L"Failed to load one or more radix-sort DXIL shaders."; + return false; + } + + ComputePipelineDesc radixPipelineDesc = {}; + radixPipelineDesc.pipelineLayout = m_radixSortPipelineLayout; + + radixPipelineDesc.computeShader = radixInitShaderDesc; + m_radixSortInitPipelineState = m_device.CreateComputePipelineState(radixPipelineDesc); + if (m_radixSortInitPipelineState == nullptr) { + m_lastErrorMessage = L"Failed to create the radix-sort init pipeline state."; + return false; + } + + radixPipelineDesc.computeShader = radixUpsweepShaderDesc; + m_radixSortUpsweepPipelineState = m_device.CreateComputePipelineState(radixPipelineDesc); + if (m_radixSortUpsweepPipelineState == nullptr) { + m_lastErrorMessage = L"Failed to create the radix-sort upsweep pipeline state."; + return false; + } + + radixPipelineDesc.computeShader = radixGlobalHistogramShaderDesc; + m_radixSortGlobalHistogramPipelineState = m_device.CreateComputePipelineState(radixPipelineDesc); + if (m_radixSortGlobalHistogramPipelineState == nullptr) { + m_lastErrorMessage = L"Failed to create the radix-sort global-histogram pipeline state."; + return false; + } + + radixPipelineDesc.computeShader = radixScanShaderDesc; + m_radixSortScanPipelineState = m_device.CreateComputePipelineState(radixPipelineDesc); + if (m_radixSortScanPipelineState == nullptr) { + m_lastErrorMessage 
= L"Failed to create the radix-sort scan pipeline state."; + return false; + } + + radixPipelineDesc.computeShader = radixDownsweepShaderDesc; + m_radixSortDownsweepPipelineState = m_device.CreateComputePipelineState(radixPipelineDesc); + if (m_radixSortDownsweepPipelineState == nullptr) { + m_lastErrorMessage = L"Failed to create the radix-sort downsweep pipeline state."; + return false; + } return true; } @@ -928,17 +1166,42 @@ void App::ShutdownPreparePassResources() { } void App::ShutdownSortResources() { - if (m_sortDescriptorSet != nullptr) { - m_sortDescriptorSet->Shutdown(); - delete m_sortDescriptorSet; - m_sortDescriptorSet = nullptr; + if (m_buildSortKeyDescriptorSet != nullptr) { + m_buildSortKeyDescriptorSet->Shutdown(); + delete m_buildSortKeyDescriptorSet; + m_buildSortKeyDescriptorSet = nullptr; } - ShutdownAndDelete(m_sortDescriptorPool); - ShutdownAndDelete(m_sortPipelineState); - ShutdownAndDelete(m_sortPipelineLayout); + ShutdownAndDelete(m_buildSortKeyDescriptorPool); + ShutdownAndDelete(m_buildSortKeyPipelineState); + ShutdownAndDelete(m_buildSortKeyPipelineLayout); + + if (m_radixSortDescriptorSetPrimaryToScratch != nullptr) { + m_radixSortDescriptorSetPrimaryToScratch->Shutdown(); + delete m_radixSortDescriptorSetPrimaryToScratch; + m_radixSortDescriptorSetPrimaryToScratch = nullptr; + } + + if (m_radixSortDescriptorSetScratchToPrimary != nullptr) { + m_radixSortDescriptorSetScratchToPrimary->Shutdown(); + delete m_radixSortDescriptorSetScratchToPrimary; + m_radixSortDescriptorSetScratchToPrimary = nullptr; + } + + ShutdownAndDelete(m_radixSortDescriptorPool); + ShutdownAndDelete(m_radixSortDownsweepPipelineState); + ShutdownAndDelete(m_radixSortScanPipelineState); + ShutdownAndDelete(m_radixSortGlobalHistogramPipelineState); + ShutdownAndDelete(m_radixSortUpsweepPipelineState); + ShutdownAndDelete(m_radixSortInitPipelineState); + ShutdownAndDelete(m_radixSortPipelineLayout); + + m_globalHistogramUav.reset(); + m_passHistogramUav.reset(); + 
m_orderScratchUav.reset(); + m_orderBufferUav.reset(); m_orderBufferSrv.reset(); + m_sortKeyScratchUav.reset(); m_sortKeyUav.reset(); - m_sortKeySrv.reset(); if (m_orderBuffer != nullptr) { m_orderBuffer->Shutdown(); @@ -946,11 +1209,35 @@ void App::ShutdownSortResources() { m_orderBuffer = nullptr; } + if (m_orderScratchBuffer != nullptr) { + m_orderScratchBuffer->Shutdown(); + delete m_orderScratchBuffer; + m_orderScratchBuffer = nullptr; + } + if (m_sortKeyBuffer != nullptr) { m_sortKeyBuffer->Shutdown(); delete m_sortKeyBuffer; m_sortKeyBuffer = nullptr; } + + if (m_sortKeyScratchBuffer != nullptr) { + m_sortKeyScratchBuffer->Shutdown(); + delete m_sortKeyScratchBuffer; + m_sortKeyScratchBuffer = nullptr; + } + + if (m_passHistogramBuffer != nullptr) { + m_passHistogramBuffer->Shutdown(); + delete m_passHistogramBuffer; + m_passHistogramBuffer = nullptr; + } + + if (m_globalHistogramBuffer != nullptr) { + m_globalHistogramBuffer->Shutdown(); + delete m_globalHistogramBuffer; + m_globalHistogramBuffer = nullptr; + } } void App::ShutdownDebugDrawResources() { @@ -1010,22 +1297,35 @@ void App::Shutdown() { AppendTrace("Shutdown: end"); } -bool App::CaptureSortKeySnapshot() { - if (m_sortKeyBuffer == nullptr || m_gaussianSceneData.splatCount == 0 || m_sortKeySnapshotPath.empty()) { +bool App::CaptureSortSnapshot() { + if (m_sortKeyBuffer == nullptr || m_orderBuffer == nullptr || m_gaussianSceneData.splatCount == 0 || m_sortKeySnapshotPath.empty()) { return true; } const uint32_t sampleCount = std::min(16u, m_gaussianSceneData.splatCount); + const uint64_t keyBufferBytes = static_cast(m_gaussianSceneData.splatCount) * sizeof(uint32_t); const uint64_t sampleBytes = static_cast(sampleCount) * sizeof(uint32_t); - D3D12Buffer readbackBuffer; - if (!readbackBuffer.Initialize( + D3D12Buffer sortKeyReadbackBuffer; + if (!sortKeyReadbackBuffer.Initialize( + m_device.GetDevice(), + keyBufferBytes, + D3D12_RESOURCE_STATE_COPY_DEST, + D3D12_HEAP_TYPE_READBACK, + 
D3D12_RESOURCE_FLAG_NONE)) { + m_lastErrorMessage = L"Failed to create the sort key readback buffer."; + return false; + } + + D3D12Buffer orderReadbackBuffer; + if (!orderReadbackBuffer.Initialize( m_device.GetDevice(), sampleBytes, D3D12_RESOURCE_STATE_COPY_DEST, D3D12_HEAP_TYPE_READBACK, D3D12_RESOURCE_FLAG_NONE)) { - m_lastErrorMessage = L"Failed to create the sort key readback buffer."; + m_lastErrorMessage = L"Failed to create the order readback buffer."; + sortKeyReadbackBuffer.Shutdown(); return false; } @@ -1040,11 +1340,26 @@ bool App::CaptureSortKeySnapshot() { m_sortKeyBuffer->SetState(ResourceStates::CopySrc); } + if (m_orderBuffer->GetState() != ResourceStates::CopySrc) { + m_commandList.TransitionBarrier( + m_orderBuffer->GetResource(), + m_orderBuffer->GetState(), + ResourceStates::CopySrc); + m_orderBuffer->SetState(ResourceStates::CopySrc); + } + m_commandList.GetCommandList()->CopyBufferRegion( - readbackBuffer.GetResource(), + sortKeyReadbackBuffer.GetResource(), 0, m_sortKeyBuffer->GetResource(), 0, + keyBufferBytes); + + m_commandList.GetCommandList()->CopyBufferRegion( + orderReadbackBuffer.GetResource(), + 0, + m_orderBuffer->GetResource(), + 0, sampleBytes); m_commandList.Close(); @@ -1052,10 +1367,20 @@ bool App::CaptureSortKeySnapshot() { m_commandQueue.ExecuteCommandLists(1, commandLists); m_commandQueue.WaitForIdle(); - const uint32_t* keys = static_cast(readbackBuffer.Map()); + const uint32_t* keys = static_cast(sortKeyReadbackBuffer.Map()); if (keys == nullptr) { m_lastErrorMessage = L"Failed to map the sort key readback buffer."; - readbackBuffer.Shutdown(); + sortKeyReadbackBuffer.Shutdown(); + orderReadbackBuffer.Shutdown(); + return false; + } + + const uint32_t* order = static_cast(orderReadbackBuffer.Map()); + if (order == nullptr) { + sortKeyReadbackBuffer.Unmap(); + sortKeyReadbackBuffer.Shutdown(); + orderReadbackBuffer.Shutdown(); + m_lastErrorMessage = L"Failed to map the order readback buffer."; return false; } @@ -1066,19 
+1391,265 @@ bool App::CaptureSortKeySnapshot() { std::ofstream output(snapshotPath, std::ios::binary | std::ios::trunc); if (!output.is_open()) { - readbackBuffer.Unmap(); - readbackBuffer.Shutdown(); + orderReadbackBuffer.Unmap(); + sortKeyReadbackBuffer.Unmap(); + orderReadbackBuffer.Shutdown(); + sortKeyReadbackBuffer.Shutdown(); m_lastErrorMessage = L"Failed to open the sort key snapshot output file."; return false; } output << "sample_count=" << sampleCount << '\n'; - for (uint32_t index = 0; index < sampleCount; ++index) { - output << "key[" << index << "]=" << keys[index] << '\n'; + bool isSorted = true; + uint32_t firstInversionIndex = 0u; + uint32_t firstInversionPrevious = 0u; + uint32_t firstInversionCurrent = 0u; + for (uint32_t index = 1; index < m_gaussianSceneData.splatCount; ++index) { + if (keys[index - 1u] > keys[index]) { + isSorted = false; + firstInversionIndex = index; + firstInversionPrevious = keys[index - 1u]; + firstInversionCurrent = keys[index]; + break; + } } - readbackBuffer.Unmap(); - readbackBuffer.Shutdown(); + std::vector cpuReferenceOrder; + std::vector cpuReferenceKeys; + BuildCpuSortedOrder(m_gaussianSceneData, cpuReferenceOrder, &cpuReferenceKeys); + + uint32_t firstCpuMismatchIndex = 0u; + uint32_t firstCpuMismatchGpu = 0u; + uint32_t firstCpuMismatchCpu = 0u; + bool matchesCpuReference = true; + for (uint32_t index = 0; index < m_gaussianSceneData.splatCount; ++index) { + if (keys[index] != cpuReferenceKeys[index]) { + matchesCpuReference = false; + firstCpuMismatchIndex = index; + firstCpuMismatchGpu = keys[index]; + firstCpuMismatchCpu = cpuReferenceKeys[index]; + break; + } + } + + output << "sorted=" << (isSorted ? 1 : 0) << '\n'; + output << "first_inversion_index=" << firstInversionIndex << '\n'; + output << "first_inversion_prev=" << firstInversionPrevious << '\n'; + output << "first_inversion_curr=" << firstInversionCurrent << '\n'; + output << "matches_cpu_reference=" << (matchesCpuReference ? 
1 : 0) << '\n'; + output << "first_cpu_mismatch_index=" << firstCpuMismatchIndex << '\n'; + output << "first_cpu_mismatch_gpu=" << firstCpuMismatchGpu << '\n'; + output << "first_cpu_mismatch_cpu=" << firstCpuMismatchCpu << '\n'; + for (uint32_t index = 0; index < sampleCount; ++index) { + output << "key[" << index << "]=" << keys[index] << '\n'; + output << "order[" << index << "]=" << order[index] << '\n'; + output << "cpu_order[" << index << "]=" << cpuReferenceOrder[index] << '\n'; + output << "cpu_key[" << index << "]=" << cpuReferenceKeys[index] << '\n'; + } + + orderReadbackBuffer.Unmap(); + sortKeyReadbackBuffer.Unmap(); + orderReadbackBuffer.Shutdown(); + sortKeyReadbackBuffer.Shutdown(); + return output.good(); +} + +bool App::CapturePass3HistogramDebug() { + if (m_sortKeyScratchBuffer == nullptr || + m_sortKeyBuffer == nullptr || + m_globalHistogramBuffer == nullptr || + m_gaussianSceneData.splatCount == 0) { + return true; + } + + const uint64_t keyBufferBytes = static_cast(m_gaussianSceneData.splatCount) * sizeof(uint32_t); + const uint32_t histogramElementCount = kDeviceRadixSortRadix * kDeviceRadixSortPassCount; + const uint64_t histogramBytes = static_cast(histogramElementCount) * sizeof(uint32_t); + + D3D12Buffer keyReadbackBuffer; + if (!keyReadbackBuffer.Initialize( + m_device.GetDevice(), + keyBufferBytes, + D3D12_RESOURCE_STATE_COPY_DEST, + D3D12_HEAP_TYPE_READBACK, + D3D12_RESOURCE_FLAG_NONE)) { + m_lastErrorMessage = L"Failed to create the pass3 histogram key readback buffer."; + return false; + } + + D3D12Buffer histogramReadbackBuffer; + if (!histogramReadbackBuffer.Initialize( + m_device.GetDevice(), + histogramBytes, + D3D12_RESOURCE_STATE_COPY_DEST, + D3D12_HEAP_TYPE_READBACK, + D3D12_RESOURCE_FLAG_NONE)) { + keyReadbackBuffer.Shutdown(); + m_lastErrorMessage = L"Failed to create the pass3 histogram readback buffer."; + return false; + } + + D3D12Buffer primaryKeyReadbackBuffer; + if (!primaryKeyReadbackBuffer.Initialize( + 
m_device.GetDevice(), + keyBufferBytes, + D3D12_RESOURCE_STATE_COPY_DEST, + D3D12_HEAP_TYPE_READBACK, + D3D12_RESOURCE_FLAG_NONE)) { + histogramReadbackBuffer.Shutdown(); + keyReadbackBuffer.Shutdown(); + m_lastErrorMessage = L"Failed to create the pass3 primary-key readback buffer."; + return false; + } + + m_commandAllocator.Reset(); + m_commandList.Reset(); + + if (m_sortKeyScratchBuffer->GetState() != ResourceStates::CopySrc) { + m_commandList.TransitionBarrier( + m_sortKeyScratchBuffer->GetResource(), + m_sortKeyScratchBuffer->GetState(), + ResourceStates::CopySrc); + m_sortKeyScratchBuffer->SetState(ResourceStates::CopySrc); + } + + if (m_sortKeyBuffer->GetState() != ResourceStates::CopySrc) { + m_commandList.TransitionBarrier( + m_sortKeyBuffer->GetResource(), + m_sortKeyBuffer->GetState(), + ResourceStates::CopySrc); + m_sortKeyBuffer->SetState(ResourceStates::CopySrc); + } + + if (m_globalHistogramBuffer->GetState() != ResourceStates::CopySrc) { + m_commandList.TransitionBarrier( + m_globalHistogramBuffer->GetResource(), + m_globalHistogramBuffer->GetState(), + ResourceStates::CopySrc); + m_globalHistogramBuffer->SetState(ResourceStates::CopySrc); + } + + m_commandList.GetCommandList()->CopyBufferRegion( + keyReadbackBuffer.GetResource(), + 0, + m_sortKeyScratchBuffer->GetResource(), + 0, + keyBufferBytes); + m_commandList.GetCommandList()->CopyBufferRegion( + primaryKeyReadbackBuffer.GetResource(), + 0, + m_sortKeyBuffer->GetResource(), + 0, + keyBufferBytes); + m_commandList.GetCommandList()->CopyBufferRegion( + histogramReadbackBuffer.GetResource(), + 0, + m_globalHistogramBuffer->GetResource(), + 0, + histogramBytes); + + m_commandList.Close(); + void* commandLists[] = { &m_commandList }; + m_commandQueue.ExecuteCommandLists(1, commandLists); + m_commandQueue.WaitForIdle(); + + const uint32_t* keys = static_cast(keyReadbackBuffer.Map()); + if (keys == nullptr) { + primaryKeyReadbackBuffer.Shutdown(); + histogramReadbackBuffer.Shutdown(); + 
keyReadbackBuffer.Shutdown(); + m_lastErrorMessage = L"Failed to map the pass3 key readback buffer."; + return false; + } + + const uint32_t* primaryKeys = static_cast(primaryKeyReadbackBuffer.Map()); + if (primaryKeys == nullptr) { + keyReadbackBuffer.Unmap(); + primaryKeyReadbackBuffer.Shutdown(); + histogramReadbackBuffer.Shutdown(); + keyReadbackBuffer.Shutdown(); + m_lastErrorMessage = L"Failed to map the pass3 primary-key readback buffer."; + return false; + } + + const uint32_t* histogram = static_cast(histogramReadbackBuffer.Map()); + if (histogram == nullptr) { + primaryKeyReadbackBuffer.Unmap(); + keyReadbackBuffer.Unmap(); + primaryKeyReadbackBuffer.Shutdown(); + histogramReadbackBuffer.Shutdown(); + keyReadbackBuffer.Shutdown(); + m_lastErrorMessage = L"Failed to map the pass3 histogram readback buffer."; + return false; + } + + std::vector cpuCounts(kDeviceRadixSortRadix, 0u); + std::vector primaryCpuCounts(kDeviceRadixSortRadix, 0u); + for (uint32_t index = 0; index < m_gaussianSceneData.splatCount; ++index) { + const uint32_t bin = (keys[index] >> 24u) & 0xffu; + ++cpuCounts[bin]; + const uint32_t primaryBin = (primaryKeys[index] >> 24u) & 0xffu; + ++primaryCpuCounts[primaryBin]; + } + + std::vector cpuExclusiveOffsets(kDeviceRadixSortRadix, 0u); + std::vector primaryCpuExclusiveOffsets(kDeviceRadixSortRadix, 0u); + uint32_t runningOffset = 0u; + uint32_t primaryRunningOffset = 0u; + for (uint32_t bin = 0; bin < kDeviceRadixSortRadix; ++bin) { + cpuExclusiveOffsets[bin] = runningOffset; + runningOffset += cpuCounts[bin]; + primaryCpuExclusiveOffsets[bin] = primaryRunningOffset; + primaryRunningOffset += primaryCpuCounts[bin]; + } + + const std::filesystem::path debugPath = ResolveNearExecutable(L"phase3_hist_debug.txt"); + if (!debugPath.parent_path().empty()) { + std::filesystem::create_directories(debugPath.parent_path()); + } + + std::ofstream output(debugPath, std::ios::binary | std::ios::trunc); + if (!output.is_open()) { + 
histogramReadbackBuffer.Unmap(); + primaryKeyReadbackBuffer.Unmap(); + keyReadbackBuffer.Unmap(); + primaryKeyReadbackBuffer.Shutdown(); + histogramReadbackBuffer.Shutdown(); + keyReadbackBuffer.Shutdown(); + m_lastErrorMessage = L"Failed to open the pass3 histogram debug output file."; + return false; + } + + uint32_t mismatchCount = 0u; + uint32_t primaryMismatchCount = 0u; + constexpr uint32_t kPass3GlobalHistogramBase = kDeviceRadixSortRadix * 3u; + output << "splat_count=" << m_gaussianSceneData.splatCount << '\n'; + output << "final_offset=" << runningOffset << '\n'; + output << "primary_final_offset=" << primaryRunningOffset << '\n'; + for (uint32_t bin = 0; bin < kDeviceRadixSortRadix; ++bin) { + const uint32_t gpuOffset = histogram[kPass3GlobalHistogramBase + bin]; + const uint32_t cpuOffset = cpuExclusiveOffsets[bin]; + const uint32_t primaryCpuOffset = primaryCpuExclusiveOffsets[bin]; + if (gpuOffset != cpuOffset) { + ++mismatchCount; + } + if (gpuOffset != primaryCpuOffset) { + ++primaryMismatchCount; + } + + output << "bin[" << bin << "].count=" << cpuCounts[bin] + << " cpu_offset=" << cpuOffset + << " primary_count=" << primaryCpuCounts[bin] + << " primary_cpu_offset=" << primaryCpuOffset + << " gpu_offset=" << gpuOffset << '\n'; + } + output << "mismatch_count=" << mismatchCount << '\n'; + output << "primary_mismatch_count=" << primaryMismatchCount << '\n'; + + histogramReadbackBuffer.Unmap(); + primaryKeyReadbackBuffer.Unmap(); + keyReadbackBuffer.Unmap(); + primaryKeyReadbackBuffer.Shutdown(); + histogramReadbackBuffer.Shutdown(); + keyReadbackBuffer.Shutdown(); return output.good(); } @@ -1122,16 +1693,104 @@ void App::RenderFrame(bool captureScreenshot) { static_cast(m_width), static_cast(m_height), m_gaussianSceneData.splatCount); + const uint32_t threadBlocks = + static_cast((m_gaussianSceneData.splatCount + (kDeviceRadixSortPartitionSize - 1u)) / kDeviceRadixSortPartitionSize); + const uint32_t passHistogramElementCount = threadBlocks * kDeviceRadixSortRadix; + constexpr uint32_t kSortDebugStageCount = 5u; + const 
char* const sortDebugStageNames[kSortDebugStageCount] = { + "build_sort_keys", + "radix_pass_0", + "radix_pass_1", + "radix_pass_2", + "radix_pass_3", + }; + const uint32_t sortDebugSampleCount = 0u; + const uint64_t sortDebugSampleBytes = static_cast(sortDebugSampleCount) * sizeof(uint32_t); + D3D12Buffer sortDebugKeyReadbackBuffer; + D3D12Buffer sortDebugOrderReadbackBuffer; + const bool sortDebugEnabled = + sortDebugSampleCount > 0 && + sortDebugKeyReadbackBuffer.Initialize( + m_device.GetDevice(), + sortDebugSampleBytes * kSortDebugStageCount, + D3D12_RESOURCE_STATE_COPY_DEST, + D3D12_HEAP_TYPE_READBACK, + D3D12_RESOURCE_FLAG_NONE) && + sortDebugOrderReadbackBuffer.Initialize( + m_device.GetDevice(), + sortDebugSampleBytes * kSortDebugStageCount, + D3D12_RESOURCE_STATE_COPY_DEST, + D3D12_HEAP_TYPE_READBACK, + D3D12_RESOURCE_FLAG_NONE); + D3D12Buffer pass3PassHistogramReadbackBuffer; + D3D12Buffer pass3PreScanGlobalHistogramReadbackBuffer; + const uint64_t pass3PassHistogramReadbackBytes = + static_cast(passHistogramElementCount) * sizeof(uint32_t); + const uint64_t pass3PreScanGlobalHistogramReadbackBytes = + static_cast(kDeviceRadixSortRadix * kDeviceRadixSortPassCount) * sizeof(uint32_t); + const bool pass3PassHistogramDebugEnabled = + captureScreenshot && + pass3PassHistogramReadbackBuffer.Initialize( + m_device.GetDevice(), + pass3PassHistogramReadbackBytes, + D3D12_RESOURCE_STATE_COPY_DEST, + D3D12_HEAP_TYPE_READBACK, + D3D12_RESOURCE_FLAG_NONE); + const bool pass3PreScanGlobalHistogramDebugEnabled = + captureScreenshot && + pass3PreScanGlobalHistogramReadbackBuffer.Initialize( + m_device.GetDevice(), + pass3PreScanGlobalHistogramReadbackBytes, + D3D12_RESOURCE_STATE_COPY_DEST, + D3D12_HEAP_TYPE_READBACK, + D3D12_RESOURCE_FLAG_NONE); + bool pass3PassHistogramCaptured = false; + bool pass3PreScanGlobalHistogramCaptured = false; + + auto transitionBuffer = [this](D3D12Buffer* buffer, ResourceStates newState) { + if (buffer != nullptr && buffer->GetState() != 
newState) { + m_commandList.TransitionBarrier( + buffer->GetResource(), + buffer->GetState(), + newState); + buffer->SetState(newState); + } + }; + + auto recordSortDebugStage = [&](uint32_t stageIndex, + D3D12Buffer* keyBuffer, + D3D12Buffer* orderBuffer, + ResourceStates keyRestoreState, + ResourceStates orderRestoreState) { + if (!sortDebugEnabled || + keyBuffer == nullptr || + orderBuffer == nullptr || + sortDebugSampleBytes == 0 || + stageIndex >= kSortDebugStageCount) { + return; + } + + transitionBuffer(keyBuffer, ResourceStates::CopySrc); + transitionBuffer(orderBuffer, ResourceStates::CopySrc); + m_commandList.GetCommandList()->CopyBufferRegion( + sortDebugKeyReadbackBuffer.GetResource(), + static_cast(stageIndex) * sortDebugSampleBytes, + keyBuffer->GetResource(), + 0, + sortDebugSampleBytes); + m_commandList.GetCommandList()->CopyBufferRegion( + sortDebugOrderReadbackBuffer.GetResource(), + static_cast(stageIndex) * sortDebugSampleBytes, + orderBuffer->GetResource(), + 0, + sortDebugSampleBytes); + transitionBuffer(keyBuffer, keyRestoreState); + transitionBuffer(orderBuffer, orderRestoreState); + }; + m_prepareDescriptorSet->WriteConstant(0, &frameConstants, sizeof(frameConstants)); m_debugDescriptorSet->WriteConstant(0, &frameConstants, sizeof(frameConstants)); - - if (m_preparedViewBuffer->GetState() != ResourceStates::UnorderedAccess) { - m_commandList.TransitionBarrier( - m_preparedViewBuffer->GetResource(), - m_preparedViewBuffer->GetState(), - ResourceStates::UnorderedAccess); - m_preparedViewBuffer->SetState(ResourceStates::UnorderedAccess); - } + transitionBuffer(m_preparedViewBuffer, ResourceStates::UnorderedAccess); m_commandList.SetPipelineState(m_preparePipelineState); RHIDescriptorSet* prepareSets[] = { m_prepareDescriptorSet }; @@ -1139,26 +1798,112 @@ void App::RenderFrame(bool captureScreenshot) { m_commandList.Dispatch((m_gaussianSceneData.splatCount + (kPrepareThreadGroupSize - 1)) / kPrepareThreadGroupSize, 1, 1); 
m_commandList.UAVBarrier(m_preparedViewBuffer->GetResource()); - if (m_sortKeyBuffer->GetState() != ResourceStates::UnorderedAccess) { - m_commandList.TransitionBarrier( - m_sortKeyBuffer->GetResource(), - m_sortKeyBuffer->GetState(), - ResourceStates::UnorderedAccess); - m_sortKeyBuffer->SetState(ResourceStates::UnorderedAccess); - } + transitionBuffer(m_orderBuffer, ResourceStates::NonPixelShaderResource); + transitionBuffer(m_sortKeyBuffer, ResourceStates::UnorderedAccess); - m_sortDescriptorSet->WriteConstant(0, &frameConstants, sizeof(frameConstants)); - m_commandList.SetPipelineState(m_sortPipelineState); - RHIDescriptorSet* sortSets[] = { m_sortDescriptorSet }; - m_commandList.SetComputeDescriptorSets(0, 1, sortSets, m_sortPipelineLayout); + m_buildSortKeyDescriptorSet->WriteConstant(0, &frameConstants, sizeof(frameConstants)); + m_commandList.SetPipelineState(m_buildSortKeyPipelineState); + RHIDescriptorSet* buildSortKeySets[] = { m_buildSortKeyDescriptorSet }; + m_commandList.SetComputeDescriptorSets(0, 1, buildSortKeySets, m_buildSortKeyPipelineLayout); m_commandList.Dispatch((m_gaussianSceneData.splatCount + (kSortThreadGroupSize - 1)) / kSortThreadGroupSize, 1, 1); m_commandList.UAVBarrier(m_sortKeyBuffer->GetResource()); - - m_commandList.TransitionBarrier( - m_preparedViewBuffer->GetResource(), + recordSortDebugStage( + 0u, + m_sortKeyBuffer, + m_orderBuffer, ResourceStates::UnorderedAccess, ResourceStates::NonPixelShaderResource); - m_preparedViewBuffer->SetState(ResourceStates::NonPixelShaderResource); + + if (!kUseCpuSortBaseline) { + transitionBuffer(m_orderBuffer, ResourceStates::UnorderedAccess); + transitionBuffer(m_orderScratchBuffer, ResourceStates::UnorderedAccess); + transitionBuffer(m_sortKeyScratchBuffer, ResourceStates::UnorderedAccess); + transitionBuffer(m_passHistogramBuffer, ResourceStates::UnorderedAccess); + transitionBuffer(m_globalHistogramBuffer, ResourceStates::UnorderedAccess); + + RadixSortConstants radixSortConstants = {}; + 
radixSortConstants.numKeys = m_gaussianSceneData.splatCount; + radixSortConstants.threadBlocks = threadBlocks; + + m_radixSortDescriptorSetPrimaryToScratch->WriteConstant(0, &radixSortConstants, sizeof(radixSortConstants)); + m_commandList.SetPipelineState(m_radixSortInitPipelineState); + RHIDescriptorSet* radixInitSets[] = { m_radixSortDescriptorSetPrimaryToScratch }; + m_commandList.SetComputeDescriptorSets(0, 1, radixInitSets, m_radixSortPipelineLayout); + m_commandList.Dispatch(1, 1, 1); + m_commandList.UAVBarrier(m_globalHistogramBuffer->GetResource()); + + for (uint32_t passIndex = 0; passIndex < kDeviceRadixSortPassCount; ++passIndex) { + radixSortConstants.radixShift = passIndex * 8u; + RHIDescriptorSet* activeRadixSet = + (passIndex & 1u) == 0u + ? m_radixSortDescriptorSetPrimaryToScratch + : m_radixSortDescriptorSetScratchToPrimary; + D3D12Buffer* destinationKeyBuffer = + (passIndex & 1u) == 0u + ? m_sortKeyScratchBuffer + : m_sortKeyBuffer; + D3D12Buffer* destinationOrderBuffer = + (passIndex & 1u) == 0u + ? 
m_orderScratchBuffer + : m_orderBuffer; + + activeRadixSet->WriteConstant(0, &radixSortConstants, sizeof(radixSortConstants)); + + m_commandList.SetPipelineState(m_radixSortUpsweepPipelineState); + m_commandList.SetComputeDescriptorSets(0, 1, &activeRadixSet, m_radixSortPipelineLayout); + m_commandList.Dispatch(threadBlocks, 1, 1); + m_commandList.UAVBarrier(m_passHistogramBuffer->GetResource()); + + if (pass3PassHistogramDebugEnabled && passIndex == 3u && !pass3PassHistogramCaptured) { + transitionBuffer(m_passHistogramBuffer, ResourceStates::CopySrc); + m_commandList.GetCommandList()->CopyBufferRegion( + pass3PassHistogramReadbackBuffer.GetResource(), + 0, + m_passHistogramBuffer->GetResource(), + 0, + pass3PassHistogramReadbackBytes); + transitionBuffer(m_passHistogramBuffer, ResourceStates::UnorderedAccess); + pass3PassHistogramCaptured = true; + } + + m_commandList.SetPipelineState(m_radixSortGlobalHistogramPipelineState); + m_commandList.SetComputeDescriptorSets(0, 1, &activeRadixSet, m_radixSortPipelineLayout); + m_commandList.Dispatch(1, 1, 1); + m_commandList.UAVBarrier(m_globalHistogramBuffer->GetResource()); + + if (pass3PreScanGlobalHistogramDebugEnabled && passIndex == 3u && !pass3PreScanGlobalHistogramCaptured) { + transitionBuffer(m_globalHistogramBuffer, ResourceStates::CopySrc); + m_commandList.GetCommandList()->CopyBufferRegion( + pass3PreScanGlobalHistogramReadbackBuffer.GetResource(), + 0, + m_globalHistogramBuffer->GetResource(), + 0, + pass3PreScanGlobalHistogramReadbackBytes); + transitionBuffer(m_globalHistogramBuffer, ResourceStates::UnorderedAccess); + pass3PreScanGlobalHistogramCaptured = true; + } + + m_commandList.SetPipelineState(m_radixSortScanPipelineState); + m_commandList.SetComputeDescriptorSets(0, 1, &activeRadixSet, m_radixSortPipelineLayout); + m_commandList.Dispatch(kDeviceRadixSortRadix, 1, 1); + m_commandList.UAVBarrier(m_passHistogramBuffer->GetResource()); + + 
m_commandList.SetPipelineState(m_radixSortDownsweepPipelineState); + m_commandList.SetComputeDescriptorSets(0, 1, &activeRadixSet, m_radixSortPipelineLayout); + m_commandList.Dispatch(threadBlocks, 1, 1); + m_commandList.UAVBarrier(destinationKeyBuffer->GetResource()); + m_commandList.UAVBarrier(destinationOrderBuffer->GetResource()); + recordSortDebugStage( + passIndex + 1u, + destinationKeyBuffer, + destinationOrderBuffer, + ResourceStates::UnorderedAccess, + ResourceStates::UnorderedAccess); + } + } + + transitionBuffer(m_preparedViewBuffer, ResourceStates::NonPixelShaderResource); + transitionBuffer(m_orderBuffer, ResourceStates::NonPixelShaderResource); m_commandList.SetPipelineState(m_debugPipelineState); RHIDescriptorSet* debugSets[] = { m_debugDescriptorSet }; @@ -1174,9 +1919,97 @@ void App::RenderFrame(bool captureScreenshot) { AppendTrace("RenderFrame: WaitForIdle before screenshot"); m_commandQueue.WaitForIdle(); - AppendTrace("RenderFrame: Capture sort key snapshot"); - if (!CaptureSortKeySnapshot()) { - AppendTrace(std::string("RenderFrame: Capture sort key snapshot failed: ") + NarrowAscii(m_lastErrorMessage)); + if (sortDebugEnabled) { + const std::filesystem::path debugPath = ResolveNearExecutable(L"phase3_sort_debug.txt"); + std::ofstream debugOutput(debugPath, std::ios::binary | std::ios::trunc); + const uint32_t* stageKeys = static_cast(sortDebugKeyReadbackBuffer.Map()); + const uint32_t* stageOrder = static_cast(sortDebugOrderReadbackBuffer.Map()); + if (debugOutput.is_open() && stageKeys != nullptr && stageOrder != nullptr) { + debugOutput << "sample_count=" << sortDebugSampleCount << '\n'; + for (uint32_t stageIndex = 0; stageIndex < kSortDebugStageCount; ++stageIndex) { + debugOutput << "stage=" << sortDebugStageNames[stageIndex] << '\n'; + const uint32_t* stageKeyBase = stageKeys + stageIndex * sortDebugSampleCount; + const uint32_t* stageOrderBase = stageOrder + stageIndex * sortDebugSampleCount; + for (uint32_t sampleIndex = 0; 
sampleIndex < sortDebugSampleCount; ++sampleIndex) { + debugOutput << "key[" << sampleIndex << "]=" << stageKeyBase[sampleIndex] << '\n'; + debugOutput << "order[" << sampleIndex << "]=" << stageOrderBase[sampleIndex] << '\n'; + } + } + } + if (stageOrder != nullptr) { + sortDebugOrderReadbackBuffer.Unmap(); + } + if (stageKeys != nullptr) { + sortDebugKeyReadbackBuffer.Unmap(); + } + } + + if (pass3PassHistogramDebugEnabled && pass3PassHistogramCaptured) { + const uint32_t* passHistogram = + static_cast(pass3PassHistogramReadbackBuffer.Map()); + if (passHistogram != nullptr) { + const std::filesystem::path passHistogramDebugPath = + ResolveNearExecutable(L"phase3_passhist_debug.txt"); + if (!passHistogramDebugPath.parent_path().empty()) { + std::filesystem::create_directories(passHistogramDebugPath.parent_path()); + } + + std::ofstream passHistogramOutput( + passHistogramDebugPath, + std::ios::binary | std::ios::trunc); + if (passHistogramOutput.is_open()) { + uint64_t totalCount = 0u; + passHistogramOutput << "thread_blocks=" << threadBlocks << '\n'; + for (uint32_t bin = 0; bin < kDeviceRadixSortRadix; ++bin) { + uint64_t binTotal = 0u; + for (uint32_t blockIndex = 0; blockIndex < threadBlocks; ++blockIndex) { + binTotal += passHistogram[bin * threadBlocks + blockIndex]; + } + totalCount += binTotal; + passHistogramOutput << "bin[" << bin << "]=" << binTotal << '\n'; + } + passHistogramOutput << "total_count=" << totalCount << '\n'; + } + pass3PassHistogramReadbackBuffer.Unmap(); + } + } + + if (pass3PreScanGlobalHistogramDebugEnabled && pass3PreScanGlobalHistogramCaptured) { + const uint32_t* preScanGlobalHistogram = + static_cast(pass3PreScanGlobalHistogramReadbackBuffer.Map()); + if (preScanGlobalHistogram != nullptr) { + const std::filesystem::path preScanGlobalHistogramPath = + ResolveNearExecutable(L"phase3_pre_scan_globalhist.txt"); + if (!preScanGlobalHistogramPath.parent_path().empty()) { + 
std::filesystem::create_directories(preScanGlobalHistogramPath.parent_path()); + } + + std::ofstream preScanGlobalHistogramOutput( + preScanGlobalHistogramPath, + std::ios::binary | std::ios::trunc); + if (preScanGlobalHistogramOutput.is_open()) { + constexpr uint32_t kPass3GlobalHistogramBase = kDeviceRadixSortRadix * 3u; + for (uint32_t bin = 0; bin < kDeviceRadixSortRadix; ++bin) { + preScanGlobalHistogramOutput + << "bin[" << bin << "]=" + << preScanGlobalHistogram[kPass3GlobalHistogramBase + bin] + << '\n'; + } + } + pass3PreScanGlobalHistogramReadbackBuffer.Unmap(); + } + } + + AppendTrace("RenderFrame: Capture sort snapshot"); + if (!CaptureSortSnapshot()) { + AppendTrace(std::string("RenderFrame: Capture sort snapshot failed: ") + NarrowAscii(m_lastErrorMessage)); + } + + if (!kUseCpuSortBaseline) { + AppendTrace("RenderFrame: Capture pass3 histogram debug"); + if (!CapturePass3HistogramDebug()) { + AppendTrace(std::string("RenderFrame: Capture pass3 histogram debug failed: ") + NarrowAscii(m_lastErrorMessage)); + } } if (!m_screenshotPath.empty()) { @@ -1209,6 +2042,17 @@ void App::RenderFrame(bool captureScreenshot) { m_swapChain.Present(1, 0); } + if (sortDebugEnabled) { + sortDebugOrderReadbackBuffer.Shutdown(); + sortDebugKeyReadbackBuffer.Shutdown(); + } + if (pass3PassHistogramDebugEnabled) { + pass3PassHistogramReadbackBuffer.Shutdown(); + } + if (pass3PreScanGlobalHistogramDebugEnabled) { + pass3PreScanGlobalHistogramReadbackBuffer.Shutdown(); + } + m_hasRenderedAtLeastOneFrame = true; AppendTrace("RenderFrame: end"); }