Establish 3DGS D3D12 sorted baseline
This commit is contained in:
@@ -6,6 +6,9 @@ set(CMAKE_CXX_STANDARD 20)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
find_program(XC_DXC_EXECUTABLE NAMES dxc)
|
||||
if(NOT XC_DXC_EXECUTABLE)
|
||||
message(FATAL_ERROR "dxc is required to build the 3DGS D3D12 MVS sort shaders.")
|
||||
endif()
|
||||
|
||||
get_filename_component(XCENGINE_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../.." ABSOLUTE)
|
||||
set(XCENGINE_BUILD_DIR "${XCENGINE_ROOT}/build")
|
||||
@@ -35,6 +38,8 @@ add_executable(xc_3dgs_d3d12_mvs
|
||||
shaders/PreparedSplatView.hlsli
|
||||
shaders/PrepareGaussiansCS.hlsl
|
||||
shaders/BuildSortKeysCS.hlsl
|
||||
shaders/SortCommon.hlsl
|
||||
shaders/DeviceRadixSort.hlsl
|
||||
shaders/DebugPointsVS.hlsl
|
||||
shaders/DebugPointsPS.hlsl
|
||||
)
|
||||
@@ -43,6 +48,8 @@ set_source_files_properties(
|
||||
shaders/PreparedSplatView.hlsli
|
||||
shaders/PrepareGaussiansCS.hlsl
|
||||
shaders/BuildSortKeysCS.hlsl
|
||||
shaders/SortCommon.hlsl
|
||||
shaders/DeviceRadixSort.hlsl
|
||||
shaders/DebugPointsVS.hlsl
|
||||
shaders/DebugPointsPS.hlsl
|
||||
PROPERTIES
|
||||
@@ -93,12 +100,55 @@ add_custom_command(TARGET xc_3dgs_d3d12_mvs POST_BUILD
|
||||
"$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/shaders"
|
||||
)
|
||||
|
||||
if(XC_DXC_EXECUTABLE)
|
||||
add_custom_command(TARGET xc_3dgs_d3d12_mvs POST_BUILD
|
||||
COMMAND "${XC_DXC_EXECUTABLE}"
|
||||
-T cs_6_6
|
||||
-E MainCS
|
||||
-Fo "$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/shaders/BuildSortKeysCS.dxil"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/shaders/BuildSortKeysCS.hlsl"
|
||||
)
|
||||
endif()
|
||||
add_custom_command(TARGET xc_3dgs_d3d12_mvs POST_BUILD
|
||||
COMMAND "${XC_DXC_EXECUTABLE}"
|
||||
-T cs_6_6
|
||||
-E MainCS
|
||||
-Fo "$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/shaders/BuildSortKeysCS.dxil"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/shaders/BuildSortKeysCS.hlsl"
|
||||
COMMAND "${XC_DXC_EXECUTABLE}"
|
||||
-T cs_6_6
|
||||
-E InitDeviceRadixSort
|
||||
-D KEY_UINT=1
|
||||
-D PAYLOAD_UINT=1
|
||||
-D SORT_PAIRS=1
|
||||
-D SHOULD_ASCEND=1
|
||||
-Fo "$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/shaders/RadixInit.dxil"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/shaders/DeviceRadixSort.hlsl"
|
||||
COMMAND "${XC_DXC_EXECUTABLE}"
|
||||
-T cs_6_6
|
||||
-E Upsweep
|
||||
-D KEY_UINT=1
|
||||
-D PAYLOAD_UINT=1
|
||||
-D SORT_PAIRS=1
|
||||
-D SHOULD_ASCEND=1
|
||||
-Fo "$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/shaders/RadixUpsweep.dxil"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/shaders/DeviceRadixSort.hlsl"
|
||||
COMMAND "${XC_DXC_EXECUTABLE}"
|
||||
-T cs_6_6
|
||||
-E BuildGlobalHistogram
|
||||
-D KEY_UINT=1
|
||||
-D PAYLOAD_UINT=1
|
||||
-D SORT_PAIRS=1
|
||||
-D SHOULD_ASCEND=1
|
||||
-Fo "$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/shaders/RadixGlobalHistogram.dxil"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/shaders/DeviceRadixSort.hlsl"
|
||||
COMMAND "${XC_DXC_EXECUTABLE}"
|
||||
-T cs_6_6
|
||||
-E Scan
|
||||
-D KEY_UINT=1
|
||||
-D PAYLOAD_UINT=1
|
||||
-D SORT_PAIRS=1
|
||||
-D SHOULD_ASCEND=1
|
||||
-Fo "$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/shaders/RadixScan.dxil"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/shaders/DeviceRadixSort.hlsl"
|
||||
COMMAND "${XC_DXC_EXECUTABLE}"
|
||||
-T cs_6_6
|
||||
-E Downsweep
|
||||
-D KEY_UINT=1
|
||||
-D PAYLOAD_UINT=1
|
||||
-D SORT_PAIRS=1
|
||||
-D SHOULD_ASCEND=1
|
||||
-Fo "$<TARGET_FILE_DIR:xc_3dgs_d3d12_mvs>/shaders/RadixDownsweep.dxil"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/shaders/DeviceRadixSort.hlsl"
|
||||
)
|
||||
|
||||
@@ -72,7 +72,8 @@ private:
|
||||
void ShutdownSortResources();
|
||||
void ShutdownDebugDrawResources();
|
||||
void Shutdown();
|
||||
bool CaptureSortKeySnapshot();
|
||||
bool CaptureSortSnapshot();
|
||||
bool CapturePass3HistogramDebug();
|
||||
void RenderFrame(bool captureScreenshot);
|
||||
|
||||
HWND m_hwnd = nullptr;
|
||||
@@ -107,14 +108,31 @@ private:
|
||||
XCEngine::RHI::RHIDescriptorPool* m_prepareDescriptorPool = nullptr;
|
||||
XCEngine::RHI::RHIDescriptorSet* m_prepareDescriptorSet = nullptr;
|
||||
XCEngine::RHI::D3D12Buffer* m_sortKeyBuffer = nullptr;
|
||||
XCEngine::RHI::D3D12Buffer* m_sortKeyScratchBuffer = nullptr;
|
||||
XCEngine::RHI::D3D12Buffer* m_orderBuffer = nullptr;
|
||||
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_sortKeySrv;
|
||||
XCEngine::RHI::D3D12Buffer* m_orderScratchBuffer = nullptr;
|
||||
XCEngine::RHI::D3D12Buffer* m_passHistogramBuffer = nullptr;
|
||||
XCEngine::RHI::D3D12Buffer* m_globalHistogramBuffer = nullptr;
|
||||
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_sortKeyUav;
|
||||
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_sortKeyScratchUav;
|
||||
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_orderBufferSrv;
|
||||
XCEngine::RHI::RHIPipelineLayout* m_sortPipelineLayout = nullptr;
|
||||
XCEngine::RHI::RHIPipelineState* m_sortPipelineState = nullptr;
|
||||
XCEngine::RHI::RHIDescriptorPool* m_sortDescriptorPool = nullptr;
|
||||
XCEngine::RHI::RHIDescriptorSet* m_sortDescriptorSet = nullptr;
|
||||
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_orderBufferUav;
|
||||
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_orderScratchUav;
|
||||
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_passHistogramUav;
|
||||
std::unique_ptr<XCEngine::RHI::D3D12ResourceView> m_globalHistogramUav;
|
||||
XCEngine::RHI::RHIPipelineLayout* m_buildSortKeyPipelineLayout = nullptr;
|
||||
XCEngine::RHI::RHIPipelineState* m_buildSortKeyPipelineState = nullptr;
|
||||
XCEngine::RHI::RHIDescriptorPool* m_buildSortKeyDescriptorPool = nullptr;
|
||||
XCEngine::RHI::RHIDescriptorSet* m_buildSortKeyDescriptorSet = nullptr;
|
||||
XCEngine::RHI::RHIPipelineLayout* m_radixSortPipelineLayout = nullptr;
|
||||
XCEngine::RHI::RHIPipelineState* m_radixSortInitPipelineState = nullptr;
|
||||
XCEngine::RHI::RHIPipelineState* m_radixSortUpsweepPipelineState = nullptr;
|
||||
XCEngine::RHI::RHIPipelineState* m_radixSortGlobalHistogramPipelineState = nullptr;
|
||||
XCEngine::RHI::RHIPipelineState* m_radixSortScanPipelineState = nullptr;
|
||||
XCEngine::RHI::RHIPipelineState* m_radixSortDownsweepPipelineState = nullptr;
|
||||
XCEngine::RHI::RHIDescriptorPool* m_radixSortDescriptorPool = nullptr;
|
||||
XCEngine::RHI::RHIDescriptorSet* m_radixSortDescriptorSetPrimaryToScratch = nullptr;
|
||||
XCEngine::RHI::RHIDescriptorSet* m_radixSortDescriptorSetScratchToPrimary = nullptr;
|
||||
XCEngine::RHI::RHIPipelineLayout* m_debugPipelineLayout = nullptr;
|
||||
XCEngine::RHI::RHIPipelineState* m_debugPipelineState = nullptr;
|
||||
XCEngine::RHI::RHIDescriptorPool* m_debugDescriptorPool = nullptr;
|
||||
|
||||
477
MVS/3DGS-D3D12/shaders/DeviceRadixSort.hlsl
Normal file
477
MVS/3DGS-D3D12/shaders/DeviceRadixSort.hlsl
Normal file
@@ -0,0 +1,477 @@
|
||||
/******************************************************************************
|
||||
* DeviceRadixSort
|
||||
* Device Level 8-bit LSD Radix Sort using reduce then scan
|
||||
*
|
||||
* SPDX-License-Identifier: MIT
|
||||
* Copyright Thomas Smith 5/17/2024
|
||||
* https://github.com/b0nes164/GPUSorting
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
******************************************************************************/
|
||||
#include "SortCommon.hlsl"
|
||||
|
||||
#define US_DIM 128U //The number of threads in a Upsweep threadblock
|
||||
#define SCAN_DIM 128U //The number of threads in a Scan threadblock
|
||||
|
||||
RWStructuredBuffer<uint> b_globalHist : register(u5); //buffer holding device level offsets for each binning pass
|
||||
RWStructuredBuffer<uint> b_passHist : register(u4); //buffer used to store reduced sums of partition tiles
|
||||
|
||||
groupshared uint g_us[RADIX * 2]; //Shared memory for upsweep
|
||||
groupshared uint g_scan[SCAN_DIM]; //Shared memory for the scan
|
||||
|
||||
//*****************************************************************************
|
||||
//INIT KERNEL
|
||||
//*****************************************************************************
|
||||
//Clear the global histogram, as we will be adding to it atomically
|
||||
[numthreads(1024, 1, 1)]
|
||||
void InitDeviceRadixSort(int3 id : SV_DispatchThreadID)
|
||||
{
|
||||
b_globalHist[id.x] = 0;
|
||||
}
|
||||
|
||||
//*****************************************************************************
|
||||
//UPSWEEP KERNEL
|
||||
//*****************************************************************************
|
||||
//histogram, 64 threads to a histogram
|
||||
inline void HistogramDigitCounts(uint gtid, uint gid)
|
||||
{
|
||||
const uint histOffset = gtid / 64 * RADIX;
|
||||
const uint partitionEnd = gid == e_threadBlocks - 1 ?
|
||||
e_numKeys : (gid + 1) * PART_SIZE;
|
||||
for (uint i = gtid + gid * PART_SIZE; i < partitionEnd; i += US_DIM)
|
||||
{
|
||||
#if defined(KEY_UINT)
|
||||
InterlockedAdd(g_us[ExtractDigit(b_sort[i]) + histOffset], 1);
|
||||
#elif defined(KEY_INT)
|
||||
InterlockedAdd(g_us[ExtractDigit(IntToUint(b_sort[i])) + histOffset], 1);
|
||||
#elif defined(KEY_FLOAT)
|
||||
InterlockedAdd(g_us[ExtractDigit(FloatToUint(b_sort[i])) + histOffset], 1);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
//reduce and pass to tile histogram
|
||||
inline void ReduceWriteDigitCounts(uint gtid, uint gid)
|
||||
{
|
||||
for (uint i = gtid; i < RADIX; i += US_DIM)
|
||||
{
|
||||
g_us[i] += g_us[i + RADIX];
|
||||
b_passHist[i * e_threadBlocks + gid] = g_us[i];
|
||||
}
|
||||
}
|
||||
|
||||
//Build the per-pass 256-bin exclusive prefix from the reduced pass histogram.
|
||||
inline void BuildGlobalHistogramExclusive(uint gtid)
|
||||
{
|
||||
uint digitIndices[2];
|
||||
uint digitTotals[2];
|
||||
uint digitCount = 0;
|
||||
for (uint i = gtid; i < RADIX; i += US_DIM)
|
||||
{
|
||||
uint total = 0u;
|
||||
const uint baseOffset = i * e_threadBlocks;
|
||||
for (uint blockIndex = 0; blockIndex < e_threadBlocks; ++blockIndex)
|
||||
{
|
||||
total += b_passHist[baseOffset + blockIndex];
|
||||
}
|
||||
|
||||
g_us[i] = total;
|
||||
digitIndices[digitCount] = i;
|
||||
digitTotals[digitCount] = total;
|
||||
++digitCount;
|
||||
}
|
||||
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
for (uint offset = 1; offset < RADIX; offset <<= 1)
|
||||
{
|
||||
for (uint i = gtid; i < RADIX; i += US_DIM)
|
||||
{
|
||||
g_us[i + RADIX] = g_us[i] + (i >= offset ? g_us[i - offset] : 0u);
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
for (uint i = gtid; i < RADIX; i += US_DIM)
|
||||
{
|
||||
g_us[i] = g_us[i + RADIX];
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
}
|
||||
|
||||
const uint globalHistOffset = GlobalHistOffset();
|
||||
for (uint localIndex = 0; localIndex < digitCount; ++localIndex)
|
||||
{
|
||||
const uint digitIndex = digitIndices[localIndex];
|
||||
b_globalHist[digitIndex + globalHistOffset] = g_us[digitIndex] - digitTotals[localIndex];
|
||||
}
|
||||
}
|
||||
|
||||
[numthreads(US_DIM, 1, 1)]
|
||||
void Upsweep(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID)
|
||||
{
|
||||
//get the wave size
|
||||
const uint waveSize = getWaveSize();
|
||||
|
||||
//clear shared memory
|
||||
const uint histsEnd = RADIX * 2;
|
||||
for (uint i = gtid.x; i < histsEnd; i += US_DIM)
|
||||
g_us[i] = 0;
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
HistogramDigitCounts(gtid.x, gid.x);
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
ReduceWriteDigitCounts(gtid.x, gid.x);
|
||||
}
|
||||
|
||||
[numthreads(US_DIM, 1, 1)]
|
||||
void BuildGlobalHistogram(uint3 gtid : SV_GroupThreadID)
|
||||
{
|
||||
const uint histsEnd = RADIX * 2;
|
||||
for (uint i = gtid.x; i < histsEnd; i += US_DIM)
|
||||
g_us[i] = 0;
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
BuildGlobalHistogramExclusive(gtid.x);
|
||||
}
|
||||
|
||||
//*****************************************************************************
|
||||
//SCAN KERNEL
|
||||
//*****************************************************************************
|
||||
inline void ExclusiveThreadBlockScanFullWGE16(
|
||||
uint gtid,
|
||||
uint laneMask,
|
||||
uint circularLaneShift,
|
||||
uint partEnd,
|
||||
uint deviceOffset,
|
||||
uint waveSize,
|
||||
inout uint reduction)
|
||||
{
|
||||
for (uint i = gtid; i < partEnd; i += SCAN_DIM)
|
||||
{
|
||||
g_scan[gtid] = b_passHist[i + deviceOffset];
|
||||
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
if (gtid < SCAN_DIM / waveSize)
|
||||
{
|
||||
g_scan[(gtid + 1) * waveSize - 1] +=
|
||||
WavePrefixSum(g_scan[(gtid + 1) * waveSize - 1]);
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
uint t = (WaveGetLaneIndex() != laneMask ? g_scan[gtid] : 0) + reduction;
|
||||
if (gtid >= waveSize)
|
||||
t += WaveReadLaneAt(g_scan[gtid - 1], 0);
|
||||
b_passHist[circularLaneShift + (i & ~laneMask) + deviceOffset] = t;
|
||||
|
||||
reduction += g_scan[SCAN_DIM - 1];
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
}
|
||||
}
|
||||
|
||||
inline void ExclusiveThreadBlockScanPartialWGE16(
|
||||
uint gtid,
|
||||
uint laneMask,
|
||||
uint circularLaneShift,
|
||||
uint partEnd,
|
||||
uint deviceOffset,
|
||||
uint waveSize,
|
||||
uint reduction)
|
||||
{
|
||||
uint i = gtid + partEnd;
|
||||
if (i < e_threadBlocks)
|
||||
g_scan[gtid] = b_passHist[deviceOffset + i];
|
||||
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
if (gtid < SCAN_DIM / waveSize)
|
||||
{
|
||||
g_scan[(gtid + 1) * waveSize - 1] +=
|
||||
WavePrefixSum(g_scan[(gtid + 1) * waveSize - 1]);
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
const uint index = circularLaneShift + (i & ~laneMask);
|
||||
if (index < e_threadBlocks)
|
||||
{
|
||||
uint t = (WaveGetLaneIndex() != laneMask ? g_scan[gtid] : 0) + reduction;
|
||||
if (gtid >= waveSize)
|
||||
t += g_scan[(gtid & ~laneMask) - 1];
|
||||
b_passHist[index + deviceOffset] = t;
|
||||
}
|
||||
}
|
||||
|
||||
inline void ExclusiveThreadBlockScanWGE16(uint gtid, uint gid, uint waveSize)
|
||||
{
|
||||
uint reduction = 0;
|
||||
const uint laneMask = waveSize - 1;
|
||||
const uint circularLaneShift = WaveGetLaneIndex() + 1 & laneMask;
|
||||
const uint partionsEnd = e_threadBlocks / SCAN_DIM * SCAN_DIM;
|
||||
const uint deviceOffset = gid * e_threadBlocks;
|
||||
|
||||
ExclusiveThreadBlockScanFullWGE16(
|
||||
gtid,
|
||||
laneMask,
|
||||
circularLaneShift,
|
||||
partionsEnd,
|
||||
deviceOffset,
|
||||
waveSize,
|
||||
reduction);
|
||||
|
||||
ExclusiveThreadBlockScanPartialWGE16(
|
||||
gtid,
|
||||
laneMask,
|
||||
circularLaneShift,
|
||||
partionsEnd,
|
||||
deviceOffset,
|
||||
waveSize,
|
||||
reduction);
|
||||
}
|
||||
|
||||
inline void ExclusiveThreadBlockScanFullWLT16(
|
||||
uint gtid,
|
||||
uint partitions,
|
||||
uint deviceOffset,
|
||||
uint laneLog,
|
||||
uint circularLaneShift,
|
||||
uint waveSize,
|
||||
inout uint reduction)
|
||||
{
|
||||
for (uint k = 0; k < partitions; ++k)
|
||||
{
|
||||
g_scan[gtid] = b_passHist[gtid + k * SCAN_DIM + deviceOffset];
|
||||
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
if (gtid < waveSize)
|
||||
{
|
||||
b_passHist[circularLaneShift + k * SCAN_DIM + deviceOffset] =
|
||||
(circularLaneShift ? g_scan[gtid] : 0) + reduction;
|
||||
}
|
||||
|
||||
uint offset = laneLog;
|
||||
uint j = waveSize;
|
||||
for (; j < (SCAN_DIM >> 1); j <<= laneLog)
|
||||
{
|
||||
if (gtid < (SCAN_DIM >> offset))
|
||||
{
|
||||
g_scan[((gtid + 1) << offset) - 1] +=
|
||||
WavePrefixSum(g_scan[((gtid + 1) << offset) - 1]);
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
if ((gtid & ((j << laneLog) - 1)) >= j)
|
||||
{
|
||||
if (gtid < (j << laneLog))
|
||||
{
|
||||
b_passHist[gtid + k * SCAN_DIM + deviceOffset] =
|
||||
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0) +
|
||||
((gtid & (j - 1)) ? g_scan[gtid - 1] : 0) + reduction;
|
||||
}
|
||||
else
|
||||
{
|
||||
if ((gtid + 1) & (j - 1))
|
||||
{
|
||||
g_scan[gtid] +=
|
||||
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
offset += laneLog;
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
//If SCAN_DIM is not a power of lanecount
|
||||
for (uint i = gtid + j; i < SCAN_DIM; i += SCAN_DIM)
|
||||
{
|
||||
b_passHist[i + k * SCAN_DIM + deviceOffset] =
|
||||
WaveReadLaneAt(g_scan[((i >> offset) << offset) - 1], 0) +
|
||||
((i & (j - 1)) ? g_scan[i - 1] : 0) + reduction;
|
||||
}
|
||||
|
||||
reduction += WaveReadLaneAt(g_scan[SCAN_DIM - 1], 0) +
|
||||
WaveReadLaneAt(g_scan[(((SCAN_DIM - 1) >> offset) << offset) - 1], 0);
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
}
|
||||
}
|
||||
|
||||
inline void ExclusiveThreadBlockScanParitalWLT16(
|
||||
uint gtid,
|
||||
uint partitions,
|
||||
uint deviceOffset,
|
||||
uint laneLog,
|
||||
uint circularLaneShift,
|
||||
uint waveSize,
|
||||
uint reduction)
|
||||
{
|
||||
const uint finalPartSize = e_threadBlocks - partitions * SCAN_DIM;
|
||||
if (gtid < finalPartSize)
|
||||
{
|
||||
g_scan[gtid] = b_passHist[gtid + partitions * SCAN_DIM + deviceOffset];
|
||||
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
if (gtid < waveSize && circularLaneShift < finalPartSize)
|
||||
{
|
||||
b_passHist[circularLaneShift + partitions * SCAN_DIM + deviceOffset] =
|
||||
(circularLaneShift ? g_scan[gtid] : 0) + reduction;
|
||||
}
|
||||
|
||||
uint offset = laneLog;
|
||||
for (uint j = waveSize; j < finalPartSize; j <<= laneLog)
|
||||
{
|
||||
if (gtid < (finalPartSize >> offset))
|
||||
{
|
||||
g_scan[((gtid + 1) << offset) - 1] +=
|
||||
WavePrefixSum(g_scan[((gtid + 1) << offset) - 1]);
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
if ((gtid & ((j << laneLog) - 1)) >= j && gtid < finalPartSize)
|
||||
{
|
||||
if (gtid < (j << laneLog))
|
||||
{
|
||||
b_passHist[gtid + partitions * SCAN_DIM + deviceOffset] =
|
||||
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0) +
|
||||
((gtid & (j - 1)) ? g_scan[gtid - 1] : 0) + reduction;
|
||||
}
|
||||
else
|
||||
{
|
||||
if ((gtid + 1) & (j - 1))
|
||||
{
|
||||
g_scan[gtid] +=
|
||||
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
offset += laneLog;
|
||||
}
|
||||
}
|
||||
|
||||
inline void ExclusiveThreadBlockScanWLT16(uint gtid, uint gid, uint waveSize)
|
||||
{
|
||||
uint reduction = 0;
|
||||
const uint partitions = e_threadBlocks / SCAN_DIM;
|
||||
const uint deviceOffset = gid * e_threadBlocks;
|
||||
const uint laneLog = countbits(waveSize - 1);
|
||||
const uint circularLaneShift = WaveGetLaneIndex() + 1 & waveSize - 1;
|
||||
|
||||
ExclusiveThreadBlockScanFullWLT16(
|
||||
gtid,
|
||||
partitions,
|
||||
deviceOffset,
|
||||
laneLog,
|
||||
circularLaneShift,
|
||||
waveSize,
|
||||
reduction);
|
||||
|
||||
ExclusiveThreadBlockScanParitalWLT16(
|
||||
gtid,
|
||||
partitions,
|
||||
deviceOffset,
|
||||
laneLog,
|
||||
circularLaneShift,
|
||||
waveSize,
|
||||
reduction);
|
||||
}
|
||||
|
||||
//Scan does not need flattening of gids
|
||||
[numthreads(SCAN_DIM, 1, 1)]
|
||||
void Scan(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID)
|
||||
{
|
||||
if (gtid.x != 0u)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const uint deviceOffset = gid.x * e_threadBlocks;
|
||||
uint runningOffset = 0u;
|
||||
for (uint blockIndex = 0u; blockIndex < e_threadBlocks; ++blockIndex)
|
||||
{
|
||||
const uint index = deviceOffset + blockIndex;
|
||||
const uint count = b_passHist[index];
|
||||
b_passHist[index] = runningOffset;
|
||||
runningOffset += count;
|
||||
}
|
||||
}
|
||||
|
||||
//*****************************************************************************
|
||||
//DOWNSWEEP KERNEL
|
||||
//*****************************************************************************
|
||||
inline void LoadThreadBlockReductions(uint gtid, uint gid, uint exclusiveHistReduction)
|
||||
{
|
||||
if (gtid < RADIX)
|
||||
{
|
||||
g_d[gtid + PART_SIZE] = b_globalHist[gtid + GlobalHistOffset()] +
|
||||
b_passHist[gtid * e_threadBlocks + gid] - exclusiveHistReduction;
|
||||
}
|
||||
}
|
||||
|
||||
[numthreads(D_DIM, 1, 1)]
|
||||
void Downsweep(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID)
|
||||
{
|
||||
if (gtid.x != 0u)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const uint partitionStart = gid.x * PART_SIZE;
|
||||
const uint partitionEnd = min(partitionStart + PART_SIZE, e_numKeys);
|
||||
uint digitOffsets[RADIX];
|
||||
const uint globalHistOffset = GlobalHistOffset();
|
||||
for (uint digit = 0u; digit < RADIX; ++digit)
|
||||
{
|
||||
digitOffsets[digit] =
|
||||
b_globalHist[globalHistOffset + digit] +
|
||||
b_passHist[digit * e_threadBlocks + gid.x];
|
||||
}
|
||||
|
||||
for (uint index = partitionStart; index < partitionEnd; ++index)
|
||||
{
|
||||
uint key;
|
||||
#if defined(KEY_UINT)
|
||||
key = b_sort[index];
|
||||
#elif defined(KEY_INT)
|
||||
key = IntToUint(b_sort[index]);
|
||||
#elif defined(KEY_FLOAT)
|
||||
key = FloatToUint(b_sort[index]);
|
||||
#endif
|
||||
|
||||
const uint digit = ExtractDigit(key);
|
||||
const uint destinationIndex = digitOffsets[digit]++;
|
||||
|
||||
#if defined(KEY_UINT)
|
||||
b_alt[destinationIndex] = key;
|
||||
#elif defined(KEY_INT)
|
||||
b_alt[destinationIndex] = UintToInt(key);
|
||||
#elif defined(KEY_FLOAT)
|
||||
b_alt[destinationIndex] = UintToFloat(key);
|
||||
#endif
|
||||
|
||||
#if defined(SORT_PAIRS)
|
||||
#if defined(PAYLOAD_UINT)
|
||||
b_altPayload[destinationIndex] = b_sortPayload[index];
|
||||
#elif defined(PAYLOAD_INT)
|
||||
b_altPayload[destinationIndex] = b_sortPayload[index];
|
||||
#elif defined(PAYLOAD_FLOAT)
|
||||
b_altPayload[destinationIndex] = b_sortPayload[index];
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
}
|
||||
959
MVS/3DGS-D3D12/shaders/SortCommon.hlsl
Normal file
959
MVS/3DGS-D3D12/shaders/SortCommon.hlsl
Normal file
@@ -0,0 +1,959 @@
|
||||
/******************************************************************************
|
||||
* SortCommon
|
||||
* Common functions for GPUSorting
|
||||
*
|
||||
* SPDX-License-Identifier: MIT
|
||||
* Copyright Thomas Smith 5/17/2024
|
||||
* https://github.com/b0nes164/GPUSorting
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
******************************************************************************/
|
||||
#define KEYS_PER_THREAD 15U
|
||||
#define D_DIM 256U
|
||||
#define PART_SIZE 3840U
|
||||
#define D_TOTAL_SMEM 4096U
|
||||
|
||||
#define RADIX 256U //Number of digit bins
|
||||
#define RADIX_MASK 255U //Mask of digit bins
|
||||
#define HALF_RADIX 128U //For smaller waves where bit packing is necessary
|
||||
#define HALF_MASK 127U // ''
|
||||
#define RADIX_LOG 8U //log2(RADIX)
|
||||
#define RADIX_PASSES 4U //(Key width) / RADIX_LOG
|
||||
|
||||
cbuffer cbGpuSorting : register(b0)
|
||||
{
|
||||
uint e_numKeys;
|
||||
uint e_radixShift;
|
||||
uint e_threadBlocks;
|
||||
uint padding;
|
||||
};
|
||||
|
||||
#if defined(KEY_UINT)
|
||||
RWStructuredBuffer<uint> b_sort : register(u0);
|
||||
RWStructuredBuffer<uint> b_alt : register(u1);
|
||||
#elif defined(KEY_INT)
|
||||
RWStructuredBuffer<int> b_sort : register(u0);
|
||||
RWStructuredBuffer<int> b_alt : register(u1);
|
||||
#elif defined(KEY_FLOAT)
|
||||
RWStructuredBuffer<float> b_sort : register(u0);
|
||||
RWStructuredBuffer<float> b_alt : register(u1);
|
||||
#endif
|
||||
|
||||
#if defined(PAYLOAD_UINT)
|
||||
RWStructuredBuffer<uint> b_sortPayload : register(u2);
|
||||
RWStructuredBuffer<uint> b_altPayload : register(u3);
|
||||
#elif defined(PAYLOAD_INT)
|
||||
RWStructuredBuffer<int> b_sortPayload : register(u2);
|
||||
RWStructuredBuffer<int> b_altPayload : register(u3);
|
||||
#elif defined(PAYLOAD_FLOAT)
|
||||
RWStructuredBuffer<float> b_sortPayload : register(u2);
|
||||
RWStructuredBuffer<float> b_altPayload : register(u3);
|
||||
#endif
|
||||
|
||||
groupshared uint g_d[D_TOTAL_SMEM]; //Shared memory for DigitBinningPass and DownSweep kernels
|
||||
|
||||
struct KeyStruct
|
||||
{
|
||||
uint k[KEYS_PER_THREAD];
|
||||
};
|
||||
|
||||
struct OffsetStruct
|
||||
{
|
||||
#if defined(ENABLE_16_BIT)
|
||||
uint16_t o[KEYS_PER_THREAD];
|
||||
#else
|
||||
uint o[KEYS_PER_THREAD];
|
||||
#endif
|
||||
};
|
||||
|
||||
struct DigitStruct
|
||||
{
|
||||
#if defined(ENABLE_16_BIT)
|
||||
uint16_t d[KEYS_PER_THREAD];
|
||||
#else
|
||||
uint d[KEYS_PER_THREAD];
|
||||
#endif
|
||||
};
|
||||
|
||||
//*****************************************************************************
|
||||
//HELPER FUNCTIONS
|
||||
//*****************************************************************************
|
||||
//Due to a bug with SPIRV pre 1.6, we cannot use WaveGetLaneCount() to get the currently active wavesize
|
||||
inline uint getWaveSize()
|
||||
{
|
||||
#if defined(VULKAN)
|
||||
GroupMemoryBarrierWithGroupSync(); //Make absolutely sure the wave is not diverged here
|
||||
return dot(countbits(WaveActiveBallot(true)), uint4(1, 1, 1, 1));
|
||||
#else
|
||||
return WaveGetLaneCount();
|
||||
#endif
|
||||
}
|
||||
|
||||
inline uint getWaveIndex(uint gtid, uint waveSize)
|
||||
{
|
||||
return gtid / waveSize;
|
||||
}
|
||||
|
||||
//Radix Tricks by Michael Herf
|
||||
//http://stereopsis.com/radix.html
|
||||
inline uint FloatToUint(float f)
|
||||
{
|
||||
uint mask = -((int) (asuint(f) >> 31)) | 0x80000000;
|
||||
return asuint(f) ^ mask;
|
||||
}
|
||||
|
||||
inline float UintToFloat(uint u)
|
||||
{
|
||||
uint mask = ((u >> 31) - 1) | 0x80000000;
|
||||
return asfloat(u ^ mask);
|
||||
}
|
||||
|
||||
inline uint IntToUint(int i)
|
||||
{
|
||||
return asuint(i ^ 0x80000000);
|
||||
}
|
||||
|
||||
inline int UintToInt(uint u)
|
||||
{
|
||||
return asint(u ^ 0x80000000);
|
||||
}
|
||||
|
||||
inline uint getWaveCountPass(uint waveSize)
|
||||
{
|
||||
return D_DIM / waveSize;
|
||||
}
|
||||
|
||||
inline uint ExtractDigit(uint key)
|
||||
{
|
||||
return key >> e_radixShift & RADIX_MASK;
|
||||
}
|
||||
|
||||
inline uint ExtractDigit(uint key, uint shift)
|
||||
{
|
||||
return key >> shift & RADIX_MASK;
|
||||
}
|
||||
|
||||
inline uint ExtractPackedIndex(uint key)
|
||||
{
|
||||
return key >> (e_radixShift + 1) & HALF_MASK;
|
||||
}
|
||||
|
||||
inline uint ExtractPackedShift(uint key)
|
||||
{
|
||||
return (key >> e_radixShift & 1) ? 16 : 0;
|
||||
}
|
||||
|
||||
inline uint ExtractPackedValue(uint packed, uint key)
|
||||
{
|
||||
return packed >> ExtractPackedShift(key) & 0xffff;
|
||||
}
|
||||
|
||||
inline uint SubPartSizeWGE16(uint waveSize)
|
||||
{
|
||||
return KEYS_PER_THREAD * waveSize;
|
||||
}
|
||||
|
||||
inline uint SharedOffsetWGE16(uint gtid, uint waveSize)
|
||||
{
|
||||
return WaveGetLaneIndex() + getWaveIndex(gtid, waveSize) * SubPartSizeWGE16(waveSize);
|
||||
}
|
||||
|
||||
inline uint SubPartSizeWLT16(uint waveSize, uint _serialIterations)
|
||||
{
|
||||
return KEYS_PER_THREAD * waveSize * _serialIterations;
|
||||
}
|
||||
|
||||
inline uint SharedOffsetWLT16(uint gtid, uint waveSize, uint _serialIterations)
|
||||
{
|
||||
return WaveGetLaneIndex() +
|
||||
(getWaveIndex(gtid, waveSize) / _serialIterations * SubPartSizeWLT16(waveSize, _serialIterations)) +
|
||||
(getWaveIndex(gtid, waveSize) % _serialIterations * waveSize);
|
||||
}
|
||||
|
||||
inline uint DeviceOffsetWGE16(uint gtid, uint waveSize, uint partIndex)
|
||||
{
|
||||
return SharedOffsetWGE16(gtid, waveSize) + partIndex * PART_SIZE;
|
||||
}
|
||||
|
||||
inline uint DeviceOffsetWLT16(uint gtid, uint waveSize, uint partIndex, uint serialIterations)
|
||||
{
|
||||
return SharedOffsetWLT16(gtid, waveSize, serialIterations) + partIndex * PART_SIZE;
|
||||
}
|
||||
|
||||
inline uint GlobalHistOffset()
|
||||
{
|
||||
return e_radixShift << 5;
|
||||
}
|
||||
|
||||
inline uint WaveHistsSizeWGE16(uint waveSize)
|
||||
{
|
||||
return D_DIM / waveSize * RADIX;
|
||||
}
|
||||
|
||||
inline uint WaveHistsSizeWLT16()
|
||||
{
|
||||
return D_TOTAL_SMEM;
|
||||
}
|
||||
|
||||
//*****************************************************************************
|
||||
//FUNCTIONS COMMON TO THE DOWNSWEEP / DIGIT BINNING PASS
|
||||
//*****************************************************************************
|
||||
//If the size of a wave is too small, we do not have enough space in
|
||||
//shared memory to assign a histogram to each wave, so instead,
|
||||
//some operations are peformed serially.
|
||||
inline uint SerialIterations(uint waveSize)
|
||||
{
|
||||
return (D_DIM / waveSize + 31) >> 5;
|
||||
}
|
||||
|
||||
inline void ClearWaveHists(uint gtid, uint waveSize)
|
||||
{
|
||||
const uint histsEnd = waveSize >= 16 ?
|
||||
WaveHistsSizeWGE16(waveSize) : WaveHistsSizeWLT16();
|
||||
for (uint i = gtid; i < histsEnd; i += D_DIM)
|
||||
g_d[i] = 0;
|
||||
}
|
||||
|
||||
inline void LoadKey(inout uint key, uint index)
|
||||
{
|
||||
#if defined(KEY_UINT)
|
||||
key = b_sort[index];
|
||||
#elif defined(KEY_INT)
|
||||
key = UintToInt(b_sort[index]);
|
||||
#elif defined(KEY_FLOAT)
|
||||
key = FloatToUint(b_sort[index]);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void LoadDummyKey(inout uint key)
|
||||
{
|
||||
key = 0xffffffff;
|
||||
}
|
||||
|
||||
inline KeyStruct LoadKeysWGE16(uint gtid, uint waveSize, uint partIndex)
|
||||
{
|
||||
KeyStruct keys;
|
||||
[unroll]
|
||||
for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, partIndex);
|
||||
i < KEYS_PER_THREAD;
|
||||
++i, t += waveSize)
|
||||
{
|
||||
LoadKey(keys.k[i], t);
|
||||
}
|
||||
return keys;
|
||||
}
|
||||
|
||||
inline KeyStruct LoadKeysWLT16(uint gtid, uint waveSize, uint partIndex, uint serialIterations)
|
||||
{
|
||||
KeyStruct keys;
|
||||
[unroll]
|
||||
for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations);
|
||||
i < KEYS_PER_THREAD;
|
||||
++i, t += waveSize * serialIterations)
|
||||
{
|
||||
LoadKey(keys.k[i], t);
|
||||
}
|
||||
return keys;
|
||||
}
|
||||
|
||||
inline KeyStruct LoadKeysPartialWGE16(uint gtid, uint waveSize, uint partIndex)
|
||||
{
|
||||
KeyStruct keys;
|
||||
[unroll]
|
||||
for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, partIndex);
|
||||
i < KEYS_PER_THREAD;
|
||||
++i, t += waveSize)
|
||||
{
|
||||
if (t < e_numKeys)
|
||||
LoadKey(keys.k[i], t);
|
||||
else
|
||||
LoadDummyKey(keys.k[i]);
|
||||
}
|
||||
return keys;
|
||||
}
|
||||
|
||||
inline KeyStruct LoadKeysPartialWLT16(uint gtid, uint waveSize, uint partIndex, uint serialIterations)
|
||||
{
|
||||
KeyStruct keys;
|
||||
[unroll]
|
||||
for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations);
|
||||
i < KEYS_PER_THREAD;
|
||||
++i, t += waveSize * serialIterations)
|
||||
{
|
||||
if (t < e_numKeys)
|
||||
LoadKey(keys.k[i], t);
|
||||
else
|
||||
LoadDummyKey(keys.k[i]);
|
||||
}
|
||||
return keys;
|
||||
}
|
||||
|
||||
inline uint WaveFlagsWGE16(uint waveSize)
|
||||
{
|
||||
return (waveSize & 31) ? (1U << waveSize) - 1 : 0xffffffff;
|
||||
}
|
||||
|
||||
inline uint WaveFlagsWLT16(uint waveSize)
|
||||
{
|
||||
return (1U << waveSize) - 1;;
|
||||
}
|
||||
|
||||
inline void WarpLevelMultiSplitWGE16(uint key, inout uint4 waveFlags)
|
||||
{
|
||||
[unroll]
|
||||
for (uint k = 0; k < RADIX_LOG; ++k)
|
||||
{
|
||||
const uint currentBit = 1U << (k + e_radixShift);
|
||||
const bool t = (key & currentBit) != 0;
|
||||
GroupMemoryBarrierWithGroupSync(); //Play on the safe side, throw in a barrier for convergence
|
||||
const uint4 ballot = WaveActiveBallot(t);
|
||||
if(t)
|
||||
waveFlags &= ballot;
|
||||
else
|
||||
waveFlags &= (~ballot);
|
||||
}
|
||||
}
|
||||
|
||||
inline uint2 CountBitsWGE16(uint waveSize, uint ltMask, uint4 waveFlags)
|
||||
{
|
||||
uint2 count = uint2(0, 0);
|
||||
|
||||
for(uint wavePart = 0; wavePart < waveSize; wavePart += 32)
|
||||
{
|
||||
uint t = countbits(waveFlags[wavePart >> 5]);
|
||||
if (WaveGetLaneIndex() >= wavePart)
|
||||
{
|
||||
if (WaveGetLaneIndex() >= wavePart + 32)
|
||||
count.x += t;
|
||||
else
|
||||
count.x += countbits(waveFlags[wavePart >> 5] & ltMask);
|
||||
}
|
||||
count.y += t;
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
inline void WarpLevelMultiSplitWLT16(uint key, inout uint waveFlags)
|
||||
{
|
||||
[unroll]
|
||||
for (uint k = 0; k < RADIX_LOG; ++k)
|
||||
{
|
||||
const bool t = key >> (k + e_radixShift) & 1;
|
||||
waveFlags &= (t ? 0 : 0xffffffff) ^ (uint) WaveActiveBallot(t);
|
||||
}
|
||||
}
|
||||
|
||||
inline OffsetStruct RankKeysWGE16(
|
||||
uint waveSize,
|
||||
uint waveOffset,
|
||||
KeyStruct keys)
|
||||
{
|
||||
OffsetStruct offsets;
|
||||
const uint initialFlags = WaveFlagsWGE16(waveSize);
|
||||
const uint ltMask = (1U << (WaveGetLaneIndex() & 31)) - 1;
|
||||
|
||||
[unroll]
|
||||
for (uint i = 0; i < KEYS_PER_THREAD; ++i)
|
||||
{
|
||||
uint4 waveFlags = initialFlags;
|
||||
WarpLevelMultiSplitWGE16(keys.k[i], waveFlags);
|
||||
|
||||
const uint index = ExtractDigit(keys.k[i]) + waveOffset;
|
||||
const uint2 bitCount = CountBitsWGE16(waveSize, ltMask, waveFlags);
|
||||
|
||||
offsets.o[i] = g_d[index] + bitCount.x;
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
if (bitCount.x == 0)
|
||||
g_d[index] += bitCount.y;
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
}
|
||||
|
||||
return offsets;
|
||||
}
|
||||
|
||||
inline OffsetStruct RankKeysWLT16(uint waveSize, uint waveIndex, KeyStruct keys, uint serialIterations)
|
||||
{
|
||||
OffsetStruct offsets;
|
||||
const uint ltMask = (1U << WaveGetLaneIndex()) - 1;
|
||||
const uint initialFlags = WaveFlagsWLT16(waveSize);
|
||||
|
||||
[unroll]
|
||||
for (uint i = 0; i < KEYS_PER_THREAD; ++i)
|
||||
{
|
||||
uint waveFlags = initialFlags;
|
||||
WarpLevelMultiSplitWLT16(keys.k[i], waveFlags);
|
||||
|
||||
const uint index = ExtractPackedIndex(keys.k[i]) +
|
||||
(waveIndex / serialIterations * HALF_RADIX);
|
||||
|
||||
const uint peerBits = countbits(waveFlags & ltMask);
|
||||
for (uint k = 0; k < serialIterations; ++k)
|
||||
{
|
||||
if (waveIndex % serialIterations == k)
|
||||
offsets.o[i] = ExtractPackedValue(g_d[index], keys.k[i]) + peerBits;
|
||||
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
if (waveIndex % serialIterations == k && peerBits == 0)
|
||||
{
|
||||
InterlockedAdd(g_d[index],
|
||||
countbits(waveFlags) << ExtractPackedShift(keys.k[i]));
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
}
|
||||
}
|
||||
|
||||
return offsets;
|
||||
}
|
||||
|
||||
inline uint WaveHistInclusiveScanCircularShiftWGE16(uint gtid, uint waveSize)
|
||||
{
|
||||
uint histReduction = g_d[gtid];
|
||||
for (uint i = gtid + RADIX; i < WaveHistsSizeWGE16(waveSize); i += RADIX)
|
||||
{
|
||||
histReduction += g_d[i];
|
||||
g_d[i] = histReduction - g_d[i];
|
||||
}
|
||||
return histReduction;
|
||||
}
|
||||
|
||||
inline uint WaveHistInclusiveScanCircularShiftWLT16(uint gtid)
|
||||
{
|
||||
uint histReduction = g_d[gtid];
|
||||
for (uint i = gtid + HALF_RADIX; i < WaveHistsSizeWLT16(); i += HALF_RADIX)
|
||||
{
|
||||
histReduction += g_d[i];
|
||||
g_d[i] = histReduction - g_d[i];
|
||||
}
|
||||
return histReduction;
|
||||
}
|
||||
|
||||
inline void WaveHistReductionExclusiveScanWGE16(uint gtid, uint waveSize, uint histReduction)
|
||||
{
|
||||
if (gtid < RADIX)
|
||||
{
|
||||
const uint laneMask = waveSize - 1;
|
||||
g_d[((WaveGetLaneIndex() + 1) & laneMask) + (gtid & ~laneMask)] = histReduction;
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
if (gtid < RADIX / waveSize)
|
||||
{
|
||||
g_d[gtid * waveSize] =
|
||||
WavePrefixSum(g_d[gtid * waveSize]);
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
uint t = WaveReadLaneAt(g_d[gtid], 0);
|
||||
if (gtid < RADIX && WaveGetLaneIndex())
|
||||
g_d[gtid] += t;
|
||||
}
|
||||
|
||||
//inclusive/exclusive prefix sum up the histograms,
|
||||
//use a blelloch scan for in place packed exclusive
|
||||
inline void WaveHistReductionExclusiveScanWLT16(uint gtid)
|
||||
{
|
||||
uint shift = 1;
|
||||
for (uint j = RADIX >> 2; j > 0; j >>= 1)
|
||||
{
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
if (gtid < j)
|
||||
{
|
||||
g_d[((((gtid << 1) + 2) << shift) - 1) >> 1] +=
|
||||
g_d[((((gtid << 1) + 1) << shift) - 1) >> 1] & 0xffff0000;
|
||||
}
|
||||
shift++;
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
if (gtid == 0)
|
||||
g_d[HALF_RADIX - 1] &= 0xffff;
|
||||
|
||||
for (uint j = 1; j < RADIX >> 1; j <<= 1)
|
||||
{
|
||||
--shift;
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
if (gtid < j)
|
||||
{
|
||||
const uint t = ((((gtid << 1) + 1) << shift) - 1) >> 1;
|
||||
const uint t2 = ((((gtid << 1) + 2) << shift) - 1) >> 1;
|
||||
const uint t3 = g_d[t];
|
||||
g_d[t] = (g_d[t] & 0xffff) | (g_d[t2] & 0xffff0000);
|
||||
g_d[t2] += t3 & 0xffff0000;
|
||||
}
|
||||
}
|
||||
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
if (gtid < HALF_RADIX)
|
||||
{
|
||||
const uint t = g_d[gtid];
|
||||
g_d[gtid] = (t >> 16) + (t << 16) + (t & 0xffff0000);
|
||||
}
|
||||
}
|
||||
|
||||
inline void UpdateOffsetsWGE16(
|
||||
uint gtid,
|
||||
uint waveSize,
|
||||
inout OffsetStruct offsets,
|
||||
KeyStruct keys)
|
||||
{
|
||||
if (gtid >= waveSize)
|
||||
{
|
||||
const uint t = getWaveIndex(gtid, waveSize) * RADIX;
|
||||
[unroll]
|
||||
for (uint i = 0; i < KEYS_PER_THREAD; ++i)
|
||||
{
|
||||
const uint t2 = ExtractDigit(keys.k[i]);
|
||||
offsets.o[i] += g_d[t2 + t] + g_d[t2];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
[unroll]
|
||||
for (uint i = 0; i < KEYS_PER_THREAD; ++i)
|
||||
offsets.o[i] += g_d[ExtractDigit(keys.k[i])];
|
||||
}
|
||||
}
|
||||
|
||||
inline void UpdateOffsetsWLT16(
|
||||
uint gtid,
|
||||
uint waveSize,
|
||||
uint serialIterations,
|
||||
inout OffsetStruct offsets,
|
||||
KeyStruct keys)
|
||||
{
|
||||
if (gtid >= waveSize * serialIterations)
|
||||
{
|
||||
const uint t = getWaveIndex(gtid, waveSize) / serialIterations * HALF_RADIX;
|
||||
[unroll]
|
||||
for (uint i = 0; i < KEYS_PER_THREAD; ++i)
|
||||
{
|
||||
const uint t2 = ExtractPackedIndex(keys.k[i]);
|
||||
offsets.o[i] += ExtractPackedValue(g_d[t2 + t] + g_d[t2], keys.k[i]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
[unroll]
|
||||
for (uint i = 0; i < KEYS_PER_THREAD; ++i)
|
||||
offsets.o[i] += ExtractPackedValue(g_d[ExtractPackedIndex(keys.k[i])], keys.k[i]);
|
||||
}
|
||||
}
|
||||
|
||||
inline void ScatterKeysShared(OffsetStruct offsets, KeyStruct keys)
|
||||
{
|
||||
[unroll]
|
||||
for (uint i = 0; i < KEYS_PER_THREAD; ++i)
|
||||
g_d[offsets.o[i]] = keys.k[i];
|
||||
}
|
||||
|
||||
inline uint DescendingIndex(uint deviceIndex)
|
||||
{
|
||||
return e_numKeys - deviceIndex - 1;
|
||||
}
|
||||
|
||||
inline void WriteKey(uint deviceIndex, uint groupSharedIndex)
|
||||
{
|
||||
#if defined(KEY_UINT)
|
||||
b_alt[deviceIndex] = g_d[groupSharedIndex];
|
||||
#elif defined(KEY_INT)
|
||||
b_alt[deviceIndex] = UintToInt(g_d[groupSharedIndex]);
|
||||
#elif defined(KEY_FLOAT)
|
||||
b_alt[deviceIndex] = UintToFloat(g_d[groupSharedIndex]);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void LoadPayload(inout uint payload, uint deviceIndex)
|
||||
{
|
||||
#if defined(PAYLOAD_UINT)
|
||||
payload = b_sortPayload[deviceIndex];
|
||||
#elif defined(PAYLOAD_INT) || defined(PAYLOAD_FLOAT)
|
||||
payload = asuint(b_sortPayload[deviceIndex]);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void ScatterPayloadsShared(OffsetStruct offsets, KeyStruct payloads)
|
||||
{
|
||||
ScatterKeysShared(offsets, payloads);
|
||||
}
|
||||
|
||||
inline void WritePayload(uint deviceIndex, uint groupSharedIndex)
|
||||
{
|
||||
#if defined(PAYLOAD_UINT)
|
||||
b_altPayload[deviceIndex] = g_d[groupSharedIndex];
|
||||
#elif defined(PAYLOAD_INT)
|
||||
b_altPayload[deviceIndex] = asint(g_d[groupSharedIndex]);
|
||||
#elif defined(PAYLOAD_FLOAT)
|
||||
b_altPayload[deviceIndex] = asfloat(g_d[groupSharedIndex]);
|
||||
#endif
|
||||
}
|
||||
|
||||
//*****************************************************************************
|
||||
//SCATTERING: FULL PARTITIONS
|
||||
//*****************************************************************************
|
||||
//KEYS ONLY
|
||||
inline void ScatterKeysOnlyDeviceAscending(uint gtid)
|
||||
{
|
||||
for (uint i = gtid; i < PART_SIZE; i += D_DIM)
|
||||
WriteKey(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i, i);
|
||||
}
|
||||
|
||||
inline void ScatterKeysOnlyDeviceDescending(uint gtid)
|
||||
{
|
||||
if (e_radixShift == 24)
|
||||
{
|
||||
for (uint i = gtid; i < PART_SIZE; i += D_DIM)
|
||||
WriteKey(DescendingIndex(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i), i);
|
||||
}
|
||||
else
|
||||
{
|
||||
ScatterKeysOnlyDeviceAscending(gtid);
|
||||
}
|
||||
}
|
||||
|
||||
inline void ScatterKeysOnlyDevice(uint gtid)
|
||||
{
|
||||
#if defined(SHOULD_ASCEND)
|
||||
ScatterKeysOnlyDeviceAscending(gtid);
|
||||
#else
|
||||
ScatterKeysOnlyDeviceDescending(gtid);
|
||||
#endif
|
||||
}
|
||||
|
||||
//KEY VALUE PAIRS
|
||||
inline void ScatterPairsKeyPhaseAscending(
|
||||
uint gtid,
|
||||
inout DigitStruct digits)
|
||||
{
|
||||
[unroll]
|
||||
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
|
||||
{
|
||||
digits.d[i] = ExtractDigit(g_d[t]);
|
||||
WriteKey(g_d[digits.d[i] + PART_SIZE] + t, t);
|
||||
}
|
||||
}
|
||||
|
||||
inline void ScatterPairsKeyPhaseDescending(
|
||||
uint gtid,
|
||||
inout DigitStruct digits)
|
||||
{
|
||||
if (e_radixShift == 24)
|
||||
{
|
||||
[unroll]
|
||||
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
|
||||
{
|
||||
digits.d[i] = ExtractDigit(g_d[t]);
|
||||
WriteKey(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
ScatterPairsKeyPhaseAscending(gtid, digits);
|
||||
}
|
||||
}
|
||||
|
||||
inline void LoadPayloadsWGE16(
|
||||
uint gtid,
|
||||
uint waveSize,
|
||||
uint partIndex,
|
||||
inout KeyStruct payloads)
|
||||
{
|
||||
[unroll]
|
||||
for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, partIndex);
|
||||
i < KEYS_PER_THREAD;
|
||||
++i, t += waveSize)
|
||||
{
|
||||
LoadPayload(payloads.k[i], t);
|
||||
}
|
||||
}
|
||||
|
||||
inline void LoadPayloadsWLT16(
|
||||
uint gtid,
|
||||
uint waveSize,
|
||||
uint partIndex,
|
||||
uint serialIterations,
|
||||
inout KeyStruct payloads)
|
||||
{
|
||||
[unroll]
|
||||
for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations);
|
||||
i < KEYS_PER_THREAD;
|
||||
++i, t += waveSize * serialIterations)
|
||||
{
|
||||
LoadPayload(payloads.k[i], t);
|
||||
}
|
||||
}
|
||||
|
||||
inline void ScatterPayloadsAscending(uint gtid, DigitStruct digits)
|
||||
{
|
||||
[unroll]
|
||||
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
|
||||
WritePayload(g_d[digits.d[i] + PART_SIZE] + t, t);
|
||||
}
|
||||
|
||||
inline void ScatterPayloadsDescending(uint gtid, DigitStruct digits)
|
||||
{
|
||||
if (e_radixShift == 24)
|
||||
{
|
||||
[unroll]
|
||||
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
|
||||
WritePayload(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t);
|
||||
}
|
||||
else
|
||||
{
|
||||
ScatterPayloadsAscending(gtid, digits);
|
||||
}
|
||||
}
|
||||
|
||||
inline void ScatterPairsDevice(
|
||||
uint gtid,
|
||||
uint waveSize,
|
||||
uint partIndex,
|
||||
OffsetStruct offsets)
|
||||
{
|
||||
DigitStruct digits;
|
||||
#if defined(SHOULD_ASCEND)
|
||||
ScatterPairsKeyPhaseAscending(gtid, digits);
|
||||
#else
|
||||
ScatterPairsKeyPhaseDescending(gtid, digits);
|
||||
#endif
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
KeyStruct payloads;
|
||||
if (waveSize >= 16)
|
||||
LoadPayloadsWGE16(gtid, waveSize, partIndex, payloads);
|
||||
else
|
||||
LoadPayloadsWLT16(gtid, waveSize, partIndex, SerialIterations(waveSize), payloads);
|
||||
ScatterPayloadsShared(offsets, payloads);
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
#if defined(SHOULD_ASCEND)
|
||||
ScatterPayloadsAscending(gtid, digits);
|
||||
#else
|
||||
ScatterPayloadsDescending(gtid, digits);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void ScatterDevice(
|
||||
uint gtid,
|
||||
uint waveSize,
|
||||
uint partIndex,
|
||||
OffsetStruct offsets)
|
||||
{
|
||||
#if defined(SORT_PAIRS)
|
||||
ScatterPairsDevice(
|
||||
gtid,
|
||||
waveSize,
|
||||
partIndex,
|
||||
offsets);
|
||||
#else
|
||||
ScatterKeysOnlyDevice(gtid);
|
||||
#endif
|
||||
}
|
||||
|
||||
//*****************************************************************************
|
||||
//SCATTERING: PARTIAL PARTITIONS
|
||||
//*****************************************************************************
|
||||
//KEYS ONLY
|
||||
inline void ScatterKeysOnlyDevicePartialAscending(uint gtid, uint finalPartSize)
|
||||
{
|
||||
for (uint i = gtid; i < PART_SIZE; i += D_DIM)
|
||||
{
|
||||
if (i < finalPartSize)
|
||||
WriteKey(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i, i);
|
||||
}
|
||||
}
|
||||
|
||||
inline void ScatterKeysOnlyDevicePartialDescending(uint gtid, uint finalPartSize)
|
||||
{
|
||||
if (e_radixShift == 24)
|
||||
{
|
||||
for (uint i = gtid; i < PART_SIZE; i += D_DIM)
|
||||
{
|
||||
if (i < finalPartSize)
|
||||
WriteKey(DescendingIndex(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i), i);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
ScatterKeysOnlyDevicePartialAscending(gtid, finalPartSize);
|
||||
}
|
||||
}
|
||||
|
||||
inline void ScatterKeysOnlyDevicePartial(uint gtid, uint partIndex)
|
||||
{
|
||||
const uint finalPartSize = e_numKeys - partIndex * PART_SIZE;
|
||||
#if defined(SHOULD_ASCEND)
|
||||
ScatterKeysOnlyDevicePartialAscending(gtid, finalPartSize);
|
||||
#else
|
||||
ScatterKeysOnlyDevicePartialDescending(gtid, finalPartSize);
|
||||
#endif
|
||||
}
|
||||
|
||||
//KEY VALUE PAIRS
|
||||
inline void ScatterPairsKeyPhaseAscendingPartial(
|
||||
uint gtid,
|
||||
uint finalPartSize,
|
||||
inout DigitStruct digits)
|
||||
{
|
||||
[unroll]
|
||||
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
|
||||
{
|
||||
if (t < finalPartSize)
|
||||
{
|
||||
digits.d[i] = ExtractDigit(g_d[t]);
|
||||
WriteKey(g_d[digits.d[i] + PART_SIZE] + t, t);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void ScatterPairsKeyPhaseDescendingPartial(
|
||||
uint gtid,
|
||||
uint finalPartSize,
|
||||
inout DigitStruct digits)
|
||||
{
|
||||
if (e_radixShift == 24)
|
||||
{
|
||||
[unroll]
|
||||
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
|
||||
{
|
||||
if (t < finalPartSize)
|
||||
{
|
||||
digits.d[i] = ExtractDigit(g_d[t]);
|
||||
WriteKey(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
ScatterPairsKeyPhaseAscendingPartial(gtid, finalPartSize, digits);
|
||||
}
|
||||
}
|
||||
|
||||
inline void LoadPayloadsPartialWGE16(
|
||||
uint gtid,
|
||||
uint waveSize,
|
||||
uint partIndex,
|
||||
inout KeyStruct payloads)
|
||||
{
|
||||
[unroll]
|
||||
for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, partIndex);
|
||||
i < KEYS_PER_THREAD;
|
||||
++i, t += waveSize)
|
||||
{
|
||||
if (t < e_numKeys)
|
||||
LoadPayload(payloads.k[i], t);
|
||||
}
|
||||
}
|
||||
|
||||
inline void LoadPayloadsPartialWLT16(
|
||||
uint gtid,
|
||||
uint waveSize,
|
||||
uint partIndex,
|
||||
uint serialIterations,
|
||||
inout KeyStruct payloads)
|
||||
{
|
||||
[unroll]
|
||||
for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations);
|
||||
i < KEYS_PER_THREAD;
|
||||
++i, t += waveSize * serialIterations)
|
||||
{
|
||||
if (t < e_numKeys)
|
||||
LoadPayload(payloads.k[i], t);
|
||||
}
|
||||
}
|
||||
|
||||
inline void ScatterPayloadsAscendingPartial(
|
||||
uint gtid,
|
||||
uint finalPartSize,
|
||||
DigitStruct digits)
|
||||
{
|
||||
[unroll]
|
||||
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
|
||||
{
|
||||
if (t < finalPartSize)
|
||||
WritePayload(g_d[digits.d[i] + PART_SIZE] + t, t);
|
||||
}
|
||||
}
|
||||
|
||||
inline void ScatterPayloadsDescendingPartial(
|
||||
uint gtid,
|
||||
uint finalPartSize,
|
||||
DigitStruct digits)
|
||||
{
|
||||
if (e_radixShift == 24)
|
||||
{
|
||||
[unroll]
|
||||
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
|
||||
{
|
||||
if (t < finalPartSize)
|
||||
WritePayload(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
ScatterPayloadsAscendingPartial(gtid, finalPartSize, digits);
|
||||
}
|
||||
}
|
||||
|
||||
inline void ScatterPairsDevicePartial(
|
||||
uint gtid,
|
||||
uint waveSize,
|
||||
uint partIndex,
|
||||
OffsetStruct offsets)
|
||||
{
|
||||
DigitStruct digits;
|
||||
const uint finalPartSize = e_numKeys - partIndex * PART_SIZE;
|
||||
#if defined(SHOULD_ASCEND)
|
||||
ScatterPairsKeyPhaseAscendingPartial(gtid, finalPartSize, digits);
|
||||
#else
|
||||
ScatterPairsKeyPhaseDescendingPartial(gtid, finalPartSize, digits);
|
||||
#endif
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
KeyStruct payloads;
|
||||
if (waveSize >= 16)
|
||||
LoadPayloadsPartialWGE16(gtid, waveSize, partIndex, payloads);
|
||||
else
|
||||
LoadPayloadsPartialWLT16(gtid, waveSize, partIndex, SerialIterations(waveSize), payloads);
|
||||
ScatterPayloadsShared(offsets, payloads);
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
#if defined(SHOULD_ASCEND)
|
||||
ScatterPayloadsAscendingPartial(gtid, finalPartSize, digits);
|
||||
#else
|
||||
ScatterPayloadsDescendingPartial(gtid, finalPartSize, digits);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void ScatterDevicePartial(
|
||||
uint gtid,
|
||||
uint waveSize,
|
||||
uint partIndex,
|
||||
OffsetStruct offsets)
|
||||
{
|
||||
#if defined(SORT_PAIRS)
|
||||
ScatterPairsDevicePartial(
|
||||
gtid,
|
||||
waveSize,
|
||||
partIndex,
|
||||
offsets);
|
||||
#else
|
||||
ScatterKeysOnlyDevicePartial(gtid, partIndex);
|
||||
#endif
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user