Establish 3DGS D3D12 sorted baseline
This commit is contained in:
@@ -6,6 +6,9 @@ set(CMAKE_CXX_STANDARD 20)
|
|||||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
|
|
||||||
find_program(XC_DXC_EXECUTABLE NAMES dxc)
|
find_program(XC_DXC_EXECUTABLE NAMES dxc)
|
||||||
|
if(NOT XC_DXC_EXECUTABLE)
|
||||||
|
message(FATAL_ERROR "dxc is required to build the 3DGS D3D12 MVS sort shaders.")
|
||||||
|
endif()
|
||||||
|
|
||||||
get_filename_component(XCENGINE_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../.." ABSOLUTE)
|
get_filename_component(XCENGINE_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../.." ABSOLUTE)
|
||||||
set(XCENGINE_BUILD_DIR "${XCENGINE_ROOT}/build")
|
set(XCENGINE_BUILD_DIR "${XCENGINE_ROOT}/build")
|
||||||
@@ -35,6 +38,8 @@ add_executable(xc_3dgs_d3d12_mvs
|
|||||||
shaders/PreparedSplatView.hlsli
|
shaders/PreparedSplatView.hlsli
|
||||||
shaders/PrepareGaussiansCS.hlsl
|
shaders/PrepareGaussiansCS.hlsl
|
||||||
shaders/BuildSortKeysCS.hlsl
|
shaders/BuildSortKeysCS.hlsl
|
||||||
|
shaders/SortCommon.hlsl
|
||||||
|
shaders/DeviceRadixSort.hlsl
|
||||||
shaders/DebugPointsVS.hlsl
|
shaders/DebugPointsVS.hlsl
|
||||||
shaders/DebugPointsPS.hlsl
|
shaders/DebugPointsPS.hlsl
|
||||||
)
|
)
|
||||||
@@ -43,6 +48,8 @@ set_source_files_properties(
|
|||||||
shaders/PreparedSplatView.hlsli
|
shaders/PreparedSplatView.hlsli
|
||||||
shaders/PrepareGaussiansCS.hlsl
|
shaders/PrepareGaussiansCS.hlsl
|
||||||
shaders/BuildSortKeysCS.hlsl
|
shaders/BuildSortKeysCS.hlsl
|
||||||
|
shaders/SortCommon.hlsl
|
||||||
|
shaders/DeviceRadixSort.hlsl
|
||||||
shaders/DebugPointsVS.hlsl
|
shaders/DebugPointsVS.hlsl
|
||||||
shaders/DebugPointsPS.hlsl
|
shaders/DebugPointsPS.hlsl
|
||||||
PROPERTIES
|
PROPERTIES
|
||||||
@@ -93,12 +100,55 @@ add_custom_command(TARGET xc_3dgs_d3d12_mvs POST_BUILD
|
|||||||
"$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/shaders"
|
"$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/shaders"
|
||||||
)
|
)
|
||||||
|
|
||||||
if(XC_DXC_EXECUTABLE)
|
|
||||||
add_custom_command(TARGET xc_3dgs_d3d12_mvs POST_BUILD
|
add_custom_command(TARGET xc_3dgs_d3d12_mvs POST_BUILD
|
||||||
COMMAND "${XC_DXC_EXECUTABLE}"
|
COMMAND "${XC_DXC_EXECUTABLE}"
|
||||||
-T cs_6_6
|
-T cs_6_6
|
||||||
-E MainCS
|
-E MainCS
|
||||||
-Fo "$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/shaders/BuildSortKeysCS.dxil"
|
-Fo "$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/shaders/BuildSortKeysCS.dxil"
|
||||||
"${CMAKE_CURRENT_SOURCE_DIR}/shaders/BuildSortKeysCS.hlsl"
|
"${CMAKE_CURRENT_SOURCE_DIR}/shaders/BuildSortKeysCS.hlsl"
|
||||||
|
COMMAND "${XC_DXC_EXECUTABLE}"
|
||||||
|
-T cs_6_6
|
||||||
|
-E InitDeviceRadixSort
|
||||||
|
-D KEY_UINT=1
|
||||||
|
-D PAYLOAD_UINT=1
|
||||||
|
-D SORT_PAIRS=1
|
||||||
|
-D SHOULD_ASCEND=1
|
||||||
|
-Fo "$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/shaders/RadixInit.dxil"
|
||||||
|
"${CMAKE_CURRENT_SOURCE_DIR}/shaders/DeviceRadixSort.hlsl"
|
||||||
|
COMMAND "${XC_DXC_EXECUTABLE}"
|
||||||
|
-T cs_6_6
|
||||||
|
-E Upsweep
|
||||||
|
-D KEY_UINT=1
|
||||||
|
-D PAYLOAD_UINT=1
|
||||||
|
-D SORT_PAIRS=1
|
||||||
|
-D SHOULD_ASCEND=1
|
||||||
|
-Fo "$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/shaders/RadixUpsweep.dxil"
|
||||||
|
"${CMAKE_CURRENT_SOURCE_DIR}/shaders/DeviceRadixSort.hlsl"
|
||||||
|
COMMAND "${XC_DXC_EXECUTABLE}"
|
||||||
|
-T cs_6_6
|
||||||
|
-E BuildGlobalHistogram
|
||||||
|
-D KEY_UINT=1
|
||||||
|
-D PAYLOAD_UINT=1
|
||||||
|
-D SORT_PAIRS=1
|
||||||
|
-D SHOULD_ASCEND=1
|
||||||
|
-Fo "$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/shaders/RadixGlobalHistogram.dxil"
|
||||||
|
"${CMAKE_CURRENT_SOURCE_DIR}/shaders/DeviceRadixSort.hlsl"
|
||||||
|
COMMAND "${XC_DXC_EXECUTABLE}"
|
||||||
|
-T cs_6_6
|
||||||
|
-E Scan
|
||||||
|
-D KEY_UINT=1
|
||||||
|
-D PAYLOAD_UINT=1
|
||||||
|
-D SORT_PAIRS=1
|
||||||
|
-D SHOULD_ASCEND=1
|
||||||
|
-Fo "$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/shaders/RadixScan.dxil"
|
||||||
|
"${CMAKE_CURRENT_SOURCE_DIR}/shaders/DeviceRadixSort.hlsl"
|
||||||
|
COMMAND "${XC_DXC_EXECUTABLE}"
|
||||||
|
-T cs_6_6
|
||||||
|
-E Downsweep
|
||||||
|
-D KEY_UINT=1
|
||||||
|
-D PAYLOAD_UINT=1
|
||||||
|
-D SORT_PAIRS=1
|
||||||
|
-D SHOULD_ASCEND=1
|
||||||
|
-Fo "$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/shaders/RadixDownsweep.dxil"
|
||||||
|
"${CMAKE_CURRENT_SOURCE_DIR}/shaders/DeviceRadixSort.hlsl"
|
||||||
)
|
)
|
||||||
endif()
|
|
||||||
|
|||||||
@@ -72,7 +72,8 @@ private:
|
|||||||
void ShutdownSortResources();
|
void ShutdownSortResources();
|
||||||
void ShutdownDebugDrawResources();
|
void ShutdownDebugDrawResources();
|
||||||
void Shutdown();
|
void Shutdown();
|
||||||
bool CaptureSortKeySnapshot();
|
bool CaptureSortSnapshot();
|
||||||
|
bool CapturePass3HistogramDebug();
|
||||||
void RenderFrame(bool captureScreenshot);
|
void RenderFrame(bool captureScreenshot);
|
||||||
|
|
||||||
HWND m_hwnd = nullptr;
|
HWND m_hwnd = nullptr;
|
||||||
@@ -107,14 +108,31 @@ private:
|
|||||||
XCEngine::RHI::RHIDescriptorPool* m_prepareDescriptorPool = nullptr;
|
XCEngine::RHI::RHIDescriptorPool* m_prepareDescriptorPool = nullptr;
|
||||||
XCEngine::RHI::RHIDescriptorSet* m_prepareDescriptorSet = nullptr;
|
XCEngine::RHI::RHIDescriptorSet* m_prepareDescriptorSet = nullptr;
|
||||||
XCEngine::RHI::D3D12Buffer* m_sortKeyBuffer = nullptr;
|
XCEngine::RHI::D3D12Buffer* m_sortKeyBuffer = nullptr;
|
||||||
|
XCEngine::RHI::D3D12Buffer* m_sortKeyScratchBuffer = nullptr;
|
||||||
XCEngine::RHI::D3D12Buffer* m_orderBuffer = nullptr;
|
XCEngine::RHI::D3D12Buffer* m_orderBuffer = nullptr;
|
||||||
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_sortKeySrv;
|
XCEngine::RHI::D3D12Buffer* m_orderScratchBuffer = nullptr;
|
||||||
|
XCEngine::RHI::D3D12Buffer* m_passHistogramBuffer = nullptr;
|
||||||
|
XCEngine::RHI::D3D12Buffer* m_globalHistogramBuffer = nullptr;
|
||||||
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_sortKeyUav;
|
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_sortKeyUav;
|
||||||
|
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_sortKeyScratchUav;
|
||||||
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_orderBufferSrv;
|
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_orderBufferSrv;
|
||||||
XCEngine::RHI::RHIPipelineLayout* m_sortPipelineLayout = nullptr;
|
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_orderBufferUav;
|
||||||
XCEngine::RHI::RHIPipelineState* m_sortPipelineState = nullptr;
|
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_orderScratchUav;
|
||||||
XCEngine::RHI::RHIDescriptorPool* m_sortDescriptorPool = nullptr;
|
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_passHistogramUav;
|
||||||
XCEngine::RHI::RHIDescriptorSet* m_sortDescriptorSet = nullptr;
|
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_globalHistogramUav;
|
||||||
|
XCEngine::RHI::RHIPipelineLayout* m_buildSortKeyPipelineLayout = nullptr;
|
||||||
|
XCEngine::RHI::RHIPipelineState* m_buildSortKeyPipelineState = nullptr;
|
||||||
|
XCEngine::RHI::RHIDescriptorPool* m_buildSortKeyDescriptorPool = nullptr;
|
||||||
|
XCEngine::RHI::RHIDescriptorSet* m_buildSortKeyDescriptorSet = nullptr;
|
||||||
|
XCEngine::RHI::RHIPipelineLayout* m_radixSortPipelineLayout = nullptr;
|
||||||
|
XCEngine::RHI::RHIPipelineState* m_radixSortInitPipelineState = nullptr;
|
||||||
|
XCEngine::RHI::RHIPipelineState* m_radixSortUpsweepPipelineState = nullptr;
|
||||||
|
XCEngine::RHI::RHIPipelineState* m_radixSortGlobalHistogramPipelineState = nullptr;
|
||||||
|
XCEngine::RHI::RHIPipelineState* m_radixSortScanPipelineState = nullptr;
|
||||||
|
XCEngine::RHI::RHIPipelineState* m_radixSortDownsweepPipelineState = nullptr;
|
||||||
|
XCEngine::RHI::RHIDescriptorPool* m_radixSortDescriptorPool = nullptr;
|
||||||
|
XCEngine::RHI::RHIDescriptorSet* m_radixSortDescriptorSetPrimaryToScratch = nullptr;
|
||||||
|
XCEngine::RHI::RHIDescriptorSet* m_radixSortDescriptorSetScratchToPrimary = nullptr;
|
||||||
XCEngine::RHI::RHIPipelineLayout* m_debugPipelineLayout = nullptr;
|
XCEngine::RHI::RHIPipelineLayout* m_debugPipelineLayout = nullptr;
|
||||||
XCEngine::RHI::RHIPipelineState* m_debugPipelineState = nullptr;
|
XCEngine::RHI::RHIPipelineState* m_debugPipelineState = nullptr;
|
||||||
XCEngine::RHI::RHIDescriptorPool* m_debugDescriptorPool = nullptr;
|
XCEngine::RHI::RHIDescriptorPool* m_debugDescriptorPool = nullptr;
|
||||||
|
|||||||
477
MVS/3DGS-D3D12/shaders/DeviceRadixSort.hlsl
Normal file
477
MVS/3DGS-D3D12/shaders/DeviceRadixSort.hlsl
Normal file
@@ -0,0 +1,477 @@
|
|||||||
|
/******************************************************************************
|
||||||
|
* DeviceRadixSort
|
||||||
|
* Device Level 8-bit LSD Radix Sort using reduce then scan
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: MIT
|
||||||
|
* Copyright Thomas Smith 5/17/2024
|
||||||
|
* https://github.com/b0nes164/GPUSorting
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
* of this software and associated documentation files (the "Software"), to deal
|
||||||
|
* in the Software without restriction, including without limitation the rights
|
||||||
|
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
* copies of the Software, and to permit persons to whom the Software is
|
||||||
|
* furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in all
|
||||||
|
* copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
* SOFTWARE.
|
||||||
|
******************************************************************************/
|
||||||
|
#include "SortCommon.hlsl"
|
||||||
|
|
||||||
|
#define US_DIM 128U //The number of threads in a Upsweep threadblock
|
||||||
|
#define SCAN_DIM 128U //The number of threads in a Scan threadblock
|
||||||
|
|
||||||
|
RWStructuredBuffer<uint> b_globalHist : register(u5); //buffer holding device level offsets for each binning pass
|
||||||
|
RWStructuredBuffer<uint> b_passHist : register(u4); //buffer used to store reduced sums of partition tiles
|
||||||
|
|
||||||
|
groupshared uint g_us[RADIX * 2]; //Shared memory for upsweep
|
||||||
|
groupshared uint g_scan[SCAN_DIM]; //Shared memory for the scan
|
||||||
|
|
||||||
|
//*****************************************************************************
|
||||||
|
//INIT KERNEL
|
||||||
|
//*****************************************************************************
|
||||||
|
//Clear the global histogram, as we will be adding to it atomically
|
||||||
|
[numthreads(1024, 1, 1)]
|
||||||
|
void InitDeviceRadixSort(int3 id : SV_DispatchThreadID)
|
||||||
|
{
|
||||||
|
b_globalHist[id.x] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
//*****************************************************************************
|
||||||
|
//UPSWEEP KERNEL
|
||||||
|
//*****************************************************************************
|
||||||
|
//histogram, 64 threads to a histogram
|
||||||
|
inline void HistogramDigitCounts(uint gtid, uint gid)
|
||||||
|
{
|
||||||
|
const uint histOffset = gtid / 64 * RADIX;
|
||||||
|
const uint partitionEnd = gid == e_threadBlocks - 1 ?
|
||||||
|
e_numKeys : (gid + 1) * PART_SIZE;
|
||||||
|
for (uint i = gtid + gid * PART_SIZE; i < partitionEnd; i += US_DIM)
|
||||||
|
{
|
||||||
|
#if defined(KEY_UINT)
|
||||||
|
InterlockedAdd(g_us[ExtractDigit(b_sort[i]) + histOffset], 1);
|
||||||
|
#elif defined(KEY_INT)
|
||||||
|
InterlockedAdd(g_us[ExtractDigit(IntToUint(b_sort[i])) + histOffset], 1);
|
||||||
|
#elif defined(KEY_FLOAT)
|
||||||
|
InterlockedAdd(g_us[ExtractDigit(FloatToUint(b_sort[i])) + histOffset], 1);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//reduce and pass to tile histogram
|
||||||
|
inline void ReduceWriteDigitCounts(uint gtid, uint gid)
|
||||||
|
{
|
||||||
|
for (uint i = gtid; i < RADIX; i += US_DIM)
|
||||||
|
{
|
||||||
|
g_us[i] += g_us[i + RADIX];
|
||||||
|
b_passHist[i * e_threadBlocks + gid] = g_us[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//Build the per-pass 256-bin exclusive prefix from the reduced pass histogram.
|
||||||
|
inline void BuildGlobalHistogramExclusive(uint gtid)
|
||||||
|
{
|
||||||
|
uint digitIndices[2];
|
||||||
|
uint digitTotals[2];
|
||||||
|
uint digitCount = 0;
|
||||||
|
for (uint i = gtid; i < RADIX; i += US_DIM)
|
||||||
|
{
|
||||||
|
uint total = 0u;
|
||||||
|
const uint baseOffset = i * e_threadBlocks;
|
||||||
|
for (uint blockIndex = 0; blockIndex < e_threadBlocks; ++blockIndex)
|
||||||
|
{
|
||||||
|
total += b_passHist[baseOffset + blockIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
g_us[i] = total;
|
||||||
|
digitIndices[digitCount] = i;
|
||||||
|
digitTotals[digitCount] = total;
|
||||||
|
++digitCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
|
||||||
|
for (uint offset = 1; offset < RADIX; offset <<= 1)
|
||||||
|
{
|
||||||
|
for (uint i = gtid; i < RADIX; i += US_DIM)
|
||||||
|
{
|
||||||
|
g_us[i + RADIX] = g_us[i] + (i >= offset ? g_us[i - offset] : 0u);
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
|
||||||
|
for (uint i = gtid; i < RADIX; i += US_DIM)
|
||||||
|
{
|
||||||
|
g_us[i] = g_us[i + RADIX];
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint globalHistOffset = GlobalHistOffset();
|
||||||
|
for (uint localIndex = 0; localIndex < digitCount; ++localIndex)
|
||||||
|
{
|
||||||
|
const uint digitIndex = digitIndices[localIndex];
|
||||||
|
b_globalHist[digitIndex + globalHistOffset] = g_us[digitIndex] - digitTotals[localIndex];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[numthreads(US_DIM, 1, 1)]
|
||||||
|
void Upsweep(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID)
|
||||||
|
{
|
||||||
|
//get the wave size
|
||||||
|
const uint waveSize = getWaveSize();
|
||||||
|
|
||||||
|
//clear shared memory
|
||||||
|
const uint histsEnd = RADIX * 2;
|
||||||
|
for (uint i = gtid.x; i < histsEnd; i += US_DIM)
|
||||||
|
g_us[i] = 0;
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
|
||||||
|
HistogramDigitCounts(gtid.x, gid.x);
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
|
||||||
|
ReduceWriteDigitCounts(gtid.x, gid.x);
|
||||||
|
}
|
||||||
|
|
||||||
|
[numthreads(US_DIM, 1, 1)]
|
||||||
|
void BuildGlobalHistogram(uint3 gtid : SV_GroupThreadID)
|
||||||
|
{
|
||||||
|
const uint histsEnd = RADIX * 2;
|
||||||
|
for (uint i = gtid.x; i < histsEnd; i += US_DIM)
|
||||||
|
g_us[i] = 0;
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
|
||||||
|
BuildGlobalHistogramExclusive(gtid.x);
|
||||||
|
}
|
||||||
|
|
||||||
|
//*****************************************************************************
|
||||||
|
//SCAN KERNEL
|
||||||
|
//*****************************************************************************
|
||||||
|
inline void ExclusiveThreadBlockScanFullWGE16(
|
||||||
|
uint gtid,
|
||||||
|
uint laneMask,
|
||||||
|
uint circularLaneShift,
|
||||||
|
uint partEnd,
|
||||||
|
uint deviceOffset,
|
||||||
|
uint waveSize,
|
||||||
|
inout uint reduction)
|
||||||
|
{
|
||||||
|
for (uint i = gtid; i < partEnd; i += SCAN_DIM)
|
||||||
|
{
|
||||||
|
g_scan[gtid] = b_passHist[i + deviceOffset];
|
||||||
|
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
|
||||||
|
if (gtid < SCAN_DIM / waveSize)
|
||||||
|
{
|
||||||
|
g_scan[(gtid + 1) * waveSize - 1] +=
|
||||||
|
WavePrefixSum(g_scan[(gtid + 1) * waveSize - 1]);
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
|
||||||
|
uint t = (WaveGetLaneIndex() != laneMask ? g_scan[gtid] : 0) + reduction;
|
||||||
|
if (gtid >= waveSize)
|
||||||
|
t += WaveReadLaneAt(g_scan[gtid - 1], 0);
|
||||||
|
b_passHist[circularLaneShift + (i & ~laneMask) + deviceOffset] = t;
|
||||||
|
|
||||||
|
reduction += g_scan[SCAN_DIM - 1];
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ExclusiveThreadBlockScanPartialWGE16(
|
||||||
|
uint gtid,
|
||||||
|
uint laneMask,
|
||||||
|
uint circularLaneShift,
|
||||||
|
uint partEnd,
|
||||||
|
uint deviceOffset,
|
||||||
|
uint waveSize,
|
||||||
|
uint reduction)
|
||||||
|
{
|
||||||
|
uint i = gtid + partEnd;
|
||||||
|
if (i < e_threadBlocks)
|
||||||
|
g_scan[gtid] = b_passHist[deviceOffset + i];
|
||||||
|
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
|
||||||
|
if (gtid < SCAN_DIM / waveSize)
|
||||||
|
{
|
||||||
|
g_scan[(gtid + 1) * waveSize - 1] +=
|
||||||
|
WavePrefixSum(g_scan[(gtid + 1) * waveSize - 1]);
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
|
||||||
|
const uint index = circularLaneShift + (i & ~laneMask);
|
||||||
|
if (index < e_threadBlocks)
|
||||||
|
{
|
||||||
|
uint t = (WaveGetLaneIndex() != laneMask ? g_scan[gtid] : 0) + reduction;
|
||||||
|
if (gtid >= waveSize)
|
||||||
|
t += g_scan[(gtid & ~laneMask) - 1];
|
||||||
|
b_passHist[index + deviceOffset] = t;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ExclusiveThreadBlockScanWGE16(uint gtid, uint gid, uint waveSize)
|
||||||
|
{
|
||||||
|
uint reduction = 0;
|
||||||
|
const uint laneMask = waveSize - 1;
|
||||||
|
const uint circularLaneShift = WaveGetLaneIndex() + 1 & laneMask;
|
||||||
|
const uint partionsEnd = e_threadBlocks / SCAN_DIM * SCAN_DIM;
|
||||||
|
const uint deviceOffset = gid * e_threadBlocks;
|
||||||
|
|
||||||
|
ExclusiveThreadBlockScanFullWGE16(
|
||||||
|
gtid,
|
||||||
|
laneMask,
|
||||||
|
circularLaneShift,
|
||||||
|
partionsEnd,
|
||||||
|
deviceOffset,
|
||||||
|
waveSize,
|
||||||
|
reduction);
|
||||||
|
|
||||||
|
ExclusiveThreadBlockScanPartialWGE16(
|
||||||
|
gtid,
|
||||||
|
laneMask,
|
||||||
|
circularLaneShift,
|
||||||
|
partionsEnd,
|
||||||
|
deviceOffset,
|
||||||
|
waveSize,
|
||||||
|
reduction);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ExclusiveThreadBlockScanFullWLT16(
|
||||||
|
uint gtid,
|
||||||
|
uint partitions,
|
||||||
|
uint deviceOffset,
|
||||||
|
uint laneLog,
|
||||||
|
uint circularLaneShift,
|
||||||
|
uint waveSize,
|
||||||
|
inout uint reduction)
|
||||||
|
{
|
||||||
|
for (uint k = 0; k < partitions; ++k)
|
||||||
|
{
|
||||||
|
g_scan[gtid] = b_passHist[gtid + k * SCAN_DIM + deviceOffset];
|
||||||
|
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
if (gtid < waveSize)
|
||||||
|
{
|
||||||
|
b_passHist[circularLaneShift + k * SCAN_DIM + deviceOffset] =
|
||||||
|
(circularLaneShift ? g_scan[gtid] : 0) + reduction;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint offset = laneLog;
|
||||||
|
uint j = waveSize;
|
||||||
|
for (; j < (SCAN_DIM >> 1); j <<= laneLog)
|
||||||
|
{
|
||||||
|
if (gtid < (SCAN_DIM >> offset))
|
||||||
|
{
|
||||||
|
g_scan[((gtid + 1) << offset) - 1] +=
|
||||||
|
WavePrefixSum(g_scan[((gtid + 1) << offset) - 1]);
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
|
||||||
|
if ((gtid & ((j << laneLog) - 1)) >= j)
|
||||||
|
{
|
||||||
|
if (gtid < (j << laneLog))
|
||||||
|
{
|
||||||
|
b_passHist[gtid + k * SCAN_DIM + deviceOffset] =
|
||||||
|
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0) +
|
||||||
|
((gtid & (j - 1)) ? g_scan[gtid - 1] : 0) + reduction;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if ((gtid + 1) & (j - 1))
|
||||||
|
{
|
||||||
|
g_scan[gtid] +=
|
||||||
|
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
offset += laneLog;
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
|
||||||
|
//If SCAN_DIM is not a power of lanecount
|
||||||
|
for (uint i = gtid + j; i < SCAN_DIM; i += SCAN_DIM)
|
||||||
|
{
|
||||||
|
b_passHist[i + k * SCAN_DIM + deviceOffset] =
|
||||||
|
WaveReadLaneAt(g_scan[((i >> offset) << offset) - 1], 0) +
|
||||||
|
((i & (j - 1)) ? g_scan[i - 1] : 0) + reduction;
|
||||||
|
}
|
||||||
|
|
||||||
|
reduction += WaveReadLaneAt(g_scan[SCAN_DIM - 1], 0) +
|
||||||
|
WaveReadLaneAt(g_scan[(((SCAN_DIM - 1) >> offset) << offset) - 1], 0);
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ExclusiveThreadBlockScanParitalWLT16(
|
||||||
|
uint gtid,
|
||||||
|
uint partitions,
|
||||||
|
uint deviceOffset,
|
||||||
|
uint laneLog,
|
||||||
|
uint circularLaneShift,
|
||||||
|
uint waveSize,
|
||||||
|
uint reduction)
|
||||||
|
{
|
||||||
|
const uint finalPartSize = e_threadBlocks - partitions * SCAN_DIM;
|
||||||
|
if (gtid < finalPartSize)
|
||||||
|
{
|
||||||
|
g_scan[gtid] = b_passHist[gtid + partitions * SCAN_DIM + deviceOffset];
|
||||||
|
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
if (gtid < waveSize && circularLaneShift < finalPartSize)
|
||||||
|
{
|
||||||
|
b_passHist[circularLaneShift + partitions * SCAN_DIM + deviceOffset] =
|
||||||
|
(circularLaneShift ? g_scan[gtid] : 0) + reduction;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint offset = laneLog;
|
||||||
|
for (uint j = waveSize; j < finalPartSize; j <<= laneLog)
|
||||||
|
{
|
||||||
|
if (gtid < (finalPartSize >> offset))
|
||||||
|
{
|
||||||
|
g_scan[((gtid + 1) << offset) - 1] +=
|
||||||
|
WavePrefixSum(g_scan[((gtid + 1) << offset) - 1]);
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
|
||||||
|
if ((gtid & ((j << laneLog) - 1)) >= j && gtid < finalPartSize)
|
||||||
|
{
|
||||||
|
if (gtid < (j << laneLog))
|
||||||
|
{
|
||||||
|
b_passHist[gtid + partitions * SCAN_DIM + deviceOffset] =
|
||||||
|
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0) +
|
||||||
|
((gtid & (j - 1)) ? g_scan[gtid - 1] : 0) + reduction;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if ((gtid + 1) & (j - 1))
|
||||||
|
{
|
||||||
|
g_scan[gtid] +=
|
||||||
|
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
offset += laneLog;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ExclusiveThreadBlockScanWLT16(uint gtid, uint gid, uint waveSize)
|
||||||
|
{
|
||||||
|
uint reduction = 0;
|
||||||
|
const uint partitions = e_threadBlocks / SCAN_DIM;
|
||||||
|
const uint deviceOffset = gid * e_threadBlocks;
|
||||||
|
const uint laneLog = countbits(waveSize - 1);
|
||||||
|
const uint circularLaneShift = WaveGetLaneIndex() + 1 & waveSize - 1;
|
||||||
|
|
||||||
|
ExclusiveThreadBlockScanFullWLT16(
|
||||||
|
gtid,
|
||||||
|
partitions,
|
||||||
|
deviceOffset,
|
||||||
|
laneLog,
|
||||||
|
circularLaneShift,
|
||||||
|
waveSize,
|
||||||
|
reduction);
|
||||||
|
|
||||||
|
ExclusiveThreadBlockScanParitalWLT16(
|
||||||
|
gtid,
|
||||||
|
partitions,
|
||||||
|
deviceOffset,
|
||||||
|
laneLog,
|
||||||
|
circularLaneShift,
|
||||||
|
waveSize,
|
||||||
|
reduction);
|
||||||
|
}
|
||||||
|
|
||||||
|
//Scan does not need flattening of gids
|
||||||
|
[numthreads(SCAN_DIM, 1, 1)]
|
||||||
|
void Scan(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID)
|
||||||
|
{
|
||||||
|
if (gtid.x != 0u)
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint deviceOffset = gid.x * e_threadBlocks;
|
||||||
|
uint runningOffset = 0u;
|
||||||
|
for (uint blockIndex = 0u; blockIndex < e_threadBlocks; ++blockIndex)
|
||||||
|
{
|
||||||
|
const uint index = deviceOffset + blockIndex;
|
||||||
|
const uint count = b_passHist[index];
|
||||||
|
b_passHist[index] = runningOffset;
|
||||||
|
runningOffset += count;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//*****************************************************************************
|
||||||
|
//DOWNSWEEP KERNEL
|
||||||
|
//*****************************************************************************
|
||||||
|
inline void LoadThreadBlockReductions(uint gtid, uint gid, uint exclusiveHistReduction)
|
||||||
|
{
|
||||||
|
if (gtid < RADIX)
|
||||||
|
{
|
||||||
|
g_d[gtid + PART_SIZE] = b_globalHist[gtid + GlobalHistOffset()] +
|
||||||
|
b_passHist[gtid * e_threadBlocks + gid] - exclusiveHistReduction;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[numthreads(D_DIM, 1, 1)]
|
||||||
|
void Downsweep(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID)
|
||||||
|
{
|
||||||
|
if (gtid.x != 0u)
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint partitionStart = gid.x * PART_SIZE;
|
||||||
|
const uint partitionEnd = min(partitionStart + PART_SIZE, e_numKeys);
|
||||||
|
uint digitOffsets[RADIX];
|
||||||
|
const uint globalHistOffset = GlobalHistOffset();
|
||||||
|
for (uint digit = 0u; digit < RADIX; ++digit)
|
||||||
|
{
|
||||||
|
digitOffsets[digit] =
|
||||||
|
b_globalHist[globalHistOffset + digit] +
|
||||||
|
b_passHist[digit * e_threadBlocks + gid.x];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint index = partitionStart; index < partitionEnd; ++index)
|
||||||
|
{
|
||||||
|
uint key;
|
||||||
|
#if defined(KEY_UINT)
|
||||||
|
key = b_sort[index];
|
||||||
|
#elif defined(KEY_INT)
|
||||||
|
key = IntToUint(b_sort[index]);
|
||||||
|
#elif defined(KEY_FLOAT)
|
||||||
|
key = FloatToUint(b_sort[index]);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
const uint digit = ExtractDigit(key);
|
||||||
|
const uint destinationIndex = digitOffsets[digit]++;
|
||||||
|
|
||||||
|
#if defined(KEY_UINT)
|
||||||
|
b_alt[destinationIndex] = key;
|
||||||
|
#elif defined(KEY_INT)
|
||||||
|
b_alt[destinationIndex] = UintToInt(key);
|
||||||
|
#elif defined(KEY_FLOAT)
|
||||||
|
b_alt[destinationIndex] = UintToFloat(key);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(SORT_PAIRS)
|
||||||
|
#if defined(PAYLOAD_UINT)
|
||||||
|
b_altPayload[destinationIndex] = b_sortPayload[index];
|
||||||
|
#elif defined(PAYLOAD_INT)
|
||||||
|
b_altPayload[destinationIndex] = b_sortPayload[index];
|
||||||
|
#elif defined(PAYLOAD_FLOAT)
|
||||||
|
b_altPayload[destinationIndex] = b_sortPayload[index];
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
959
MVS/3DGS-D3D12/shaders/SortCommon.hlsl
Normal file
959
MVS/3DGS-D3D12/shaders/SortCommon.hlsl
Normal file
@@ -0,0 +1,959 @@
|
|||||||
|
/******************************************************************************
|
||||||
|
* SortCommon
|
||||||
|
* Common functions for GPUSorting
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: MIT
|
||||||
|
* Copyright Thomas Smith 5/17/2024
|
||||||
|
* https://github.com/b0nes164/GPUSorting
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
* of this software and associated documentation files (the "Software"), to deal
|
||||||
|
* in the Software without restriction, including without limitation the rights
|
||||||
|
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
* copies of the Software, and to permit persons to whom the Software is
|
||||||
|
* furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in all
|
||||||
|
* copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
* SOFTWARE.
|
||||||
|
******************************************************************************/
|
||||||
|
#define KEYS_PER_THREAD 15U
|
||||||
|
#define D_DIM 256U
|
||||||
|
#define PART_SIZE 3840U
|
||||||
|
#define D_TOTAL_SMEM 4096U
|
||||||
|
|
||||||
|
#define RADIX 256U //Number of digit bins
|
||||||
|
#define RADIX_MASK 255U //Mask of digit bins
|
||||||
|
#define HALF_RADIX 128U //For smaller waves where bit packing is necessary
|
||||||
|
#define HALF_MASK 127U // ''
|
||||||
|
#define RADIX_LOG 8U //log2(RADIX)
|
||||||
|
#define RADIX_PASSES 4U //(Key width) / RADIX_LOG
|
||||||
|
|
||||||
|
cbuffer cbGpuSorting : register(b0)
|
||||||
|
{
|
||||||
|
uint e_numKeys;
|
||||||
|
uint e_radixShift;
|
||||||
|
uint e_threadBlocks;
|
||||||
|
uint padding;
|
||||||
|
};
|
||||||
|
|
||||||
|
#if defined(KEY_UINT)
|
||||||
|
RWStructuredBuffer<uint> b_sort : register(u0);
|
||||||
|
RWStructuredBuffer<uint> b_alt : register(u1);
|
||||||
|
#elif defined(KEY_INT)
|
||||||
|
RWStructuredBuffer<int> b_sort : register(u0);
|
||||||
|
RWStructuredBuffer<int> b_alt : register(u1);
|
||||||
|
#elif defined(KEY_FLOAT)
|
||||||
|
RWStructuredBuffer<float> b_sort : register(u0);
|
||||||
|
RWStructuredBuffer<float> b_alt : register(u1);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(PAYLOAD_UINT)
|
||||||
|
RWStructuredBuffer<uint> b_sortPayload : register(u2);
|
||||||
|
RWStructuredBuffer<uint> b_altPayload : register(u3);
|
||||||
|
#elif defined(PAYLOAD_INT)
|
||||||
|
RWStructuredBuffer<int> b_sortPayload : register(u2);
|
||||||
|
RWStructuredBuffer<int> b_altPayload : register(u3);
|
||||||
|
#elif defined(PAYLOAD_FLOAT)
|
||||||
|
RWStructuredBuffer<float> b_sortPayload : register(u2);
|
||||||
|
RWStructuredBuffer<float> b_altPayload : register(u3);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
groupshared uint g_d[D_TOTAL_SMEM]; //Shared memory for DigitBinningPass and DownSweep kernels
|
||||||
|
|
||||||
|
struct KeyStruct
|
||||||
|
{
|
||||||
|
uint k[KEYS_PER_THREAD];
|
||||||
|
};
|
||||||
|
|
||||||
|
struct OffsetStruct
|
||||||
|
{
|
||||||
|
#if defined(ENABLE_16_BIT)
|
||||||
|
uint16_t o[KEYS_PER_THREAD];
|
||||||
|
#else
|
||||||
|
uint o[KEYS_PER_THREAD];
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
|
||||||
|
struct DigitStruct
|
||||||
|
{
|
||||||
|
#if defined(ENABLE_16_BIT)
|
||||||
|
uint16_t d[KEYS_PER_THREAD];
|
||||||
|
#else
|
||||||
|
uint d[KEYS_PER_THREAD];
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
|
||||||
|
//*****************************************************************************
|
||||||
|
//HELPER FUNCTIONS
|
||||||
|
//*****************************************************************************
|
||||||
|
//Due to a bug with SPIRV pre 1.6, we cannot use WaveGetLaneCount() to get the currently active wavesize
|
||||||
|
inline uint getWaveSize()
|
||||||
|
{
|
||||||
|
#if defined(VULKAN)
|
||||||
|
GroupMemoryBarrierWithGroupSync(); //Make absolutely sure the wave is not diverged here
|
||||||
|
return dot(countbits(WaveActiveBallot(true)), uint4(1, 1, 1, 1));
|
||||||
|
#else
|
||||||
|
return WaveGetLaneCount();
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint getWaveIndex(uint gtid, uint waveSize)
|
||||||
|
{
|
||||||
|
return gtid / waveSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
//Radix Tricks by Michael Herf
|
||||||
|
//http://stereopsis.com/radix.html
|
||||||
|
inline uint FloatToUint(float f)
|
||||||
|
{
|
||||||
|
uint mask = -((int) (asuint(f) >> 31)) | 0x80000000;
|
||||||
|
return asuint(f) ^ mask;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline float UintToFloat(uint u)
|
||||||
|
{
|
||||||
|
uint mask = ((u >> 31) - 1) | 0x80000000;
|
||||||
|
return asfloat(u ^ mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint IntToUint(int i)
|
||||||
|
{
|
||||||
|
return asuint(i ^ 0x80000000);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int UintToInt(uint u)
|
||||||
|
{
|
||||||
|
return asint(u ^ 0x80000000);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint getWaveCountPass(uint waveSize)
|
||||||
|
{
|
||||||
|
return D_DIM / waveSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint ExtractDigit(uint key)
|
||||||
|
{
|
||||||
|
return key >> e_radixShift & RADIX_MASK;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint ExtractDigit(uint key, uint shift)
|
||||||
|
{
|
||||||
|
return key >> shift & RADIX_MASK;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint ExtractPackedIndex(uint key)
|
||||||
|
{
|
||||||
|
return key >> (e_radixShift + 1) & HALF_MASK;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint ExtractPackedShift(uint key)
|
||||||
|
{
|
||||||
|
return (key >> e_radixShift & 1) ? 16 : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint ExtractPackedValue(uint packed, uint key)
|
||||||
|
{
|
||||||
|
return packed >> ExtractPackedShift(key) & 0xffff;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint SubPartSizeWGE16(uint waveSize)
|
||||||
|
{
|
||||||
|
return KEYS_PER_THREAD * waveSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint SharedOffsetWGE16(uint gtid, uint waveSize)
|
||||||
|
{
|
||||||
|
return WaveGetLaneIndex() + getWaveIndex(gtid, waveSize) * SubPartSizeWGE16(waveSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint SubPartSizeWLT16(uint waveSize, uint _serialIterations)
|
||||||
|
{
|
||||||
|
return KEYS_PER_THREAD * waveSize * _serialIterations;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint SharedOffsetWLT16(uint gtid, uint waveSize, uint _serialIterations)
|
||||||
|
{
|
||||||
|
return WaveGetLaneIndex() +
|
||||||
|
(getWaveIndex(gtid, waveSize) / _serialIterations * SubPartSizeWLT16(waveSize, _serialIterations)) +
|
||||||
|
(getWaveIndex(gtid, waveSize) % _serialIterations * waveSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint DeviceOffsetWGE16(uint gtid, uint waveSize, uint partIndex)
|
||||||
|
{
|
||||||
|
return SharedOffsetWGE16(gtid, waveSize) + partIndex * PART_SIZE;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint DeviceOffsetWLT16(uint gtid, uint waveSize, uint partIndex, uint serialIterations)
|
||||||
|
{
|
||||||
|
return SharedOffsetWLT16(gtid, waveSize, serialIterations) + partIndex * PART_SIZE;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint GlobalHistOffset()
|
||||||
|
{
|
||||||
|
return e_radixShift << 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint WaveHistsSizeWGE16(uint waveSize)
|
||||||
|
{
|
||||||
|
return D_DIM / waveSize * RADIX;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint WaveHistsSizeWLT16()
|
||||||
|
{
|
||||||
|
return D_TOTAL_SMEM;
|
||||||
|
}
|
||||||
|
|
||||||
|
//*****************************************************************************
|
||||||
|
//FUNCTIONS COMMON TO THE DOWNSWEEP / DIGIT BINNING PASS
|
||||||
|
//*****************************************************************************
|
||||||
|
//If the size of a wave is too small, we do not have enough space in
|
||||||
|
//shared memory to assign a histogram to each wave, so instead,
|
||||||
|
//some operations are peformed serially.
|
||||||
|
inline uint SerialIterations(uint waveSize)
|
||||||
|
{
|
||||||
|
return (D_DIM / waveSize + 31) >> 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ClearWaveHists(uint gtid, uint waveSize)
|
||||||
|
{
|
||||||
|
const uint histsEnd = waveSize >= 16 ?
|
||||||
|
WaveHistsSizeWGE16(waveSize) : WaveHistsSizeWLT16();
|
||||||
|
for (uint i = gtid; i < histsEnd; i += D_DIM)
|
||||||
|
g_d[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void LoadKey(inout uint key, uint index)
|
||||||
|
{
|
||||||
|
#if defined(KEY_UINT)
|
||||||
|
key = b_sort[index];
|
||||||
|
#elif defined(KEY_INT)
|
||||||
|
key = UintToInt(b_sort[index]);
|
||||||
|
#elif defined(KEY_FLOAT)
|
||||||
|
key = FloatToUint(b_sort[index]);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void LoadDummyKey(inout uint key)
|
||||||
|
{
|
||||||
|
key = 0xffffffff;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline KeyStruct LoadKeysWGE16(uint gtid, uint waveSize, uint partIndex)
|
||||||
|
{
|
||||||
|
KeyStruct keys;
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, partIndex);
|
||||||
|
i < KEYS_PER_THREAD;
|
||||||
|
++i, t += waveSize)
|
||||||
|
{
|
||||||
|
LoadKey(keys.k[i], t);
|
||||||
|
}
|
||||||
|
return keys;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline KeyStruct LoadKeysWLT16(uint gtid, uint waveSize, uint partIndex, uint serialIterations)
|
||||||
|
{
|
||||||
|
KeyStruct keys;
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations);
|
||||||
|
i < KEYS_PER_THREAD;
|
||||||
|
++i, t += waveSize * serialIterations)
|
||||||
|
{
|
||||||
|
LoadKey(keys.k[i], t);
|
||||||
|
}
|
||||||
|
return keys;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline KeyStruct LoadKeysPartialWGE16(uint gtid, uint waveSize, uint partIndex)
|
||||||
|
{
|
||||||
|
KeyStruct keys;
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, partIndex);
|
||||||
|
i < KEYS_PER_THREAD;
|
||||||
|
++i, t += waveSize)
|
||||||
|
{
|
||||||
|
if (t < e_numKeys)
|
||||||
|
LoadKey(keys.k[i], t);
|
||||||
|
else
|
||||||
|
LoadDummyKey(keys.k[i]);
|
||||||
|
}
|
||||||
|
return keys;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline KeyStruct LoadKeysPartialWLT16(uint gtid, uint waveSize, uint partIndex, uint serialIterations)
|
||||||
|
{
|
||||||
|
KeyStruct keys;
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations);
|
||||||
|
i < KEYS_PER_THREAD;
|
||||||
|
++i, t += waveSize * serialIterations)
|
||||||
|
{
|
||||||
|
if (t < e_numKeys)
|
||||||
|
LoadKey(keys.k[i], t);
|
||||||
|
else
|
||||||
|
LoadDummyKey(keys.k[i]);
|
||||||
|
}
|
||||||
|
return keys;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint WaveFlagsWGE16(uint waveSize)
|
||||||
|
{
|
||||||
|
return (waveSize & 31) ? (1U << waveSize) - 1 : 0xffffffff;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint WaveFlagsWLT16(uint waveSize)
|
||||||
|
{
|
||||||
|
return (1U << waveSize) - 1;;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void WarpLevelMultiSplitWGE16(uint key, inout uint4 waveFlags)
|
||||||
|
{
|
||||||
|
[unroll]
|
||||||
|
for (uint k = 0; k < RADIX_LOG; ++k)
|
||||||
|
{
|
||||||
|
const uint currentBit = 1U << (k + e_radixShift);
|
||||||
|
const bool t = (key & currentBit) != 0;
|
||||||
|
GroupMemoryBarrierWithGroupSync(); //Play on the safe side, throw in a barrier for convergence
|
||||||
|
const uint4 ballot = WaveActiveBallot(t);
|
||||||
|
if(t)
|
||||||
|
waveFlags &= ballot;
|
||||||
|
else
|
||||||
|
waveFlags &= (~ballot);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint2 CountBitsWGE16(uint waveSize, uint ltMask, uint4 waveFlags)
|
||||||
|
{
|
||||||
|
uint2 count = uint2(0, 0);
|
||||||
|
|
||||||
|
for(uint wavePart = 0; wavePart < waveSize; wavePart += 32)
|
||||||
|
{
|
||||||
|
uint t = countbits(waveFlags[wavePart >> 5]);
|
||||||
|
if (WaveGetLaneIndex() >= wavePart)
|
||||||
|
{
|
||||||
|
if (WaveGetLaneIndex() >= wavePart + 32)
|
||||||
|
count.x += t;
|
||||||
|
else
|
||||||
|
count.x += countbits(waveFlags[wavePart >> 5] & ltMask);
|
||||||
|
}
|
||||||
|
count.y += t;
|
||||||
|
}
|
||||||
|
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void WarpLevelMultiSplitWLT16(uint key, inout uint waveFlags)
|
||||||
|
{
|
||||||
|
[unroll]
|
||||||
|
for (uint k = 0; k < RADIX_LOG; ++k)
|
||||||
|
{
|
||||||
|
const bool t = key >> (k + e_radixShift) & 1;
|
||||||
|
waveFlags &= (t ? 0 : 0xffffffff) ^ (uint) WaveActiveBallot(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline OffsetStruct RankKeysWGE16(
|
||||||
|
uint waveSize,
|
||||||
|
uint waveOffset,
|
||||||
|
KeyStruct keys)
|
||||||
|
{
|
||||||
|
OffsetStruct offsets;
|
||||||
|
const uint initialFlags = WaveFlagsWGE16(waveSize);
|
||||||
|
const uint ltMask = (1U << (WaveGetLaneIndex() & 31)) - 1;
|
||||||
|
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0; i < KEYS_PER_THREAD; ++i)
|
||||||
|
{
|
||||||
|
uint4 waveFlags = initialFlags;
|
||||||
|
WarpLevelMultiSplitWGE16(keys.k[i], waveFlags);
|
||||||
|
|
||||||
|
const uint index = ExtractDigit(keys.k[i]) + waveOffset;
|
||||||
|
const uint2 bitCount = CountBitsWGE16(waveSize, ltMask, waveFlags);
|
||||||
|
|
||||||
|
offsets.o[i] = g_d[index] + bitCount.x;
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
if (bitCount.x == 0)
|
||||||
|
g_d[index] += bitCount.y;
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
}
|
||||||
|
|
||||||
|
return offsets;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline OffsetStruct RankKeysWLT16(uint waveSize, uint waveIndex, KeyStruct keys, uint serialIterations)
|
||||||
|
{
|
||||||
|
OffsetStruct offsets;
|
||||||
|
const uint ltMask = (1U << WaveGetLaneIndex()) - 1;
|
||||||
|
const uint initialFlags = WaveFlagsWLT16(waveSize);
|
||||||
|
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0; i < KEYS_PER_THREAD; ++i)
|
||||||
|
{
|
||||||
|
uint waveFlags = initialFlags;
|
||||||
|
WarpLevelMultiSplitWLT16(keys.k[i], waveFlags);
|
||||||
|
|
||||||
|
const uint index = ExtractPackedIndex(keys.k[i]) +
|
||||||
|
(waveIndex / serialIterations * HALF_RADIX);
|
||||||
|
|
||||||
|
const uint peerBits = countbits(waveFlags & ltMask);
|
||||||
|
for (uint k = 0; k < serialIterations; ++k)
|
||||||
|
{
|
||||||
|
if (waveIndex % serialIterations == k)
|
||||||
|
offsets.o[i] = ExtractPackedValue(g_d[index], keys.k[i]) + peerBits;
|
||||||
|
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
if (waveIndex % serialIterations == k && peerBits == 0)
|
||||||
|
{
|
||||||
|
InterlockedAdd(g_d[index],
|
||||||
|
countbits(waveFlags) << ExtractPackedShift(keys.k[i]));
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return offsets;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint WaveHistInclusiveScanCircularShiftWGE16(uint gtid, uint waveSize)
|
||||||
|
{
|
||||||
|
uint histReduction = g_d[gtid];
|
||||||
|
for (uint i = gtid + RADIX; i < WaveHistsSizeWGE16(waveSize); i += RADIX)
|
||||||
|
{
|
||||||
|
histReduction += g_d[i];
|
||||||
|
g_d[i] = histReduction - g_d[i];
|
||||||
|
}
|
||||||
|
return histReduction;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint WaveHistInclusiveScanCircularShiftWLT16(uint gtid)
|
||||||
|
{
|
||||||
|
uint histReduction = g_d[gtid];
|
||||||
|
for (uint i = gtid + HALF_RADIX; i < WaveHistsSizeWLT16(); i += HALF_RADIX)
|
||||||
|
{
|
||||||
|
histReduction += g_d[i];
|
||||||
|
g_d[i] = histReduction - g_d[i];
|
||||||
|
}
|
||||||
|
return histReduction;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void WaveHistReductionExclusiveScanWGE16(uint gtid, uint waveSize, uint histReduction)
|
||||||
|
{
|
||||||
|
if (gtid < RADIX)
|
||||||
|
{
|
||||||
|
const uint laneMask = waveSize - 1;
|
||||||
|
g_d[((WaveGetLaneIndex() + 1) & laneMask) + (gtid & ~laneMask)] = histReduction;
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
|
||||||
|
if (gtid < RADIX / waveSize)
|
||||||
|
{
|
||||||
|
g_d[gtid * waveSize] =
|
||||||
|
WavePrefixSum(g_d[gtid * waveSize]);
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
|
||||||
|
uint t = WaveReadLaneAt(g_d[gtid], 0);
|
||||||
|
if (gtid < RADIX && WaveGetLaneIndex())
|
||||||
|
g_d[gtid] += t;
|
||||||
|
}
|
||||||
|
|
||||||
|
//inclusive/exclusive prefix sum up the histograms,
|
||||||
|
//use a blelloch scan for in place packed exclusive
|
||||||
|
inline void WaveHistReductionExclusiveScanWLT16(uint gtid)
|
||||||
|
{
|
||||||
|
uint shift = 1;
|
||||||
|
for (uint j = RADIX >> 2; j > 0; j >>= 1)
|
||||||
|
{
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
if (gtid < j)
|
||||||
|
{
|
||||||
|
g_d[((((gtid << 1) + 2) << shift) - 1) >> 1] +=
|
||||||
|
g_d[((((gtid << 1) + 1) << shift) - 1) >> 1] & 0xffff0000;
|
||||||
|
}
|
||||||
|
shift++;
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
|
||||||
|
if (gtid == 0)
|
||||||
|
g_d[HALF_RADIX - 1] &= 0xffff;
|
||||||
|
|
||||||
|
for (uint j = 1; j < RADIX >> 1; j <<= 1)
|
||||||
|
{
|
||||||
|
--shift;
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
if (gtid < j)
|
||||||
|
{
|
||||||
|
const uint t = ((((gtid << 1) + 1) << shift) - 1) >> 1;
|
||||||
|
const uint t2 = ((((gtid << 1) + 2) << shift) - 1) >> 1;
|
||||||
|
const uint t3 = g_d[t];
|
||||||
|
g_d[t] = (g_d[t] & 0xffff) | (g_d[t2] & 0xffff0000);
|
||||||
|
g_d[t2] += t3 & 0xffff0000;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
if (gtid < HALF_RADIX)
|
||||||
|
{
|
||||||
|
const uint t = g_d[gtid];
|
||||||
|
g_d[gtid] = (t >> 16) + (t << 16) + (t & 0xffff0000);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void UpdateOffsetsWGE16(
|
||||||
|
uint gtid,
|
||||||
|
uint waveSize,
|
||||||
|
inout OffsetStruct offsets,
|
||||||
|
KeyStruct keys)
|
||||||
|
{
|
||||||
|
if (gtid >= waveSize)
|
||||||
|
{
|
||||||
|
const uint t = getWaveIndex(gtid, waveSize) * RADIX;
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0; i < KEYS_PER_THREAD; ++i)
|
||||||
|
{
|
||||||
|
const uint t2 = ExtractDigit(keys.k[i]);
|
||||||
|
offsets.o[i] += g_d[t2 + t] + g_d[t2];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0; i < KEYS_PER_THREAD; ++i)
|
||||||
|
offsets.o[i] += g_d[ExtractDigit(keys.k[i])];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void UpdateOffsetsWLT16(
|
||||||
|
uint gtid,
|
||||||
|
uint waveSize,
|
||||||
|
uint serialIterations,
|
||||||
|
inout OffsetStruct offsets,
|
||||||
|
KeyStruct keys)
|
||||||
|
{
|
||||||
|
if (gtid >= waveSize * serialIterations)
|
||||||
|
{
|
||||||
|
const uint t = getWaveIndex(gtid, waveSize) / serialIterations * HALF_RADIX;
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0; i < KEYS_PER_THREAD; ++i)
|
||||||
|
{
|
||||||
|
const uint t2 = ExtractPackedIndex(keys.k[i]);
|
||||||
|
offsets.o[i] += ExtractPackedValue(g_d[t2 + t] + g_d[t2], keys.k[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0; i < KEYS_PER_THREAD; ++i)
|
||||||
|
offsets.o[i] += ExtractPackedValue(g_d[ExtractPackedIndex(keys.k[i])], keys.k[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ScatterKeysShared(OffsetStruct offsets, KeyStruct keys)
|
||||||
|
{
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0; i < KEYS_PER_THREAD; ++i)
|
||||||
|
g_d[offsets.o[i]] = keys.k[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint DescendingIndex(uint deviceIndex)
|
||||||
|
{
|
||||||
|
return e_numKeys - deviceIndex - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void WriteKey(uint deviceIndex, uint groupSharedIndex)
|
||||||
|
{
|
||||||
|
#if defined(KEY_UINT)
|
||||||
|
b_alt[deviceIndex] = g_d[groupSharedIndex];
|
||||||
|
#elif defined(KEY_INT)
|
||||||
|
b_alt[deviceIndex] = UintToInt(g_d[groupSharedIndex]);
|
||||||
|
#elif defined(KEY_FLOAT)
|
||||||
|
b_alt[deviceIndex] = UintToFloat(g_d[groupSharedIndex]);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void LoadPayload(inout uint payload, uint deviceIndex)
|
||||||
|
{
|
||||||
|
#if defined(PAYLOAD_UINT)
|
||||||
|
payload = b_sortPayload[deviceIndex];
|
||||||
|
#elif defined(PAYLOAD_INT) || defined(PAYLOAD_FLOAT)
|
||||||
|
payload = asuint(b_sortPayload[deviceIndex]);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ScatterPayloadsShared(OffsetStruct offsets, KeyStruct payloads)
|
||||||
|
{
|
||||||
|
ScatterKeysShared(offsets, payloads);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void WritePayload(uint deviceIndex, uint groupSharedIndex)
|
||||||
|
{
|
||||||
|
#if defined(PAYLOAD_UINT)
|
||||||
|
b_altPayload[deviceIndex] = g_d[groupSharedIndex];
|
||||||
|
#elif defined(PAYLOAD_INT)
|
||||||
|
b_altPayload[deviceIndex] = asint(g_d[groupSharedIndex]);
|
||||||
|
#elif defined(PAYLOAD_FLOAT)
|
||||||
|
b_altPayload[deviceIndex] = asfloat(g_d[groupSharedIndex]);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
//*****************************************************************************
|
||||||
|
//SCATTERING: FULL PARTITIONS
|
||||||
|
//*****************************************************************************
|
||||||
|
//KEYS ONLY
|
||||||
|
inline void ScatterKeysOnlyDeviceAscending(uint gtid)
|
||||||
|
{
|
||||||
|
for (uint i = gtid; i < PART_SIZE; i += D_DIM)
|
||||||
|
WriteKey(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ScatterKeysOnlyDeviceDescending(uint gtid)
|
||||||
|
{
|
||||||
|
if (e_radixShift == 24)
|
||||||
|
{
|
||||||
|
for (uint i = gtid; i < PART_SIZE; i += D_DIM)
|
||||||
|
WriteKey(DescendingIndex(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i), i);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
ScatterKeysOnlyDeviceAscending(gtid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ScatterKeysOnlyDevice(uint gtid)
|
||||||
|
{
|
||||||
|
#if defined(SHOULD_ASCEND)
|
||||||
|
ScatterKeysOnlyDeviceAscending(gtid);
|
||||||
|
#else
|
||||||
|
ScatterKeysOnlyDeviceDescending(gtid);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
//KEY VALUE PAIRS
|
||||||
|
inline void ScatterPairsKeyPhaseAscending(
|
||||||
|
uint gtid,
|
||||||
|
inout DigitStruct digits)
|
||||||
|
{
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
|
||||||
|
{
|
||||||
|
digits.d[i] = ExtractDigit(g_d[t]);
|
||||||
|
WriteKey(g_d[digits.d[i] + PART_SIZE] + t, t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ScatterPairsKeyPhaseDescending(
|
||||||
|
uint gtid,
|
||||||
|
inout DigitStruct digits)
|
||||||
|
{
|
||||||
|
if (e_radixShift == 24)
|
||||||
|
{
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
|
||||||
|
{
|
||||||
|
digits.d[i] = ExtractDigit(g_d[t]);
|
||||||
|
WriteKey(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
ScatterPairsKeyPhaseAscending(gtid, digits);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void LoadPayloadsWGE16(
|
||||||
|
uint gtid,
|
||||||
|
uint waveSize,
|
||||||
|
uint partIndex,
|
||||||
|
inout KeyStruct payloads)
|
||||||
|
{
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, partIndex);
|
||||||
|
i < KEYS_PER_THREAD;
|
||||||
|
++i, t += waveSize)
|
||||||
|
{
|
||||||
|
LoadPayload(payloads.k[i], t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void LoadPayloadsWLT16(
|
||||||
|
uint gtid,
|
||||||
|
uint waveSize,
|
||||||
|
uint partIndex,
|
||||||
|
uint serialIterations,
|
||||||
|
inout KeyStruct payloads)
|
||||||
|
{
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations);
|
||||||
|
i < KEYS_PER_THREAD;
|
||||||
|
++i, t += waveSize * serialIterations)
|
||||||
|
{
|
||||||
|
LoadPayload(payloads.k[i], t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ScatterPayloadsAscending(uint gtid, DigitStruct digits)
|
||||||
|
{
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
|
||||||
|
WritePayload(g_d[digits.d[i] + PART_SIZE] + t, t);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ScatterPayloadsDescending(uint gtid, DigitStruct digits)
|
||||||
|
{
|
||||||
|
if (e_radixShift == 24)
|
||||||
|
{
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
|
||||||
|
WritePayload(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
ScatterPayloadsAscending(gtid, digits);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ScatterPairsDevice(
|
||||||
|
uint gtid,
|
||||||
|
uint waveSize,
|
||||||
|
uint partIndex,
|
||||||
|
OffsetStruct offsets)
|
||||||
|
{
|
||||||
|
DigitStruct digits;
|
||||||
|
#if defined(SHOULD_ASCEND)
|
||||||
|
ScatterPairsKeyPhaseAscending(gtid, digits);
|
||||||
|
#else
|
||||||
|
ScatterPairsKeyPhaseDescending(gtid, digits);
|
||||||
|
#endif
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
|
||||||
|
KeyStruct payloads;
|
||||||
|
if (waveSize >= 16)
|
||||||
|
LoadPayloadsWGE16(gtid, waveSize, partIndex, payloads);
|
||||||
|
else
|
||||||
|
LoadPayloadsWLT16(gtid, waveSize, partIndex, SerialIterations(waveSize), payloads);
|
||||||
|
ScatterPayloadsShared(offsets, payloads);
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
|
||||||
|
#if defined(SHOULD_ASCEND)
|
||||||
|
ScatterPayloadsAscending(gtid, digits);
|
||||||
|
#else
|
||||||
|
ScatterPayloadsDescending(gtid, digits);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ScatterDevice(
|
||||||
|
uint gtid,
|
||||||
|
uint waveSize,
|
||||||
|
uint partIndex,
|
||||||
|
OffsetStruct offsets)
|
||||||
|
{
|
||||||
|
#if defined(SORT_PAIRS)
|
||||||
|
ScatterPairsDevice(
|
||||||
|
gtid,
|
||||||
|
waveSize,
|
||||||
|
partIndex,
|
||||||
|
offsets);
|
||||||
|
#else
|
||||||
|
ScatterKeysOnlyDevice(gtid);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
//*****************************************************************************
|
||||||
|
//SCATTERING: PARTIAL PARTITIONS
|
||||||
|
//*****************************************************************************
|
||||||
|
//KEYS ONLY
|
||||||
|
inline void ScatterKeysOnlyDevicePartialAscending(uint gtid, uint finalPartSize)
|
||||||
|
{
|
||||||
|
for (uint i = gtid; i < PART_SIZE; i += D_DIM)
|
||||||
|
{
|
||||||
|
if (i < finalPartSize)
|
||||||
|
WriteKey(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ScatterKeysOnlyDevicePartialDescending(uint gtid, uint finalPartSize)
|
||||||
|
{
|
||||||
|
if (e_radixShift == 24)
|
||||||
|
{
|
||||||
|
for (uint i = gtid; i < PART_SIZE; i += D_DIM)
|
||||||
|
{
|
||||||
|
if (i < finalPartSize)
|
||||||
|
WriteKey(DescendingIndex(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i), i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
ScatterKeysOnlyDevicePartialAscending(gtid, finalPartSize);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ScatterKeysOnlyDevicePartial(uint gtid, uint partIndex)
|
||||||
|
{
|
||||||
|
const uint finalPartSize = e_numKeys - partIndex * PART_SIZE;
|
||||||
|
#if defined(SHOULD_ASCEND)
|
||||||
|
ScatterKeysOnlyDevicePartialAscending(gtid, finalPartSize);
|
||||||
|
#else
|
||||||
|
ScatterKeysOnlyDevicePartialDescending(gtid, finalPartSize);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
//KEY VALUE PAIRS
|
||||||
|
inline void ScatterPairsKeyPhaseAscendingPartial(
|
||||||
|
uint gtid,
|
||||||
|
uint finalPartSize,
|
||||||
|
inout DigitStruct digits)
|
||||||
|
{
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
|
||||||
|
{
|
||||||
|
if (t < finalPartSize)
|
||||||
|
{
|
||||||
|
digits.d[i] = ExtractDigit(g_d[t]);
|
||||||
|
WriteKey(g_d[digits.d[i] + PART_SIZE] + t, t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ScatterPairsKeyPhaseDescendingPartial(
|
||||||
|
uint gtid,
|
||||||
|
uint finalPartSize,
|
||||||
|
inout DigitStruct digits)
|
||||||
|
{
|
||||||
|
if (e_radixShift == 24)
|
||||||
|
{
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
|
||||||
|
{
|
||||||
|
if (t < finalPartSize)
|
||||||
|
{
|
||||||
|
digits.d[i] = ExtractDigit(g_d[t]);
|
||||||
|
WriteKey(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
ScatterPairsKeyPhaseAscendingPartial(gtid, finalPartSize, digits);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void LoadPayloadsPartialWGE16(
|
||||||
|
uint gtid,
|
||||||
|
uint waveSize,
|
||||||
|
uint partIndex,
|
||||||
|
inout KeyStruct payloads)
|
||||||
|
{
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, partIndex);
|
||||||
|
i < KEYS_PER_THREAD;
|
||||||
|
++i, t += waveSize)
|
||||||
|
{
|
||||||
|
if (t < e_numKeys)
|
||||||
|
LoadPayload(payloads.k[i], t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void LoadPayloadsPartialWLT16(
|
||||||
|
uint gtid,
|
||||||
|
uint waveSize,
|
||||||
|
uint partIndex,
|
||||||
|
uint serialIterations,
|
||||||
|
inout KeyStruct payloads)
|
||||||
|
{
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations);
|
||||||
|
i < KEYS_PER_THREAD;
|
||||||
|
++i, t += waveSize * serialIterations)
|
||||||
|
{
|
||||||
|
if (t < e_numKeys)
|
||||||
|
LoadPayload(payloads.k[i], t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ScatterPayloadsAscendingPartial(
|
||||||
|
uint gtid,
|
||||||
|
uint finalPartSize,
|
||||||
|
DigitStruct digits)
|
||||||
|
{
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
|
||||||
|
{
|
||||||
|
if (t < finalPartSize)
|
||||||
|
WritePayload(g_d[digits.d[i] + PART_SIZE] + t, t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ScatterPayloadsDescendingPartial(
|
||||||
|
uint gtid,
|
||||||
|
uint finalPartSize,
|
||||||
|
DigitStruct digits)
|
||||||
|
{
|
||||||
|
if (e_radixShift == 24)
|
||||||
|
{
|
||||||
|
[unroll]
|
||||||
|
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
|
||||||
|
{
|
||||||
|
if (t < finalPartSize)
|
||||||
|
WritePayload(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
ScatterPayloadsAscendingPartial(gtid, finalPartSize, digits);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ScatterPairsDevicePartial(
|
||||||
|
uint gtid,
|
||||||
|
uint waveSize,
|
||||||
|
uint partIndex,
|
||||||
|
OffsetStruct offsets)
|
||||||
|
{
|
||||||
|
DigitStruct digits;
|
||||||
|
const uint finalPartSize = e_numKeys - partIndex * PART_SIZE;
|
||||||
|
#if defined(SHOULD_ASCEND)
|
||||||
|
ScatterPairsKeyPhaseAscendingPartial(gtid, finalPartSize, digits);
|
||||||
|
#else
|
||||||
|
ScatterPairsKeyPhaseDescendingPartial(gtid, finalPartSize, digits);
|
||||||
|
#endif
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
|
||||||
|
KeyStruct payloads;
|
||||||
|
if (waveSize >= 16)
|
||||||
|
LoadPayloadsPartialWGE16(gtid, waveSize, partIndex, payloads);
|
||||||
|
else
|
||||||
|
LoadPayloadsPartialWLT16(gtid, waveSize, partIndex, SerialIterations(waveSize), payloads);
|
||||||
|
ScatterPayloadsShared(offsets, payloads);
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
|
||||||
|
#if defined(SHOULD_ASCEND)
|
||||||
|
ScatterPayloadsAscendingPartial(gtid, finalPartSize, digits);
|
||||||
|
#else
|
||||||
|
ScatterPayloadsDescendingPartial(gtid, finalPartSize, digits);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ScatterDevicePartial(
|
||||||
|
uint gtid,
|
||||||
|
uint waveSize,
|
||||||
|
uint partIndex,
|
||||||
|
OffsetStruct offsets)
|
||||||
|
{
|
||||||
|
#if defined(SORT_PAIRS)
|
||||||
|
ScatterPairsDevicePartial(
|
||||||
|
gtid,
|
||||||
|
waveSize,
|
||||||
|
partIndex,
|
||||||
|
offsets);
|
||||||
|
#else
|
||||||
|
ScatterKeysOnlyDevicePartial(gtid, partIndex);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user