// Source file: XCEngine/MVS/3DGS-D3D12/shaders/DeviceRadixSort.hlsl (HLSL)

/******************************************************************************
* DeviceRadixSort
* Device Level 8-bit LSD Radix Sort using reduce then scan
*
* SPDX-License-Identifier: MIT
* Copyright Thomas Smith 5/17/2024
* https://github.com/b0nes164/GPUSorting
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
******************************************************************************/
#include "SortCommon.hlsl"
//Threadblock size shared by the Upsweep and BuildGlobalHistogram kernels.
//HistogramDigitCounts gives every 64 threads their own RADIX-sized histogram,
//so 128 threads fill exactly the two halves of g_us.
#define US_DIM 128U //The number of threads in an Upsweep threadblock
#define SCAN_DIM 128U //The number of threads in a Scan threadblock
//Per-pass digit offsets. Zeroed by InitDeviceRadixSort, written by BuildGlobalHistogramExclusive
//at [digit + GlobalHistOffset()], and read back by the downsweep.
RWStructuredBuffer<uint> b_globalHist : register(u5); //buffer holding device level offsets for each binning pass
//Per-tile digit counts, laid out digit-major: entry [digit * e_threadBlocks + threadblock].
//Upsweep writes the counts; Scan replaces each digit's row with its exclusive prefix sum in place.
RWStructuredBuffer<uint> b_passHist : register(u4); //buffer used to store reduced sums of partition tiles
//Two RADIX-sized regions: [0, RADIX) and [RADIX, 2 * RADIX).
//Upsweep uses them as two partial histograms; BuildGlobalHistogramExclusive uses the upper one as scan scratch.
groupshared uint g_us[RADIX * 2]; //Shared memory for upsweep
//One slot per thread of a Scan threadblock; used by the ExclusiveThreadBlockScan* helpers.
groupshared uint g_scan[SCAN_DIM]; //Shared memory for the scan
//*****************************************************************************
//INIT KERNEL
//*****************************************************************************
//Zero the global histogram: every thread clears the one entry selected by its dispatch thread id.
//NOTE(review): there is no bounds check, so the dispatch presumably covers exactly the
//length of b_globalHist - confirm against the host-side dispatch.
[numthreads(1024, 1, 1)]
void InitDeviceRadixSort(int3 id : SV_DispatchThreadID)
{
    const int slot = id.x;
    b_globalHist[slot] = 0u;
}
//*****************************************************************************
//UPSWEEP KERNEL
//*****************************************************************************
//Count the digits of this threadblock's partition tile into shared memory.
//Each group of 64 threads accumulates into its own RADIX-sized histogram inside g_us;
//the partial histograms are merged afterwards by ReduceWriteDigitCounts.
//  gtid: thread index inside the threadblock
//  gid:  threadblock (partition tile) index
inline void HistogramDigitCounts(uint gtid, uint gid)
{
    //Start of the histogram owned by this thread's 64-thread group
    const uint sharedBase = (gtid / 64) * RADIX;

    //Every tile spans PART_SIZE keys, except the last one which stops at e_numKeys
    uint tileEnd = (gid + 1) * PART_SIZE;
    if (gid == e_threadBlocks - 1)
    {
        tileEnd = e_numKeys;
    }

    for (uint keyIndex = gid * PART_SIZE + gtid; keyIndex < tileEnd; keyIndex += US_DIM)
    {
        //Shared-memory atomics: several threads of a group may hit the same bin
#if defined(KEY_UINT)
        InterlockedAdd(g_us[sharedBase + ExtractDigit(b_sort[keyIndex])], 1);
#elif defined(KEY_INT)
        InterlockedAdd(g_us[sharedBase + ExtractDigit(IntToUint(b_sort[keyIndex]))], 1);
#elif defined(KEY_FLOAT)
        InterlockedAdd(g_us[sharedBase + ExtractDigit(FloatToUint(b_sort[keyIndex]))], 1);
#endif
    }
}
//Merge the two partial histograms held in g_us and publish the tile's digit counts.
//The merged count stays in g_us[digit] and is also written to b_passHist at
//[digit * e_threadBlocks + gid] (digit-major layout).
inline void ReduceWriteDigitCounts(uint gtid, uint gid)
{
    for (uint digit = gtid; digit < RADIX; digit += US_DIM)
    {
        const uint tileCount = g_us[digit] + g_us[digit + RADIX];
        g_us[digit] = tileCount;
        b_passHist[digit * e_threadBlocks + gid] = tileCount;
    }
}
//Build the per-pass exclusive digit prefix from the reduced pass histogram.
//
//Steps:
//  1. Each thread sums, for every digit it owns (gtid, gtid + US_DIM, ...), the
//     per-tile counts stored in b_passHist[digit * e_threadBlocks + tile].
//  2. The RADIX totals are turned into an inclusive prefix sum in g_us[0, RADIX)
//     with a step-doubling scan that ping-pongs through g_us[RADIX, 2 * RADIX).
//  3. The exclusive prefix of a digit is the inclusive prefix of its left
//     neighbour (0 for digit 0); it is written to b_globalHist at
//     [digit + GlobalHistOffset()].
//
//No per-thread scratch arrays are needed, so the function is valid for any
//RADIX / US_DIM ratio rather than only for two digits per thread.
//Must be called by all US_DIM threads of the group, since it contains group barriers.
//NOTE(review): step 1 expects b_passHist to still hold raw per-tile counts, i.e.
//to run before the Scan kernel rewrites the rows in place - confirm dispatch order.
inline void BuildGlobalHistogramExclusive(uint gtid)
{
    //1. Per-digit totals across all tiles
    for (uint digit = gtid; digit < RADIX; digit += US_DIM)
    {
        uint total = 0u;
        const uint baseOffset = digit * e_threadBlocks;
        for (uint blockIndex = 0; blockIndex < e_threadBlocks; ++blockIndex)
        {
            total += b_passHist[baseOffset + blockIndex];
        }
        g_us[digit] = total;
    }
    GroupMemoryBarrierWithGroupSync();

    //2. Inclusive scan; the upper half of g_us is the write buffer so that no
    //thread reads a value another thread has already updated in the same step
    for (uint offset = 1; offset < RADIX; offset <<= 1)
    {
        for (uint i = gtid; i < RADIX; i += US_DIM)
        {
            g_us[i + RADIX] = g_us[i] + (i >= offset ? g_us[i - offset] : 0u);
        }
        GroupMemoryBarrierWithGroupSync();
        for (uint i = gtid; i < RADIX; i += US_DIM)
        {
            g_us[i] = g_us[i + RADIX];
        }
        GroupMemoryBarrierWithGroupSync();
    }

    //3. g_us is final for every digit after the last barrier above, so reading the
    //left neighbour is race free. inclusive[d - 1] == inclusive[d] - total[d].
    const uint globalHistOffset = GlobalHistOffset();
    for (uint outDigit = gtid; outDigit < RADIX; outDigit += US_DIM)
    {
        uint exclusivePrefix = 0u;
        if (outDigit > 0u)
        {
            exclusivePrefix = g_us[outDigit - 1u];
        }
        b_globalHist[outDigit + globalHistOffset] = exclusivePrefix;
    }
}
//Upsweep kernel: one threadblock per partition tile.
//Clears the shared histograms, counts the tile's digits, then reduces the
//partial histograms and writes the tile's counts to b_passHist.
[numthreads(US_DIM, 1, 1)]
void Upsweep(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID)
{
    //NOTE(review): waveSize is queried but never read in this kernel; it can be
    //dropped once getWaveSize() is confirmed to have no side effects.
    const uint waveSize = getWaveSize();

    const uint localThread = gtid.x;
    const uint tile = gid.x;

    //Zero both RADIX-sized histograms before any atomic add lands in them
    for (uint slot = localThread; slot < RADIX * 2; slot += US_DIM)
    {
        g_us[slot] = 0;
    }
    GroupMemoryBarrierWithGroupSync();

    HistogramDigitCounts(localThread, tile);
    GroupMemoryBarrierWithGroupSync();

    ReduceWriteDigitCounts(localThread, tile);
}
//Kernel wrapper around BuildGlobalHistogramExclusive.
//Zeroes all of g_us, synchronises the group, then builds this pass's exclusive
//digit prefix into b_globalHist. The kernel does not read SV_GroupID.
[numthreads(US_DIM, 1, 1)]
void BuildGlobalHistogram(uint3 gtid : SV_GroupThreadID)
{
    const uint localThread = gtid.x;

    for (uint slot = localThread; slot < RADIX * 2; slot += US_DIM)
    {
        g_us[slot] = 0;
    }
    GroupMemoryBarrierWithGroupSync();

    BuildGlobalHistogramExclusive(localThread);
}
//*****************************************************************************
//SCAN KERNEL
//*****************************************************************************
//Exclusive prefix sum over the full SCAN_DIM-sized chunks of one b_passHist row.
//Per the WGE16 suffix this is presumably the path for wave sizes of 16 or more - confirm.
//  gtid:              thread index inside the threadblock
//  laneMask:          waveSize - 1 as passed by ExclusiveThreadBlockScanWGE16; also the index of the last lane
//  circularLaneShift: (lane + 1) & laneMask, the lane index rotated up by one
//  partEnd:           end of the full chunks (a multiple of SCAN_DIM in the caller)
//  deviceOffset:      start of the row inside b_passHist
//  waveSize:          number of lanes per wave
//  reduction:         running total of all chunks processed so far, updated in place
//NOTE(review): the Scan kernel in this file does not call this helper.
inline void ExclusiveThreadBlockScanFullWGE16(
uint gtid,
uint laneMask,
uint circularLaneShift,
uint partEnd,
uint deviceOffset,
uint waveSize,
inout uint reduction)
{
//partEnd is a multiple of SCAN_DIM in the caller, so all threads of the group run the
//same number of iterations and reach the barriers together
for (uint i = gtid; i < partEnd; i += SCAN_DIM)
{
//Load one element per thread and turn it into a wave-level inclusive sum
//(own value plus the sum of all lower lanes)
g_scan[gtid] = b_passHist[i + deviceOffset];
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
GroupMemoryBarrierWithGroupSync();
//The last slot of every wave holds that wave's total; scan those totals with the
//first SCAN_DIM / waveSize threads so they become inclusive across waves
if (gtid < SCAN_DIM / waveSize)
{
g_scan[(gtid + 1) * waveSize - 1] +=
WavePrefixSum(g_scan[(gtid + 1) * waveSize - 1]);
}
GroupMemoryBarrierWithGroupSync();
//The last lane's slot now holds a cross-wave total, so it contributes 0 here;
//every other lane contributes its wave-level inclusive sum
uint t = (WaveGetLaneIndex() != laneMask ? g_scan[gtid] : 0) + reduction;
//Waves after the first add the scanned total of the preceding waves, which
//lane 0 reads from the slot just before its own
if (gtid >= waveSize)
t += WaveReadLaneAt(g_scan[gtid - 1], 0);
//Write one position to the right inside the wave (the last lane wraps to the wave's
//first position); this rotation converts the inclusive sums into exclusive ones
b_passHist[circularLaneShift + (i & ~laneMask) + deviceOffset] = t;
//Carry the chunk total into the next chunk
reduction += g_scan[SCAN_DIM - 1];
GroupMemoryBarrierWithGroupSync();
}
}
//Exclusive prefix sum over the trailing, partially filled chunk of one b_passHist row.
//Counterpart of ExclusiveThreadBlockScanFullWGE16 for the elements in [partEnd, e_threadBlocks).
//  gtid:              thread index inside the threadblock
//  laneMask:          waveSize - 1 as passed by the caller; also the index of the last lane
//  circularLaneShift: (lane + 1) & laneMask, the lane index rotated up by one
//  partEnd:           index of the first element of the partial chunk
//  deviceOffset:      start of the row inside b_passHist
//  waveSize:          number of lanes per wave
//  reduction:         total of all full chunks that precede this one (read only)
inline void ExclusiveThreadBlockScanPartialWGE16(
uint gtid,
uint laneMask,
uint circularLaneShift,
uint partEnd,
uint deviceOffset,
uint waveSize,
uint reduction)
{
uint i = gtid + partEnd;
//Threads past the end of the row do not load: their g_scan slot keeps its previous
//contents. Those slots sit above every in-range slot, so in-range prefix sums are
//not affected by them.
if (i < e_threadBlocks)
g_scan[gtid] = b_passHist[deviceOffset + i];
//Wave-level inclusive sum
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
GroupMemoryBarrierWithGroupSync();
//Scan the per-wave totals stored in the last slot of each wave
if (gtid < SCAN_DIM / waveSize)
{
g_scan[(gtid + 1) * waveSize - 1] +=
WavePrefixSum(g_scan[(gtid + 1) * waveSize - 1]);
}
GroupMemoryBarrierWithGroupSync();
//Destination rotated one position to the right inside the wave
const uint index = circularLaneShift + (i & ~laneMask);
//Only destinations inside the row are written
if (index < e_threadBlocks)
{
uint t = (WaveGetLaneIndex() != laneMask ? g_scan[gtid] : 0) + reduction;
//Add the scanned total of the preceding waves, read from the slot just below
//this wave's first slot
if (gtid >= waveSize)
t += g_scan[(gtid & ~laneMask) - 1];
b_passHist[index + deviceOffset] = t;
}
}
//Exclusive scan of b_passHist row gid (e_threadBlocks entries) using the WGE16 helpers:
//first all full SCAN_DIM-sized chunks, then the remaining partial chunk.
//  gtid:     thread index inside the threadblock
//  gid:      row of b_passHist to scan
//  waveSize: number of lanes per wave
inline void ExclusiveThreadBlockScanWGE16(uint gtid, uint gid, uint waveSize)
{
    const uint laneMask = waveSize - 1;
    //Lane index rotated up by one, wrapping inside the wave
    const uint rotatedLane = (WaveGetLaneIndex() + 1) & laneMask;
    //e_threadBlocks rounded down to a multiple of SCAN_DIM
    const uint fullChunksEnd = (e_threadBlocks / SCAN_DIM) * SCAN_DIM;
    const uint rowStart = gid * e_threadBlocks;

    //Running total of the full chunks, filled in by the first helper
    uint carry = 0;
    ExclusiveThreadBlockScanFullWGE16(
        gtid, laneMask, rotatedLane, fullChunksEnd, rowStart, waveSize, carry);
    ExclusiveThreadBlockScanPartialWGE16(
        gtid, laneMask, rotatedLane, fullChunksEnd, rowStart, waveSize, carry);
}
//Exclusive prefix sum over the full SCAN_DIM-sized chunks of one b_passHist row.
//Per the WLT16 suffix this is presumably the path for wave sizes below 16, where the
//per-wave totals are combined over several levels instead of one - confirm.
//  gtid:              thread index inside the threadblock
//  partitions:        number of full chunks (e_threadBlocks / SCAN_DIM in the caller)
//  deviceOffset:      start of the row inside b_passHist
//  laneLog:           countbits(waveSize - 1) in the caller, i.e. log2(waveSize) for power-of-two waves
//  circularLaneShift: (lane + 1) & (waveSize - 1), the lane index rotated up by one
//  waveSize:          number of lanes per wave
//  reduction:         running total of all chunks processed so far, updated in place
//NOTE(review): the Scan kernel in this file does not call this helper.
inline void ExclusiveThreadBlockScanFullWLT16(
uint gtid,
uint partitions,
uint deviceOffset,
uint laneLog,
uint circularLaneShift,
uint waveSize,
inout uint reduction)
{
for (uint k = 0; k < partitions; ++k)
{
//Load one element per thread and form the wave-level inclusive sum
g_scan[gtid] = b_passHist[gtid + k * SCAN_DIM + deviceOffset];
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
GroupMemoryBarrierWithGroupSync();
//The first waveSize threads need nothing from other waves: write their results
//rotated one position to the right; the lane that wraps to position 0 writes only the carry
if (gtid < waveSize)
{
b_passHist[circularLaneShift + k * SCAN_DIM + deviceOffset] =
(circularLaneShift ? g_scan[gtid] : 0) + reduction;
}
//j is the segment size already scanned, offset its shift amount; both grow by a
//factor of the wave size per level (j <<= laneLog, offset += laneLog)
uint offset = laneLog;
uint j = waveSize;
for (; j < (SCAN_DIM >> 1); j <<= laneLog)
{
//Scan the last slot of every (1 << offset)-sized segment across the wave
if (gtid < (SCAN_DIM >> offset))
{
g_scan[((gtid + 1) << offset) - 1] +=
WavePrefixSum(g_scan[((gtid + 1) << offset) - 1]);
}
GroupMemoryBarrierWithGroupSync();
//Threads at position j or above inside their (j << laneLog)-sized span
if ((gtid & ((j << laneLog) - 1)) >= j)
{
//Inside the first such span the result is now complete: total of the preceding
//segments + preceding slot (0 at a segment start) + carry
if (gtid < (j << laneLog))
{
b_passHist[gtid + k * SCAN_DIM + deviceOffset] =
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0) +
((gtid & (j - 1)) ? g_scan[gtid - 1] : 0) + reduction;
}
else
{
//Later spans fold the preceding segment total into shared memory for the next
//level; the last slot of a j-sized segment is skipped here
if ((gtid + 1) & (j - 1))
{
g_scan[gtid] +=
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0);
}
}
}
offset += laneLog;
}
GroupMemoryBarrierWithGroupSync();
//If SCAN_DIM is not a power of lanecount
//Elements at index j and above were not written inside the loop; finish them here
for (uint i = gtid + j; i < SCAN_DIM; i += SCAN_DIM)
{
b_passHist[i + k * SCAN_DIM + deviceOffset] =
WaveReadLaneAt(g_scan[((i >> offset) << offset) - 1], 0) +
((i & (j - 1)) ? g_scan[i - 1] : 0) + reduction;
}
//Carry the chunk total forward: the final slot plus the total of the segments before it
reduction += WaveReadLaneAt(g_scan[SCAN_DIM - 1], 0) +
WaveReadLaneAt(g_scan[(((SCAN_DIM - 1) >> offset) << offset) - 1], 0);
GroupMemoryBarrierWithGroupSync();
}
}
//Exclusive prefix sum over the trailing, partially filled chunk of one b_passHist row.
//Counterpart of ExclusiveThreadBlockScanFullWLT16 for the last
//e_threadBlocks - partitions * SCAN_DIM elements.
//("Parital" in the name is a misspelling of "Partial"; kept because callers use it.)
//  gtid:              thread index inside the threadblock
//  partitions:        number of full chunks that precede this one
//  deviceOffset:      start of the row inside b_passHist
//  laneLog:           countbits(waveSize - 1) in the caller
//  circularLaneShift: (lane + 1) & (waveSize - 1), the lane index rotated up by one
//  waveSize:          number of lanes per wave
//  reduction:         total of all full chunks (read only)
inline void ExclusiveThreadBlockScanParitalWLT16(
uint gtid,
uint partitions,
uint deviceOffset,
uint laneLog,
uint circularLaneShift,
uint waveSize,
uint reduction)
{
//Number of valid elements in the partial chunk
const uint finalPartSize = e_threadBlocks - partitions * SCAN_DIM;
//Only threads that map to a valid element load and take part in the wave-level inclusive sum
if (gtid < finalPartSize)
{
g_scan[gtid] = b_passHist[gtid + partitions * SCAN_DIM + deviceOffset];
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
}
GroupMemoryBarrierWithGroupSync();
//First waveSize threads: rotated write, skipped when the destination is past the valid range
if (gtid < waveSize && circularLaneShift < finalPartSize)
{
b_passHist[circularLaneShift + partitions * SCAN_DIM + deviceOffset] =
(circularLaneShift ? g_scan[gtid] : 0) + reduction;
}
//Same level-by-level combination as the full variant, bounded by finalPartSize
//instead of SCAN_DIM
uint offset = laneLog;
for (uint j = waveSize; j < finalPartSize; j <<= laneLog)
{
//Scan the last slot of every (1 << offset)-sized segment across the wave
if (gtid < (finalPartSize >> offset))
{
g_scan[((gtid + 1) << offset) - 1] +=
WavePrefixSum(g_scan[((gtid + 1) << offset) - 1]);
}
GroupMemoryBarrierWithGroupSync();
//Valid threads at position j or above inside their (j << laneLog)-sized span
if ((gtid & ((j << laneLog) - 1)) >= j && gtid < finalPartSize)
{
//First span: result is complete, write it out
if (gtid < (j << laneLog))
{
b_passHist[gtid + partitions * SCAN_DIM + deviceOffset] =
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0) +
((gtid & (j - 1)) ? g_scan[gtid - 1] : 0) + reduction;
}
else
{
//Later spans: fold the preceding segment total into shared memory for the next
//level; the last slot of a j-sized segment is skipped here
if ((gtid + 1) & (j - 1))
{
g_scan[gtid] +=
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0);
}
}
}
offset += laneLog;
}
}
//Exclusive scan of b_passHist row gid (e_threadBlocks entries) using the WLT16 helpers:
//first all full SCAN_DIM-sized chunks, then the remaining partial chunk.
//  gtid:     thread index inside the threadblock
//  gid:      row of b_passHist to scan
//  waveSize: number of lanes per wave
inline void ExclusiveThreadBlockScanWLT16(uint gtid, uint gid, uint waveSize)
{
    const uint fullChunks = e_threadBlocks / SCAN_DIM;
    const uint rowStart = gid * e_threadBlocks;
    //Number of set bits in waveSize - 1; equals log2(waveSize) for power-of-two wave sizes
    const uint laneLog = countbits(waveSize - 1);
    //Lane index rotated up by one, wrapping inside the wave
    const uint rotatedLane = (WaveGetLaneIndex() + 1) & (waveSize - 1);

    //Running total of the full chunks, filled in by the first helper
    uint carry = 0;
    ExclusiveThreadBlockScanFullWLT16(
        gtid, fullChunks, rowStart, laneLog, rotatedLane, waveSize, carry);
    ExclusiveThreadBlockScanParitalWLT16(
        gtid, fullChunks, rowStart, laneLog, rotatedLane, waveSize, carry);
}
//Scan kernel: one threadblock per b_passHist row (row index = gid.x, no flattening of gids).
//Thread 0 alone walks the row's e_threadBlocks entries in order and replaces each
//count with the sum of the counts before it (in-place exclusive prefix sum);
//all other threads of the group do nothing.
[numthreads(SCAN_DIM, 1, 1)]
void Scan(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID)
{
    if (gtid.x == 0u)
    {
        uint slot = gid.x * e_threadBlocks;
        uint prefix = 0u;
        for (uint remaining = e_threadBlocks; remaining > 0u; --remaining, ++slot)
        {
            const uint tileCount = b_passHist[slot];
            b_passHist[slot] = prefix;
            prefix += tileCount;
        }
    }
}
//*****************************************************************************
//DOWNSWEEP KERNEL
//*****************************************************************************
//For each digit (one per thread, gtid < RADIX) store into g_d[gtid + PART_SIZE]:
//  b_globalHist[digit + GlobalHistOffset()] + b_passHist[digit * e_threadBlocks + gid]
//  - exclusiveHistReduction
//  gtid: thread index inside the threadblock, doubles as the digit
//  gid:  threadblock (partition tile) index
//  exclusiveHistReduction: value subtracted from the combined offset
inline void LoadThreadBlockReductions(uint gtid, uint gid, uint exclusiveHistReduction)
{
    if (gtid >= RADIX)
    {
        return;
    }

    const uint digitBase = b_globalHist[gtid + GlobalHistOffset()];
    const uint tileBase = b_passHist[gtid * e_threadBlocks + gid];
    g_d[gtid + PART_SIZE] = digitBase + tileBase - exclusiveHistReduction;
}
//Downsweep kernel: one threadblock per partition tile, scattered serially by thread 0;
//all other threads of the group do nothing.
//For every digit the first destination is
//  b_globalHist[GlobalHistOffset() + digit] + b_passHist[digit * e_threadBlocks + tile].
//Keys are then visited in ascending index order and each one takes the next free
//slot of its digit, so keys with equal digits keep their relative order.
[numthreads(D_DIM, 1, 1)]
void Downsweep(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID)
{
    if (gtid.x == 0u)
    {
        const uint tile = gid.x;
        const uint firstKey = tile * PART_SIZE;
        //The last tile may be shorter than PART_SIZE
        const uint lastKey = min(firstKey + PART_SIZE, e_numKeys);

        //Next free destination slot for every digit
        uint nextSlot[RADIX];
        const uint histBase = GlobalHistOffset();
        for (uint d = 0u; d < RADIX; ++d)
        {
            nextSlot[d] =
                b_globalHist[histBase + d] +
                b_passHist[d * e_threadBlocks + tile];
        }

        for (uint src = firstKey; src < lastKey; ++src)
        {
            //Key in its sortable uint form
            uint sortableKey;
#if defined(KEY_UINT)
            sortableKey = b_sort[src];
#elif defined(KEY_INT)
            sortableKey = IntToUint(b_sort[src]);
#elif defined(KEY_FLOAT)
            sortableKey = FloatToUint(b_sort[src]);
#endif

            //Claim the next slot of this key's digit
            const uint bin = ExtractDigit(sortableKey);
            const uint dst = nextSlot[bin];
            nextSlot[bin] = dst + 1u;

            //Store the key converted back to its declared type
#if defined(KEY_UINT)
            b_alt[dst] = sortableKey;
#elif defined(KEY_INT)
            b_alt[dst] = UintToInt(sortableKey);
#elif defined(KEY_FLOAT)
            b_alt[dst] = UintToFloat(sortableKey);
#endif

            //Payloads are copied unchanged, whatever their type
#if defined(SORT_PAIRS) && (defined(PAYLOAD_UINT) || defined(PAYLOAD_INT) || defined(PAYLOAD_FLOAT))
            b_altPayload[dst] = b_sortPayload[src];
#endif
        }
    }
}