chore: sync workspace state
This commit is contained in:
44
MVS/3DGS-Unity/Shaders/BlackSkybox.shader
Normal file
44
MVS/3DGS-Unity/Shaders/BlackSkybox.shader
Normal file
@@ -0,0 +1,44 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
Shader "Unlit/BlackSkybox"
|
||||
{
|
||||
Properties
|
||||
{
|
||||
_Color ("Color", Color) = (0,0,0,0)
|
||||
}
|
||||
SubShader
|
||||
{
|
||||
Pass
|
||||
{
|
||||
CGPROGRAM
|
||||
#pragma vertex vert
|
||||
#pragma fragment frag
|
||||
|
||||
#include "UnityCG.cginc"
|
||||
|
||||
struct appdata
|
||||
{
|
||||
float4 vertex : POSITION;
|
||||
};
|
||||
|
||||
struct v2f
|
||||
{
|
||||
float4 vertex : SV_POSITION;
|
||||
};
|
||||
|
||||
v2f vert (appdata v)
|
||||
{
|
||||
v2f o;
|
||||
o.vertex = UnityObjectToClipPos(v.vertex);
|
||||
return o;
|
||||
}
|
||||
|
||||
half4 _Color;
|
||||
|
||||
half4 frag (v2f i) : SV_Target
|
||||
{
|
||||
return _Color;
|
||||
}
|
||||
ENDCG
|
||||
}
|
||||
}
|
||||
}
|
||||
9
MVS/3DGS-Unity/Shaders/BlackSkybox.shader.meta
Normal file
9
MVS/3DGS-Unity/Shaders/BlackSkybox.shader.meta
Normal file
@@ -0,0 +1,9 @@
|
||||
fileFormatVersion: 2
|
||||
guid: a4867e5be68354ccda78062a92c74391
|
||||
ShaderImporter:
|
||||
externalObjects: {}
|
||||
defaultTextures: []
|
||||
nonModifiableTextures: []
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
||||
531
MVS/3DGS-Unity/Shaders/DeviceRadixSort.hlsl
Normal file
531
MVS/3DGS-Unity/Shaders/DeviceRadixSort.hlsl
Normal file
@@ -0,0 +1,531 @@
|
||||
/******************************************************************************
|
||||
* DeviceRadixSort
|
||||
* Device Level 8-bit LSD Radix Sort using reduce then scan
|
||||
*
|
||||
* SPDX-License-Identifier: MIT
|
||||
* Copyright Thomas Smith 5/17/2024
|
||||
* https://github.com/b0nes164/GPUSorting
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
******************************************************************************/
|
||||
#include "SortCommon.hlsl"
|
||||
|
||||
#define US_DIM 128U //The number of threads in a Upsweep threadblock
|
||||
#define SCAN_DIM 128U //The number of threads in a Scan threadblock
|
||||
|
||||
RWStructuredBuffer<uint> b_globalHist; //buffer holding device level offsets for each binning pass
|
||||
RWStructuredBuffer<uint> b_passHist; //buffer used to store reduced sums of partition tiles
|
||||
|
||||
groupshared uint g_us[RADIX * 2]; //Shared memory for upsweep
|
||||
groupshared uint g_scan[SCAN_DIM]; //Shared memory for the scan
|
||||
|
||||
//*****************************************************************************
|
||||
//INIT KERNEL
|
||||
//*****************************************************************************
|
||||
//Clear the global histogram, as we will be adding to it atomically
|
||||
[numthreads(1024, 1, 1)]
|
||||
void InitDeviceRadixSort(int3 id : SV_DispatchThreadID)
|
||||
{
|
||||
b_globalHist[id.x] = 0;
|
||||
}
|
||||
|
||||
//*****************************************************************************
|
||||
//UPSWEEP KERNEL
|
||||
//*****************************************************************************
|
||||
//histogram, 64 threads to a histogram
|
||||
inline void HistogramDigitCounts(uint gtid, uint gid)
|
||||
{
|
||||
const uint histOffset = gtid / 64 * RADIX;
|
||||
const uint partitionEnd = gid == e_threadBlocks - 1 ?
|
||||
e_numKeys : (gid + 1) * PART_SIZE;
|
||||
for (uint i = gtid + gid * PART_SIZE; i < partitionEnd; i += US_DIM)
|
||||
{
|
||||
#if defined(KEY_UINT)
|
||||
InterlockedAdd(g_us[ExtractDigit(b_sort[i]) + histOffset], 1);
|
||||
#elif defined(KEY_INT)
|
||||
InterlockedAdd(g_us[ExtractDigit(IntToUint(b_sort[i])) + histOffset], 1);
|
||||
#elif defined(KEY_FLOAT)
|
||||
InterlockedAdd(g_us[ExtractDigit(FloatToUint(b_sort[i])) + histOffset], 1);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
//reduce and pass to tile histogram
|
||||
inline void ReduceWriteDigitCounts(uint gtid, uint gid)
|
||||
{
|
||||
for (uint i = gtid; i < RADIX; i += US_DIM)
|
||||
{
|
||||
g_us[i] += g_us[i + RADIX];
|
||||
b_passHist[i * e_threadBlocks + gid] = g_us[i];
|
||||
g_us[i] += WavePrefixSum(g_us[i]);
|
||||
}
|
||||
}
|
||||
|
||||
//Exclusive scan over digit counts, then atomically add to global hist
|
||||
inline void GlobalHistExclusiveScanWGE16(uint gtid, uint waveSize)
|
||||
{
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
if (gtid < (RADIX / waveSize))
|
||||
{
|
||||
g_us[(gtid + 1) * waveSize - 1] +=
|
||||
WavePrefixSum(g_us[(gtid + 1) * waveSize - 1]);
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
//atomically add to global histogram
|
||||
const uint globalHistOffset = GlobalHistOffset();
|
||||
const uint laneMask = waveSize - 1;
|
||||
const uint circularLaneShift = WaveGetLaneIndex() + 1 & laneMask;
|
||||
for (uint i = gtid; i < RADIX; i += US_DIM)
|
||||
{
|
||||
const uint index = circularLaneShift + (i & ~laneMask);
|
||||
uint t = WaveGetLaneIndex() != laneMask ? g_us[i] : 0;
|
||||
if (i >= waveSize)
|
||||
t += WaveReadLaneAt(g_us[i - 1], 0);
|
||||
InterlockedAdd(b_globalHist[index + globalHistOffset], t);
|
||||
}
|
||||
}
|
||||
|
||||
inline void GlobalHistExclusiveScanWLT16(uint gtid, uint waveSize)
|
||||
{
|
||||
const uint globalHistOffset = GlobalHistOffset();
|
||||
if (gtid < waveSize)
|
||||
{
|
||||
const uint circularLaneShift = WaveGetLaneIndex() + 1 &
|
||||
waveSize - 1;
|
||||
InterlockedAdd(b_globalHist[circularLaneShift + globalHistOffset],
|
||||
circularLaneShift ? g_us[gtid] : 0);
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
const uint laneLog = countbits(waveSize - 1);
|
||||
uint offset = laneLog;
|
||||
uint j = waveSize;
|
||||
for (; j < (RADIX >> 1); j <<= laneLog)
|
||||
{
|
||||
if (gtid < (RADIX >> offset))
|
||||
{
|
||||
g_us[((gtid + 1) << offset) - 1] +=
|
||||
WavePrefixSum(g_us[((gtid + 1) << offset) - 1]);
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
for (uint i = gtid + j; i < RADIX; i += US_DIM)
|
||||
{
|
||||
if ((i & ((j << laneLog) - 1)) >= j)
|
||||
{
|
||||
if (i < (j << laneLog))
|
||||
{
|
||||
InterlockedAdd(b_globalHist[i + globalHistOffset],
|
||||
WaveReadLaneAt(g_us[((i >> offset) << offset) - 1], 0) +
|
||||
((i & (j - 1)) ? g_us[i - 1] : 0));
|
||||
}
|
||||
else
|
||||
{
|
||||
if ((i + 1) & (j - 1))
|
||||
{
|
||||
g_us[i] +=
|
||||
WaveReadLaneAt(g_us[((i >> offset) << offset) - 1], 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
offset += laneLog;
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
//If RADIX is not a power of lanecount
|
||||
for (uint i = gtid + j; i < RADIX; i += US_DIM)
|
||||
{
|
||||
InterlockedAdd(b_globalHist[i + globalHistOffset],
|
||||
WaveReadLaneAt(g_us[((i >> offset) << offset) - 1], 0) +
|
||||
((i & (j - 1)) ? g_us[i - 1] : 0));
|
||||
}
|
||||
}
|
||||
|
||||
[numthreads(US_DIM, 1, 1)]
|
||||
void Upsweep(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID)
|
||||
{
|
||||
//get the wave size
|
||||
const uint waveSize = getWaveSize();
|
||||
|
||||
//clear shared memory
|
||||
const uint histsEnd = RADIX * 2;
|
||||
for (uint i = gtid.x; i < histsEnd; i += US_DIM)
|
||||
g_us[i] = 0;
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
HistogramDigitCounts(gtid.x, gid.x);
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
ReduceWriteDigitCounts(gtid.x, gid.x);
|
||||
|
||||
if (waveSize >= 16)
|
||||
GlobalHistExclusiveScanWGE16(gtid.x, waveSize);
|
||||
|
||||
if (waveSize < 16)
|
||||
GlobalHistExclusiveScanWLT16(gtid.x, waveSize);
|
||||
}
|
||||
|
||||
//*****************************************************************************
|
||||
//SCAN KERNEL
|
||||
//*****************************************************************************
|
||||
inline void ExclusiveThreadBlockScanFullWGE16(
|
||||
uint gtid,
|
||||
uint laneMask,
|
||||
uint circularLaneShift,
|
||||
uint partEnd,
|
||||
uint deviceOffset,
|
||||
uint waveSize,
|
||||
inout uint reduction)
|
||||
{
|
||||
for (uint i = gtid; i < partEnd; i += SCAN_DIM)
|
||||
{
|
||||
g_scan[gtid] = b_passHist[i + deviceOffset];
|
||||
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
if (gtid < SCAN_DIM / waveSize)
|
||||
{
|
||||
g_scan[(gtid + 1) * waveSize - 1] +=
|
||||
WavePrefixSum(g_scan[(gtid + 1) * waveSize - 1]);
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
uint t = (WaveGetLaneIndex() != laneMask ? g_scan[gtid] : 0) + reduction;
|
||||
if (gtid >= waveSize)
|
||||
t += WaveReadLaneAt(g_scan[gtid - 1], 0);
|
||||
b_passHist[circularLaneShift + (i & ~laneMask) + deviceOffset] = t;
|
||||
|
||||
reduction += g_scan[SCAN_DIM - 1];
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
}
|
||||
}
|
||||
|
||||
inline void ExclusiveThreadBlockScanPartialWGE16(
|
||||
uint gtid,
|
||||
uint laneMask,
|
||||
uint circularLaneShift,
|
||||
uint partEnd,
|
||||
uint deviceOffset,
|
||||
uint waveSize,
|
||||
uint reduction)
|
||||
{
|
||||
uint i = gtid + partEnd;
|
||||
if (i < e_threadBlocks)
|
||||
g_scan[gtid] = b_passHist[deviceOffset + i];
|
||||
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
if (gtid < SCAN_DIM / waveSize)
|
||||
{
|
||||
g_scan[(gtid + 1) * waveSize - 1] +=
|
||||
WavePrefixSum(g_scan[(gtid + 1) * waveSize - 1]);
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
const uint index = circularLaneShift + (i & ~laneMask);
|
||||
if (index < e_threadBlocks)
|
||||
{
|
||||
uint t = (WaveGetLaneIndex() != laneMask ? g_scan[gtid] : 0) + reduction;
|
||||
if (gtid >= waveSize)
|
||||
t += g_scan[(gtid & ~laneMask) - 1];
|
||||
b_passHist[index + deviceOffset] = t;
|
||||
}
|
||||
}
|
||||
|
||||
inline void ExclusiveThreadBlockScanWGE16(uint gtid, uint gid, uint waveSize)
|
||||
{
|
||||
uint reduction = 0;
|
||||
const uint laneMask = waveSize - 1;
|
||||
const uint circularLaneShift = WaveGetLaneIndex() + 1 & laneMask;
|
||||
const uint partionsEnd = e_threadBlocks / SCAN_DIM * SCAN_DIM;
|
||||
const uint deviceOffset = gid * e_threadBlocks;
|
||||
|
||||
ExclusiveThreadBlockScanFullWGE16(
|
||||
gtid,
|
||||
laneMask,
|
||||
circularLaneShift,
|
||||
partionsEnd,
|
||||
deviceOffset,
|
||||
waveSize,
|
||||
reduction);
|
||||
|
||||
ExclusiveThreadBlockScanPartialWGE16(
|
||||
gtid,
|
||||
laneMask,
|
||||
circularLaneShift,
|
||||
partionsEnd,
|
||||
deviceOffset,
|
||||
waveSize,
|
||||
reduction);
|
||||
}
|
||||
|
||||
inline void ExclusiveThreadBlockScanFullWLT16(
|
||||
uint gtid,
|
||||
uint partitions,
|
||||
uint deviceOffset,
|
||||
uint laneLog,
|
||||
uint circularLaneShift,
|
||||
uint waveSize,
|
||||
inout uint reduction)
|
||||
{
|
||||
for (uint k = 0; k < partitions; ++k)
|
||||
{
|
||||
g_scan[gtid] = b_passHist[gtid + k * SCAN_DIM + deviceOffset];
|
||||
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
if (gtid < waveSize)
|
||||
{
|
||||
b_passHist[circularLaneShift + k * SCAN_DIM + deviceOffset] =
|
||||
(circularLaneShift ? g_scan[gtid] : 0) + reduction;
|
||||
}
|
||||
|
||||
uint offset = laneLog;
|
||||
uint j = waveSize;
|
||||
for (; j < (SCAN_DIM >> 1); j <<= laneLog)
|
||||
{
|
||||
if (gtid < (SCAN_DIM >> offset))
|
||||
{
|
||||
g_scan[((gtid + 1) << offset) - 1] +=
|
||||
WavePrefixSum(g_scan[((gtid + 1) << offset) - 1]);
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
if ((gtid & ((j << laneLog) - 1)) >= j)
|
||||
{
|
||||
if (gtid < (j << laneLog))
|
||||
{
|
||||
b_passHist[gtid + k * SCAN_DIM + deviceOffset] =
|
||||
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0) +
|
||||
((gtid & (j - 1)) ? g_scan[gtid - 1] : 0) + reduction;
|
||||
}
|
||||
else
|
||||
{
|
||||
if ((gtid + 1) & (j - 1))
|
||||
{
|
||||
g_scan[gtid] +=
|
||||
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
offset += laneLog;
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
//If SCAN_DIM is not a power of lanecount
|
||||
for (uint i = gtid + j; i < SCAN_DIM; i += SCAN_DIM)
|
||||
{
|
||||
b_passHist[i + k * SCAN_DIM + deviceOffset] =
|
||||
WaveReadLaneAt(g_scan[((i >> offset) << offset) - 1], 0) +
|
||||
((i & (j - 1)) ? g_scan[i - 1] : 0) + reduction;
|
||||
}
|
||||
|
||||
reduction += WaveReadLaneAt(g_scan[SCAN_DIM - 1], 0) +
|
||||
WaveReadLaneAt(g_scan[(((SCAN_DIM - 1) >> offset) << offset) - 1], 0);
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
}
|
||||
}
|
||||
|
||||
inline void ExclusiveThreadBlockScanParitalWLT16(
|
||||
uint gtid,
|
||||
uint partitions,
|
||||
uint deviceOffset,
|
||||
uint laneLog,
|
||||
uint circularLaneShift,
|
||||
uint waveSize,
|
||||
uint reduction)
|
||||
{
|
||||
const uint finalPartSize = e_threadBlocks - partitions * SCAN_DIM;
|
||||
if (gtid < finalPartSize)
|
||||
{
|
||||
g_scan[gtid] = b_passHist[gtid + partitions * SCAN_DIM + deviceOffset];
|
||||
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
if (gtid < waveSize && circularLaneShift < finalPartSize)
|
||||
{
|
||||
b_passHist[circularLaneShift + partitions * SCAN_DIM + deviceOffset] =
|
||||
(circularLaneShift ? g_scan[gtid] : 0) + reduction;
|
||||
}
|
||||
|
||||
uint offset = laneLog;
|
||||
for (uint j = waveSize; j < finalPartSize; j <<= laneLog)
|
||||
{
|
||||
if (gtid < (finalPartSize >> offset))
|
||||
{
|
||||
g_scan[((gtid + 1) << offset) - 1] +=
|
||||
WavePrefixSum(g_scan[((gtid + 1) << offset) - 1]);
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
if ((gtid & ((j << laneLog) - 1)) >= j && gtid < finalPartSize)
|
||||
{
|
||||
if (gtid < (j << laneLog))
|
||||
{
|
||||
b_passHist[gtid + partitions * SCAN_DIM + deviceOffset] =
|
||||
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0) +
|
||||
((gtid & (j - 1)) ? g_scan[gtid - 1] : 0) + reduction;
|
||||
}
|
||||
else
|
||||
{
|
||||
if ((gtid + 1) & (j - 1))
|
||||
{
|
||||
g_scan[gtid] +=
|
||||
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
offset += laneLog;
|
||||
}
|
||||
}
|
||||
|
||||
inline void ExclusiveThreadBlockScanWLT16(uint gtid, uint gid, uint waveSize)
|
||||
{
|
||||
uint reduction = 0;
|
||||
const uint partitions = e_threadBlocks / SCAN_DIM;
|
||||
const uint deviceOffset = gid * e_threadBlocks;
|
||||
const uint laneLog = countbits(waveSize - 1);
|
||||
const uint circularLaneShift = WaveGetLaneIndex() + 1 & waveSize - 1;
|
||||
|
||||
ExclusiveThreadBlockScanFullWLT16(
|
||||
gtid,
|
||||
partitions,
|
||||
deviceOffset,
|
||||
laneLog,
|
||||
circularLaneShift,
|
||||
waveSize,
|
||||
reduction);
|
||||
|
||||
ExclusiveThreadBlockScanParitalWLT16(
|
||||
gtid,
|
||||
partitions,
|
||||
deviceOffset,
|
||||
laneLog,
|
||||
circularLaneShift,
|
||||
waveSize,
|
||||
reduction);
|
||||
}
|
||||
|
||||
//Scan does not need flattening of gids
|
||||
[numthreads(SCAN_DIM, 1, 1)]
|
||||
void Scan(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID)
|
||||
{
|
||||
const uint waveSize = getWaveSize();
|
||||
if (waveSize >= 16)
|
||||
ExclusiveThreadBlockScanWGE16(gtid.x, gid.x, waveSize);
|
||||
|
||||
if (waveSize < 16)
|
||||
ExclusiveThreadBlockScanWLT16(gtid.x, gid.x, waveSize);
|
||||
}
|
||||
|
||||
//*****************************************************************************
|
||||
//DOWNSWEEP KERNEL
|
||||
//*****************************************************************************
|
||||
inline void LoadThreadBlockReductions(uint gtid, uint gid, uint exclusiveHistReduction)
|
||||
{
|
||||
if (gtid < RADIX)
|
||||
{
|
||||
g_d[gtid + PART_SIZE] = b_globalHist[gtid + GlobalHistOffset()] +
|
||||
b_passHist[gtid * e_threadBlocks + gid] - exclusiveHistReduction;
|
||||
}
|
||||
}
|
||||
|
||||
[numthreads(D_DIM, 1, 1)]
|
||||
void Downsweep(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID)
|
||||
{
|
||||
KeyStruct keys;
|
||||
OffsetStruct offsets;
|
||||
const uint waveSize = getWaveSize();
|
||||
|
||||
ClearWaveHists(gtid.x, waveSize);
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
if (gid.x < e_threadBlocks - 1)
|
||||
{
|
||||
if (waveSize >= 16)
|
||||
keys = LoadKeysWGE16(gtid.x, waveSize, gid.x);
|
||||
|
||||
if (waveSize < 16)
|
||||
keys = LoadKeysWLT16(gtid.x, waveSize, gid.x, SerialIterations(waveSize));
|
||||
}
|
||||
|
||||
if (gid.x == e_threadBlocks - 1)
|
||||
{
|
||||
if (waveSize >= 16)
|
||||
keys = LoadKeysPartialWGE16(gtid.x, waveSize, gid.x);
|
||||
|
||||
if (waveSize < 16)
|
||||
keys = LoadKeysPartialWLT16(gtid.x, waveSize, gid.x, SerialIterations(waveSize));
|
||||
}
|
||||
|
||||
uint exclusiveHistReduction;
|
||||
|
||||
if (waveSize >= 16)
|
||||
{
|
||||
offsets = RankKeysWGE16(waveSize, getWaveIndex(gtid.x, waveSize) * RADIX, keys);
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
uint histReduction;
|
||||
if (gtid.x < RADIX)
|
||||
{
|
||||
histReduction = WaveHistInclusiveScanCircularShiftWGE16(gtid.x, waveSize);
|
||||
histReduction += WavePrefixSum(histReduction); //take advantage of barrier to begin scan
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
WaveHistReductionExclusiveScanWGE16(gtid.x, waveSize, histReduction);
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
UpdateOffsetsWGE16(gtid.x, waveSize, offsets, keys);
|
||||
if (gtid.x < RADIX)
|
||||
exclusiveHistReduction = g_d[gtid.x]; //take advantage of barrier to grab value
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
}
|
||||
|
||||
if (waveSize < 16)
|
||||
{
|
||||
offsets = RankKeysWLT16(waveSize, getWaveIndex(gtid.x, waveSize), keys, SerialIterations(waveSize));
|
||||
|
||||
if (gtid.x < HALF_RADIX)
|
||||
{
|
||||
uint histReduction = WaveHistInclusiveScanCircularShiftWLT16(gtid.x);
|
||||
g_d[gtid.x] = histReduction + (histReduction << 16); //take advantage of barrier to begin scan
|
||||
}
|
||||
|
||||
WaveHistReductionExclusiveScanWLT16(gtid.x);
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
UpdateOffsetsWLT16(gtid.x, waveSize, SerialIterations(waveSize), offsets, keys);
|
||||
if (gtid.x < RADIX) //take advantage of barrier to grab value
|
||||
exclusiveHistReduction = g_d[gtid.x >> 1] >> ((gtid.x & 1) ? 16 : 0) & 0xffff;
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
}
|
||||
|
||||
ScatterKeysShared(offsets, keys);
|
||||
LoadThreadBlockReductions(gtid.x, gid.x, exclusiveHistReduction);
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
|
||||
if (gid.x < e_threadBlocks - 1)
|
||||
ScatterDevice(gtid.x, waveSize, gid.x, offsets);
|
||||
|
||||
if (gid.x == e_threadBlocks - 1)
|
||||
ScatterDevicePartial(gtid.x, waveSize, gid.x, offsets);
|
||||
}
|
||||
7
MVS/3DGS-Unity/Shaders/DeviceRadixSort.hlsl.meta
Normal file
7
MVS/3DGS-Unity/Shaders/DeviceRadixSort.hlsl.meta
Normal file
@@ -0,0 +1,7 @@
|
||||
fileFormatVersion: 2
|
||||
guid: 02209b8d952e7fc418492b88139826fd
|
||||
ShaderIncludeImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
||||
45
MVS/3DGS-Unity/Shaders/GaussianComposite.shader
Normal file
45
MVS/3DGS-Unity/Shaders/GaussianComposite.shader
Normal file
@@ -0,0 +1,45 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
Shader "Hidden/Gaussian Splatting/Composite"
|
||||
{
|
||||
SubShader
|
||||
{
|
||||
Pass
|
||||
{
|
||||
ZWrite Off
|
||||
ZTest Always
|
||||
Cull Off
|
||||
Blend SrcAlpha OneMinusSrcAlpha
|
||||
|
||||
CGPROGRAM
|
||||
#pragma vertex vert
|
||||
#pragma fragment frag
|
||||
#pragma require compute
|
||||
#pragma use_dxc
|
||||
#include "UnityCG.cginc"
|
||||
|
||||
struct v2f
|
||||
{
|
||||
float4 vertex : SV_POSITION;
|
||||
};
|
||||
|
||||
v2f vert (uint vtxID : SV_VertexID)
|
||||
{
|
||||
v2f o;
|
||||
float2 quadPos = float2(vtxID&1, (vtxID>>1)&1) * 4.0 - 1.0;
|
||||
o.vertex = float4(quadPos, 1, 1);
|
||||
return o;
|
||||
}
|
||||
|
||||
Texture2D _GaussianSplatRT;
|
||||
|
||||
half4 frag (v2f i) : SV_Target
|
||||
{
|
||||
half4 col = _GaussianSplatRT.Load(int3(i.vertex.xy, 0));
|
||||
col.rgb = GammaToLinearSpace(col.rgb);
|
||||
col.a = saturate(col.a * 1.5);
|
||||
return col;
|
||||
}
|
||||
ENDCG
|
||||
}
|
||||
}
|
||||
}
|
||||
9
MVS/3DGS-Unity/Shaders/GaussianComposite.shader.meta
Normal file
9
MVS/3DGS-Unity/Shaders/GaussianComposite.shader.meta
Normal file
@@ -0,0 +1,9 @@
|
||||
fileFormatVersion: 2
|
||||
guid: 7e184af7d01193a408eb916d8acafff9
|
||||
ShaderImporter:
|
||||
externalObjects: {}
|
||||
defaultTextures: []
|
||||
nonModifiableTextures: []
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
||||
102
MVS/3DGS-Unity/Shaders/GaussianDebugRenderBoxes.shader
Normal file
102
MVS/3DGS-Unity/Shaders/GaussianDebugRenderBoxes.shader
Normal file
@@ -0,0 +1,102 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
Shader "Gaussian Splatting/Debug/Render Boxes"
|
||||
{
|
||||
SubShader
|
||||
{
|
||||
Tags { "RenderType"="Transparent" "Queue"="Transparent" }
|
||||
|
||||
Pass
|
||||
{
|
||||
ZWrite Off
|
||||
Blend OneMinusDstAlpha One
|
||||
Cull Front
|
||||
|
||||
CGPROGRAM
|
||||
#pragma vertex vert
|
||||
#pragma fragment frag
|
||||
#pragma require compute
|
||||
#pragma use_dxc
|
||||
|
||||
#include "UnityCG.cginc"
|
||||
#include "GaussianSplatting.hlsl"
|
||||
|
||||
StructuredBuffer<uint> _OrderBuffer;
|
||||
|
||||
bool _DisplayChunks;
|
||||
|
||||
struct v2f
|
||||
{
|
||||
half4 col : COLOR0;
|
||||
float4 vertex : SV_POSITION;
|
||||
};
|
||||
|
||||
float _SplatScale;
|
||||
float _SplatOpacityScale;
|
||||
|
||||
// based on https://iquilezles.org/articles/palettes/
|
||||
// cosine based palette, 4 vec3 params
|
||||
half3 palette(float t, half3 a, half3 b, half3 c, half3 d)
|
||||
{
|
||||
return a + b*cos(6.28318*(c*t+d));
|
||||
}
|
||||
|
||||
v2f vert (uint vtxID : SV_VertexID, uint instID : SV_InstanceID)
|
||||
{
|
||||
v2f o;
|
||||
bool chunks = _DisplayChunks;
|
||||
uint idx = vtxID;
|
||||
float3 localPos = float3(idx&1, (idx>>1)&1, (idx>>2)&1) * 2.0 - 1.0;
|
||||
|
||||
float3 centerWorldPos = 0;
|
||||
|
||||
if (!chunks)
|
||||
{
|
||||
// display splat boxes
|
||||
instID = _OrderBuffer[instID];
|
||||
SplatData splat = LoadSplatData(instID);
|
||||
|
||||
float4 boxRot = splat.rot;
|
||||
float3 boxSize = splat.scale;
|
||||
boxSize *= _SplatScale;
|
||||
|
||||
float3x3 splatRotScaleMat = CalcMatrixFromRotationScale(boxRot, boxSize);
|
||||
splatRotScaleMat = mul((float3x3)unity_ObjectToWorld, splatRotScaleMat);
|
||||
|
||||
centerWorldPos = splat.pos;
|
||||
centerWorldPos = mul(unity_ObjectToWorld, float4(centerWorldPos,1)).xyz;
|
||||
|
||||
o.col.rgb = saturate(splat.sh.col);
|
||||
o.col.a = saturate(splat.opacity * _SplatOpacityScale);
|
||||
|
||||
localPos = mul(splatRotScaleMat, localPos) * 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
// display chunk boxes
|
||||
localPos = localPos * 0.5 + 0.5;
|
||||
SplatChunkInfo chunk = _SplatChunks[instID];
|
||||
float3 posMin = float3(chunk.posX.x, chunk.posY.x, chunk.posZ.x);
|
||||
float3 posMax = float3(chunk.posX.y, chunk.posY.y, chunk.posZ.y);
|
||||
|
||||
localPos = lerp(posMin, posMax, localPos);
|
||||
localPos = mul(unity_ObjectToWorld, float4(localPos,1)).xyz;
|
||||
|
||||
o.col.rgb = palette((float)instID / (float)_SplatChunkCount, half3(0.5,0.5,0.5), half3(0.5,0.5,0.5), half3(1,1,1), half3(0.0, 0.33, 0.67));
|
||||
o.col.a = 0.1;
|
||||
}
|
||||
|
||||
float3 worldPos = centerWorldPos + localPos;
|
||||
o.vertex = UnityWorldToClipPos(worldPos);
|
||||
|
||||
return o;
|
||||
}
|
||||
|
||||
half4 frag (v2f i) : SV_Target
|
||||
{
|
||||
half4 res = half4(i.col.rgb * i.col.a, i.col.a);
|
||||
return res;
|
||||
}
|
||||
ENDCG
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
fileFormatVersion: 2
|
||||
guid: 4006f2680fd7c8b4cbcb881454c782be
|
||||
ShaderImporter:
|
||||
externalObjects: {}
|
||||
defaultTextures: []
|
||||
nonModifiableTextures: []
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
||||
65
MVS/3DGS-Unity/Shaders/GaussianDebugRenderPoints.shader
Normal file
65
MVS/3DGS-Unity/Shaders/GaussianDebugRenderPoints.shader
Normal file
@@ -0,0 +1,65 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
Shader "Gaussian Splatting/Debug/Render Points"
|
||||
{
|
||||
SubShader
|
||||
{
|
||||
Tags { "RenderType"="Transparent" "Queue"="Transparent" }
|
||||
|
||||
Pass
|
||||
{
|
||||
ZWrite On
|
||||
Cull Off
|
||||
|
||||
CGPROGRAM
|
||||
#pragma vertex vert
|
||||
#pragma fragment frag
|
||||
#pragma require compute
|
||||
#pragma use_dxc
|
||||
|
||||
#include "GaussianSplatting.hlsl"
|
||||
|
||||
struct v2f
|
||||
{
|
||||
half3 color : TEXCOORD0;
|
||||
float4 vertex : SV_POSITION;
|
||||
};
|
||||
|
||||
float _SplatSize;
|
||||
bool _DisplayIndex;
|
||||
int _SplatCount;
|
||||
|
||||
v2f vert (uint vtxID : SV_VertexID, uint instID : SV_InstanceID)
|
||||
{
|
||||
v2f o;
|
||||
uint splatIndex = instID;
|
||||
|
||||
SplatData splat = LoadSplatData(splatIndex);
|
||||
|
||||
float3 centerWorldPos = splat.pos;
|
||||
centerWorldPos = mul(unity_ObjectToWorld, float4(centerWorldPos,1)).xyz;
|
||||
|
||||
float4 centerClipPos = mul(UNITY_MATRIX_VP, float4(centerWorldPos, 1));
|
||||
|
||||
o.vertex = centerClipPos;
|
||||
uint idx = vtxID;
|
||||
float2 quadPos = float2(idx&1, (idx>>1)&1) * 2.0 - 1.0;
|
||||
o.vertex.xy += (quadPos * _SplatSize / _ScreenParams.xy) * o.vertex.w;
|
||||
|
||||
o.color.rgb = saturate(splat.sh.col);
|
||||
if (_DisplayIndex)
|
||||
{
|
||||
o.color.r = frac((float)splatIndex / (float)_SplatCount * 100);
|
||||
o.color.g = frac((float)splatIndex / (float)_SplatCount * 10);
|
||||
o.color.b = (float)splatIndex / (float)_SplatCount;
|
||||
}
|
||||
return o;
|
||||
}
|
||||
|
||||
half4 frag (v2f i) : SV_Target
|
||||
{
|
||||
return half4(i.color.rgb, 1);
|
||||
}
|
||||
ENDCG
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
fileFormatVersion: 2
|
||||
guid: b44409fc67214394f8f47e4e2648425e
|
||||
ShaderImporter:
|
||||
externalObjects: {}
|
||||
defaultTextures: []
|
||||
nonModifiableTextures: []
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
||||
617
MVS/3DGS-Unity/Shaders/GaussianSplatting.hlsl
Normal file
617
MVS/3DGS-Unity/Shaders/GaussianSplatting.hlsl
Normal file
@@ -0,0 +1,617 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
#ifndef GAUSSIAN_SPLATTING_HLSL
|
||||
#define GAUSSIAN_SPLATTING_HLSL
|
||||
|
||||
float InvSquareCentered01(float x)
|
||||
{
|
||||
x -= 0.5;
|
||||
x *= 0.5;
|
||||
x = sqrt(abs(x)) * sign(x);
|
||||
return x + 0.5;
|
||||
}
|
||||
|
||||
float3 QuatRotateVector(float3 v, float4 r)
|
||||
{
|
||||
float3 t = 2 * cross(r.xyz, v);
|
||||
return v + r.w * t + cross(r.xyz, t);
|
||||
}
|
||||
|
||||
float4 QuatMul(float4 a, float4 b)
|
||||
{
|
||||
return float4(a.wwww * b + (a.xyzx * b.wwwx + a.yzxy * b.zxyy) * float4(1,1,1,-1) - a.zxyz * b.yzxz);
|
||||
}
|
||||
|
||||
float4 QuatInverse(float4 q)
|
||||
{
|
||||
return rcp(dot(q, q)) * q * float4(-1,-1,-1,1);
|
||||
}
|
||||
|
||||
float3x3 CalcMatrixFromRotationScale(float4 rot, float3 scale)
|
||||
{
|
||||
float3x3 ms = float3x3(
|
||||
scale.x, 0, 0,
|
||||
0, scale.y, 0,
|
||||
0, 0, scale.z
|
||||
);
|
||||
float x = rot.x;
|
||||
float y = rot.y;
|
||||
float z = rot.z;
|
||||
float w = rot.w;
|
||||
float3x3 mr = float3x3(
|
||||
1-2*(y*y + z*z), 2*(x*y - w*z), 2*(x*z + w*y),
|
||||
2*(x*y + w*z), 1-2*(x*x + z*z), 2*(y*z - w*x),
|
||||
2*(x*z - w*y), 2*(y*z + w*x), 1-2*(x*x + y*y)
|
||||
);
|
||||
return mul(mr, ms);
|
||||
}
|
||||
|
||||
void CalcCovariance3D(float3x3 rotMat, out float3 sigma0, out float3 sigma1)
|
||||
{
|
||||
float3x3 sig = mul(rotMat, transpose(rotMat));
|
||||
sigma0 = float3(sig._m00, sig._m01, sig._m02);
|
||||
sigma1 = float3(sig._m11, sig._m12, sig._m22);
|
||||
}
|
||||
|
||||
// from "EWA Splatting" (Zwicker et al 2002) eq. 31
|
||||
float3 CalcCovariance2D(float3 worldPos, float3 cov3d0, float3 cov3d1, float4x4 matrixV, float4x4 matrixP, float4 screenParams)
|
||||
{
|
||||
float4x4 viewMatrix = matrixV;
|
||||
float3 viewPos = mul(viewMatrix, float4(worldPos, 1)).xyz;
|
||||
|
||||
// this is needed in order for splats that are visible in view but clipped "quite a lot" to work
|
||||
float aspect = matrixP._m00 / matrixP._m11;
|
||||
float tanFovX = rcp(matrixP._m00);
|
||||
float tanFovY = rcp(matrixP._m11 * aspect);
|
||||
float limX = 1.3 * tanFovX;
|
||||
float limY = 1.3 * tanFovY;
|
||||
viewPos.x = clamp(viewPos.x / viewPos.z, -limX, limX) * viewPos.z;
|
||||
viewPos.y = clamp(viewPos.y / viewPos.z, -limY, limY) * viewPos.z;
|
||||
|
||||
float focal = screenParams.x * matrixP._m00 / 2;
|
||||
|
||||
float3x3 J = float3x3(
|
||||
focal / viewPos.z, 0, -(focal * viewPos.x) / (viewPos.z * viewPos.z),
|
||||
0, focal / viewPos.z, -(focal * viewPos.y) / (viewPos.z * viewPos.z),
|
||||
0, 0, 0
|
||||
);
|
||||
float3x3 W = (float3x3)viewMatrix;
|
||||
float3x3 T = mul(J, W);
|
||||
float3x3 V = float3x3(
|
||||
cov3d0.x, cov3d0.y, cov3d0.z,
|
||||
cov3d0.y, cov3d1.x, cov3d1.y,
|
||||
cov3d0.z, cov3d1.y, cov3d1.z
|
||||
);
|
||||
float3x3 cov = mul(T, mul(V, transpose(T)));
|
||||
|
||||
// Low pass filter to make each splat at least 1px size.
|
||||
cov._m00 += 0.3;
|
||||
cov._m11 += 0.3;
|
||||
return float3(cov._m00, cov._m01, cov._m11);
|
||||
}
|
||||
|
||||
float3 CalcConic(float3 cov2d)
|
||||
{
|
||||
float det = cov2d.x * cov2d.z - cov2d.y * cov2d.y;
|
||||
return float3(cov2d.z, -cov2d.y, cov2d.x) * rcp(det);
|
||||
}
|
||||
|
||||
float2 CalcScreenSpaceDelta(float2 svPositionXY, float2 centerXY, float4 projectionParams)
|
||||
{
|
||||
float2 d = svPositionXY - centerXY;
|
||||
d.y *= projectionParams.x;
|
||||
return d;
|
||||
}
|
||||
|
||||
float CalcPowerFromConic(float3 conic, float2 d)
|
||||
{
|
||||
return -0.5 * (conic.x * d.x*d.x + conic.z * d.y*d.y) + conic.y * d.x*d.y;
|
||||
}
|
||||
|
||||
// Morton interleaving 16x16 group i.e. by 4 bits of coordinates, based on this thread:
|
||||
// https://twitter.com/rygorous/status/986715358852608000
|
||||
// which is simplified version of https://fgiesen.wordpress.com/2009/12/13/decoding-morton-codes/
|
||||
uint EncodeMorton2D_16x16(uint2 c)
|
||||
{
|
||||
uint t = ((c.y & 0xF) << 8) | (c.x & 0xF); // ----EFGH----ABCD
|
||||
t = (t ^ (t << 2)) & 0x3333; // --EF--GH--AB--CD
|
||||
t = (t ^ (t << 1)) & 0x5555; // -E-F-G-H-A-B-C-D
|
||||
return (t | (t >> 7)) & 0xFF; // --------EAFBGCHD
|
||||
}
|
||||
uint2 DecodeMorton2D_16x16(uint t) // --------EAFBGCHD
|
||||
{
|
||||
t = (t & 0xFF) | ((t & 0xFE) << 7); // -EAFBGCHEAFBGCHD
|
||||
t &= 0x5555; // -E-F-G-H-A-B-C-D
|
||||
t = (t ^ (t >> 1)) & 0x3333; // --EF--GH--AB--CD
|
||||
t = (t ^ (t >> 2)) & 0x0f0f; // ----EFGH----ABCD
|
||||
return uint2(t & 0xF, t >> 8); // --------EFGHABCD
|
||||
}
|
||||
|
||||
|
||||
static const float SH_C1 = 0.4886025;
|
||||
static const float SH_C2[] = { 1.0925484, -1.0925484, 0.3153916, -1.0925484, 0.5462742 };
|
||||
static const float SH_C3[] = { -0.5900436, 2.8906114, -0.4570458, 0.3731763, -0.4570458, 1.4453057, -0.5900436 };
|
||||
|
||||
// Base color plus the 15 spherical-harmonics coefficients for bands 1..3
// (sh1..sh3 = band 1, sh4..sh8 = band 2, sh9..sh15 = band 3); see ShadeSH.
struct SplatSHData
{
    half3 col, sh1, sh2, sh3, sh4, sh5, sh6, sh7, sh8, sh9, sh10, sh11, sh12, sh13, sh14, sh15;
};
|
||||
|
||||
// Evaluates the splat's spherical-harmonics color for view direction `dir`,
// up to `shOrder` bands (0..3). When `onlySH` is set, the precomputed base
// color is replaced by a neutral 0.5 so only the SH contribution is shown.
// The result is clamped to be non-negative.
half3 ShadeSH(SplatSHData splat, half3 dir, int shOrder, bool onlySH)
{
    // Sign convention of the stored coefficients expects the negated
    // direction — presumably matching the reference 3DGS implementation;
    // confirm against the asset creator.
    dir *= -1;

    half x = dir.x, y = dir.y, z = dir.z;

    // ambient band
    half3 res = splat.col; // col = sh0 * SH_C0 + 0.5 is already precomputed
    if (onlySH)
        res = 0.5;
    // 1st degree
    if (shOrder >= 1)
    {
        res += SH_C1 * (-splat.sh1 * y + splat.sh2 * z - splat.sh3 * x);
        // 2nd degree
        if (shOrder >= 2)
        {
            half xx = x * x, yy = y * y, zz = z * z;
            half xy = x * y, yz = y * z, xz = x * z;
            res +=
                (SH_C2[0] * xy) * splat.sh4 +
                (SH_C2[1] * yz) * splat.sh5 +
                (SH_C2[2] * (2 * zz - xx - yy)) * splat.sh6 +
                (SH_C2[3] * xz) * splat.sh7 +
                (SH_C2[4] * (xx - yy)) * splat.sh8;
            // 3rd degree
            if (shOrder >= 3)
            {
                res +=
                    (SH_C3[0] * y * (3 * xx - yy)) * splat.sh9 +
                    (SH_C3[1] * xy * z) * splat.sh10 +
                    (SH_C3[2] * y * (4 * zz - xx - yy)) * splat.sh11 +
                    (SH_C3[3] * z * (2 * zz - 3 * xx - 3 * yy)) * splat.sh12 +
                    (SH_C3[4] * x * (4 * zz - xx - yy)) * splat.sh13 +
                    (SH_C3[5] * z * (xx - yy)) * splat.sh14 +
                    (SH_C3[6] * x * (xx - 3 * yy)) * splat.sh15;
            }
        }
    }
    // SH evaluation can go negative; clamp to a displayable color.
    return max(res, 0);
}
|
||||
|
||||
static const uint kTexWidth = 2048;
|
||||
|
||||
// Maps a linear splat index to a texel coordinate in the kTexWidth-wide color
// texture. Splats are stored in 16x16 tiles (256 splats each) with Morton
// ordering inside the tile, so neighboring indices land in nearby texels.
uint3 SplatIndexToPixelIndex(uint idx)
{
    uint2 inTile = DecodeMorton2D_16x16(idx); // position within the 16x16 tile
    uint tile = idx >> 8;                     // which tile (256 splats per tile)
    uint tilesPerRow = kTexWidth / 16;

    uint3 coord;
    coord.x = (tile % tilesPerRow) * 16 + inTile.x;
    coord.y = (tile / tilesPerRow) * 16 + inTile.y;
    coord.z = 0;
    return coord;
}
|
||||
|
||||
// Per-chunk (256 splats, see kChunkSize) dequantization bounds. Chunk-relative
// splat attributes in [0..1] are lerped between these min/max values in
// LoadSplatPos / LoadSplatData.
struct SplatChunkInfo
{
    uint colR, colG, colB, colA;  // color min/max per channel, packed as 2x fp16 (low = min, high = max)
    float2 posX, posY, posZ;      // position bounds per axis (.x = min, .y = max)
    uint sclX, sclY, sclZ;        // scale min/max per axis, packed as 2x fp16
    uint shR, shG, shB;           // SH coefficient min/max per channel, packed as 2x fp16
};
|
||||
|
||||
StructuredBuffer<SplatChunkInfo> _SplatChunks;
|
||||
uint _SplatChunkCount;
|
||||
|
||||
static const uint kChunkSize = 256;
|
||||
|
||||
// Fully decoded splat, produced by LoadSplatData.
struct SplatData
{
    float3 pos;      // splat center position
    float4 rot;      // rotation quaternion (decoded from "smallest 3" packing)
    float3 scale;    // per-axis Gaussian scale
    half opacity;    // alpha from the color texture
    SplatSHData sh;  // base color + SH coefficients
};
|
||||
|
||||
// Decode quaternion from a "smallest 3" e.g. 10.10.10.2 format
// pq.xyz hold the three smallest quaternion components remapped to [0..1];
// pq.w encodes which component was dropped (index / 3). The dropped
// component is rebuilt from the unit-length constraint and swizzled back
// into its original slot.
float4 DecodeRotation(float4 pq)
{
    uint idx = (uint)round(pq.w * 3.0); // note: need to round or index might come out wrong in some formats (e.g. fp16.fp16.fp16.fp16)
    float4 q;
    q.xyz = pq.xyz * sqrt(2.0) - (1.0 / sqrt(2.0)); // back to [-1/sqrt2 .. +1/sqrt2]
    q.w = sqrt(1.0 - saturate(dot(q.xyz, q.xyz)));  // reconstruct dropped component (always non-negative)
    // Move the reconstructed component from .w into the slot it was dropped from.
    if (idx == 0) q = q.wxyz;
    if (idx == 1) q = q.xwyz;
    if (idx == 2) q = q.xywz;
    return q;
}
|
||||
// Inverse of DecodeRotation: packs a quaternion into "smallest three" form —
// xyz hold the three smallest components remapped to [0..1], w holds the
// index of the dropped (largest-magnitude) component as index/3.
float4 PackSmallest3Rotation(float4 q)
{
    // find biggest component
    float4 absQ = abs(q);
    int index = 0;
    float maxV = absQ.x;
    if (absQ.y > maxV) { index = 1; maxV = absQ.y; }
    if (absQ.z > maxV) { index = 2; maxV = absQ.z; }
    if (absQ.w > maxV) { index = 3; maxV = absQ.w; }

    // Swizzle so the largest component ends up in .w (index == 3 is already there).
    if (index == 0) q = q.yzwx;
    else if (index == 1) q = q.xzwy;
    else if (index == 2) q = q.xywz;

    // q and -q represent the same rotation: force the dropped component
    // non-negative so decode can reconstruct it with a plain sqrt.
    float3 smallest = q.xyz * (q.w >= 0 ? 1 : -1); // -1/sqrt2..+1/sqrt2 range
    smallest = (smallest * sqrt(2.0)) * 0.5 + 0.5; // 0..1 range
    return float4(smallest, index / 3.0);
}
|
||||
|
||||
// Unpacks a 16-bit 6.5.5 value (low bits first) into a normalized half3.
half3 DecodePacked_6_5_5(uint enc)
{
    uint r = enc & 63;
    uint g = (enc >> 6) & 31;
    uint b = (enc >> 11) & 31;
    return half3(r / 63.0, g / 31.0, b / 31.0);
}
|
||||
|
||||
// Unpacks a 16-bit 5.6.5 value (low bits first) into a normalized half3.
half3 DecodePacked_5_6_5(uint enc)
{
    uint r = enc & 31;
    uint g = (enc >> 5) & 63;
    uint b = (enc >> 11) & 31;
    return half3(r / 31.0, g / 63.0, b / 31.0);
}
|
||||
|
||||
// Unpacks a 32-bit 11.10.11 value (low bits first) into a normalized half3.
half3 DecodePacked_11_10_11(uint enc)
{
    uint r = enc & 2047;
    uint g = (enc >> 11) & 1023;
    uint b = (enc >> 21) & 2047;
    return half3(r / 2047.0, g / 1023.0, b / 2047.0);
}
|
||||
|
||||
// Unpacks three 16-bit normalized values (x low/high halves of enc.x,
// z in the low half of enc.y) into a float3.
float3 DecodePacked_16_16_16(uint2 enc)
{
    uint r = enc.x & 65535;
    uint g = (enc.x >> 16) & 65535;
    uint b = enc.y & 65535;
    return float3(r / 65535.0, g / 65535.0, b / 65535.0);
}
|
||||
|
||||
// Unpacks a 32-bit 10.10.10.2 value (low bits first) into a normalized float4.
float4 DecodePacked_10_10_10_2(uint enc)
{
    uint x = enc & 1023;
    uint y = (enc >> 10) & 1023;
    uint z = (enc >> 20) & 1023;
    uint w = (enc >> 30) & 3;
    return float4(x / 1023.0, y / 1023.0, z / 1023.0, w / 3.0);
}
|
||||
// Packs a float4 with components in [0..1] (e.g. PackSmallest3Rotation
// output) into 10.10.10.2 bits. The *1023.5 / *3.5 scale plus truncation
// effectively rounds to the nearest quantization level.
uint EncodeQuatToNorm10(float4 v) // 32 bits: 10.10.10.2
{
    return (uint) (v.x * 1023.5f) | ((uint) (v.y * 1023.5f) << 10) | ((uint) (v.z * 1023.5f) << 20) | ((uint) (v.w * 3.5f) << 30);
}
|
||||
|
||||
|
||||
#ifdef SHADER_STAGE_COMPUTE
|
||||
#define SplatBufferDataType RWByteAddressBuffer
|
||||
#else
|
||||
#define SplatBufferDataType ByteAddressBuffer
|
||||
#endif
|
||||
|
||||
SplatBufferDataType _SplatPos;
|
||||
SplatBufferDataType _SplatOther;
|
||||
SplatBufferDataType _SplatSH;
|
||||
Texture2D _SplatColor;
|
||||
uint _SplatFormat;
|
||||
|
||||
// Match GaussianSplatAsset.VectorFormat
|
||||
#define VECTOR_FMT_32F 0
|
||||
#define VECTOR_FMT_16 1
|
||||
#define VECTOR_FMT_11 2
|
||||
#define VECTOR_FMT_6 3
|
||||
|
||||
// Loads a 16-bit value from a byte-address buffer. addrU may be only 2-byte
// aligned, but ByteAddressBuffer.Load requires 4-byte alignment — so read
// the containing dword and select the correct half.
uint LoadUShort(SplatBufferDataType dataBuffer, uint addrU)
{
    uint addrA = addrU & ~0x3;          // round down to dword alignment
    uint val = dataBuffer.Load(addrA);
    if (addrU != addrA)
        val >>= 16;                     // requested the upper half of the dword
    return val & 0xFFFF;
}
|
||||
|
||||
// Loads a 32-bit value from a byte-address buffer at an address that may be
// only 2-byte aligned: if unaligned, the value is stitched from the upper
// half of one dword and the lower half of the next.
uint LoadUInt(SplatBufferDataType dataBuffer, uint addrU)
{
    uint addrA = addrU & ~0x3;          // round down to dword alignment
    uint val = dataBuffer.Load(addrA);
    if (addrU != addrA)
    {
        uint val1 = dataBuffer.Load(addrA + 4);
        val = (val >> 16) | ((val1 & 0xFFFF) << 16);
    }
    return val;
}
|
||||
|
||||
// Loads a 3-component vector stored at byte address addrU in one of the
// VECTOR_FMT_* encodings and returns it decoded (normalized components for
// packed formats). addrU may be only 2-byte aligned (fp16/ushort strides),
// so raw dwords are re-aligned by 16-bit shifting where needed.
float3 LoadAndDecodeVector(SplatBufferDataType dataBuffer, uint addrU, uint fmt)
{
    uint addrA = addrU & ~0x3; // dword-aligned base address

    uint val0 = dataBuffer.Load(addrA);

    float3 res = 0;
    if (fmt == VECTOR_FMT_32F) // 3x fp32, 12 bytes
    {
        uint val1 = dataBuffer.Load(addrA + 4);
        uint val2 = dataBuffer.Load(addrA + 8);
        if (addrU != addrA)
        {
            // Unaligned: each component straddles two dwords; shift down 16 bits.
            uint val3 = dataBuffer.Load(addrA + 12);
            val0 = (val0 >> 16) | ((val1 & 0xFFFF) << 16);
            val1 = (val1 >> 16) | ((val2 & 0xFFFF) << 16);
            val2 = (val2 >> 16) | ((val3 & 0xFFFF) << 16);
        }
        res = float3(asfloat(val0), asfloat(val1), asfloat(val2));
    }
    else if (fmt == VECTOR_FMT_16) // 3x norm16, 6 bytes
    {
        uint val1 = dataBuffer.Load(addrA + 4);
        if (addrU != addrA)
        {
            val0 = (val0 >> 16) | ((val1 & 0xFFFF) << 16);
            val1 >>= 16;
        }
        res = DecodePacked_16_16_16(uint2(val0, val1));
    }
    else if (fmt == VECTOR_FMT_11) // 11.10.11 packed into one dword
    {
        uint val1 = dataBuffer.Load(addrA + 4);
        if (addrU != addrA)
        {
            val0 = (val0 >> 16) | ((val1 & 0xFFFF) << 16);
        }
        res = DecodePacked_11_10_11(val0);
    }
    else if (fmt == VECTOR_FMT_6) // 6.5.5 packed into one ushort
    {
        if (addrU != addrA)
            val0 >>= 16;
        res = DecodePacked_6_5_5(val0);
    }
    return res;
}
|
||||
|
||||
// Loads the raw (possibly chunk-relative, see LoadSplatPos) position of a
// splat from _SplatPos, using the position format encoded in the low byte
// of _SplatFormat to derive the per-splat byte stride.
float3 LoadSplatPosValue(uint index)
{
    uint fmt = _SplatFormat & 0xFF;
    uint stride;
    switch (fmt)
    {
        case VECTOR_FMT_32F: stride = 12; break; // 3x fp32
        case VECTOR_FMT_16:  stride = 6;  break; // 3x norm16
        case VECTOR_FMT_11:  stride = 4;  break; // 11.10.11 in one dword
        case VECTOR_FMT_6:   stride = 2;  break; // 6.5.5 in one ushort
        default:             stride = 0;  break;
    }
    return LoadAndDecodeVector(_SplatPos, index * stride, fmt);
}
|
||||
|
||||
// Loads a splat position and, when chunk data is present, dequantizes the
// chunk-relative [0..1] value into the chunk's position bounds.
float3 LoadSplatPos(uint idx)
{
    float3 rawPos = LoadSplatPosValue(idx);
    uint chunkIdx = idx / kChunkSize;
    if (chunkIdx >= _SplatChunkCount)
        return rawPos; // no chunking: stored value is the final position

    SplatChunkInfo chunk = _SplatChunks[chunkIdx];
    float3 boundsMin = float3(chunk.posX.x, chunk.posY.x, chunk.posZ.x);
    float3 boundsMax = float3(chunk.posX.y, chunk.posY.y, chunk.posZ.y);
    return lerp(boundsMin, boundsMax, rawPos);
}
|
||||
|
||||
// Fetches the splat color texel (coordinate from SplatIndexToPixelIndex,
// mip level in coord.z).
half4 LoadSplatColTex(uint3 coord)
{
    half4 texel = _SplatColor.Load(coord);
    return texel;
}
|
||||
|
||||
// Fetches and fully decodes one splat: position, rotation, scale, color,
// opacity and SH coefficients. Handles every VECTOR_FMT_* packing for scale
// and SH, an optional indexed/clustered SH palette (shFormat > VECTOR_FMT_6),
// and optional per-chunk quantization (dequantized at the end).
SplatData LoadSplatData(uint idx)
{
    SplatData s = (SplatData)0;

    // figure out raw data offsets / locations
    uint3 coord = SplatIndexToPixelIndex(idx);

    uint scaleFmt = (_SplatFormat >> 8) & 0xFF;
    uint shFormat = (_SplatFormat >> 16) & 0xFF;

    // Per-splat stride in _SplatOther: rotation + scale (+ optional SH index).
    uint otherStride = 4; // rotation is 10.10.10.2
    if (scaleFmt == VECTOR_FMT_32F)
        otherStride += 12;
    else if (scaleFmt == VECTOR_FMT_16)
        otherStride += 6;
    else if (scaleFmt == VECTOR_FMT_11)
        otherStride += 4;
    else if (scaleFmt == VECTOR_FMT_6)
        otherStride += 2;
    if (shFormat > VECTOR_FMT_6)
        otherStride += 2; // indexed SH: a ushort palette index trails each entry
    uint otherAddr = idx * otherStride;

    uint shStride = 0;
    if (shFormat == VECTOR_FMT_32F)
        shStride = 192; // 15*3 fp32, rounded up to multiple of 16
    else if (shFormat == VECTOR_FMT_16 || shFormat > VECTOR_FMT_6)
        shStride = 96; // 15*3 fp16, rounded up to multiple of 16
    else if (shFormat == VECTOR_FMT_11)
        shStride = 60; // 15x uint
    else if (shFormat == VECTOR_FMT_6)
        shStride = 32; // 15x ushort, rounded up to multiple of 4


    // load raw splat data, which might be chunk-relative
    s.pos = LoadSplatPosValue(idx);
    s.rot = DecodeRotation(DecodePacked_10_10_10_2(LoadUInt(_SplatOther, otherAddr)));
    s.scale = LoadAndDecodeVector(_SplatOther, otherAddr + 4, scaleFmt);
    half4 col = LoadSplatColTex(coord);

    // Indexed SH formats store a palette index per splat instead of per-splat data.
    uint shIndex = idx;
    if (shFormat > VECTOR_FMT_6)
        shIndex = LoadUShort(_SplatOther, otherAddr + otherStride - 2);

    uint shOffset = shIndex * shStride;
    uint4 shRaw0 = _SplatSH.Load4(shOffset);
    uint4 shRaw1 = _SplatSH.Load4(shOffset + 16);
    if (shFormat == VECTOR_FMT_32F)
    {
        // 45 fp32 coefficients, read as 12 dword4 loads + one trailing dword.
        uint4 shRaw2 = _SplatSH.Load4(shOffset + 32);
        uint4 shRaw3 = _SplatSH.Load4(shOffset + 48);
        uint4 shRaw4 = _SplatSH.Load4(shOffset + 64);
        uint4 shRaw5 = _SplatSH.Load4(shOffset + 80);
        uint4 shRaw6 = _SplatSH.Load4(shOffset + 96);
        uint4 shRaw7 = _SplatSH.Load4(shOffset + 112);
        uint4 shRaw8 = _SplatSH.Load4(shOffset + 128);
        uint4 shRaw9 = _SplatSH.Load4(shOffset + 144);
        uint4 shRawA = _SplatSH.Load4(shOffset + 160);
        uint shRawB = _SplatSH.Load(shOffset + 176);
        s.sh.sh1.r = asfloat(shRaw0.x); s.sh.sh1.g = asfloat(shRaw0.y); s.sh.sh1.b = asfloat(shRaw0.z);
        s.sh.sh2.r = asfloat(shRaw0.w); s.sh.sh2.g = asfloat(shRaw1.x); s.sh.sh2.b = asfloat(shRaw1.y);
        s.sh.sh3.r = asfloat(shRaw1.z); s.sh.sh3.g = asfloat(shRaw1.w); s.sh.sh3.b = asfloat(shRaw2.x);
        s.sh.sh4.r = asfloat(shRaw2.y); s.sh.sh4.g = asfloat(shRaw2.z); s.sh.sh4.b = asfloat(shRaw2.w);
        s.sh.sh5.r = asfloat(shRaw3.x); s.sh.sh5.g = asfloat(shRaw3.y); s.sh.sh5.b = asfloat(shRaw3.z);
        s.sh.sh6.r = asfloat(shRaw3.w); s.sh.sh6.g = asfloat(shRaw4.x); s.sh.sh6.b = asfloat(shRaw4.y);
        s.sh.sh7.r = asfloat(shRaw4.z); s.sh.sh7.g = asfloat(shRaw4.w); s.sh.sh7.b = asfloat(shRaw5.x);
        s.sh.sh8.r = asfloat(shRaw5.y); s.sh.sh8.g = asfloat(shRaw5.z); s.sh.sh8.b = asfloat(shRaw5.w);
        s.sh.sh9.r = asfloat(shRaw6.x); s.sh.sh9.g = asfloat(shRaw6.y); s.sh.sh9.b = asfloat(shRaw6.z);
        s.sh.sh10.r = asfloat(shRaw6.w); s.sh.sh10.g = asfloat(shRaw7.x); s.sh.sh10.b = asfloat(shRaw7.y);
        s.sh.sh11.r = asfloat(shRaw7.z); s.sh.sh11.g = asfloat(shRaw7.w); s.sh.sh11.b = asfloat(shRaw8.x);
        s.sh.sh12.r = asfloat(shRaw8.y); s.sh.sh12.g = asfloat(shRaw8.z); s.sh.sh12.b = asfloat(shRaw8.w);
        s.sh.sh13.r = asfloat(shRaw9.x); s.sh.sh13.g = asfloat(shRaw9.y); s.sh.sh13.b = asfloat(shRaw9.z);
        s.sh.sh14.r = asfloat(shRaw9.w); s.sh.sh14.g = asfloat(shRawA.x); s.sh.sh14.b = asfloat(shRawA.y);
        s.sh.sh15.r = asfloat(shRawA.z); s.sh.sh15.g = asfloat(shRawA.w); s.sh.sh15.b = asfloat(shRawB);
    }
    else if (shFormat == VECTOR_FMT_16 || shFormat > VECTOR_FMT_6)
    {
        // 45 fp16 coefficients, two per dword (low half first).
        uint4 shRaw2 = _SplatSH.Load4(shOffset + 32);
        uint4 shRaw3 = _SplatSH.Load4(shOffset + 48);
        uint4 shRaw4 = _SplatSH.Load4(shOffset + 64);
        uint3 shRaw5 = _SplatSH.Load3(shOffset + 80);
        s.sh.sh1.r  = f16tof32(shRaw0.x      ); s.sh.sh1.g  = f16tof32(shRaw0.x >> 16); s.sh.sh1.b  = f16tof32(shRaw0.y      );
        s.sh.sh2.r  = f16tof32(shRaw0.y >> 16); s.sh.sh2.g  = f16tof32(shRaw0.z      ); s.sh.sh2.b  = f16tof32(shRaw0.z >> 16);
        s.sh.sh3.r  = f16tof32(shRaw0.w      ); s.sh.sh3.g  = f16tof32(shRaw0.w >> 16); s.sh.sh3.b  = f16tof32(shRaw1.x      );
        s.sh.sh4.r  = f16tof32(shRaw1.x >> 16); s.sh.sh4.g  = f16tof32(shRaw1.y      ); s.sh.sh4.b  = f16tof32(shRaw1.y >> 16);
        s.sh.sh5.r  = f16tof32(shRaw1.z      ); s.sh.sh5.g  = f16tof32(shRaw1.z >> 16); s.sh.sh5.b  = f16tof32(shRaw1.w      );
        s.sh.sh6.r  = f16tof32(shRaw1.w >> 16); s.sh.sh6.g  = f16tof32(shRaw2.x      ); s.sh.sh6.b  = f16tof32(shRaw2.x >> 16);
        s.sh.sh7.r  = f16tof32(shRaw2.y      ); s.sh.sh7.g  = f16tof32(shRaw2.y >> 16); s.sh.sh7.b  = f16tof32(shRaw2.z      );
        s.sh.sh8.r  = f16tof32(shRaw2.z >> 16); s.sh.sh8.g  = f16tof32(shRaw2.w      ); s.sh.sh8.b  = f16tof32(shRaw2.w >> 16);
        s.sh.sh9.r  = f16tof32(shRaw3.x      ); s.sh.sh9.g  = f16tof32(shRaw3.x >> 16); s.sh.sh9.b  = f16tof32(shRaw3.y      );
        s.sh.sh10.r = f16tof32(shRaw3.y >> 16); s.sh.sh10.g = f16tof32(shRaw3.z      ); s.sh.sh10.b = f16tof32(shRaw3.z >> 16);
        s.sh.sh11.r = f16tof32(shRaw3.w      ); s.sh.sh11.g = f16tof32(shRaw3.w >> 16); s.sh.sh11.b = f16tof32(shRaw4.x      );
        s.sh.sh12.r = f16tof32(shRaw4.x >> 16); s.sh.sh12.g = f16tof32(shRaw4.y      ); s.sh.sh12.b = f16tof32(shRaw4.y >> 16);
        s.sh.sh13.r = f16tof32(shRaw4.z      ); s.sh.sh13.g = f16tof32(shRaw4.z >> 16); s.sh.sh13.b = f16tof32(shRaw4.w      );
        s.sh.sh14.r = f16tof32(shRaw4.w >> 16); s.sh.sh14.g = f16tof32(shRaw5.x      ); s.sh.sh14.b = f16tof32(shRaw5.x >> 16);
        s.sh.sh15.r = f16tof32(shRaw5.y      ); s.sh.sh15.g = f16tof32(shRaw5.y >> 16); s.sh.sh15.b = f16tof32(shRaw5.z      );
    }
    else if (shFormat == VECTOR_FMT_11)
    {
        // 15 coefficients, one 11.10.11 dword each (chunk-relative, see below).
        uint4 shRaw2 = _SplatSH.Load4(shOffset + 32);
        uint3 shRaw3 = _SplatSH.Load3(shOffset + 48);
        s.sh.sh1 = DecodePacked_11_10_11(shRaw0.x);
        s.sh.sh2 = DecodePacked_11_10_11(shRaw0.y);
        s.sh.sh3 = DecodePacked_11_10_11(shRaw0.z);
        s.sh.sh4 = DecodePacked_11_10_11(shRaw0.w);
        s.sh.sh5 = DecodePacked_11_10_11(shRaw1.x);
        s.sh.sh6 = DecodePacked_11_10_11(shRaw1.y);
        s.sh.sh7 = DecodePacked_11_10_11(shRaw1.z);
        s.sh.sh8 = DecodePacked_11_10_11(shRaw1.w);
        s.sh.sh9 = DecodePacked_11_10_11(shRaw2.x);
        s.sh.sh10 = DecodePacked_11_10_11(shRaw2.y);
        s.sh.sh11 = DecodePacked_11_10_11(shRaw2.z);
        s.sh.sh12 = DecodePacked_11_10_11(shRaw2.w);
        s.sh.sh13 = DecodePacked_11_10_11(shRaw3.x);
        s.sh.sh14 = DecodePacked_11_10_11(shRaw3.y);
        s.sh.sh15 = DecodePacked_11_10_11(shRaw3.z);
    }
    else if (shFormat == VECTOR_FMT_6)
    {
        // 15 coefficients, one 5.6.5 ushort each, two per dword.
        s.sh.sh1 = DecodePacked_5_6_5(shRaw0.x);
        s.sh.sh2 = DecodePacked_5_6_5(shRaw0.x >> 16);
        s.sh.sh3 = DecodePacked_5_6_5(shRaw0.y);
        s.sh.sh4 = DecodePacked_5_6_5(shRaw0.y >> 16);
        s.sh.sh5 = DecodePacked_5_6_5(shRaw0.z);
        s.sh.sh6 = DecodePacked_5_6_5(shRaw0.z >> 16);
        s.sh.sh7 = DecodePacked_5_6_5(shRaw0.w);
        s.sh.sh8 = DecodePacked_5_6_5(shRaw0.w >> 16);
        s.sh.sh9 = DecodePacked_5_6_5(shRaw1.x);
        s.sh.sh10 = DecodePacked_5_6_5(shRaw1.x >> 16);
        s.sh.sh11 = DecodePacked_5_6_5(shRaw1.y);
        s.sh.sh12 = DecodePacked_5_6_5(shRaw1.y >> 16);
        s.sh.sh13 = DecodePacked_5_6_5(shRaw1.z);
        s.sh.sh14 = DecodePacked_5_6_5(shRaw1.z >> 16);
        s.sh.sh15 = DecodePacked_5_6_5(shRaw1.w);
    }

    // if raw data is chunk-relative, convert to final values by interpolating between chunk min/max
    uint chunkIdx = idx / kChunkSize;
    if (chunkIdx < _SplatChunkCount)
    {
        SplatChunkInfo chunk = _SplatChunks[chunkIdx];
        float3 posMin = float3(chunk.posX.x, chunk.posY.x, chunk.posZ.x);
        float3 posMax = float3(chunk.posX.y, chunk.posY.y, chunk.posZ.y);
        half3 sclMin = half3(f16tof32(chunk.sclX    ), f16tof32(chunk.sclY    ), f16tof32(chunk.sclZ    ));
        half3 sclMax = half3(f16tof32(chunk.sclX>>16), f16tof32(chunk.sclY>>16), f16tof32(chunk.sclZ>>16));
        half4 colMin = half4(f16tof32(chunk.colR    ), f16tof32(chunk.colG    ), f16tof32(chunk.colB    ), f16tof32(chunk.colA    ));
        half4 colMax = half4(f16tof32(chunk.colR>>16), f16tof32(chunk.colG>>16), f16tof32(chunk.colB>>16), f16tof32(chunk.colA>>16));
        half3 shMin = half3(f16tof32(chunk.shR    ), f16tof32(chunk.shG    ), f16tof32(chunk.shB    ));
        half3 shMax = half3(f16tof32(chunk.shR>>16), f16tof32(chunk.shG>>16), f16tof32(chunk.shB>>16));
        s.pos = lerp(posMin, posMax, s.pos);
        s.scale = lerp(sclMin, sclMax, s.scale);
        // Undo the sqrt-based scale quantization: three squarings = x^8
        // (inverse of the encoder's eighth root — assumed; confirm against
        // the asset creator).
        s.scale *= s.scale;
        s.scale *= s.scale;
        s.scale *= s.scale;
        col = lerp(colMin, colMax, col);
        col.a = InvSquareCentered01(col.a);

        // Only the per-chunk-quantized SH formats need dequantization here
        // (fp32/fp16 SH are stored as absolute values).
        if (shFormat > VECTOR_FMT_32F && shFormat <= VECTOR_FMT_6)
        {
            s.sh.sh1  = lerp(shMin, shMax, s.sh.sh1 );
            s.sh.sh2  = lerp(shMin, shMax, s.sh.sh2 );
            s.sh.sh3  = lerp(shMin, shMax, s.sh.sh3 );
            s.sh.sh4  = lerp(shMin, shMax, s.sh.sh4 );
            s.sh.sh5  = lerp(shMin, shMax, s.sh.sh5 );
            s.sh.sh6  = lerp(shMin, shMax, s.sh.sh6 );
            s.sh.sh7  = lerp(shMin, shMax, s.sh.sh7 );
            s.sh.sh8  = lerp(shMin, shMax, s.sh.sh8 );
            s.sh.sh9  = lerp(shMin, shMax, s.sh.sh9 );
            s.sh.sh10 = lerp(shMin, shMax, s.sh.sh10);
            s.sh.sh11 = lerp(shMin, shMax, s.sh.sh11);
            s.sh.sh12 = lerp(shMin, shMax, s.sh.sh12);
            s.sh.sh13 = lerp(shMin, shMax, s.sh.sh13);
            s.sh.sh14 = lerp(shMin, shMax, s.sh.sh14);
            s.sh.sh15 = lerp(shMin, shMax, s.sh.sh15);
        }
    }
    s.opacity = col.a;
    s.sh.col = col.rgb;

    return s;
}
|
||||
|
||||
// Per-splat data produced for rendering and consumed by the vertex shader
// in RenderGaussianSplats.shader.
struct SplatViewData
{
    float4 pos;          // splat center in clip space (w <= 0 means behind camera)
    float2 axis1, axis2; // screen-space ellipse axes of the projected Gaussian
    uint2 color; // 4xFP16
};
|
||||
|
||||
#endif // GAUSSIAN_SPLATTING_HLSL
|
||||
7
MVS/3DGS-Unity/Shaders/GaussianSplatting.hlsl.meta
Normal file
7
MVS/3DGS-Unity/Shaders/GaussianSplatting.hlsl.meta
Normal file
@@ -0,0 +1,7 @@
|
||||
fileFormatVersion: 2
|
||||
guid: d4087e54957693c48a7be32de91c99e2
|
||||
ShaderIncludeImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
||||
111
MVS/3DGS-Unity/Shaders/RenderGaussianSplats.shader
Normal file
111
MVS/3DGS-Unity/Shaders/RenderGaussianSplats.shader
Normal file
@@ -0,0 +1,111 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
Shader "Gaussian Splatting/Render Splats"
|
||||
{
|
||||
SubShader
|
||||
{
|
||||
Tags { "RenderType"="Transparent" "Queue"="Transparent" }
|
||||
|
||||
Pass
|
||||
{
|
||||
ZWrite Off
|
||||
Blend OneMinusDstAlpha One
|
||||
Cull Off
|
||||
|
||||
CGPROGRAM
|
||||
#pragma vertex vert
|
||||
#pragma fragment frag
|
||||
#pragma require compute
|
||||
#pragma use_dxc
|
||||
|
||||
#include "GaussianSplatting.hlsl"
|
||||
|
||||
StructuredBuffer<uint> _OrderBuffer;
|
||||
|
||||
struct v2f
|
||||
{
|
||||
half4 col : COLOR0;
|
||||
float2 pos : TEXCOORD0;
|
||||
float4 vertex : SV_POSITION;
|
||||
};
|
||||
|
||||
StructuredBuffer<SplatViewData> _SplatViewData;
|
||||
ByteAddressBuffer _SplatSelectedBits;
|
||||
uint _SplatBitsValid;
|
||||
|
||||
            // Expands each splat (one instance) into a quad (4 vertices via
            // SV_VertexID) sized by the precomputed screen-space ellipse axes.
            // Instances are drawn in _OrderBuffer order (back-to-front sort).
            v2f vert (uint vtxID : SV_VertexID, uint instID : SV_InstanceID)
            {
                v2f o = (v2f)0;
                instID = _OrderBuffer[instID];
                SplatViewData view = _SplatViewData[instID];
                float4 centerClipPos = view.pos;
                bool behindCam = centerClipPos.w <= 0;
                if (behindCam)
                {
                    o.vertex = asfloat(0x7fc00000); // NaN discards the primitive
                }
                else
                {
                    // Unpack 4x fp16 color (r in high half of .x, a in low half of .y).
                    o.col.r = f16tof32(view.color.x >> 16);
                    o.col.g = f16tof32(view.color.x);
                    o.col.b = f16tof32(view.color.y >> 16);
                    o.col.a = f16tof32(view.color.y);

                    // Quad corner in [-1..1], then scaled by 2: the quad covers
                    // +-2 units in the local Gaussian frame (see frag's exp falloff).
                    uint idx = vtxID;
                    float2 quadPos = float2(idx&1, (idx>>1)&1) * 2.0 - 1.0;
                    quadPos *= 2;

                    o.pos = quadPos;

                    // Convert the pixel-space ellipse axes to NDC, then to clip
                    // space by multiplying with w (undoes the later perspective divide).
                    float2 deltaScreenPos = (quadPos.x * view.axis1 + quadPos.y * view.axis2) * 2 / _ScreenParams.xy;
                    o.vertex = centerClipPos;
                    o.vertex.xy += deltaScreenPos * centerClipPos.w;

                    // is this splat selected?
                    if (_SplatBitsValid)
                    {
                        uint wordIdx = instID / 32;
                        uint bitIdx = instID & 31;
                        uint selVal = _SplatSelectedBits.Load(wordIdx * 4);
                        if (selVal & (1 << bitIdx))
                        {
                            // Negative alpha is the "selected" sentinel decoded in frag.
                            o.col.a = -1;
                        }
                    }
                }
                return o;
            }
|
||||
|
||||
            // Evaluates the Gaussian falloff over the quad (i.pos is the corner
            // offset in local units) and emits premultiplied-alpha color for the
            // OneMinusDstAlpha/One blend mode. A negative vertex alpha marks a
            // selected splat, which gets a magenta tint and outline.
            half4 frag (v2f i) : SV_Target
            {
                float power = -dot(i.pos, i.pos);
                half alpha = exp(power);
                if (i.col.a >= 0)
                {
                    alpha = saturate(alpha * i.col.a);
                }
                else
                {
                    // "selected" splat: magenta outline, increase opacity, magenta tint
                    half3 selectedColor = half3(1,0,1);
                    if (alpha > 7.0/255.0)
                    {
                        // Narrow alpha band around the Gaussian edge forms the outline.
                        if (alpha < 10.0/255.0)
                        {
                            alpha = 1;
                            i.col.rgb = selectedColor;
                        }
                        alpha = saturate(alpha + 0.3);
                    }
                    i.col.rgb = lerp(i.col.rgb, selectedColor, 0.5);
                }

                // Skip fragments that would contribute less than 1/255.
                if (alpha < 1.0/255.0)
                    discard;

                half4 res = half4(i.col.rgb * alpha, alpha); // premultiplied alpha
                return res;
            }
|
||||
ENDCG
|
||||
}
|
||||
}
|
||||
}
|
||||
9
MVS/3DGS-Unity/Shaders/RenderGaussianSplats.shader.meta
Normal file
9
MVS/3DGS-Unity/Shaders/RenderGaussianSplats.shader.meta
Normal file
@@ -0,0 +1,9 @@
|
||||
fileFormatVersion: 2
|
||||
guid: ed800126ae8844a67aad1974ddddd59c
|
||||
ShaderImporter:
|
||||
externalObjects: {}
|
||||
defaultTextures: []
|
||||
nonModifiableTextures: []
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
||||
959
MVS/3DGS-Unity/Shaders/SortCommon.hlsl
Normal file
959
MVS/3DGS-Unity/Shaders/SortCommon.hlsl
Normal file
@@ -0,0 +1,959 @@
|
||||
/******************************************************************************
|
||||
* SortCommon
|
||||
* Common functions for GPUSorting
|
||||
*
|
||||
* SPDX-License-Identifier: MIT
|
||||
* Copyright Thomas Smith 5/17/2024
|
||||
* https://github.com/b0nes164/GPUSorting
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
******************************************************************************/
|
||||
#define KEYS_PER_THREAD 15U
|
||||
#define D_DIM 256U
|
||||
#define PART_SIZE 3840U
|
||||
#define D_TOTAL_SMEM 4096U
|
||||
|
||||
#define RADIX 256U //Number of digit bins
|
||||
#define RADIX_MASK 255U //Mask of digit bins
|
||||
#define HALF_RADIX 128U //For smaller waves where bit packing is necessary
|
||||
#define HALF_MASK 127U // ''
|
||||
#define RADIX_LOG 8U //log2(RADIX)
|
||||
#define RADIX_PASSES 4U //(Key width) / RADIX_LOG
|
||||
|
||||
// Sort configuration, set by the host for each dispatch/pass.
cbuffer cbGpuSorting : register(b0)
{
    uint e_numKeys;      // total number of keys to sort
    uint e_radixShift;   // bit shift of the current 8-bit digit (0, 8, 16, 24)
    uint e_threadBlocks; // number of thread blocks dispatched
    uint padding;        // pads the cbuffer to a 16-byte multiple
};
|
||||
|
||||
#if defined(KEY_UINT)
|
||||
RWStructuredBuffer<uint> b_sort;
|
||||
RWStructuredBuffer<uint> b_alt;
|
||||
#elif defined(KEY_INT)
|
||||
RWStructuredBuffer<int> b_sort;
|
||||
RWStructuredBuffer<int> b_alt;
|
||||
#elif defined(KEY_FLOAT)
|
||||
RWStructuredBuffer<float> b_sort;
|
||||
RWStructuredBuffer<float> b_alt;
|
||||
#endif
|
||||
|
||||
#if defined(PAYLOAD_UINT)
|
||||
RWStructuredBuffer<uint> b_sortPayload;
|
||||
RWStructuredBuffer<uint> b_altPayload;
|
||||
#elif defined(PAYLOAD_INT)
|
||||
RWStructuredBuffer<int> b_sortPayload;
|
||||
RWStructuredBuffer<int> b_altPayload;
|
||||
#elif defined(PAYLOAD_FLOAT)
|
||||
RWStructuredBuffer<float> b_sortPayload;
|
||||
RWStructuredBuffer<float> b_altPayload;
|
||||
#endif
|
||||
|
||||
groupshared uint g_d[D_TOTAL_SMEM]; //Shared memory for DigitBinningPass and DownSweep kernels
|
||||
|
||||
// The keys one thread is responsible for within a partition (always held as
// uint; int/float keys are converted to sortable bit patterns by LoadKey).
struct KeyStruct
{
    uint k[KEYS_PER_THREAD];
};
|
||||
|
||||
// Per-key scatter offsets within the partition; stored as 16-bit where the
// target supports it to reduce register pressure.
struct OffsetStruct
{
#if defined(ENABLE_16_BIT)
    uint16_t o[KEYS_PER_THREAD];
#else
    uint o[KEYS_PER_THREAD];
#endif
};
|
||||
|
||||
// Per-key extracted radix digits; stored as 16-bit where the target supports
// it to reduce register pressure.
struct DigitStruct
{
#if defined(ENABLE_16_BIT)
    uint16_t d[KEYS_PER_THREAD];
#else
    uint d[KEYS_PER_THREAD];
#endif
};
|
||||
|
||||
//*****************************************************************************
|
||||
//HELPER FUNCTIONS
|
||||
//*****************************************************************************
|
||||
//Due to a bug with SPIRV pre 1.6, we cannot use WaveGetLaneCount() to get the currently active wavesize
// Returns the wave (subgroup) size. On Vulkan it is derived from a full
// ballot: with all lanes converged, the total number of set ballot bits
// across the four 32-bit words equals the wave size.
inline uint getWaveSize()
{
#if defined(VULKAN)
    GroupMemoryBarrierWithGroupSync(); //Make absolutely sure the wave is not diverged here
    return dot(countbits(WaveActiveBallot(true)), uint4(1, 1, 1, 1));
#else
    return WaveGetLaneCount();
#endif
}
|
||||
|
||||
// Index of the wave within the threadgroup for the given group thread id.
inline uint getWaveIndex(uint gtid, uint waveSize)
{
    const uint waveIdx = gtid / waveSize;
    return waveIdx;
}
|
||||
|
||||
//Radix Tricks by Michael Herf
//http://stereopsis.com/radix.html
// Maps a float's bit pattern to a uint whose unsigned ordering matches float
// ordering: negatives get all bits flipped, non-negatives only the sign bit.
inline uint FloatToUint(float f)
{
    uint mask = -((int) (asuint(f) >> 31)) | 0x80000000;
    return asuint(f) ^ mask;
}
|
||||
|
||||
// Inverse of FloatToUint: recovers the original float bit pattern from the
// sortable uint representation.
inline float UintToFloat(uint u)
{
    uint mask = ((u >> 31) - 1) | 0x80000000;
    return asfloat(u ^ mask);
}
|
||||
|
||||
// Maps a signed int to a uint whose unsigned ordering matches signed
// ordering (flips the sign bit).
inline uint IntToUint(int i)
{
    return asuint(i ^ 0x80000000);
}
|
||||
|
||||
// Inverse of IntToUint: recovers the signed int from its sortable uint
// representation (flips the sign bit back).
inline int UintToInt(uint u)
{
    return asint(u ^ 0x80000000);
}
|
||||
|
||||
// Number of waves per threadgroup in the binning pass (D_DIM threads total).
inline uint getWaveCountPass(uint waveSize)
{
    return D_DIM / waveSize;
}
|
||||
|
||||
// Extracts the 8-bit digit of the current pass (selected by e_radixShift).
inline uint ExtractDigit(uint key)
{
    return (key >> e_radixShift) & RADIX_MASK;
}
|
||||
|
||||
// Extracts the 8-bit digit at an explicit bit shift.
inline uint ExtractDigit(uint key, uint shift)
{
    return (key >> shift) & RADIX_MASK;
}
|
||||
|
||||
// For small-wave histograms that pack two 16-bit counters per uint:
// returns the word index for this key's digit (digit / 2).
inline uint ExtractPackedIndex(uint key)
{
    return key >> (e_radixShift + 1) & HALF_MASK;
}
|
||||
|
||||
// Companion of ExtractPackedIndex: odd digits live in the high 16 bits of
// the packed counter word, even digits in the low 16 bits.
inline uint ExtractPackedShift(uint key)
{
    return (key >> e_radixShift & 1) ? 16 : 0;
}
|
||||
|
||||
// Reads this key's 16-bit counter out of a packed histogram word.
inline uint ExtractPackedValue(uint packed, uint key)
{
    return packed >> ExtractPackedShift(key) & 0xffff;
}
|
||||
|
||||
// Keys handled by one wave per partition (waves >= 16 lanes).
inline uint SubPartSizeWGE16(uint waveSize)
{
    return KEYS_PER_THREAD * waveSize;
}
|
||||
|
||||
// This thread's base offset within the partition: lane index within its
// wave's contiguous sub-partition (waves >= 16 lanes).
inline uint SharedOffsetWGE16(uint gtid, uint waveSize)
{
    return WaveGetLaneIndex() + getWaveIndex(gtid, waveSize) * SubPartSizeWGE16(waveSize);
}
|
||||
|
||||
// Keys handled per serialized group of waves when waves are < 16 lanes
// (several waves share one histogram and iterate serially).
inline uint SubPartSizeWLT16(uint waveSize, uint _serialIterations)
{
    return KEYS_PER_THREAD * waveSize * _serialIterations;
}
|
||||
|
||||
// This thread's base offset within the partition for small (< 16 lane)
// waves: waves are grouped into serialIterations-sized clusters, with each
// wave interleaved at waveSize granularity inside its cluster's sub-partition.
inline uint SharedOffsetWLT16(uint gtid, uint waveSize, uint _serialIterations)
{
    return WaveGetLaneIndex() +
        (getWaveIndex(gtid, waveSize) / _serialIterations * SubPartSizeWLT16(waveSize, _serialIterations)) +
        (getWaveIndex(gtid, waveSize) % _serialIterations * waveSize);
}
|
||||
|
||||
// Global (device) offset of this thread's first key for partition partIndex
// (waves >= 16 lanes).
inline uint DeviceOffsetWGE16(uint gtid, uint waveSize, uint partIndex)
{
    return SharedOffsetWGE16(gtid, waveSize) + partIndex * PART_SIZE;
}
|
||||
|
||||
// Global (device) offset of this thread's first key for partition partIndex
// (waves < 16 lanes).
inline uint DeviceOffsetWLT16(uint gtid, uint waveSize, uint partIndex, uint serialIterations)
{
    return SharedOffsetWLT16(gtid, waveSize, serialIterations) + partIndex * PART_SIZE;
}
|
||||
|
||||
// Offset of the current pass's 256-bin slice in the global histogram:
// pass index = e_radixShift / RADIX_LOG, slice size = RADIX, so
// offset = (e_radixShift / 8) * 256 = e_radixShift * 32 = e_radixShift << 5.
inline uint GlobalHistOffset()
{
    return e_radixShift << 5;
}
|
||||
|
||||
// Shared-memory footprint of the per-wave histograms: one RADIX-bin
// histogram per wave (waves >= 16 lanes).
inline uint WaveHistsSizeWGE16(uint waveSize)
{
    return D_DIM / waveSize * RADIX;
}
|
||||
|
||||
// For waves < 16 lanes the whole shared-memory allocation is used for the
// (packed) histograms.
inline uint WaveHistsSizeWLT16()
{
    return D_TOTAL_SMEM;
}
|
||||
|
||||
//*****************************************************************************
|
||||
//FUNCTIONS COMMON TO THE DOWNSWEEP / DIGIT BINNING PASS
|
||||
//*****************************************************************************
|
||||
//If the size of a wave is too small, we do not have enough space in
//shared memory to assign a histogram to each wave, so instead,
//some operations are performed serially.
// Returns ceil((waves per group) / 32): how many waves must share one
// histogram slot and therefore process their keys in serial turns.
inline uint SerialIterations(uint waveSize)
{
    return (D_DIM / waveSize + 31) >> 5;
}
|
||||
|
||||
// Zeroes the per-wave histogram region of shared memory, striding the whole
// threadgroup (D_DIM threads) across it.
inline void ClearWaveHists(uint gtid, uint waveSize)
{
    uint histsEnd;
    if (waveSize >= 16)
        histsEnd = WaveHistsSizeWGE16(waveSize);
    else
        histsEnd = WaveHistsSizeWLT16();

    for (uint i = gtid; i < histsEnd; i += D_DIM)
        g_d[i] = 0;
}
|
||||
|
||||
// Loads one key into its order-preserving uint representation:
// uints are used as-is, ints and floats are remapped so that unsigned
// comparison of the result matches the native ordering.
inline void LoadKey(inout uint key, uint index)
{
#if defined(KEY_UINT)
    key = b_sort[index];
#elif defined(KEY_INT)
    // Encode direction is int -> sortable uint. (The previous UintToInt call
    // produced the same bits — both XOR the sign bit — but was the decode
    // direction and relied on an implicit int->uint conversion.)
    key = IntToUint(b_sort[index]);
#elif defined(KEY_FLOAT)
    key = FloatToUint(b_sort[index]);
#endif
}
|
||||
|
||||
// Fills out-of-range slots of a partial partition with the maximum key so
// the padding sorts to the very end and never disturbs real keys.
inline void LoadDummyKey(inout uint key)
{
    const uint maxKey = 0xffffffff;
    key = maxKey;
}
|
||||
|
||||
// Loads this thread's KEYS_PER_THREAD keys from a full partition, strided by
// waveSize so a wave reads coalesced runs (waves >= 16 lanes).
inline KeyStruct LoadKeysWGE16(uint gtid, uint waveSize, uint partIndex)
{
    KeyStruct keys;
    [unroll]
    for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, partIndex);
        i < KEYS_PER_THREAD;
        ++i, t += waveSize)
    {
        LoadKey(keys.k[i], t);
    }
    return keys;
}
|
||||
|
||||
// Loads this thread's keys from a full partition for small (< 16 lane)
// waves, strided by waveSize * serialIterations to match SharedOffsetWLT16.
inline KeyStruct LoadKeysWLT16(uint gtid, uint waveSize, uint partIndex, uint serialIterations)
{
    KeyStruct keys;
    [unroll]
    for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations);
        i < KEYS_PER_THREAD;
        ++i, t += waveSize * serialIterations)
    {
        LoadKey(keys.k[i], t);
    }
    return keys;
}
|
||||
|
||||
// Same as LoadKeysWGE16, but for the final (partial) partition: slots past
// e_numKeys are padded with the maximum key so they sort to the end.
inline KeyStruct LoadKeysPartialWGE16(uint gtid, uint waveSize, uint partIndex)
{
    KeyStruct keys;
    [unroll]
    for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, partIndex);
        i < KEYS_PER_THREAD;
        ++i, t += waveSize)
    {
        if (t < e_numKeys)
            LoadKey(keys.k[i], t);
        else
            LoadDummyKey(keys.k[i]);
    }
    return keys;
}
|
||||
|
||||
// Same as LoadKeysWLT16, but for the final (partial) partition: slots past
// e_numKeys are padded with the maximum key so they sort to the end.
inline KeyStruct LoadKeysPartialWLT16(uint gtid, uint waveSize, uint partIndex, uint serialIterations)
{
    KeyStruct keys;
    [unroll]
    for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations);
        i < KEYS_PER_THREAD;
        ++i, t += waveSize * serialIterations)
    {
        if (t < e_numKeys)
            LoadKey(keys.k[i], t);
        else
            LoadDummyKey(keys.k[i]);
    }
    return keys;
}
|
||||
|
||||
// Initial per-32-lane ballot mask for waves >= 16 lanes: all ones when the
// wave size is a multiple of 32 (each ballot word fully populated),
// otherwise only the low waveSize bits (e.g. 16-lane waves).
inline uint WaveFlagsWGE16(uint waveSize)
{
    return (waveSize & 31) ? (1U << waveSize) - 1 : 0xffffffff;
}
|
||||
|
||||
// Initial ballot mask for small (< 16 lane) waves: one bit per lane.
// (Fix: removed a stray empty statement — the original ended in ";;".)
inline uint WaveFlagsWLT16(uint waveSize)
{
    return (1U << waveSize) - 1;
}
|
||||
|
||||
// Warp-level multi-split: after the loop, waveFlags has a bit set for every
// lane in the wave whose key shares this lane's digit (one ballot per digit
// bit, intersecting matching lanes).
inline void WarpLevelMultiSplitWGE16(uint key, inout uint4 waveFlags)
{
    [unroll]
    for (uint k = 0; k < RADIX_LOG; ++k)
    {
        // Note: shift binds looser than +, so this is 1 << (k + e_radixShift).
        const uint currentBit = 1 << k + e_radixShift;
        const bool t = (key & currentBit) != 0;
        GroupMemoryBarrierWithGroupSync(); //Play on the safe side, throw in a barrier for convergence
        const uint4 ballot = WaveActiveBallot(t);
        // Keep only lanes that agree with this lane on the current bit.
        if(t)
            waveFlags &= ballot;
        else
            waveFlags &= (~ballot);
    }
}
|
||||
|
||||
//Counts peer lanes (same digit) from the multi-split mask, walking the
//ballot in 32-bit words for waves wider than 32 lanes.
//Returns: x = number of peers with a lower lane index than this lane,
//         y = total number of peers in the wave.
inline uint2 CountBitsWGE16(uint waveSize, uint ltMask, uint4 waveFlags)
{
    uint2 count = uint2(0, 0);

    for (uint wavePart = 0; wavePart < waveSize; wavePart += 32)
    {
        uint t = countbits(waveFlags[wavePart >> 5]);
        if (WaveGetLaneIndex() >= wavePart)
        {
            if (WaveGetLaneIndex() >= wavePart + 32)
                count.x += t; //whole word is below this lane
            else
                count.x += countbits(waveFlags[wavePart >> 5] & ltMask); //partial word: mask to lower lanes only
        }
        count.y += t;
    }

    return count;
}
|
||||
|
||||
//Multi-split for waves narrower than 16 lanes: the ballot fits in a single
//uint. After the loop, waveFlags holds one bit per lane whose key shares
//this lane's digit. (t ? 0 : ~0) ^ ballot flips the ballot for the 0-branch.
inline void WarpLevelMultiSplitWLT16(uint key, inout uint waveFlags)
{
    [unroll]
    for (uint k = 0; k < RADIX_LOG; ++k)
    {
        const bool t = key >> (k + e_radixShift) & 1;
        waveFlags &= (t ? 0 : 0xffffffff) ^ (uint) WaveActiveBallot(t);
    }
}
|
||||
|
||||
//Ranks this thread's keys within the partition (wave-size >= 16 path).
//For each key: multi-split finds peer lanes, then the key's offset is the
//wave histogram count so far plus the number of lower-indexed peers. The
//first peer lane (bitCount.x == 0) bumps the shared histogram by the total
//peer count. Barriers keep the read-then-update of g_d ordered across waves.
inline OffsetStruct RankKeysWGE16(
    uint waveSize,
    uint waveOffset,
    KeyStruct keys)
{
    OffsetStruct offsets;
    const uint initialFlags = WaveFlagsWGE16(waveSize);
    const uint ltMask = (1U << (WaveGetLaneIndex() & 31)) - 1;

    [unroll]
    for (uint i = 0; i < KEYS_PER_THREAD; ++i)
    {
        uint4 waveFlags = initialFlags;
        WarpLevelMultiSplitWGE16(keys.k[i], waveFlags);

        const uint index = ExtractDigit(keys.k[i]) + waveOffset;
        const uint2 bitCount = CountBitsWGE16(waveSize, ltMask, waveFlags);

        offsets.o[i] = g_d[index] + bitCount.x;
        GroupMemoryBarrierWithGroupSync();
        if (bitCount.x == 0) //this lane is the lowest-indexed holder of the digit
            g_d[index] += bitCount.y;
        GroupMemoryBarrierWithGroupSync();
    }

    return offsets;
}
|
||||
|
||||
//Ranks keys for the wave-size < 16 path. Histogram counters are packed two
//per uint (HALF_RADIX entries), and waves sharing a histogram row take turns
//over serialIterations rounds to avoid clobbering each other's packed
//updates. The first peer lane (peerBits == 0) commits the wave's total via
//InterlockedAdd shifted into the correct packed half.
inline OffsetStruct RankKeysWLT16(uint waveSize, uint waveIndex, KeyStruct keys, uint serialIterations)
{
    OffsetStruct offsets;
    const uint ltMask = (1U << WaveGetLaneIndex()) - 1;
    const uint initialFlags = WaveFlagsWLT16(waveSize);

    [unroll]
    for (uint i = 0; i < KEYS_PER_THREAD; ++i)
    {
        uint waveFlags = initialFlags;
        WarpLevelMultiSplitWLT16(keys.k[i], waveFlags);

        const uint index = ExtractPackedIndex(keys.k[i]) +
            (waveIndex / serialIterations * HALF_RADIX);

        const uint peerBits = countbits(waveFlags & ltMask); //lower-indexed peers in this wave
        for (uint k = 0; k < serialIterations; ++k)
        {
            if (waveIndex % serialIterations == k)
                offsets.o[i] = ExtractPackedValue(g_d[index], keys.k[i]) + peerBits;

            GroupMemoryBarrierWithGroupSync();
            if (waveIndex % serialIterations == k && peerBits == 0)
            {
                InterlockedAdd(g_d[index],
                    countbits(waveFlags) << ExtractPackedShift(keys.k[i]));
            }
            GroupMemoryBarrierWithGroupSync();
        }
    }

    return offsets;
}
|
||||
|
||||
//Serially scans one digit's counter across all per-wave histograms
//(stride RADIX). Each wave's slot is rewritten with the EXCLUSIVE running
//sum (histReduction - g_d[i]); the function returns the inclusive total for
//this digit, to be fed into the cross-digit exclusive scan.
inline uint WaveHistInclusiveScanCircularShiftWGE16(uint gtid, uint waveSize)
{
    uint histReduction = g_d[gtid];
    for (uint i = gtid + RADIX; i < WaveHistsSizeWGE16(waveSize); i += RADIX)
    {
        histReduction += g_d[i];
        g_d[i] = histReduction - g_d[i]; //store exclusive prefix for this wave
    }
    return histReduction;
}
|
||||
|
||||
//Wave-size < 16 variant of the per-digit histogram scan: counters are packed
//two per uint, so the stride is HALF_RADIX and both packed halves are summed
//at once. Slots are rewritten with their exclusive prefix; returns the
//inclusive (packed) total.
inline uint WaveHistInclusiveScanCircularShiftWLT16(uint gtid)
{
    uint histReduction = g_d[gtid];
    for (uint i = gtid + HALF_RADIX; i < WaveHistsSizeWLT16(); i += HALF_RADIX)
    {
        histReduction += g_d[i];
        g_d[i] = histReduction - g_d[i]; //store exclusive prefix for this wave
    }
    return histReduction;
}
|
||||
|
||||
//Turns the per-digit totals into an exclusive prefix sum over all RADIX
//digits (wave-size >= 16 path). Step 1: each lane writes its total into the
//NEXT lane's slot (circular shift within the wave) so WavePrefixSum yields
//an exclusive result. Step 2: one lane per wave-sized segment scans the
//segment heads. Step 3: the segment head value is broadcast and added to the
//rest of the segment (lane 0 keeps its head value unchanged).
inline void WaveHistReductionExclusiveScanWGE16(uint gtid, uint waveSize, uint histReduction)
{
    if (gtid < RADIX)
    {
        const uint laneMask = waveSize - 1;
        g_d[((WaveGetLaneIndex() + 1) & laneMask) + (gtid & ~laneMask)] = histReduction;
    }
    GroupMemoryBarrierWithGroupSync();

    if (gtid < RADIX / waveSize)
    {
        g_d[gtid * waveSize] =
            WavePrefixSum(g_d[gtid * waveSize]);
    }
    GroupMemoryBarrierWithGroupSync();

    uint t = WaveReadLaneAt(g_d[gtid], 0); //segment head, broadcast to the wave
    if (gtid < RADIX && WaveGetLaneIndex())
        g_d[gtid] += t;
}
|
||||
|
||||
//inclusive/exclusive prefix sum up the histograms,
//use a blelloch scan for in place packed exclusive
//Wave-size < 16 path: counters are packed two-per-uint, so the up-sweep and
//down-sweep operate on the HIGH halves (& 0xffff0000) and the final fixup
//recombines halves so each packed slot ends up holding exclusive prefixes.
inline void WaveHistReductionExclusiveScanWLT16(uint gtid)
{
    uint shift = 1;
    //up-sweep (reduce) phase over the packed high halves
    for (uint j = RADIX >> 2; j > 0; j >>= 1)
    {
        GroupMemoryBarrierWithGroupSync();
        if (gtid < j)
        {
            g_d[((((gtid << 1) + 2) << shift) - 1) >> 1] +=
                g_d[((((gtid << 1) + 1) << shift) - 1) >> 1] & 0xffff0000;
        }
        shift++;
    }
    GroupMemoryBarrierWithGroupSync();

    //clear the root's high half (identity for the exclusive scan)
    if (gtid == 0)
        g_d[HALF_RADIX - 1] &= 0xffff;

    //down-sweep phase
    for (uint j = 1; j < RADIX >> 1; j <<= 1)
    {
        --shift;
        GroupMemoryBarrierWithGroupSync();
        if (gtid < j)
        {
            const uint t = ((((gtid << 1) + 1) << shift) - 1) >> 1;
            const uint t2 = ((((gtid << 1) + 2) << shift) - 1) >> 1;
            const uint t3 = g_d[t];
            g_d[t] = (g_d[t] & 0xffff) | (g_d[t2] & 0xffff0000);
            g_d[t2] += t3 & 0xffff0000;
        }
    }

    GroupMemoryBarrierWithGroupSync();
    //final fixup: fold the scanned high halves back across both packed slots
    if (gtid < HALF_RADIX)
    {
        const uint t = g_d[gtid];
        g_d[gtid] = (t >> 16) + (t << 16) + (t & 0xffff0000);
    }
}
|
||||
|
||||
//Converts per-wave key offsets into group-wide offsets (wave >= 16 path):
//adds the digit's group-level exclusive prefix (g_d[t2]) plus this wave's
//exclusive prefix within the group (g_d[t2 + t]). Threads of the first wave
//(gtid < waveSize) skip the per-wave term, which is zero for wave 0.
inline void UpdateOffsetsWGE16(
    uint gtid,
    uint waveSize,
    inout OffsetStruct offsets,
    KeyStruct keys)
{
    if (gtid >= waveSize)
    {
        const uint t = getWaveIndex(gtid, waveSize) * RADIX;
        [unroll]
        for (uint i = 0; i < KEYS_PER_THREAD; ++i)
        {
            const uint t2 = ExtractDigit(keys.k[i]);
            offsets.o[i] += g_d[t2 + t] + g_d[t2];
        }
    }
    else
    {
        [unroll]
        for (uint i = 0; i < KEYS_PER_THREAD; ++i)
            offsets.o[i] += g_d[ExtractDigit(keys.k[i])];
    }
}
|
||||
|
||||
//Wave-size < 16 variant of UpdateOffsetsWGE16: histograms are packed
//two-per-uint (HALF_RADIX rows) and shared by serialIterations waves, so the
//row index divides the wave index by serialIterations and the packed half is
//selected by the key.
inline void UpdateOffsetsWLT16(
    uint gtid,
    uint waveSize,
    uint serialIterations,
    inout OffsetStruct offsets,
    KeyStruct keys)
{
    if (gtid >= waveSize * serialIterations)
    {
        const uint t = getWaveIndex(gtid, waveSize) / serialIterations * HALF_RADIX;
        [unroll]
        for (uint i = 0; i < KEYS_PER_THREAD; ++i)
        {
            const uint t2 = ExtractPackedIndex(keys.k[i]);
            offsets.o[i] += ExtractPackedValue(g_d[t2 + t] + g_d[t2], keys.k[i]);
        }
    }
    else
    {
        [unroll]
        for (uint i = 0; i < KEYS_PER_THREAD; ++i)
            offsets.o[i] += ExtractPackedValue(g_d[ExtractPackedIndex(keys.k[i])], keys.k[i]);
    }
}
|
||||
|
||||
//Writes each of this thread's keys into groupshared memory at the slot
//computed during ranking, leaving g_d digit-sorted for the device scatter.
inline void ScatterKeysShared(OffsetStruct offsets, KeyStruct keys)
{
    [unroll]
    for (uint slot = 0; slot < KEYS_PER_THREAD; slot++)
    {
        const uint dest = offsets.o[slot];
        g_d[dest] = keys.k[slot];
    }
}
|
||||
|
||||
//Mirrors a device index about the end of the key range so output is written
//back-to-front, producing a descending sort order.
inline uint DescendingIndex(uint deviceIndex)
{
    const uint lastIndex = e_numKeys - 1;
    return lastIndex - deviceIndex;
}
|
||||
|
||||
//Writes one key from groupshared memory to the alternate device buffer,
//converting back from the sortable uint representation to the buffer's
//declared key type (selected at compile time).
inline void WriteKey(uint deviceIndex, uint groupSharedIndex)
{
#if defined(KEY_UINT)
    b_alt[deviceIndex] = g_d[groupSharedIndex];
#elif defined(KEY_INT)
    b_alt[deviceIndex] = UintToInt(g_d[groupSharedIndex]);
#elif defined(KEY_FLOAT)
    b_alt[deviceIndex] = UintToFloat(g_d[groupSharedIndex]);
#endif
}
|
||||
|
||||
//Reads one payload element from the device buffer as raw uint bits
//(int/float payloads are reinterpreted via asuint, not converted).
inline void LoadPayload(inout uint payload, uint deviceIndex)
{
#if defined(PAYLOAD_UINT)
    payload = b_sortPayload[deviceIndex];
#elif defined(PAYLOAD_INT) || defined(PAYLOAD_FLOAT)
    payload = asuint(b_sortPayload[deviceIndex]);
#endif
}
|
||||
|
||||
//Payloads reuse KeyStruct (raw uint bits), so scattering them through the
//key offsets is exactly the key-scatter routine.
inline void ScatterPayloadsShared(OffsetStruct offsets, KeyStruct payloads)
{
    ScatterKeysShared(offsets, payloads);
}
|
||||
|
||||
//Writes one payload from groupshared memory to the alternate payload buffer,
//reinterpreting the stored uint bits back to the buffer's declared type.
inline void WritePayload(uint deviceIndex, uint groupSharedIndex)
{
#if defined(PAYLOAD_UINT)
    b_altPayload[deviceIndex] = g_d[groupSharedIndex];
#elif defined(PAYLOAD_INT)
    b_altPayload[deviceIndex] = asint(g_d[groupSharedIndex]);
#elif defined(PAYLOAD_FLOAT)
    b_altPayload[deviceIndex] = asfloat(g_d[groupSharedIndex]);
#endif
}
|
||||
|
||||
//*****************************************************************************
|
||||
//SCATTERING: FULL PARTITIONS
|
||||
//*****************************************************************************
|
||||
//KEYS ONLY
|
||||
//Scatters a full partition's digit-sorted keys from groupshared memory to
//their final device positions. The digit's device offset lives in
//g_d[digit + PART_SIZE]; adding i (the key's shared slot) yields the
//destination, since keys in g_d are already grouped by digit.
inline void ScatterKeysOnlyDeviceAscending(uint gtid)
{
    for (uint i = gtid; i < PART_SIZE; i += D_DIM)
        WriteKey(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i, i);
}
|
||||
|
||||
//Descending scatter for a full partition: only the final radix pass
//(e_radixShift == 24, most significant byte) mirrors the destination index;
//earlier passes scatter ascending, which presumably yields the correct
//descending order once the last pass reverses — confirm against upstream.
inline void ScatterKeysOnlyDeviceDescending(uint gtid)
{
    if (e_radixShift == 24)
    {
        for (uint i = gtid; i < PART_SIZE; i += D_DIM)
            WriteKey(DescendingIndex(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i), i);
    }
    else
    {
        ScatterKeysOnlyDeviceAscending(gtid);
    }
}
|
||||
|
||||
//Dispatches the full-partition keys-only device scatter by the compile-time
//sort direction.
inline void ScatterKeysOnlyDevice(uint gtid)
{
#if defined(SHOULD_ASCEND)
    ScatterKeysOnlyDeviceAscending(gtid);
#else
    ScatterKeysOnlyDeviceDescending(gtid);
#endif
}
|
||||
|
||||
//KEY VALUE PAIRS
|
||||
//Key phase of the pair scatter (full partition, ascending): writes keys to
//device and records each key's digit so the later payload phase can reuse
//the same destination computation after g_d is overwritten with payloads.
inline void ScatterPairsKeyPhaseAscending(
    uint gtid,
    inout DigitStruct digits)
{
    [unroll]
    for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
    {
        digits.d[i] = ExtractDigit(g_d[t]);
        WriteKey(g_d[digits.d[i] + PART_SIZE] + t, t);
    }
}
|
||||
|
||||
//Descending key phase of the pair scatter: mirrors the destination index on
//the final radix pass only (e_radixShift == 24); earlier passes delegate to
//the ascending variant. Digits are cached for the payload phase.
inline void ScatterPairsKeyPhaseDescending(
    uint gtid,
    inout DigitStruct digits)
{
    if (e_radixShift == 24)
    {
        [unroll]
        for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
        {
            digits.d[i] = ExtractDigit(g_d[t]);
            WriteKey(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t);
        }
    }
    else
    {
        ScatterPairsKeyPhaseAscending(gtid, digits);
    }
}
|
||||
|
||||
//Loads this thread's payloads from the device buffer using the same
//addressing as LoadKeysWGE16, so payload i pairs with key i.
inline void LoadPayloadsWGE16(
    uint gtid,
    uint waveSize,
    uint partIndex,
    inout KeyStruct payloads)
{
    [unroll]
    for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, partIndex);
        i < KEYS_PER_THREAD;
        ++i, t += waveSize)
    {
        LoadPayload(payloads.k[i], t);
    }
}
|
||||
|
||||
//Loads this thread's payloads on the wave-size < 16 path, mirroring
//LoadKeysWLT16's addressing so payload i pairs with key i.
inline void LoadPayloadsWLT16(
    uint gtid,
    uint waveSize,
    uint partIndex,
    uint serialIterations,
    inout KeyStruct payloads)
{
    [unroll]
    for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations);
        i < KEYS_PER_THREAD;
        ++i, t += waveSize * serialIterations)
    {
        LoadPayload(payloads.k[i], t);
    }
}
|
||||
|
||||
//Payload phase of the pair scatter (ascending): recomputes each element's
//device destination from the digit cached during the key phase and writes
//the payload now occupying the same groupshared slot.
inline void ScatterPayloadsAscending(uint gtid, DigitStruct digits)
{
    [unroll]
    for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
        WritePayload(g_d[digits.d[i] + PART_SIZE] + t, t);
}
|
||||
|
||||
//Descending payload phase: mirrors the destination on the final radix pass
//(e_radixShift == 24) to match ScatterPairsKeyPhaseDescending; otherwise
//falls back to the ascending variant.
inline void ScatterPayloadsDescending(uint gtid, DigitStruct digits)
{
    if (e_radixShift == 24)
    {
        [unroll]
        for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
            WritePayload(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t);
    }
    else
    {
        ScatterPayloadsAscending(gtid, digits);
    }
}
|
||||
|
||||
//Full-partition pair scatter: (1) write keys to device while caching their
//digits, (2) after a barrier, reload payloads from device and scatter them
//through the SAME groupshared offsets the keys used, (3) after another
//barrier, write payloads to device using the cached digits. The barriers
//separate the phases because g_d is reused for keys and then payloads.
inline void ScatterPairsDevice(
    uint gtid,
    uint waveSize,
    uint partIndex,
    OffsetStruct offsets)
{
    DigitStruct digits;
#if defined(SHOULD_ASCEND)
    ScatterPairsKeyPhaseAscending(gtid, digits);
#else
    ScatterPairsKeyPhaseDescending(gtid, digits);
#endif
    GroupMemoryBarrierWithGroupSync();

    KeyStruct payloads;
    if (waveSize >= 16)
        LoadPayloadsWGE16(gtid, waveSize, partIndex, payloads);
    else
        LoadPayloadsWLT16(gtid, waveSize, partIndex, SerialIterations(waveSize), payloads);
    ScatterPayloadsShared(offsets, payloads);
    GroupMemoryBarrierWithGroupSync();

#if defined(SHOULD_ASCEND)
    ScatterPayloadsAscending(gtid, digits);
#else
    ScatterPayloadsDescending(gtid, digits);
#endif
}
|
||||
|
||||
//Entry point for scattering a full partition to device memory: pairs path
//when SORT_PAIRS is defined, keys-only otherwise.
inline void ScatterDevice(
    uint gtid,
    uint waveSize,
    uint partIndex,
    OffsetStruct offsets)
{
#if defined(SORT_PAIRS)
    ScatterPairsDevice(
        gtid,
        waveSize,
        partIndex,
        offsets);
#else
    ScatterKeysOnlyDevice(gtid);
#endif
}
|
||||
|
||||
//*****************************************************************************
|
||||
//SCATTERING: PARTIAL PARTITIONS
|
||||
//*****************************************************************************
|
||||
//KEYS ONLY
|
||||
//Keys-only device scatter for the final (partial) partition, ascending:
//identical to the full-partition version but skips slots past the real
//key count (dummy keys must not be written out).
inline void ScatterKeysOnlyDevicePartialAscending(uint gtid, uint finalPartSize)
{
    for (uint i = gtid; i < PART_SIZE; i += D_DIM)
    {
        if (i < finalPartSize)
            WriteKey(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i, i);
    }
}
|
||||
|
||||
//Descending keys-only scatter for the partial partition: mirrors indices on
//the final radix pass (e_radixShift == 24) only, and skips dummy slots.
inline void ScatterKeysOnlyDevicePartialDescending(uint gtid, uint finalPartSize)
{
    if (e_radixShift == 24)
    {
        for (uint i = gtid; i < PART_SIZE; i += D_DIM)
        {
            if (i < finalPartSize)
                WriteKey(DescendingIndex(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i), i);
        }
    }
    else
    {
        ScatterKeysOnlyDevicePartialAscending(gtid, finalPartSize);
    }
}
|
||||
|
||||
//Keys-only scatter entry for the final partition: computes how many real
//keys it contains, then dispatches on the compile-time sort direction.
inline void ScatterKeysOnlyDevicePartial(uint gtid, uint partIndex)
{
    const uint finalPartSize = e_numKeys - partIndex * PART_SIZE;
#if defined(SHOULD_ASCEND)
    ScatterKeysOnlyDevicePartialAscending(gtid, finalPartSize);
#else
    ScatterKeysOnlyDevicePartialDescending(gtid, finalPartSize);
#endif
}
|
||||
|
||||
//KEY VALUE PAIRS
|
||||
//Partial-partition key phase (ascending): like the full version but only
//in-range slots write keys and record digits; digits.d[i] for out-of-range
//slots stays unwritten and is never used by the payload phase.
inline void ScatterPairsKeyPhaseAscendingPartial(
    uint gtid,
    uint finalPartSize,
    inout DigitStruct digits)
{
    [unroll]
    for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
    {
        if (t < finalPartSize)
        {
            digits.d[i] = ExtractDigit(g_d[t]);
            WriteKey(g_d[digits.d[i] + PART_SIZE] + t, t);
        }
    }
}
|
||||
|
||||
//Partial-partition key phase, descending: mirrors destinations on the final
//radix pass (e_radixShift == 24) only, skipping out-of-range slots.
inline void ScatterPairsKeyPhaseDescendingPartial(
    uint gtid,
    uint finalPartSize,
    inout DigitStruct digits)
{
    if (e_radixShift == 24)
    {
        [unroll]
        for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
        {
            if (t < finalPartSize)
            {
                digits.d[i] = ExtractDigit(g_d[t]);
                WriteKey(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t);
            }
        }
    }
    else
    {
        ScatterPairsKeyPhaseAscendingPartial(gtid, finalPartSize, digits);
    }
}
|
||||
|
||||
//Partial-partition payload load (wave >= 16): same addressing as the key
//load; slots past e_numKeys are simply left unloaded (their keys were
//dummies and are never written back).
inline void LoadPayloadsPartialWGE16(
    uint gtid,
    uint waveSize,
    uint partIndex,
    inout KeyStruct payloads)
{
    [unroll]
    for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, partIndex);
        i < KEYS_PER_THREAD;
        ++i, t += waveSize)
    {
        if (t < e_numKeys)
            LoadPayload(payloads.k[i], t);
    }
}
|
||||
|
||||
//Partial-partition payload load for the wave-size < 16 path; out-of-range
//slots are skipped (their paired keys were dummies).
inline void LoadPayloadsPartialWLT16(
    uint gtid,
    uint waveSize,
    uint partIndex,
    uint serialIterations,
    inout KeyStruct payloads)
{
    [unroll]
    for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations);
        i < KEYS_PER_THREAD;
        ++i, t += waveSize * serialIterations)
    {
        if (t < e_numKeys)
            LoadPayload(payloads.k[i], t);
    }
}
|
||||
|
||||
//Partial-partition payload phase (ascending): reuses the digits cached in
//the key phase; only slots holding real elements are written out.
inline void ScatterPayloadsAscendingPartial(
    uint gtid,
    uint finalPartSize,
    DigitStruct digits)
{
    [unroll]
    for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
    {
        if (t < finalPartSize)
            WritePayload(g_d[digits.d[i] + PART_SIZE] + t, t);
    }
}
|
||||
|
||||
//Partial-partition payload phase, descending: mirrors destinations on the
//final radix pass (e_radixShift == 24) only, matching the key phase.
inline void ScatterPayloadsDescendingPartial(
    uint gtid,
    uint finalPartSize,
    DigitStruct digits)
{
    if (e_radixShift == 24)
    {
        [unroll]
        for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
        {
            if (t < finalPartSize)
                WritePayload(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t);
        }
    }
    else
    {
        ScatterPayloadsAscendingPartial(gtid, finalPartSize, digits);
    }
}
|
||||
|
||||
//Pair scatter for the final (partial) partition: same three-phase structure
//as ScatterPairsDevice — keys out (caching digits), barrier, payloads into
//g_d via the key offsets, barrier, payloads out — with bounds checks
//against the real element count throughout.
inline void ScatterPairsDevicePartial(
    uint gtid,
    uint waveSize,
    uint partIndex,
    OffsetStruct offsets)
{
    DigitStruct digits;
    const uint finalPartSize = e_numKeys - partIndex * PART_SIZE;
#if defined(SHOULD_ASCEND)
    ScatterPairsKeyPhaseAscendingPartial(gtid, finalPartSize, digits);
#else
    ScatterPairsKeyPhaseDescendingPartial(gtid, finalPartSize, digits);
#endif
    GroupMemoryBarrierWithGroupSync();

    KeyStruct payloads;
    if (waveSize >= 16)
        LoadPayloadsPartialWGE16(gtid, waveSize, partIndex, payloads);
    else
        LoadPayloadsPartialWLT16(gtid, waveSize, partIndex, SerialIterations(waveSize), payloads);
    ScatterPayloadsShared(offsets, payloads);
    GroupMemoryBarrierWithGroupSync();

#if defined(SHOULD_ASCEND)
    ScatterPayloadsAscendingPartial(gtid, finalPartSize, digits);
#else
    ScatterPayloadsDescendingPartial(gtid, finalPartSize, digits);
#endif
}
|
||||
|
||||
//Entry point for scattering the final (partial) partition: pairs path when
//SORT_PAIRS is defined, keys-only otherwise.
inline void ScatterDevicePartial(
    uint gtid,
    uint waveSize,
    uint partIndex,
    OffsetStruct offsets)
{
#if defined(SORT_PAIRS)
    ScatterPairsDevicePartial(
        gtid,
        waveSize,
        partIndex,
        offsets);
#else
    ScatterKeysOnlyDevicePartial(gtid, partIndex);
#endif
}
|
||||
7
MVS/3DGS-Unity/Shaders/SortCommon.hlsl.meta
Normal file
7
MVS/3DGS-Unity/Shaders/SortCommon.hlsl.meta
Normal file
@@ -0,0 +1,7 @@
|
||||
fileFormatVersion: 2
|
||||
guid: 268e5936ab6d79f4b8aeef8f5d14e7ee
|
||||
ShaderIncludeImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
||||
212
MVS/3DGS-Unity/Shaders/SphericalHarmonics.hlsl
Normal file
212
MVS/3DGS-Unity/Shaders/SphericalHarmonics.hlsl
Normal file
@@ -0,0 +1,212 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#ifndef SPHERICAL_HARMONICS_HLSL
|
||||
#define SPHERICAL_HARMONICS_HLSL
|
||||
|
||||
// SH rotation based on https://github.com/andrewwillmott/sh-lib (Unlicense / public domain)
|
||||
|
||||
#define SH_MAX_ORDER 4
|
||||
#define SH_MAX_COEFFS_COUNT (SH_MAX_ORDER*SH_MAX_ORDER)
|
||||
|
||||
// Weighted sum of three consecutive SH coefficient vectors starting at vidx;
// accumulation order matches the original left-to-right expression.
float3 Dot3(int vidx, float3 v[SH_MAX_COEFFS_COUNT], float f[3])
{
    float3 acc = v[vidx + 0] * f[0];
    acc += v[vidx + 1] * f[1];
    acc += v[vidx + 2] * f[2];
    return acc;
}
|
||||
// Weighted sum of five consecutive SH coefficient vectors starting at vidx;
// left-to-right accumulation preserves the original float summation order.
float3 Dot5(int vidx, float3 v[SH_MAX_COEFFS_COUNT], float f[5])
{
    float3 acc = v[vidx] * f[0];
    [unroll]
    for (int j = 1; j < 5; ++j)
        acc += v[vidx + j] * f[j];
    return acc;
}
|
||||
// Weighted sum of seven consecutive SH coefficient vectors starting at vidx;
// left-to-right accumulation preserves the original float summation order.
float3 Dot7(int vidx, float3 v[SH_MAX_COEFFS_COUNT], float f[7])
{
    float3 acc = v[vidx] * f[0];
    [unroll]
    for (int j = 1; j < 7; ++j)
        acc += v[vidx + j] * f[j];
    return acc;
}
|
||||
|
||||
//Rotates spherical-harmonic RGB coefficients (n = number of bands, up to 4)
//by the orientation matrix. Each band's rotation matrix is built recursively
//from the band-1 matrix (sh1), following andrewwillmott/sh-lib; the Dot3/5/7
//helpers apply each row to the source coefficients. Early-outs after each
//band when n is smaller.
void RotateSH(float3x3 orient, int n, float3 coeffsIn[SH_MAX_COEFFS_COUNT], out float3 coeffs[SH_MAX_COEFFS_COUNT])
{
    //Precomputed square-root ratios used by the band-recurrence formulas.
    const float kSqrt03_02 = sqrt( 3.0 / 2.0);
    const float kSqrt01_03 = sqrt( 1.0 / 3.0);
    const float kSqrt02_03 = sqrt( 2.0 / 3.0);
    const float kSqrt04_03 = sqrt( 4.0 / 3.0);
    const float kSqrt01_04 = sqrt( 1.0 / 4.0);
    const float kSqrt03_04 = sqrt( 3.0 / 4.0);
    const float kSqrt01_05 = sqrt( 1.0 / 5.0);
    const float kSqrt03_05 = sqrt( 3.0 / 5.0);
    const float kSqrt06_05 = sqrt( 6.0 / 5.0);
    const float kSqrt08_05 = sqrt( 8.0 / 5.0);
    const float kSqrt09_05 = sqrt( 9.0 / 5.0);
    const float kSqrt05_06 = sqrt( 5.0 / 6.0);
    const float kSqrt01_06 = sqrt( 1.0 / 6.0);
    const float kSqrt03_08 = sqrt( 3.0 / 8.0);
    const float kSqrt05_08 = sqrt( 5.0 / 8.0);
    const float kSqrt07_08 = sqrt( 7.0 / 8.0);
    const float kSqrt09_08 = sqrt( 9.0 / 8.0);
    const float kSqrt05_09 = sqrt( 5.0 / 9.0);
    const float kSqrt08_09 = sqrt( 8.0 / 9.0);

    const float kSqrt01_10 = sqrt( 1.0 / 10.0);
    const float kSqrt03_10 = sqrt( 3.0 / 10.0);
    const float kSqrt01_12 = sqrt( 1.0 / 12.0);
    const float kSqrt04_15 = sqrt( 4.0 / 15.0);
    const float kSqrt01_16 = sqrt( 1.0 / 16.0);
    const float kSqrt07_16 = sqrt( 7.0 / 16.0);
    const float kSqrt15_16 = sqrt(15.0 / 16.0);
    const float kSqrt01_18 = sqrt( 1.0 / 18.0);
    const float kSqrt03_25 = sqrt( 3.0 / 25.0);
    const float kSqrt14_25 = sqrt(14.0 / 25.0);
    const float kSqrt15_25 = sqrt(15.0 / 25.0);
    const float kSqrt18_25 = sqrt(18.0 / 25.0);
    const float kSqrt01_32 = sqrt( 1.0 / 32.0);
    const float kSqrt03_32 = sqrt( 3.0 / 32.0);
    const float kSqrt15_32 = sqrt(15.0 / 32.0);
    const float kSqrt21_32 = sqrt(21.0 / 32.0);
    const float kSqrt01_50 = sqrt( 1.0 / 50.0);
    const float kSqrt03_50 = sqrt( 3.0 / 50.0);
    const float kSqrt21_50 = sqrt(21.0 / 50.0);

    int srcIdx = 0;
    int dstIdx = 0;

    // band 0: constant term is rotation-invariant
    coeffs[dstIdx++] = coeffsIn[0];
    if (n < 2)
        return;

    // band 1: 3x3 rotation derived directly from the orientation matrix
    srcIdx += 1;
    float sh1[3][3] =
    {
        // NOTE: change from upstream code at https://github.com/andrewwillmott/sh-lib, some of the
        // values need to have "-" in front of them.
        orient._22, -orient._23, orient._21,
        -orient._32, orient._33, -orient._31,
        orient._12, -orient._13, orient._11
    };
    coeffs[dstIdx++] = Dot3(srcIdx, coeffsIn, sh1[0]);
    coeffs[dstIdx++] = Dot3(srcIdx, coeffsIn, sh1[1]);
    coeffs[dstIdx++] = Dot3(srcIdx, coeffsIn, sh1[2]);
    if (n < 3)
        return;

    // band 2: 5x5 rotation built from products of sh1 entries
    srcIdx += 3;
    float sh2[5][5];

    sh2[0][0] = kSqrt01_04 * ((sh1[2][2] * sh1[0][0] + sh1[2][0] * sh1[0][2]) + (sh1[0][2] * sh1[2][0] + sh1[0][0] * sh1[2][2]));
    sh2[0][1] = (sh1[2][1] * sh1[0][0] + sh1[0][1] * sh1[2][0]);
    sh2[0][2] = kSqrt03_04 * (sh1[2][1] * sh1[0][1] + sh1[0][1] * sh1[2][1]);
    sh2[0][3] = (sh1[2][1] * sh1[0][2] + sh1[0][1] * sh1[2][2]);
    sh2[0][4] = kSqrt01_04 * ((sh1[2][2] * sh1[0][2] - sh1[2][0] * sh1[0][0]) + (sh1[0][2] * sh1[2][2] - sh1[0][0] * sh1[2][0]));

    coeffs[dstIdx++] = Dot5(srcIdx, coeffsIn, sh2[0]);

    sh2[1][0] = kSqrt01_04 * ((sh1[1][2] * sh1[0][0] + sh1[1][0] * sh1[0][2]) + (sh1[0][2] * sh1[1][0] + sh1[0][0] * sh1[1][2]));
    sh2[1][1] = sh1[1][1] * sh1[0][0] + sh1[0][1] * sh1[1][0];
    sh2[1][2] = kSqrt03_04 * (sh1[1][1] * sh1[0][1] + sh1[0][1] * sh1[1][1]);
    sh2[1][3] = sh1[1][1] * sh1[0][2] + sh1[0][1] * sh1[1][2];
    sh2[1][4] = kSqrt01_04 * ((sh1[1][2] * sh1[0][2] - sh1[1][0] * sh1[0][0]) + (sh1[0][2] * sh1[1][2] - sh1[0][0] * sh1[1][0]));

    coeffs[dstIdx++] = Dot5(srcIdx, coeffsIn, sh2[1]);

    sh2[2][0] = kSqrt01_03 * (sh1[1][2] * sh1[1][0] + sh1[1][0] * sh1[1][2]) + -kSqrt01_12 * ((sh1[2][2] * sh1[2][0] + sh1[2][0] * sh1[2][2]) + (sh1[0][2] * sh1[0][0] + sh1[0][0] * sh1[0][2]));
    sh2[2][1] = kSqrt04_03 * sh1[1][1] * sh1[1][0] + -kSqrt01_03 * (sh1[2][1] * sh1[2][0] + sh1[0][1] * sh1[0][0]);
    sh2[2][2] = sh1[1][1] * sh1[1][1] + -kSqrt01_04 * (sh1[2][1] * sh1[2][1] + sh1[0][1] * sh1[0][1]);
    sh2[2][3] = kSqrt04_03 * sh1[1][1] * sh1[1][2] + -kSqrt01_03 * (sh1[2][1] * sh1[2][2] + sh1[0][1] * sh1[0][2]);
    sh2[2][4] = kSqrt01_03 * (sh1[1][2] * sh1[1][2] - sh1[1][0] * sh1[1][0]) + -kSqrt01_12 * ((sh1[2][2] * sh1[2][2] - sh1[2][0] * sh1[2][0]) + (sh1[0][2] * sh1[0][2] - sh1[0][0] * sh1[0][0]));

    coeffs[dstIdx++] = Dot5(srcIdx, coeffsIn, sh2[2]);

    sh2[3][0] = kSqrt01_04 * ((sh1[1][2] * sh1[2][0] + sh1[1][0] * sh1[2][2]) + (sh1[2][2] * sh1[1][0] + sh1[2][0] * sh1[1][2]));
    sh2[3][1] = sh1[1][1] * sh1[2][0] + sh1[2][1] * sh1[1][0];
    sh2[3][2] = kSqrt03_04 * (sh1[1][1] * sh1[2][1] + sh1[2][1] * sh1[1][1]);
    sh2[3][3] = sh1[1][1] * sh1[2][2] + sh1[2][1] * sh1[1][2];
    sh2[3][4] = kSqrt01_04 * ((sh1[1][2] * sh1[2][2] - sh1[1][0] * sh1[2][0]) + (sh1[2][2] * sh1[1][2] - sh1[2][0] * sh1[1][0]));

    coeffs[dstIdx++] = Dot5(srcIdx, coeffsIn, sh2[3]);

    sh2[4][0] = kSqrt01_04 * ((sh1[2][2] * sh1[2][0] + sh1[2][0] * sh1[2][2]) - (sh1[0][2] * sh1[0][0] + sh1[0][0] * sh1[0][2]));
    sh2[4][1] = (sh1[2][1] * sh1[2][0] - sh1[0][1] * sh1[0][0]);
    sh2[4][2] = kSqrt03_04 * (sh1[2][1] * sh1[2][1] - sh1[0][1] * sh1[0][1]);
    sh2[4][3] = (sh1[2][1] * sh1[2][2] - sh1[0][1] * sh1[0][2]);
    sh2[4][4] = kSqrt01_04 * ((sh1[2][2] * sh1[2][2] - sh1[2][0] * sh1[2][0]) - (sh1[0][2] * sh1[0][2] - sh1[0][0] * sh1[0][0]));

    coeffs[dstIdx++] = Dot5(srcIdx, coeffsIn, sh2[4]);

    if (n < 4)
        return;

    // band 3: 7x7 rotation built from products of sh1 and sh2 entries
    srcIdx += 5;
    float sh3[7][7];

    sh3[0][0] = kSqrt01_04 * ((sh1[2][2] * sh2[0][0] + sh1[2][0] * sh2[0][4]) + (sh1[0][2] * sh2[4][0] + sh1[0][0] * sh2[4][4]));
    sh3[0][1] = kSqrt03_02 * (sh1[2][1] * sh2[0][0] + sh1[0][1] * sh2[4][0]);
    sh3[0][2] = kSqrt15_16 * (sh1[2][1] * sh2[0][1] + sh1[0][1] * sh2[4][1]);
    sh3[0][3] = kSqrt05_06 * (sh1[2][1] * sh2[0][2] + sh1[0][1] * sh2[4][2]);
    sh3[0][4] = kSqrt15_16 * (sh1[2][1] * sh2[0][3] + sh1[0][1] * sh2[4][3]);
    sh3[0][5] = kSqrt03_02 * (sh1[2][1] * sh2[0][4] + sh1[0][1] * sh2[4][4]);
    sh3[0][6] = kSqrt01_04 * ((sh1[2][2] * sh2[0][4] - sh1[2][0] * sh2[0][0]) + (sh1[0][2] * sh2[4][4] - sh1[0][0] * sh2[4][0]));

    coeffs[dstIdx++] = Dot7(srcIdx, coeffsIn, sh3[0]);

    sh3[1][0] = kSqrt01_06 * (sh1[1][2] * sh2[0][0] + sh1[1][0] * sh2[0][4]) + kSqrt01_06 * ((sh1[2][2] * sh2[1][0] + sh1[2][0] * sh2[1][4]) + (sh1[0][2] * sh2[3][0] + sh1[0][0] * sh2[3][4]));
    sh3[1][1] = sh1[1][1] * sh2[0][0] + (sh1[2][1] * sh2[1][0] + sh1[0][1] * sh2[3][0]);
    sh3[1][2] = kSqrt05_08 * sh1[1][1] * sh2[0][1] + kSqrt05_08 * (sh1[2][1] * sh2[1][1] + sh1[0][1] * sh2[3][1]);
    sh3[1][3] = kSqrt05_09 * sh1[1][1] * sh2[0][2] + kSqrt05_09 * (sh1[2][1] * sh2[1][2] + sh1[0][1] * sh2[3][2]);
    sh3[1][4] = kSqrt05_08 * sh1[1][1] * sh2[0][3] + kSqrt05_08 * (sh1[2][1] * sh2[1][3] + sh1[0][1] * sh2[3][3]);
    sh3[1][5] = sh1[1][1] * sh2[0][4] + (sh1[2][1] * sh2[1][4] + sh1[0][1] * sh2[3][4]);
    sh3[1][6] = kSqrt01_06 * (sh1[1][2] * sh2[0][4] - sh1[1][0] * sh2[0][0]) + kSqrt01_06 * ((sh1[2][2] * sh2[1][4] - sh1[2][0] * sh2[1][0]) + (sh1[0][2] * sh2[3][4] - sh1[0][0] * sh2[3][0]));

    coeffs[dstIdx++] = Dot7(srcIdx, coeffsIn, sh3[1]);

    sh3[2][0] = kSqrt04_15 * (sh1[1][2] * sh2[1][0] + sh1[1][0] * sh2[1][4]) + kSqrt01_05 * (sh1[0][2] * sh2[2][0] + sh1[0][0] * sh2[2][4]) + -sqrt(1.0 / 60.0) * ((sh1[2][2] * sh2[0][0] + sh1[2][0] * sh2[0][4]) - (sh1[0][2] * sh2[4][0] + sh1[0][0] * sh2[4][4]));
    sh3[2][1] = kSqrt08_05 * sh1[1][1] * sh2[1][0] + kSqrt06_05 * sh1[0][1] * sh2[2][0] + -kSqrt01_10 * (sh1[2][1] * sh2[0][0] - sh1[0][1] * sh2[4][0]);
    sh3[2][2] = sh1[1][1] * sh2[1][1] + kSqrt03_04 * sh1[0][1] * sh2[2][1] + -kSqrt01_16 * (sh1[2][1] * sh2[0][1] - sh1[0][1] * sh2[4][1]);
    sh3[2][3] = kSqrt08_09 * sh1[1][1] * sh2[1][2] + kSqrt02_03 * sh1[0][1] * sh2[2][2] + -kSqrt01_18 * (sh1[2][1] * sh2[0][2] - sh1[0][1] * sh2[4][2]);
    sh3[2][4] = sh1[1][1] * sh2[1][3] + kSqrt03_04 * sh1[0][1] * sh2[2][3] + -kSqrt01_16 * (sh1[2][1] * sh2[0][3] - sh1[0][1] * sh2[4][3]);
    sh3[2][5] = kSqrt08_05 * sh1[1][1] * sh2[1][4] + kSqrt06_05 * sh1[0][1] * sh2[2][4] + -kSqrt01_10 * (sh1[2][1] * sh2[0][4] - sh1[0][1] * sh2[4][4]);
    sh3[2][6] = kSqrt04_15 * (sh1[1][2] * sh2[1][4] - sh1[1][0] * sh2[1][0]) + kSqrt01_05 * (sh1[0][2] * sh2[2][4] - sh1[0][0] * sh2[2][0]) + -sqrt(1.0 / 60.0) * ((sh1[2][2] * sh2[0][4] - sh1[2][0] * sh2[0][0]) - (sh1[0][2] * sh2[4][4] - sh1[0][0] * sh2[4][0]));

    coeffs[dstIdx++] = Dot7(srcIdx, coeffsIn, sh3[2]);

    sh3[3][0] = kSqrt03_10 * (sh1[1][2] * sh2[2][0] + sh1[1][0] * sh2[2][4]) + -kSqrt01_10 * ((sh1[2][2] * sh2[3][0] + sh1[2][0] * sh2[3][4]) + (sh1[0][2] * sh2[1][0] + sh1[0][0] * sh2[1][4]));
    sh3[3][1] = kSqrt09_05 * sh1[1][1] * sh2[2][0] + -kSqrt03_05 * (sh1[2][1] * sh2[3][0] + sh1[0][1] * sh2[1][0]);
    sh3[3][2] = kSqrt09_08 * sh1[1][1] * sh2[2][1] + -kSqrt03_08 * (sh1[2][1] * sh2[3][1] + sh1[0][1] * sh2[1][1]);
    sh3[3][3] = sh1[1][1] * sh2[2][2] + -kSqrt01_03 * (sh1[2][1] * sh2[3][2] + sh1[0][1] * sh2[1][2]);
    sh3[3][4] = kSqrt09_08 * sh1[1][1] * sh2[2][3] + -kSqrt03_08 * (sh1[2][1] * sh2[3][3] + sh1[0][1] * sh2[1][3]);
    sh3[3][5] = kSqrt09_05 * sh1[1][1] * sh2[2][4] + -kSqrt03_05 * (sh1[2][1] * sh2[3][4] + sh1[0][1] * sh2[1][4]);
    sh3[3][6] = kSqrt03_10 * (sh1[1][2] * sh2[2][4] - sh1[1][0] * sh2[2][0]) + -kSqrt01_10 * ((sh1[2][2] * sh2[3][4] - sh1[2][0] * sh2[3][0]) + (sh1[0][2] * sh2[1][4] - sh1[0][0] * sh2[1][0]));

    coeffs[dstIdx++] = Dot7(srcIdx, coeffsIn, sh3[3]);

    sh3[4][0] = kSqrt04_15 * (sh1[1][2] * sh2[3][0] + sh1[1][0] * sh2[3][4]) + kSqrt01_05 * (sh1[2][2] * sh2[2][0] + sh1[2][0] * sh2[2][4]) + -sqrt(1.0 / 60.0) * ((sh1[2][2] * sh2[4][0] + sh1[2][0] * sh2[4][4]) + (sh1[0][2] * sh2[0][0] + sh1[0][0] * sh2[0][4]));
    sh3[4][1] = kSqrt08_05 * sh1[1][1] * sh2[3][0] + kSqrt06_05 * sh1[2][1] * sh2[2][0] + -kSqrt01_10 * (sh1[2][1] * sh2[4][0] + sh1[0][1] * sh2[0][0]);
    sh3[4][2] = sh1[1][1] * sh2[3][1] + kSqrt03_04 * sh1[2][1] * sh2[2][1] + -kSqrt01_16 * (sh1[2][1] * sh2[4][1] + sh1[0][1] * sh2[0][1]);
    sh3[4][3] = kSqrt08_09 * sh1[1][1] * sh2[3][2] + kSqrt02_03 * sh1[2][1] * sh2[2][2] + -kSqrt01_18 * (sh1[2][1] * sh2[4][2] + sh1[0][1] * sh2[0][2]);
    sh3[4][4] = sh1[1][1] * sh2[3][3] + kSqrt03_04 * sh1[2][1] * sh2[2][3] + -kSqrt01_16 * (sh1[2][1] * sh2[4][3] + sh1[0][1] * sh2[0][3]);
    sh3[4][5] = kSqrt08_05 * sh1[1][1] * sh2[3][4] + kSqrt06_05 * sh1[2][1] * sh2[2][4] + -kSqrt01_10 * (sh1[2][1] * sh2[4][4] + sh1[0][1] * sh2[0][4]);
    sh3[4][6] = kSqrt04_15 * (sh1[1][2] * sh2[3][4] - sh1[1][0] * sh2[3][0]) + kSqrt01_05 * (sh1[2][2] * sh2[2][4] - sh1[2][0] * sh2[2][0]) + -sqrt(1.0 / 60.0) * ((sh1[2][2] * sh2[4][4] - sh1[2][0] * sh2[4][0]) + (sh1[0][2] * sh2[0][4] - sh1[0][0] * sh2[0][0]));

    coeffs[dstIdx++] = Dot7(srcIdx, coeffsIn, sh3[4]);

    sh3[5][0] = kSqrt01_06 * (sh1[1][2] * sh2[4][0] + sh1[1][0] * sh2[4][4]) + kSqrt01_06 * ((sh1[2][2] * sh2[3][0] + sh1[2][0] * sh2[3][4]) - (sh1[0][2] * sh2[1][0] + sh1[0][0] * sh2[1][4]));
    sh3[5][1] = sh1[1][1] * sh2[4][0] + (sh1[2][1] * sh2[3][0] - sh1[0][1] * sh2[1][0]);
    sh3[5][2] = kSqrt05_08 * sh1[1][1] * sh2[4][1] + kSqrt05_08 * (sh1[2][1] * sh2[3][1] - sh1[0][1] * sh2[1][1]);
    sh3[5][3] = kSqrt05_09 * sh1[1][1] * sh2[4][2] + kSqrt05_09 * (sh1[2][1] * sh2[3][2] - sh1[0][1] * sh2[1][2]);
    sh3[5][4] = kSqrt05_08 * sh1[1][1] * sh2[4][3] + kSqrt05_08 * (sh1[2][1] * sh2[3][3] - sh1[0][1] * sh2[1][3]);
    sh3[5][5] = sh1[1][1] * sh2[4][4] + (sh1[2][1] * sh2[3][4] - sh1[0][1] * sh2[1][4]);
    sh3[5][6] = kSqrt01_06 * (sh1[1][2] * sh2[4][4] - sh1[1][0] * sh2[4][0]) + kSqrt01_06 * ((sh1[2][2] * sh2[3][4] - sh1[2][0] * sh2[3][0]) - (sh1[0][2] * sh2[1][4] - sh1[0][0] * sh2[1][0]));

    coeffs[dstIdx++] = Dot7(srcIdx, coeffsIn, sh3[5]);

    sh3[6][0] = kSqrt01_04 * ((sh1[2][2] * sh2[4][0] + sh1[2][0] * sh2[4][4]) - (sh1[0][2] * sh2[0][0] + sh1[0][0] * sh2[0][4]));
    sh3[6][1] = kSqrt03_02 * (sh1[2][1] * sh2[4][0] - sh1[0][1] * sh2[0][0]);
    sh3[6][2] = kSqrt15_16 * (sh1[2][1] * sh2[4][1] - sh1[0][1] * sh2[0][1]);
    sh3[6][3] = kSqrt05_06 * (sh1[2][1] * sh2[4][2] - sh1[0][1] * sh2[0][2]);
    sh3[6][4] = kSqrt15_16 * (sh1[2][1] * sh2[4][3] - sh1[0][1] * sh2[0][3]);
    sh3[6][5] = kSqrt03_02 * (sh1[2][1] * sh2[4][4] - sh1[0][1] * sh2[0][4]);
    sh3[6][6] = kSqrt01_04 * ((sh1[2][2] * sh2[4][4] - sh1[2][0] * sh2[4][0]) - (sh1[0][2] * sh2[0][4] - sh1[0][0] * sh2[0][0]));

    coeffs[dstIdx++] = Dot7(srcIdx, coeffsIn, sh3[6]);
}
|
||||
|
||||
#endif // SPHERICAL_HARMONICS_HLSL
|
||||
7
MVS/3DGS-Unity/Shaders/SphericalHarmonics.hlsl.meta
Normal file
7
MVS/3DGS-Unity/Shaders/SphericalHarmonics.hlsl.meta
Normal file
@@ -0,0 +1,7 @@
|
||||
fileFormatVersion: 2
|
||||
guid: 0e45617a7c5ba4b4eb55e897761dcb31
|
||||
ShaderIncludeImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
||||
757
MVS/3DGS-Unity/Shaders/SplatUtilities.compute
Normal file
757
MVS/3DGS-Unity/Shaders/SplatUtilities.compute
Normal file
@@ -0,0 +1,757 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
#define GROUP_SIZE 1024
|
||||
|
||||
#pragma kernel CSSetIndices
|
||||
#pragma kernel CSCalcDistances
|
||||
#pragma kernel CSCalcViewData
|
||||
#pragma kernel CSUpdateEditData
|
||||
#pragma kernel CSInitEditData
|
||||
#pragma kernel CSClearBuffer
|
||||
#pragma kernel CSInvertSelection
|
||||
#pragma kernel CSSelectAll
|
||||
#pragma kernel CSOrBuffers
|
||||
#pragma kernel CSSelectionUpdate
|
||||
#pragma kernel CSTranslateSelection
|
||||
#pragma kernel CSRotateSelection
|
||||
#pragma kernel CSScaleSelection
|
||||
#pragma kernel CSExportData
|
||||
#pragma kernel CSCopySplats
|
||||
|
||||
// DeviceRadixSort
|
||||
#pragma multi_compile __ KEY_UINT KEY_INT KEY_FLOAT
|
||||
#pragma multi_compile __ PAYLOAD_UINT PAYLOAD_INT PAYLOAD_FLOAT
|
||||
#pragma multi_compile __ SHOULD_ASCEND
|
||||
#pragma multi_compile __ SORT_PAIRS
|
||||
#pragma multi_compile __ VULKAN
|
||||
#pragma kernel InitDeviceRadixSort
|
||||
#pragma kernel Upsweep
|
||||
#pragma kernel Scan
|
||||
#pragma kernel Downsweep
|
||||
|
||||
// GPU sorting needs wave ops
|
||||
#pragma require wavebasic
|
||||
#pragma require waveballot
|
||||
#pragma use_dxc
|
||||
|
||||
#include "DeviceRadixSort.hlsl"
|
||||
#include "GaussianSplatting.hlsl"
|
||||
#include "UnityCG.cginc"
|
||||
|
||||
// Per-dispatch transform and camera uniforms, set from C#.
float4x4 _MatrixObjectToWorld;
float4x4 _MatrixWorldToObject;
float4x4 _MatrixMV;            // model-view matrix (used for view-space depth)
float4 _VecScreenParams;       // xy = render target size in pixels
float4 _VecWorldSpaceCameraPos;
int _SelectionMode;            // nonzero = add to selection, zero = subtract (see CSSelectionUpdate)

// Sort buffers: _SplatSortKeys holds splat indices; _SplatSortDistances holds
// the matching radix-sortable view depths (see CSCalcDistances).
RWStructuredBuffer<uint> _SplatSortDistances;
RWStructuredBuffer<uint> _SplatSortKeys;
uint _SplatCount;
|
||||
|
||||
// radix sort etc. friendly, see http://stereopsis.com/radix.html
|
||||
// Map an IEEE-754 float's bit pattern to a uint whose unsigned ordering
// matches the float ordering, so depths can be radix sorted / compared with
// integer atomics. See http://stereopsis.com/radix.html
uint FloatToSortableUint(float f)
{
    uint bits = asuint(f);
    // Negative values: flip all bits; non-negative: flip only the sign bit.
    uint flip = (bits & 0x80000000u) != 0 ? 0xFFFFFFFFu : 0x80000000u;
    return bits ^ flip;
}
|
||||
|
||||
// Fills the sort key buffer with the identity permutation (slot i -> splat i).
[numthreads(GROUP_SIZE,1,1)]
void CSSetIndices (uint3 id : SV_DispatchThreadID)
{
    uint splatIdx = id.x;
    if (splatIdx < _SplatCount)
        _SplatSortKeys[splatIdx] = splatIdx;
}
|
||||
|
||||
// Computes a radix-sortable view-space depth for each sort slot, following
// the current key order (keys hold splat indices).
[numthreads(GROUP_SIZE,1,1)]
void CSCalcDistances (uint3 id : SV_DispatchThreadID)
{
    uint sortSlot = id.x;
    if (sortSlot >= _SplatCount)
        return;

    // Fetch the splat this slot currently refers to and project to view space.
    uint splatIdx = _SplatSortKeys[sortSlot];
    float3 viewPos = mul(_MatrixMV, float4(LoadSplatPos(splatIdx).xyz, 1)).xyz;
    _SplatSortDistances[sortSlot] = FloatToSortableUint(viewPos.z);
}
|
||||
|
||||
RWStructuredBuffer<SplatViewData> _SplatViewData; // per-splat output of CSCalcViewData

float _SplatScale;        // global size multiplier applied to covariance
float _SplatOpacityScale; // global opacity multiplier
uint _SHOrder;            // max SH order to evaluate (passed to ShadeSH)
uint _SHOnly;             // passed to ShadeSH as a bool flag when nonzero

uint _SplatCutoutsCount;

#define SPLAT_CUTOUT_TYPE_ELLIPSOID 0
#define SPLAT_CUTOUT_TYPE_BOX 1

struct GaussianCutoutShaderData // match GaussianCutout.ShaderData in C#
{
    float4x4 mat;       // transforms a splat object-space position into the cutout's unit volume
    uint typeAndFlags;  // low byte: type (0xFF = invalid); next byte nonzero = inverted
};
StructuredBuffer<GaussianCutoutShaderData> _SplatCutouts;

// Edit-state bit buffers, one bit per splat.
RWByteAddressBuffer _SplatSelectedBits;
ByteAddressBuffer _SplatDeletedBits;
uint _SplatBitsValid; // nonzero when the bit buffers above are bound/usable
|
||||
|
||||
// Decomposes a symmetric 2D covariance matrix (packed as (m00, m01, m11) in
// cov2d.xyz) into two screen-space axis vectors v1/v2: eigenvector directions
// scaled by the splat's extent along each, used to stretch the splat quad.
void DecomposeCovariance(float3 cov2d, out float2 v1, out float2 v2)
{
#if 0 // does not quite give the correct results?

    // https://jsfiddle.net/mattrossman/ehxmtgw6/
    // References:
    // - https://www.youtube.com/watch?v=e50Bj7jn9IQ
    // - https://en.wikipedia.org/wiki/Eigenvalue_algorithm#2%C3%972_matrices
    // - https://people.math.harvard.edu/~knill/teaching/math21b2004/exhibits/2dmatrices/index.html
    float a = cov2d.x;
    float b = cov2d.y;
    float d = cov2d.z;
    float det = a * d - b * b; // matrix is symmetric, so "c" is same as "b"
    float trace = a + d;

    float mean = 0.5 * trace;
    float dist = sqrt(mean * mean - det);

    float lambda1 = mean + dist; // 1st eigenvalue
    float lambda2 = mean - dist; // 2nd eigenvalue

    if (b == 0) {
        // https://twitter.com/the_ross_man/status/1706342719776551360
        if (a > d) v1 = float2(1, 0);
        else v1 = float2(0, 1);
    } else
        v1 = normalize(float2(b, d - lambda2));

    v1.y = -v1.y;
    // The 2nd eigenvector is just a 90 degree rotation of the first since Gaussian axes are orthogonal
    v2 = float2(v1.y, -v1.x);

    // scaling components
    v1 *= sqrt(lambda1);
    v2 *= sqrt(lambda2);

    float radius = 1.5;
    v1 *= radius;
    v2 *= radius;

#else

    // same as in antimatter15/splat
    // Eigen-decomposition of the symmetric 2x2 [diag1 offDiag; offDiag diag2].
    float diag1 = cov2d.x, diag2 = cov2d.z, offDiag = cov2d.y;
    float mid = 0.5f * (diag1 + diag2);
    float radius = length(float2((diag1 - diag2) / 2.0, offDiag));
    float lambda1 = mid + radius;           // larger eigenvalue
    float lambda2 = max(mid - radius, 0.1); // smaller eigenvalue, clamped so the minor axis never degenerates
    float2 diagVec = normalize(float2(offDiag, lambda1 - diag1));
    diagVec.y = -diagVec.y;
    // Axis length is sqrt(2*lambda) (~1.41 standard deviations), capped to
    // keep near-camera splats from producing enormous quads.
    float maxSize = 4096.0;
    v1 = min(sqrt(2.0 * lambda1), maxSize) * diagVec;
    v2 = min(sqrt(2.0 * lambda2), maxSize) * float2(diagVec.y, -diagVec.x);

#endif
}
|
||||
|
||||
// Returns true if the given object-space position is removed by the active
// cutout volumes. Semantics: being inside any cutout returns that cutout's
// invert flag immediately; a position outside every cutout is removed if at
// least one non-inverted cutout exists (finalCut accumulates that vote).
bool IsSplatCut(float3 pos)
{
    bool finalCut = false;
    for (uint i = 0; i < _SplatCutoutsCount; ++i)
    {
        GaussianCutoutShaderData cutData = _SplatCutouts[i];
        uint type = cutData.typeAndFlags & 0xFF;
        if (type == 0xFF) // invalid/null cutout, ignore
            continue;
        bool invert = (cutData.typeAndFlags & 0xFF00) != 0;

        // Transform into the cutout's local unit space (unit sphere / unit box).
        float3 cutoutPos = mul(cutData.mat, float4(pos, 1)).xyz;
        if (type == SPLAT_CUTOUT_TYPE_ELLIPSOID)
        {
            // Inside the unit sphere: cut iff this cutout is inverted.
            if (dot(cutoutPos, cutoutPos) <= 1) return invert;
        }
        if (type == SPLAT_CUTOUT_TYPE_BOX)
        {
            // Inside the unit box: cut iff this cutout is inverted.
            if (all(abs(cutoutPos) <= 1)) return invert;
        }
        // Outside this cutout: a non-inverted cutout votes to remove the
        // splat, unless a later cutout contains it (early returns above).
        finalCut |= !invert;
    }
    return finalCut;
}
|
||||
|
||||
// Builds per-splat view data: clip-space center, screen-space ellipse axes,
// and SH-shaded color packed as two fp16 pairs. Deleted, cut, or behind-camera
// splats get pos.w <= 0 so downstream rendering can reject them.
[numthreads(GROUP_SIZE,1,1)]
void CSCalcViewData (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _SplatCount)
        return;

    SplatData splat = LoadSplatData(idx);
    SplatViewData view = (SplatViewData)0;

    float3 centerWorldPos = mul(_MatrixObjectToWorld, float4(splat.pos,1)).xyz;
    float4 centerClipPos = mul(UNITY_MATRIX_VP, float4(centerWorldPos, 1));
    half opacityScale = _SplatOpacityScale;
    float splatScale = _SplatScale;

    // deleted? (w = 0 marks the splat as rejected)
    if (_SplatBitsValid)
    {
        uint wordIdx = idx / 32;
        uint bitIdx = idx & 31;
        uint wordVal = _SplatDeletedBits.Load(wordIdx * 4);
        if (wordVal & (1 << bitIdx))
        {
            centerClipPos.w = 0;
        }
    }

    // cutouts
    if (IsSplatCut(splat.pos))
    {
        centerClipPos.w = 0;
    }

    view.pos = centerClipPos;
    bool behindCam = centerClipPos.w <= 0;
    if (!behindCam)
    {
        float4 boxRot = splat.rot;
        float3 boxSize = splat.scale;

        float3x3 splatRotScaleMat = CalcMatrixFromRotationScale(boxRot, boxSize);

        // 3D covariance (6 unique symmetric entries in two float3s), scaled
        // by the global splat scale squared since covariance is quadratic.
        float3 cov3d0, cov3d1;
        CalcCovariance3D(splatRotScaleMat, cov3d0, cov3d1);
        float splatScale2 = splatScale * splatScale;
        cov3d0 *= splatScale2;
        cov3d1 *= splatScale2;
        float3 cov2d = CalcCovariance2D(splat.pos, cov3d0, cov3d1, _MatrixMV, UNITY_MATRIX_P, _VecScreenParams);

        // screen-space ellipse axes from the projected 2D covariance
        DecomposeCovariance(cov2d, view.axis1, view.axis2);

        // object-space view direction for SH evaluation
        float3 worldViewDir = _VecWorldSpaceCameraPos.xyz - centerWorldPos;
        float3 objViewDir = mul((float3x3)_MatrixWorldToObject, worldViewDir);
        objViewDir = normalize(objViewDir);

        half4 col;
        col.rgb = ShadeSH(splat.sh, objViewDir, _SHOrder, _SHOnly != 0);
        col.a = min(splat.opacity * opacityScale, 65000); // keep below fp16 max
        // pack as fp16 pairs: color.x = (r,g), color.y = (b,a)
        view.color.x = (f32tof16(col.r) << 16) | f32tof16(col.g);
        view.color.y = (f32tof16(col.b) << 16) | f32tof16(col.a);
    }

    _SplatViewData[idx] = view;
}
|
||||
|
||||
|
||||
// Generic source/destination buffers for the editing utility kernels below.
RWByteAddressBuffer _DstBuffer;
ByteAddressBuffer _SrcBuffer;
uint _BufferSize; // buffer size in 32-bit words (kernels address idx * 4 bytes)
|
||||
|
||||
// Maps a 32-bit word index in a splat bit buffer to the [begin, end) range of
// splat indices it covers; the end is clamped to the total splat count.
uint2 GetSplatIndicesFromWord(uint idx)
{
    uint begin = idx * 32;
    return uint2(begin, min(begin + 32, _SplatCount));
}
|
||||
|
||||
// One thread per 32-bit word of the selection/deletion bit buffers.
// Accumulates into _DstBuffer: selected/deleted/cut counts (byte offsets
// 0/4/8) and the object-space bounds of the selection (sortable-uint min at
// 12/16/20, max at 24/28/32 — seeded by CSInitEditData).
[numthreads(GROUP_SIZE,1,1)]
void CSUpdateEditData (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _BufferSize)
        return;

    uint valSel = _SplatSelectedBits.Load(idx * 4);
    uint valDel = _SplatDeletedBits.Load(idx * 4);
    valSel &= ~valDel; // don't count deleted splats as selected
    uint2 splatIndices = GetSplatIndicesFromWord(idx);

    // update selection bounds
    float3 bmin = 1.0e38;
    float3 bmax = -1.0e38;
    uint mask = 1;
    uint valCut = 0;
    for (uint sidx = splatIndices.x; sidx < splatIndices.y; ++sidx, mask <<= 1)
    {
        float3 spos = LoadSplatPos(sidx);
        // don't count cut splats as selected
        if (IsSplatCut(spos))
        {
            valSel &= ~mask;
            valCut |= mask;
        }
        if (valSel & mask)
        {
            bmin = min(bmin, spos);
            bmax = max(bmax, spos);
        }
    }
    valCut &= ~valDel; // don't count deleted splats as cut

    if (valSel != 0)
    {
        // FloatToSortableUint lets float min/max be done with integer atomics
        _DstBuffer.InterlockedMin(12, FloatToSortableUint(bmin.x));
        _DstBuffer.InterlockedMin(16, FloatToSortableUint(bmin.y));
        _DstBuffer.InterlockedMin(20, FloatToSortableUint(bmin.z));
        _DstBuffer.InterlockedMax(24, FloatToSortableUint(bmax.x));
        _DstBuffer.InterlockedMax(28, FloatToSortableUint(bmax.y));
        _DstBuffer.InterlockedMax(32, FloatToSortableUint(bmax.z));
    }
    uint sumSel = countbits(valSel);
    uint sumDel = countbits(valDel);
    uint sumCut = countbits(valCut);
    _DstBuffer.InterlockedAdd(0, sumSel);
    _DstBuffer.InterlockedAdd(4, sumDel);
    _DstBuffer.InterlockedAdd(8, sumCut);
}
|
||||
|
||||
// Resets the edit-info buffer: zeroes the selected/deleted/cut counters and
// seeds the selection-bounds slots with extreme min/max sentinels in the
// sortable-uint encoding used by CSUpdateEditData's atomics.
[numthreads(1,1,1)]
void CSInitEditData (uint3 id : SV_DispatchThreadID)
{
    _DstBuffer.Store3(0, uint3(0,0,0)); // selected, deleted, cut counts
    uint boundsMin = FloatToSortableUint(1.0e38);
    uint boundsMax = FloatToSortableUint(-1.0e38);
    _DstBuffer.Store3(12, uint3(boundsMin, boundsMin, boundsMin)); // bounds min xyz
    _DstBuffer.Store3(24, uint3(boundsMax, boundsMax, boundsMax)); // bounds max xyz
}
|
||||
|
||||
// Zeroes the destination buffer, one 32-bit word per thread.
[numthreads(GROUP_SIZE,1,1)]
void CSClearBuffer (uint3 id : SV_DispatchThreadID)
{
    uint wordIdx = id.x;
    if (wordIdx < _BufferSize)
        _DstBuffer.Store(wordIdx * 4, 0);
}
|
||||
|
||||
// Inverts all selection bits in one 32-bit word, then clears bits belonging
// to cutout-removed splats so they can never become selected.
[numthreads(GROUP_SIZE,1,1)]
void CSInvertSelection (uint3 id : SV_DispatchThreadID)
{
    uint wordIdx = id.x;
    if (wordIdx >= _BufferSize)
        return;

    uint bits = ~_DstBuffer.Load(wordIdx * 4);

    // do not select splats that are cut
    uint2 range = GetSplatIndicesFromWord(wordIdx);
    uint bit = 1;
    for (uint splatIdx = range.x; splatIdx < range.y; ++splatIdx, bit <<= 1)
    {
        if (IsSplatCut(LoadSplatPos(splatIdx)))
            bits &= ~bit;
    }

    _DstBuffer.Store(wordIdx * 4, bits);
}
|
||||
|
||||
// Sets every selection bit in one 32-bit word, except bits belonging to
// cutout-removed splats, which must stay unselected.
[numthreads(GROUP_SIZE,1,1)]
void CSSelectAll (uint3 id : SV_DispatchThreadID)
{
    uint wordIdx = id.x;
    if (wordIdx >= _BufferSize)
        return;

    uint bits = ~0;

    // do not select splats that are cut
    uint2 range = GetSplatIndicesFromWord(wordIdx);
    uint bit = 1;
    for (uint splatIdx = range.x; splatIdx < range.y; ++splatIdx, bit <<= 1)
    {
        if (IsSplatCut(LoadSplatPos(splatIdx)))
            bits &= ~bit;
    }

    _DstBuffer.Store(wordIdx * 4, bits);
}
|
||||
|
||||
|
||||
// Bitwise-ORs the source buffer into the destination, one word per thread.
[numthreads(GROUP_SIZE,1,1)]
void CSOrBuffers (uint3 id : SV_DispatchThreadID)
{
    uint wordIdx = id.x;
    if (wordIdx >= _BufferSize)
        return;
    uint addr = wordIdx * 4;
    _DstBuffer.Store(addr, _SrcBuffer.Load(addr) | _DstBuffer.Load(addr));
}
|
||||
|
||||
float4 _SelectionRect;
|
||||
|
||||
// Adds (_SelectionMode != 0) or removes splats whose projected center lies
// inside the pixel-space rectangle _SelectionRect (xMin, yMin, xMax, yMax).
[numthreads(GROUP_SIZE,1,1)]
void CSSelectionUpdate (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _SplatCount)
        return;

    float3 pos = LoadSplatPos(idx);
    if (IsSplatCut(pos))
        return; // cut splats are never selectable

    float3 centerWorldPos = mul(_MatrixObjectToWorld, float4(pos,1)).xyz;
    float4 centerClipPos = mul(UNITY_MATRIX_VP, float4(centerWorldPos, 1));
    bool behindCam = centerClipPos.w <= 0;
    if (behindCam)
        return;

    // NDC -> pixel coordinates; y is flipped to match the rect's origin
    float2 pixelPos = (centerClipPos.xy / centerClipPos.w * float2(0.5, -0.5) + 0.5) * _VecScreenParams.xy;
    if (pixelPos.x < _SelectionRect.x || pixelPos.x > _SelectionRect.z ||
        pixelPos.y < _SelectionRect.y || pixelPos.y > _SelectionRect.w)
    {
        return;
    }
    uint wordIdx = idx / 32;
    uint bitIdx = idx & 31;
    if (_SelectionMode)
        _SplatSelectedBits.InterlockedOr(wordIdx * 4, 1u << bitIdx); // +
    else
        _SplatSelectedBits.InterlockedAnd(wordIdx * 4, ~(1u << bitIdx)); // -
}
|
||||
|
||||
float3 _SelectionDelta;
|
||||
|
||||
// Returns true if the selection bit for the given splat index is set.
bool IsSplatSelected(uint idx)
{
    uint wordIdx = idx / 32;
    uint bitIdx = idx & 31;
    uint selVal = _SplatSelectedBits.Load(wordIdx * 4);
    // Use an unsigned literal for the mask: "1 << 31" on a signed literal
    // lands on the sign bit; 1u matches how these masks are built in
    // CSSelectionUpdate and the deleted-bits checks elsewhere in this file.
    return (selVal & (1u << bitIdx)) != 0;
}
|
||||
|
||||
// Moves every selected splat by _SelectionDelta. Only supported for
// uncompressed data: no chunks and fp32 position format; otherwise a no-op.
[numthreads(GROUP_SIZE,1,1)]
void CSTranslateSelection (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _SplatCount)
        return;
    if (!IsSplatSelected(idx))
        return;

    uint fmt = _SplatFormat & 0xFF; // low byte = position format
    if (_SplatChunkCount == 0 && fmt == VECTOR_FMT_32F)
    {
        uint stride = 12; // 3 x fp32 per position
        float3 pos = asfloat(_SplatPos.Load3(idx * stride));
        pos += _SelectionDelta;
        _SplatPos.Store3(idx * stride, asuint(pos));
    }
}
|
||||
|
||||
float3 _SelectionCenter;
|
||||
float4 _SelectionDeltaRot;
|
||||
ByteAddressBuffer _SplatPosMouseDown;
|
||||
ByteAddressBuffer _SplatOtherMouseDown;
|
||||
|
||||
// Rotates every selected splat by _SelectionDeltaRot around _SelectionCenter,
// reading the positions/orientations captured at mouse-down so repeated drags
// don't accumulate error. Only supported for uncompressed data (no chunks,
// fp32 formats); otherwise a no-op.
[numthreads(GROUP_SIZE,1,1)]
void CSRotateSelection (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _SplatCount)
        return;
    if (!IsSplatSelected(idx))
        return;

    uint posFmt = _SplatFormat & 0xFF;
    if (_SplatChunkCount == 0 && posFmt == VECTOR_FMT_32F)
    {
        // position: recenter, to world space, rotate, back to object space
        uint posStride = 12;
        float3 pos = asfloat(_SplatPosMouseDown.Load3(idx * posStride));
        pos -= _SelectionCenter;
        pos = mul(_MatrixObjectToWorld, float4(pos,1)).xyz;
        pos = QuatRotateVector(pos, _SelectionDeltaRot);
        pos = mul(_MatrixWorldToObject, float4(pos,1)).xyz;
        pos += _SelectionCenter;
        _SplatPos.Store3(idx * posStride, asuint(pos));
    }

    uint scaleFmt = (_SplatFormat >> 8) & 0xFF;
    uint shFormat = (_SplatFormat >> 16) & 0xFF;
    if (_SplatChunkCount == 0 && scaleFmt == VECTOR_FMT_32F && shFormat == VECTOR_FMT_32F)
    {
        // rotate the splat's own orientation (packed 10.10.10.2 quaternion,
        // first 4 bytes of the "other" record; the following 12 are scale)
        uint otherStride = 4 + 12;
        uint rotVal = _SplatOtherMouseDown.Load(idx * otherStride);
        float4 rot = DecodeRotation(DecodePacked_10_10_10_2(rotVal));

        //@TODO: correct rotation
        rot = QuatMul(rot, _SelectionDeltaRot);

        rotVal = EncodeQuatToNorm10(PackSmallest3Rotation(rot));
        _SplatOther.Store(idx * otherStride, rotVal);
    }

    //@TODO: rotate SHs
}
|
||||
|
||||
//@TODO: maybe scale the splat scale itself too?
// Scales positions of every selected splat by _SelectionDelta (per axis)
// around _SelectionCenter, reading positions captured at mouse-down. Only
// supported for uncompressed data (no chunks, fp32 positions).
[numthreads(GROUP_SIZE,1,1)]
void CSScaleSelection (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _SplatCount)
        return;
    if (!IsSplatSelected(idx))
        return;

    uint fmt = _SplatFormat & 0xFF; // low byte = position format
    if (_SplatChunkCount == 0 && fmt == VECTOR_FMT_32F)
    {
        uint stride = 12; // 3 x fp32 per position
        float3 pos = asfloat(_SplatPosMouseDown.Load3(idx * stride));
        pos -= _SelectionCenter;
        pos = mul(_MatrixObjectToWorld, float4(pos,1)).xyz;
        pos *= _SelectionDelta; // per-axis scale applied in world space
        pos = mul(_MatrixWorldToObject, float4(pos,1)).xyz;
        pos += _SelectionCenter;
        _SplatPos.Store3(idx * stride, asuint(pos));
    }
}
|
||||
|
||||
// Export record written by CSExportData. Field layout must stay in sync with
// the CPU-side exporter; it appears to mirror the common 3DGS .ply vertex
// layout (pos, unused normal, SH DC, 15 SH coeffs per channel, pre-sigmoid
// opacity, log-space scale, wxyz rotation) — confirm against the C# reader.
struct ExportSplatData
{
    float3 pos;
    float3 nor; // unused; also doubles as a "skipped" marker (set to 1 for cut splats)
    float3 dc0; // base color converted back to the SH degree-0 coefficient
    float4 shR14; float4 shR58; float4 shR9C; float3 shRDF; // red   SH coeffs 1..15
    float4 shG14; float4 shG58; float4 shG9C; float3 shGDF; // green SH coeffs 1..15
    float4 shB14; float4 shB58; float4 shB9C; float3 shBDF; // blue  SH coeffs 1..15
    float opacity; // inverse-sigmoid of the stored opacity
    float3 scale;  // log of the stored scale
    float4 rot;    // quaternion in (w,x,y,z) order
};
RWStructuredBuffer<ExportSplatData> _ExportBuffer;
|
||||
|
||||
// Converts a displayed base color back into the SH degree-0 (DC) coefficient,
// inverting col = sh0 * 0.2820948 + 0.5, where 0.2820948 = 1 / (2*sqrt(pi)).
float3 ColorToSH0(float3 col)
{
    const float kSH0 = 0.2820948;
    return (col - 0.5) / kSH0;
}
|
||||
// Inverse of the sigmoid used to store opacity: log(v / (1 - v)), with the
// denominator clamped away from zero so v near 1 doesn't produce log(inf).
float InvSigmoid(float v)
{
    float denom = max(1 - v, 1.0e-6);
    return log(v / denom);
}
|
||||
|
||||
// SH rotation
|
||||
#include "SphericalHarmonics.hlsl"
|
||||
|
||||
// Rotates a splat's spherical harmonics in place by the given 3x3 rotation:
// marshals the 16 float3 coefficients (DC + 15) into an array, calls the
// order-4 RotateSH overload from SphericalHarmonics.hlsl, and unpacks back.
void RotateSH(inout SplatSHData sh, float3x3 rot)
{
    float3 shin[16];
    float3 shout[16];
    shin[0] = sh.col; // DC term travels through the rotation as band 0
    shin[1] = sh.sh1;
    shin[2] = sh.sh2;
    shin[3] = sh.sh3;
    shin[4] = sh.sh4;
    shin[5] = sh.sh5;
    shin[6] = sh.sh6;
    shin[7] = sh.sh7;
    shin[8] = sh.sh8;
    shin[9] = sh.sh9;
    shin[10] = sh.sh10;
    shin[11] = sh.sh11;
    shin[12] = sh.sh12;
    shin[13] = sh.sh13;
    shin[14] = sh.sh14;
    shin[15] = sh.sh15;
    RotateSH(rot, 4, shin, shout); // order 4 covers all 16 coefficients
    sh.col = shout[0];
    sh.sh1 = shout[1];
    sh.sh2 = shout[2];
    sh.sh3 = shout[3];
    sh.sh4 = shout[4];
    sh.sh5 = shout[5];
    sh.sh6 = shout[6];
    sh.sh7 = shout[7];
    sh.sh8 = shout[8];
    sh.sh9 = shout[9];
    sh.sh10 = shout[10];
    sh.sh11 = shout[11];
    sh.sh12 = shout[12];
    sh.sh13 = shout[13];
    sh.sh14 = shout[14];
    sh.sh15 = shout[15];
}
|
||||
|
||||
// Extracts the rotation part of an object-to-world matrix for SH rotation by
// stripping per-axis scale: each row of the upper-left 3x3 is normalized.
// Note: this removes only axis scale, not shear.
float3x3 CalcSHRotMatrix(float4x4 objToWorld)
{
    float3x3 rot = (float3x3)objToWorld;
    [unroll]
    for (int row = 0; row < 3; ++row)
    {
        float3 axis = rot[row];
        float invLen = 1.0 / length(axis);
        rot[row] = axis * invLen;
    }
    return rot;
}
|
||||
|
||||
|
||||
// Export transform parameters, used by CSExportData.
float4 _ExportTransformRotation; // extra rotation applied to each splat's orientation
float3 _ExportTransformScale;    // extra scale; negative components trigger axis-flip handling
uint _ExportTransformFlags;      // nonzero: bake _MatrixObjectToWorld into the exported splats
|
||||
|
||||
// Converts every splat into the ExportSplatData layout in _ExportBuffer,
// optionally baking the object's world transform first. Splats removed by
// cutouts are flagged via nor = 1 so the CPU exporter can skip them.
[numthreads(GROUP_SIZE,1,1)]
void CSExportData (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _SplatCount)
        return;
    SplatData src = LoadSplatData(idx);

    bool isCut = IsSplatCut(src.pos);

    // transform splat by matrix, if needed
    if (_ExportTransformFlags != 0)
    {
        src.pos = mul(_MatrixObjectToWorld, float4(src.pos,1)).xyz;

        // note: this only handles axis flips from scale, not any arbitrary scaling
        if (_ExportTransformScale.x < 0)
            src.rot.yz = -src.rot.yz;
        if (_ExportTransformScale.y < 0)
            src.rot.xz = -src.rot.xz;
        if (_ExportTransformScale.z < 0)
            src.rot.xy = -src.rot.xy;
        src.rot = QuatMul(_ExportTransformRotation, src.rot);
        src.scale *= abs(_ExportTransformScale);

        // rotate SH coefficients to match the new orientation
        float3x3 shRot = CalcSHRotMatrix(_MatrixObjectToWorld);
        RotateSH(src.sh, shRot);
    }

    ExportSplatData dst;
    dst.pos = src.pos;
    dst.nor = 0; // normals are unused by the export format
    dst.dc0 = ColorToSH0(src.sh.col); // displayed color back to SH DC term

    // per-channel SH coefficients 1..15, packed as 4+4+4+3 floats
    dst.shR14 = float4(src.sh.sh1.r, src.sh.sh2.r, src.sh.sh3.r, src.sh.sh4.r);
    dst.shR58 = float4(src.sh.sh5.r, src.sh.sh6.r, src.sh.sh7.r, src.sh.sh8.r);
    dst.shR9C = float4(src.sh.sh9.r, src.sh.sh10.r, src.sh.sh11.r, src.sh.sh12.r);
    dst.shRDF = float3(src.sh.sh13.r, src.sh.sh14.r, src.sh.sh15.r);

    dst.shG14 = float4(src.sh.sh1.g, src.sh.sh2.g, src.sh.sh3.g, src.sh.sh4.g);
    dst.shG58 = float4(src.sh.sh5.g, src.sh.sh6.g, src.sh.sh7.g, src.sh.sh8.g);
    dst.shG9C = float4(src.sh.sh9.g, src.sh.sh10.g, src.sh.sh11.g, src.sh.sh12.g);
    dst.shGDF = float3(src.sh.sh13.g, src.sh.sh14.g, src.sh.sh15.g);

    dst.shB14 = float4(src.sh.sh1.b, src.sh.sh2.b, src.sh.sh3.b, src.sh.sh4.b);
    dst.shB58 = float4(src.sh.sh5.b, src.sh.sh6.b, src.sh.sh7.b, src.sh.sh8.b);
    dst.shB9C = float4(src.sh.sh9.b, src.sh.sh10.b, src.sh.sh11.b, src.sh.sh12.b);
    dst.shBDF = float3(src.sh.sh13.b, src.sh.sh14.b, src.sh.sh15.b);

    dst.opacity = InvSigmoid(src.opacity); // stored pre-sigmoid
    dst.scale = log(src.scale);            // stored in log space
    dst.rot = src.rot.wxyz;                // export convention is (w,x,y,z)

    if (isCut)
        dst.nor = 1; // mark as skipped for export

    _ExportBuffer[idx] = dst;
}
|
||||
|
||||
// Destination buffers and parameters for CSCopySplats. The destination is
// assumed uncompressed: fp32 positions, packed rot + fp32 scale, color
// texture, fp32 SH.
RWByteAddressBuffer _CopyDstPos;
RWByteAddressBuffer _CopyDstOther;
RWByteAddressBuffer _CopyDstSH;
RWByteAddressBuffer _CopyDstEditDeleted; // deleted-splat bit buffer of the destination
RWTexture2D<float4> _CopyDstColor;
uint _CopyDstSize, _CopySrcStartIndex, _CopyDstStartIndex, _CopyCount;

float4x4 _CopyTransformMatrix;  // applied to positions (and drives SH rotation)
float4 _CopyTransformRotation;  // applied to splat orientations
float3 _CopyTransformScale;     // applied to scales; negative components trigger axis-flip handling
|
||||
|
||||
// Copies _CopyCount splats from source range [_CopySrcStartIndex, ...) into
// destination range [_CopyDstStartIndex, ...), applying the copy transform
// (_CopyTransformMatrix/Rotation/Scale) on the way. Destination layout is
// uncompressed: fp32 positions, Norm10 rotation + fp32 scale, color texture,
// fp32 SH, plus the deleted-splat bit buffer.
[numthreads(GROUP_SIZE,1,1)]
void CSCopySplats (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _CopyCount)
        return;
    uint srcIdx = _CopySrcStartIndex + idx;
    uint dstIdx = _CopyDstStartIndex + idx;
    if (srcIdx >= _SplatCount || dstIdx >= _CopyDstSize)
        return;

    // BUGFIX: load from srcIdx, not idx. Using idx copied the wrong splats
    // whenever _CopySrcStartIndex was nonzero, and was inconsistent with the
    // deleted-bits copy below, which already reads from srcIdx.
    SplatData src = LoadSplatData(srcIdx);

    // transform the splat
    src.pos = mul(_CopyTransformMatrix, float4(src.pos,1)).xyz;
    // note: this only handles axis flips from scale, not any arbitrary scaling
    if (_CopyTransformScale.x < 0)
        src.rot.yz = -src.rot.yz;
    if (_CopyTransformScale.y < 0)
        src.rot.xz = -src.rot.xz;
    if (_CopyTransformScale.z < 0)
        src.rot.xy = -src.rot.xy;
    src.rot = QuatMul(_CopyTransformRotation, src.rot);
    src.scale *= abs(_CopyTransformScale);

    // rotate SH coefficients to match the new orientation
    float3x3 shRot = CalcSHRotMatrix(_CopyTransformMatrix);
    RotateSH(src.sh, shRot);

    // output data into destination:
    // pos
    uint posStride = 12; // 3 x fp32
    _CopyDstPos.Store3(dstIdx * posStride, asuint(src.pos));
    // rot + scale
    uint otherStride = 4 + 12; // packed rotation + fp32 scale
    uint rotVal = EncodeQuatToNorm10(PackSmallest3Rotation(src.rot));
    _CopyDstOther.Store4(dstIdx * otherStride, uint4(
        rotVal,
        asuint(src.scale.x),
        asuint(src.scale.y),
        asuint(src.scale.z)));
    // color
    uint3 pixelIndex = SplatIndexToPixelIndex(dstIdx);
    _CopyDstColor[pixelIndex.xy] = float4(src.sh.col, src.opacity);

    // SH
    uint shStride = 192; // 15*3 fp32, rounded up to multiple of 16
    uint shOffset = dstIdx * shStride;
    _CopyDstSH.Store3(shOffset + 12 * 0, asuint(src.sh.sh1));
    _CopyDstSH.Store3(shOffset + 12 * 1, asuint(src.sh.sh2));
    _CopyDstSH.Store3(shOffset + 12 * 2, asuint(src.sh.sh3));
    _CopyDstSH.Store3(shOffset + 12 * 3, asuint(src.sh.sh4));
    _CopyDstSH.Store3(shOffset + 12 * 4, asuint(src.sh.sh5));
    _CopyDstSH.Store3(shOffset + 12 * 5, asuint(src.sh.sh6));
    _CopyDstSH.Store3(shOffset + 12 * 6, asuint(src.sh.sh7));
    _CopyDstSH.Store3(shOffset + 12 * 7, asuint(src.sh.sh8));
    _CopyDstSH.Store3(shOffset + 12 * 8, asuint(src.sh.sh9));
    _CopyDstSH.Store3(shOffset + 12 * 9, asuint(src.sh.sh10));
    _CopyDstSH.Store3(shOffset + 12 * 10, asuint(src.sh.sh11));
    _CopyDstSH.Store3(shOffset + 12 * 11, asuint(src.sh.sh12));
    _CopyDstSH.Store3(shOffset + 12 * 12, asuint(src.sh.sh13));
    _CopyDstSH.Store3(shOffset + 12 * 13, asuint(src.sh.sh14));
    _CopyDstSH.Store3(shOffset + 12 * 14, asuint(src.sh.sh15));

    // deleted bits: propagate the source splat's deleted flag
    uint srcWordIdx = srcIdx / 32;
    uint srcBitIdx = srcIdx & 31;
    if (_SplatDeletedBits.Load(srcWordIdx * 4) & (1u << srcBitIdx))
    {
        uint dstWordIdx = dstIdx / 32;
        uint dstBitIdx = dstIdx & 31;
        _CopyDstEditDeleted.InterlockedOr(dstWordIdx * 4, 1u << dstBitIdx);
    }
}
|
||||
7
MVS/3DGS-Unity/Shaders/SplatUtilities.compute.meta
Normal file
7
MVS/3DGS-Unity/Shaders/SplatUtilities.compute.meta
Normal file
@@ -0,0 +1,7 @@
|
||||
fileFormatVersion: 2
|
||||
guid: ec84f78b836bd4f96a105d6b804f08bd
|
||||
ComputeShaderImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
||||
Reference in New Issue
Block a user