using UnityEngine; using UnityEngine.Assertions; using UnityEngine.Rendering; namespace GaussianSplatting.Runtime { // GPU (uint key, uint payload) 8 bit-LSD radix sort, using reduce-then-scan // Copyright Thomas Smith 2024, MIT license // https://github.com/b0nes164/GPUSorting public class GpuSorting { //The size of a threadblock partition in the sort const uint DEVICE_RADIX_SORT_PARTITION_SIZE = 3840; //The size of our radix in bits const uint DEVICE_RADIX_SORT_BITS = 8; //Number of digits in our radix, 1 << DEVICE_RADIX_SORT_BITS const uint DEVICE_RADIX_SORT_RADIX = 256; //Number of sorting passes required to sort a 32bit key, KEY_BITS / DEVICE_RADIX_SORT_BITS const uint DEVICE_RADIX_SORT_PASSES = 4; //Keywords to enable for the shader private LocalKeyword m_keyUintKeyword; private LocalKeyword m_payloadUintKeyword; private LocalKeyword m_ascendKeyword; private LocalKeyword m_sortPairKeyword; private LocalKeyword m_vulkanKeyword; public struct Args { public uint count; public GraphicsBuffer inputKeys; public GraphicsBuffer inputValues; public SupportResources resources; internal int workGroupCount; } public struct SupportResources { public GraphicsBuffer altBuffer; public GraphicsBuffer altPayloadBuffer; public GraphicsBuffer passHistBuffer; public GraphicsBuffer globalHistBuffer; public static SupportResources Load(uint count) { //This is threadBlocks * DEVICE_RADIX_SORT_RADIX uint scratchBufferSize = DivRoundUp(count, DEVICE_RADIX_SORT_PARTITION_SIZE) * DEVICE_RADIX_SORT_RADIX; uint reducedScratchBufferSize = DEVICE_RADIX_SORT_RADIX * DEVICE_RADIX_SORT_PASSES; var target = GraphicsBuffer.Target.Structured; var resources = new SupportResources { altBuffer = new GraphicsBuffer(target, (int)count, 4) { name = "DeviceRadixAlt" }, altPayloadBuffer = new GraphicsBuffer(target, (int)count, 4) { name = "DeviceRadixAltPayload" }, passHistBuffer = new GraphicsBuffer(target, (int)scratchBufferSize, 4) { name = "DeviceRadixPassHistogram" }, globalHistBuffer = new GraphicsBuffer(target, (int)reducedScratchBufferSize, 4) { name = "DeviceRadixGlobalHistogram" }, }; return resources; } public void Dispose() { altBuffer?.Dispose(); altPayloadBuffer?.Dispose(); passHistBuffer?.Dispose(); globalHistBuffer?.Dispose(); altBuffer = null; altPayloadBuffer = null; passHistBuffer = null; globalHistBuffer = null; } } readonly ComputeShader m_CS; readonly int m_kernelInitDeviceRadixSort = -1; readonly int m_kernelUpsweep = -1; readonly int m_kernelScan = -1; readonly int m_kernelDownsweep = -1; readonly bool m_Valid; public bool Valid => m_Valid; public GpuSorting(ComputeShader cs) { m_CS = cs; if (cs) { m_kernelInitDeviceRadixSort = cs.FindKernel("InitDeviceRadixSort"); m_kernelUpsweep = cs.FindKernel("Upsweep"); m_kernelScan = cs.FindKernel("Scan"); m_kernelDownsweep = cs.FindKernel("Downsweep"); } m_Valid = m_kernelInitDeviceRadixSort >= 0 && m_kernelUpsweep >= 0 && m_kernelScan >= 0 && m_kernelDownsweep >= 0; if (m_Valid) { if (!cs.IsSupported(m_kernelInitDeviceRadixSort) || !cs.IsSupported(m_kernelUpsweep) || !cs.IsSupported(m_kernelScan) || !cs.IsSupported(m_kernelDownsweep)) { m_Valid = false; } } m_keyUintKeyword = new LocalKeyword(cs, "KEY_UINT"); m_payloadUintKeyword = new LocalKeyword(cs, "PAYLOAD_UINT"); m_ascendKeyword = new LocalKeyword(cs, "SHOULD_ASCEND"); m_sortPairKeyword = new LocalKeyword(cs, "SORT_PAIRS"); m_vulkanKeyword = new LocalKeyword(cs, "VULKAN"); cs.EnableKeyword(m_keyUintKeyword); cs.EnableKeyword(m_payloadUintKeyword); cs.EnableKeyword(m_ascendKeyword); cs.EnableKeyword(m_sortPairKeyword); if (SystemInfo.graphicsDeviceType == UnityEngine.Rendering.GraphicsDeviceType.Vulkan) cs.EnableKeyword(m_vulkanKeyword); else cs.DisableKeyword(m_vulkanKeyword); } static uint DivRoundUp(uint x, uint y) => (x + y - 1) / y; //Can we remove the last 4 padding without breaking? struct SortConstants { public uint numKeys; // The number of keys to sort public uint radixShift; // The radix shift value for the current pass public uint threadBlocks; // threadBlocks public uint padding0; // Padding - unused } public void Dispatch(CommandBuffer cmd, Args args) { Assert.IsTrue(Valid); GraphicsBuffer srcKeyBuffer = args.inputKeys; GraphicsBuffer srcPayloadBuffer = args.inputValues; GraphicsBuffer dstKeyBuffer = args.resources.altBuffer; GraphicsBuffer dstPayloadBuffer = args.resources.altPayloadBuffer; SortConstants constants = default; constants.numKeys = args.count; constants.threadBlocks = DivRoundUp(args.count, DEVICE_RADIX_SORT_PARTITION_SIZE); // Setup overall constants cmd.SetComputeIntParam(m_CS, "e_numKeys", (int)constants.numKeys); cmd.SetComputeIntParam(m_CS, "e_threadBlocks", (int)constants.threadBlocks); //Set statically located buffers //Upsweep cmd.SetComputeBufferParam(m_CS, m_kernelUpsweep, "b_passHist", args.resources.passHistBuffer); cmd.SetComputeBufferParam(m_CS, m_kernelUpsweep, "b_globalHist", args.resources.globalHistBuffer); //Scan cmd.SetComputeBufferParam(m_CS, m_kernelScan, "b_passHist", args.resources.passHistBuffer); //Downsweep cmd.SetComputeBufferParam(m_CS, m_kernelDownsweep, "b_passHist", args.resources.passHistBuffer); cmd.SetComputeBufferParam(m_CS, m_kernelDownsweep, "b_globalHist", args.resources.globalHistBuffer); //Clear the global histogram cmd.SetComputeBufferParam(m_CS, m_kernelInitDeviceRadixSort, "b_globalHist", args.resources.globalHistBuffer); cmd.DispatchCompute(m_CS, m_kernelInitDeviceRadixSort, 1, 1, 1); // Execute the sort algorithm in 8-bit increments for (constants.radixShift = 0; constants.radixShift < 32; constants.radixShift += DEVICE_RADIX_SORT_BITS) { cmd.SetComputeIntParam(m_CS, "e_radixShift", (int)constants.radixShift); //Upsweep cmd.SetComputeBufferParam(m_CS, m_kernelUpsweep, "b_sort", srcKeyBuffer); cmd.DispatchCompute(m_CS, m_kernelUpsweep, (int)constants.threadBlocks, 1, 1); // Scan cmd.DispatchCompute(m_CS, m_kernelScan, (int)DEVICE_RADIX_SORT_RADIX, 1, 1); // Downsweep cmd.SetComputeBufferParam(m_CS, m_kernelDownsweep, "b_sort", srcKeyBuffer); cmd.SetComputeBufferParam(m_CS, m_kernelDownsweep, "b_sortPayload", srcPayloadBuffer); cmd.SetComputeBufferParam(m_CS, m_kernelDownsweep, "b_alt", dstKeyBuffer); cmd.SetComputeBufferParam(m_CS, m_kernelDownsweep, "b_altPayload", dstPayloadBuffer); cmd.DispatchCompute(m_CS, m_kernelDownsweep, (int)constants.threadBlocks, 1, 1); // Swap (srcKeyBuffer, dstKeyBuffer) = (dstKeyBuffer, srcKeyBuffer); (srcPayloadBuffer, dstPayloadBuffer) = (dstPayloadBuffer, srcPayloadBuffer); } } } }