chore: sync workspace state

This commit is contained in:
2026-03-29 01:36:53 +08:00
parent eb5de3e3d4
commit e5cb79f3ce
4935 changed files with 35593 additions and 360696 deletions

View File

@@ -0,0 +1,44 @@
// SPDX-License-Identifier: MIT
// Unlit skybox shader that outputs a single flat color for every pixel
// (defaults to black with zero alpha). No lighting, no texturing.
Shader "Unlit/BlackSkybox"
{
Properties
{
// Flat output color; default (0,0,0,0) renders black.
_Color ("Color", Color) = (0,0,0,0)
}
SubShader
{
Pass
{
CGPROGRAM
#pragma vertex vert
#pragma fragment frag
#include "UnityCG.cginc"
// Vertex input: object-space position only.
struct appdata
{
float4 vertex : POSITION;
};
// Vertex-to-fragment payload: clip-space position only.
struct v2f
{
float4 vertex : SV_POSITION;
};
// Transform the mesh vertex to clip space; nothing else is needed.
v2f vert (appdata v)
{
v2f o;
o.vertex = UnityObjectToClipPos(v.vertex);
return o;
}
half4 _Color;
// Every pixel returns the constant _Color.
half4 frag (v2f i) : SV_Target
{
return _Color;
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: a4867e5be68354ccda78062a92c74391
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,531 @@
/******************************************************************************
* DeviceRadixSort
* Device Level 8-bit LSD Radix Sort using reduce then scan
*
* SPDX-License-Identifier: MIT
* Copyright Thomas Smith 5/17/2024
* https://github.com/b0nes164/GPUSorting
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
******************************************************************************/
#include "SortCommon.hlsl"
#define US_DIM 128U //The number of threads in a Upsweep threadblock
#define SCAN_DIM 128U //The number of threads in a Scan threadblock
RWStructuredBuffer<uint> b_globalHist; //buffer holding device level offsets for each binning pass
RWStructuredBuffer<uint> b_passHist; //buffer used to store reduced sums of partition tiles
groupshared uint g_us[RADIX * 2]; //Shared memory for upsweep
groupshared uint g_scan[SCAN_DIM]; //Shared memory for the scan
//*****************************************************************************
//INIT KERNEL
//*****************************************************************************
//Clear the global histogram, as we will be adding to it atomically
//Each thread zeroes one global-histogram slot. The dispatch must cover the
//full histogram (RADIX entries per binning pass).
//Use uint3 for SV_DispatchThreadID to match the other kernels in this file
//(Upsweep/Scan/Downsweep) — the system value is inherently unsigned.
[numthreads(1024, 1, 1)]
void InitDeviceRadixSort(uint3 id : SV_DispatchThreadID)
{
b_globalHist[id.x] = 0;
}
//*****************************************************************************
//UPSWEEP KERNEL
//*****************************************************************************
//histogram, 64 threads to a histogram
inline void HistogramDigitCounts(uint gtid, uint gid)
{
//Two RADIX-bin histograms live in g_us: threads 0-63 accumulate into the
//first, threads 64-127 into the second, halving atomic contention.
const uint histOffset = gtid / 64 * RADIX;
//Last partition may be short: clamp its end to the total key count.
const uint partitionEnd = gid == e_threadBlocks - 1 ?
e_numKeys : (gid + 1) * PART_SIZE;
for (uint i = gtid + gid * PART_SIZE; i < partitionEnd; i += US_DIM)
{
//Keys are bit-converted so int/float keys sort correctly as unsigned.
#if defined(KEY_UINT)
InterlockedAdd(g_us[ExtractDigit(b_sort[i]) + histOffset], 1);
#elif defined(KEY_INT)
InterlockedAdd(g_us[ExtractDigit(IntToUint(b_sort[i])) + histOffset], 1);
#elif defined(KEY_FLOAT)
InterlockedAdd(g_us[ExtractDigit(FloatToUint(b_sort[i])) + histOffset], 1);
#endif
}
}
//reduce and pass to tile histogram
inline void ReduceWriteDigitCounts(uint gtid, uint gid)
{
//Fold the two partial histograms together, publish this tile's digit
//counts to b_passHist (digit-major layout: digit * e_threadBlocks + tile),
//then begin a wave-level inclusive prefix sum over the totals in place.
for (uint i = gtid; i < RADIX; i += US_DIM)
{
g_us[i] += g_us[i + RADIX];
b_passHist[i * e_threadBlocks + gid] = g_us[i];
g_us[i] += WavePrefixSum(g_us[i]);
}
}
//Exclusive scan over digit counts, then atomically add to global hist
inline void GlobalHistExclusiveScanWGE16(uint gtid, uint waveSize)
{
GroupMemoryBarrierWithGroupSync();
//Scan the per-wave partial sums: one thread per wave-sized span folds the
//prefix of earlier spans onto its span's last element.
if (gtid < (RADIX / waveSize))
{
g_us[(gtid + 1) * waveSize - 1] +=
WavePrefixSum(g_us[(gtid + 1) * waveSize - 1]);
}
GroupMemoryBarrierWithGroupSync();
//atomically add to global histogram
const uint globalHistOffset = GlobalHistOffset();
const uint laneMask = waveSize - 1;
//Each lane targets the next lane's slot (circular shift), turning the
//inclusive wave scan into an exclusive result on write-out.
const uint circularLaneShift = WaveGetLaneIndex() + 1 & laneMask;
for (uint i = gtid; i < RADIX; i += US_DIM)
{
const uint index = circularLaneShift + (i & ~laneMask);
uint t = WaveGetLaneIndex() != laneMask ? g_us[i] : 0;
if (i >= waveSize)
t += WaveReadLaneAt(g_us[i - 1], 0);
InterlockedAdd(b_globalHist[index + globalHistOffset], t);
}
}
//Variant for waves narrower than 16 lanes: a single wave scan cannot span
//RADIX entries, so climb the scan in log2(waveSize)-sized levels.
inline void GlobalHistExclusiveScanWLT16(uint gtid, uint waveSize)
{
const uint globalHistOffset = GlobalHistOffset();
//First waveSize entries: circular-shift the wave-inclusive sums into an
//exclusive form and add them straight into the global histogram.
if (gtid < waveSize)
{
const uint circularLaneShift = WaveGetLaneIndex() + 1 &
waveSize - 1;
InterlockedAdd(b_globalHist[circularLaneShift + globalHistOffset],
circularLaneShift ? g_us[gtid] : 0);
}
GroupMemoryBarrierWithGroupSync();
const uint laneLog = countbits(waveSize - 1);
uint offset = laneLog;
uint j = waveSize;
for (; j < (RADIX >> 1); j <<= laneLog)
{
if (gtid < (RADIX >> offset))
{
g_us[((gtid + 1) << offset) - 1] +=
WavePrefixSum(g_us[((gtid + 1) << offset) - 1]);
}
GroupMemoryBarrierWithGroupSync();
for (uint i = gtid + j; i < RADIX; i += US_DIM)
{
if ((i & ((j << laneLog) - 1)) >= j)
{
if (i < (j << laneLog))
{
//Entry is fully resolved at this level: push to global hist.
InterlockedAdd(b_globalHist[i + globalHistOffset],
WaveReadLaneAt(g_us[((i >> offset) << offset) - 1], 0) +
((i & (j - 1)) ? g_us[i - 1] : 0));
}
else
{
//Propagate partial sums upward for the next level.
if ((i + 1) & (j - 1))
{
g_us[i] +=
WaveReadLaneAt(g_us[((i >> offset) << offset) - 1], 0);
}
}
}
}
offset += laneLog;
}
GroupMemoryBarrierWithGroupSync();
//If RADIX is not a power of lanecount
for (uint i = gtid + j; i < RADIX; i += US_DIM)
{
InterlockedAdd(b_globalHist[i + globalHistOffset],
WaveReadLaneAt(g_us[((i >> offset) << offset) - 1], 0) +
((i & (j - 1)) ? g_us[i - 1] : 0));
}
}
//Upsweep: per-tile digit histogram + contribution to the global histogram.
[numthreads(US_DIM, 1, 1)]
void Upsweep(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID)
{
//get the wave size
const uint waveSize = getWaveSize();
//clear shared memory
const uint histsEnd = RADIX * 2;
for (uint i = gtid.x; i < histsEnd; i += US_DIM)
g_us[i] = 0;
GroupMemoryBarrierWithGroupSync();
//Count digits for this tile, reduce the two partial histograms and
//publish per-tile counts, then fold totals into the global histogram
//using the scan variant matching the hardware wave width.
HistogramDigitCounts(gtid.x, gid.x);
GroupMemoryBarrierWithGroupSync();
ReduceWriteDigitCounts(gtid.x, gid.x);
if (waveSize >= 16)
GlobalHistExclusiveScanWGE16(gtid.x, waveSize);
if (waveSize < 16)
GlobalHistExclusiveScanWLT16(gtid.x, waveSize);
}
//*****************************************************************************
//SCAN KERNEL
//*****************************************************************************
//Scan all full SCAN_DIM-sized partitions of one digit's tile counts.
//`reduction` (inout) carries the running total from partition to partition.
inline void ExclusiveThreadBlockScanFullWGE16(
uint gtid,
uint laneMask,
uint circularLaneShift,
uint partEnd,
uint deviceOffset,
uint waveSize,
inout uint reduction)
{
for (uint i = gtid; i < partEnd; i += SCAN_DIM)
{
g_scan[gtid] = b_passHist[i + deviceOffset];
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
GroupMemoryBarrierWithGroupSync();
//Scan the per-wave sums to obtain the block-wide prefix.
if (gtid < SCAN_DIM / waveSize)
{
g_scan[(gtid + 1) * waveSize - 1] +=
WavePrefixSum(g_scan[(gtid + 1) * waveSize - 1]);
}
GroupMemoryBarrierWithGroupSync();
//Circular lane shift converts the inclusive result to exclusive on write.
uint t = (WaveGetLaneIndex() != laneMask ? g_scan[gtid] : 0) + reduction;
if (gtid >= waveSize)
t += WaveReadLaneAt(g_scan[gtid - 1], 0);
b_passHist[circularLaneShift + (i & ~laneMask) + deviceOffset] = t;
reduction += g_scan[SCAN_DIM - 1];
GroupMemoryBarrierWithGroupSync();
}
}
//Scan the final, partial partition (fewer than SCAN_DIM entries).
//Out-of-range threads may operate on stale shared values, but every
//global write below is bounds-checked, so those values never escape.
inline void ExclusiveThreadBlockScanPartialWGE16(
uint gtid,
uint laneMask,
uint circularLaneShift,
uint partEnd,
uint deviceOffset,
uint waveSize,
uint reduction)
{
uint i = gtid + partEnd;
if (i < e_threadBlocks)
g_scan[gtid] = b_passHist[deviceOffset + i];
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
GroupMemoryBarrierWithGroupSync();
if (gtid < SCAN_DIM / waveSize)
{
g_scan[(gtid + 1) * waveSize - 1] +=
WavePrefixSum(g_scan[(gtid + 1) * waveSize - 1]);
}
GroupMemoryBarrierWithGroupSync();
//Circular lane shift turns the inclusive scan exclusive on write-back.
const uint index = circularLaneShift + (i & ~laneMask);
if (index < e_threadBlocks)
{
uint t = (WaveGetLaneIndex() != laneMask ? g_scan[gtid] : 0) + reduction;
if (gtid >= waveSize)
t += g_scan[(gtid & ~laneMask) - 1];
b_passHist[index + deviceOffset] = t;
}
}
//Exclusive scan over one digit's per-tile reductions, for waves >= 16
//lanes: process every full SCAN_DIM-sized partition, then the remainder.
inline void ExclusiveThreadBlockScanWGE16(uint gtid, uint gid, uint waveSize)
{
uint runningTotal = 0; //carried between partitions by the full-scan pass
const uint mask = waveSize - 1;
const uint rotatedLane = (WaveGetLaneIndex() + 1) & mask;
const uint alignedEnd = e_threadBlocks / SCAN_DIM * SCAN_DIM; //end of full partitions
const uint histBase = gid * e_threadBlocks; //this digit's row in b_passHist
ExclusiveThreadBlockScanFullWGE16(
gtid,
mask,
rotatedLane,
alignedEnd,
histBase,
waveSize,
runningTotal);
ExclusiveThreadBlockScanPartialWGE16(
gtid,
mask,
rotatedLane,
alignedEnd,
histBase,
waveSize,
runningTotal);
}
//Narrow-wave (<16 lane) scan over all full SCAN_DIM-sized partitions:
//each partition requires a multi-level shared-memory scan because one
//wave cannot span SCAN_DIM entries. `reduction` carries the running total.
inline void ExclusiveThreadBlockScanFullWLT16(
uint gtid,
uint partitions,
uint deviceOffset,
uint laneLog,
uint circularLaneShift,
uint waveSize,
inout uint reduction)
{
for (uint k = 0; k < partitions; ++k)
{
g_scan[gtid] = b_passHist[gtid + k * SCAN_DIM + deviceOffset];
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
GroupMemoryBarrierWithGroupSync();
//First wave's entries can be written out immediately; the circular
//shift makes the inclusive scan exclusive.
if (gtid < waveSize)
{
b_passHist[circularLaneShift + k * SCAN_DIM + deviceOffset] =
(circularLaneShift ? g_scan[gtid] : 0) + reduction;
}
uint offset = laneLog;
uint j = waveSize;
for (; j < (SCAN_DIM >> 1); j <<= laneLog)
{
if (gtid < (SCAN_DIM >> offset))
{
g_scan[((gtid + 1) << offset) - 1] +=
WavePrefixSum(g_scan[((gtid + 1) << offset) - 1]);
}
GroupMemoryBarrierWithGroupSync();
if ((gtid & ((j << laneLog) - 1)) >= j)
{
if (gtid < (j << laneLog))
{
//Fully resolved at this level: write the final value.
b_passHist[gtid + k * SCAN_DIM + deviceOffset] =
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0) +
((gtid & (j - 1)) ? g_scan[gtid - 1] : 0) + reduction;
}
else
{
//Propagate partial sums for the next scan level.
if ((gtid + 1) & (j - 1))
{
g_scan[gtid] +=
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0);
}
}
}
offset += laneLog;
}
GroupMemoryBarrierWithGroupSync();
//If SCAN_DIM is not a power of lanecount
for (uint i = gtid + j; i < SCAN_DIM; i += SCAN_DIM)
{
b_passHist[i + k * SCAN_DIM + deviceOffset] =
WaveReadLaneAt(g_scan[((i >> offset) << offset) - 1], 0) +
((i & (j - 1)) ? g_scan[i - 1] : 0) + reduction;
}
//Carry this partition's total into the next iteration.
reduction += WaveReadLaneAt(g_scan[SCAN_DIM - 1], 0) +
WaveReadLaneAt(g_scan[(((SCAN_DIM - 1) >> offset) << offset) - 1], 0);
GroupMemoryBarrierWithGroupSync();
}
}
//Scan the final partial partition for narrow (<16 lane) waves.
//(NOTE(review): "Parital" typo is kept — the name is part of the call
//interface used elsewhere in this file.)
inline void ExclusiveThreadBlockScanParitalWLT16(
uint gtid,
uint partitions,
uint deviceOffset,
uint laneLog,
uint circularLaneShift,
uint waveSize,
uint reduction)
{
const uint finalPartSize = e_threadBlocks - partitions * SCAN_DIM;
if (gtid < finalPartSize)
{
g_scan[gtid] = b_passHist[gtid + partitions * SCAN_DIM + deviceOffset];
g_scan[gtid] += WavePrefixSum(g_scan[gtid]);
}
GroupMemoryBarrierWithGroupSync();
//First wave writes out directly; circular shift makes the result exclusive.
if (gtid < waveSize && circularLaneShift < finalPartSize)
{
b_passHist[circularLaneShift + partitions * SCAN_DIM + deviceOffset] =
(circularLaneShift ? g_scan[gtid] : 0) + reduction;
}
//Multi-level scan over the remainder, same scheme as the full-partition path.
uint offset = laneLog;
for (uint j = waveSize; j < finalPartSize; j <<= laneLog)
{
if (gtid < (finalPartSize >> offset))
{
g_scan[((gtid + 1) << offset) - 1] +=
WavePrefixSum(g_scan[((gtid + 1) << offset) - 1]);
}
GroupMemoryBarrierWithGroupSync();
if ((gtid & ((j << laneLog) - 1)) >= j && gtid < finalPartSize)
{
if (gtid < (j << laneLog))
{
b_passHist[gtid + partitions * SCAN_DIM + deviceOffset] =
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0) +
((gtid & (j - 1)) ? g_scan[gtid - 1] : 0) + reduction;
}
else
{
if ((gtid + 1) & (j - 1))
{
g_scan[gtid] +=
WaveReadLaneAt(g_scan[((gtid >> offset) << offset) - 1], 0);
}
}
}
offset += laneLog;
}
}
//Exclusive scan over one digit's per-tile reductions, for waves narrower
//than 16 lanes: full partitions first, then the trailing remainder.
inline void ExclusiveThreadBlockScanWLT16(uint gtid, uint gid, uint waveSize)
{
uint runningTotal = 0; //carried between partitions by the full-scan pass
const uint fullPartitions = e_threadBlocks / SCAN_DIM;
const uint histBase = gid * e_threadBlocks; //this digit's row in b_passHist
const uint waveBits = countbits(waveSize - 1); //log2 of the wave size
const uint rotatedLane = (WaveGetLaneIndex() + 1) & (waveSize - 1);
ExclusiveThreadBlockScanFullWLT16(
gtid,
fullPartitions,
histBase,
waveBits,
rotatedLane,
waveSize,
runningTotal);
ExclusiveThreadBlockScanParitalWLT16(
gtid,
fullPartitions,
histBase,
waveBits,
rotatedLane,
waveSize,
runningTotal);
}
//Scan does not need flattening of gids
[numthreads(SCAN_DIM, 1, 1)]
void Scan(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID)
{
//Dispatch to the scan variant matching the hardware wave width.
const uint waveSize = getWaveSize();
if (waveSize >= 16)
ExclusiveThreadBlockScanWGE16(gtid.x, gid.x, waveSize);
else
ExclusiveThreadBlockScanWLT16(gtid.x, gid.x, waveSize);
}
//*****************************************************************************
//DOWNSWEEP KERNEL
//*****************************************************************************
inline void LoadThreadBlockReductions(uint gtid, uint gid, uint exclusiveHistReduction)
{
//Combine the device-wide digit start (global histogram) with this tile's
//offset among tiles (pass histogram), minus the tile-local exclusive
//reduction, and stash it past the key region of shared memory for scatter.
if (gtid < RADIX)
{
g_d[gtid + PART_SIZE] = b_globalHist[gtid + GlobalHistOffset()] +
b_passHist[gtid * e_threadBlocks + gid] - exclusiveHistReduction;
}
}
//Downsweep: rank keys within each tile and scatter them to their final
//device positions for this binning pass.
[numthreads(D_DIM, 1, 1)]
void Downsweep(uint3 gtid : SV_GroupThreadID, uint3 gid : SV_GroupID)
{
KeyStruct keys;
OffsetStruct offsets;
const uint waveSize = getWaveSize();
ClearWaveHists(gtid.x, waveSize);
GroupMemoryBarrierWithGroupSync();
//Load this tile's keys; the last tile may be partial.
if (gid.x < e_threadBlocks - 1)
{
if (waveSize >= 16)
keys = LoadKeysWGE16(gtid.x, waveSize, gid.x);
if (waveSize < 16)
keys = LoadKeysWLT16(gtid.x, waveSize, gid.x, SerialIterations(waveSize));
}
if (gid.x == e_threadBlocks - 1)
{
if (waveSize >= 16)
keys = LoadKeysPartialWGE16(gtid.x, waveSize, gid.x);
if (waveSize < 16)
keys = LoadKeysPartialWLT16(gtid.x, waveSize, gid.x, SerialIterations(waveSize));
}
//Rank keys and scan the wave histograms into exclusive digit offsets;
//exclusiveHistReduction snapshots each digit's tile-local exclusive
//count before shared memory is repurposed for scattering.
uint exclusiveHistReduction;
if (waveSize >= 16)
{
offsets = RankKeysWGE16(waveSize, getWaveIndex(gtid.x, waveSize) * RADIX, keys);
GroupMemoryBarrierWithGroupSync();
uint histReduction;
if (gtid.x < RADIX)
{
histReduction = WaveHistInclusiveScanCircularShiftWGE16(gtid.x, waveSize);
histReduction += WavePrefixSum(histReduction); //take advantage of barrier to begin scan
}
GroupMemoryBarrierWithGroupSync();
WaveHistReductionExclusiveScanWGE16(gtid.x, waveSize, histReduction);
GroupMemoryBarrierWithGroupSync();
UpdateOffsetsWGE16(gtid.x, waveSize, offsets, keys);
if (gtid.x < RADIX)
exclusiveHistReduction = g_d[gtid.x]; //take advantage of barrier to grab value
GroupMemoryBarrierWithGroupSync();
}
if (waveSize < 16)
{
offsets = RankKeysWLT16(waveSize, getWaveIndex(gtid.x, waveSize), keys, SerialIterations(waveSize));
if (gtid.x < HALF_RADIX)
{
//Narrow-wave path packs two 16-bit digit counts per uint in g_d.
uint histReduction = WaveHistInclusiveScanCircularShiftWLT16(gtid.x);
g_d[gtid.x] = histReduction + (histReduction << 16); //take advantage of barrier to begin scan
}
WaveHistReductionExclusiveScanWLT16(gtid.x);
GroupMemoryBarrierWithGroupSync();
UpdateOffsetsWLT16(gtid.x, waveSize, SerialIterations(waveSize), offsets, keys);
if (gtid.x < RADIX) //take advantage of barrier to grab value
exclusiveHistReduction = g_d[gtid.x >> 1] >> ((gtid.x & 1) ? 16 : 0) & 0xffff;
GroupMemoryBarrierWithGroupSync();
}
//Scatter keys into shared memory, compute the per-digit device
//destinations, then scatter to device (partial variant for the last tile).
ScatterKeysShared(offsets, keys);
LoadThreadBlockReductions(gtid.x, gid.x, exclusiveHistReduction);
GroupMemoryBarrierWithGroupSync();
if (gid.x < e_threadBlocks - 1)
ScatterDevice(gtid.x, waveSize, gid.x, offsets);
if (gid.x == e_threadBlocks - 1)
ScatterDevicePartial(gtid.x, waveSize, gid.x, offsets);
}

View File

@@ -0,0 +1,7 @@
fileFormatVersion: 2
guid: 02209b8d952e7fc418492b88139826fd
ShaderIncludeImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,45 @@
// SPDX-License-Identifier: MIT
// Fullscreen composite pass: alpha-blends the offscreen Gaussian splat
// render target over the camera image.
Shader "Hidden/Gaussian Splatting/Composite"
{
SubShader
{
Pass
{
ZWrite Off
ZTest Always
Cull Off
Blend SrcAlpha OneMinusSrcAlpha
CGPROGRAM
#pragma vertex vert
#pragma fragment frag
#pragma require compute
#pragma use_dxc
#include "UnityCG.cginc"
struct v2f
{
float4 vertex : SV_POSITION;
};
// Build an oversized fullscreen quad purely from the vertex ID;
// no vertex buffer is needed.
v2f vert (uint vtxID : SV_VertexID)
{
v2f o;
float2 quadPos = float2(vtxID&1, (vtxID>>1)&1) * 4.0 - 1.0;
o.vertex = float4(quadPos, 1, 1);
return o;
}
Texture2D _GaussianSplatRT;
half4 frag (v2f i) : SV_Target
{
// Fetch the splat RT texel under this pixel (no sampler needed).
half4 col = _GaussianSplatRT.Load(int3(i.vertex.xy, 0));
col.rgb = GammaToLinearSpace(col.rgb);
// NOTE(review): the 1.5 alpha boost looks empirically chosen —
// confirm intent before changing.
col.a = saturate(col.a * 1.5);
return col;
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 7e184af7d01193a408eb916d8acafff9
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,102 @@
// SPDX-License-Identifier: MIT
// Debug visualization: draws each splat (or each chunk) as a transparent
// box, colored by splat color or by a chunk-index palette.
Shader "Gaussian Splatting/Debug/Render Boxes"
{
SubShader
{
Tags { "RenderType"="Transparent" "Queue"="Transparent" }
Pass
{
ZWrite Off
Blend OneMinusDstAlpha One
Cull Front
CGPROGRAM
#pragma vertex vert
#pragma fragment frag
#pragma require compute
#pragma use_dxc
#include "UnityCG.cginc"
#include "GaussianSplatting.hlsl"
StructuredBuffer<uint> _OrderBuffer;
bool _DisplayChunks;
struct v2f
{
half4 col : COLOR0;
float4 vertex : SV_POSITION;
};
float _SplatScale;
float _SplatOpacityScale;
// based on https://iquilezles.org/articles/palettes/
// cosine based palette, 4 vec3 params
half3 palette(float t, half3 a, half3 b, half3 c, half3 d)
{
return a + b*cos(6.28318*(c*t+d));
}
v2f vert (uint vtxID : SV_VertexID, uint instID : SV_InstanceID)
{
v2f o;
bool chunks = _DisplayChunks;
// Unit cube corner from the 3 low bits of the vertex ID, in [-1,1].
uint idx = vtxID;
float3 localPos = float3(idx&1, (idx>>1)&1, (idx>>2)&1) * 2.0 - 1.0;
float3 centerWorldPos = 0;
if (!chunks)
{
// display splat boxes
instID = _OrderBuffer[instID];
SplatData splat = LoadSplatData(instID);
float4 boxRot = splat.rot;
float3 boxSize = splat.scale;
boxSize *= _SplatScale;
float3x3 splatRotScaleMat = CalcMatrixFromRotationScale(boxRot, boxSize);
splatRotScaleMat = mul((float3x3)unity_ObjectToWorld, splatRotScaleMat);
centerWorldPos = splat.pos;
centerWorldPos = mul(unity_ObjectToWorld, float4(centerWorldPos,1)).xyz;
o.col.rgb = saturate(splat.sh.col);
o.col.a = saturate(splat.opacity * _SplatOpacityScale);
localPos = mul(splatRotScaleMat, localPos) * 2;
}
else
{
// display chunk boxes
localPos = localPos * 0.5 + 0.5;
SplatChunkInfo chunk = _SplatChunks[instID];
float3 posMin = float3(chunk.posX.x, chunk.posY.x, chunk.posZ.x);
float3 posMax = float3(chunk.posX.y, chunk.posY.y, chunk.posZ.y);
localPos = lerp(posMin, posMax, localPos);
localPos = mul(unity_ObjectToWorld, float4(localPos,1)).xyz;
// Distinct color per chunk via a cosine palette over the chunk index.
o.col.rgb = palette((float)instID / (float)_SplatChunkCount, half3(0.5,0.5,0.5), half3(0.5,0.5,0.5), half3(1,1,1), half3(0.0, 0.33, 0.67));
o.col.a = 0.1;
}
float3 worldPos = centerWorldPos + localPos;
o.vertex = UnityWorldToClipPos(worldPos);
return o;
}
half4 frag (v2f i) : SV_Target
{
// Premultiplied-alpha output to pair with OneMinusDstAlpha/One blending.
half4 res = half4(i.col.rgb * i.col.a, i.col.a);
return res;
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: 4006f2680fd7c8b4cbcb881454c782be
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,65 @@
// SPDX-License-Identifier: MIT
// Debug visualization: draws each splat as a small screen-space quad,
// colored by splat color or by an index-derived pattern.
Shader "Gaussian Splatting/Debug/Render Points"
{
SubShader
{
Tags { "RenderType"="Transparent" "Queue"="Transparent" }
Pass
{
ZWrite On
Cull Off
CGPROGRAM
#pragma vertex vert
#pragma fragment frag
#pragma require compute
#pragma use_dxc
#include "GaussianSplatting.hlsl"
struct v2f
{
half3 color : TEXCOORD0;
float4 vertex : SV_POSITION;
};
float _SplatSize;
bool _DisplayIndex;
int _SplatCount;
v2f vert (uint vtxID : SV_VertexID, uint instID : SV_InstanceID)
{
v2f o;
uint splatIndex = instID;
SplatData splat = LoadSplatData(splatIndex);
float3 centerWorldPos = splat.pos;
centerWorldPos = mul(unity_ObjectToWorld, float4(centerWorldPos,1)).xyz;
float4 centerClipPos = mul(UNITY_MATRIX_VP, float4(centerWorldPos, 1));
o.vertex = centerClipPos;
// Expand a quad around the projected center; multiplying by w keeps
// it a constant pixel size after the perspective divide.
uint idx = vtxID;
float2 quadPos = float2(idx&1, (idx>>1)&1) * 2.0 - 1.0;
o.vertex.xy += (quadPos * _SplatSize / _ScreenParams.xy) * o.vertex.w;
o.color.rgb = saturate(splat.sh.col);
if (_DisplayIndex)
{
// Encode the splat index into RGB at three frequency scales so
// neighboring indices are visually distinguishable.
o.color.r = frac((float)splatIndex / (float)_SplatCount * 100);
o.color.g = frac((float)splatIndex / (float)_SplatCount * 10);
o.color.b = (float)splatIndex / (float)_SplatCount;
}
return o;
}
half4 frag (v2f i) : SV_Target
{
return half4(i.color.rgb, 1);
}
ENDCG
}
}
}

View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: b44409fc67214394f8f47e4e2648425e
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,617 @@
// SPDX-License-Identifier: MIT
#ifndef GAUSSIAN_SPLATTING_HLSL
#define GAUSSIAN_SPLATTING_HLSL
// Remap x in [0,1] through a signed square-root curve centered at 0.5;
// the result stays in [0,1]. Inverse of a squared falloff around 0.5.
float InvSquareCentered01(float x)
{
x -= 0.5;
x *= 0.5;
x = sqrt(abs(x)) * sign(x);
return x + 0.5;
}
// Rotate vector v by quaternion r (xyz = vector part, w = scalar part)
// using the standard two-cross-product expansion.
float3 QuatRotateVector(float3 v, float4 r)
{
float3 t = 2 * cross(r.xyz, v);
return v + r.w * t + cross(r.xyz, t);
}
// Quaternion product a*b ((x,y,z,w) layout), in a compact swizzled form.
float4 QuatMul(float4 a, float4 b)
{
return float4(a.wwww * b + (a.xyzx * b.wwwx + a.yzxy * b.zxyy) * float4(1,1,1,-1) - a.zxyz * b.yzxz);
}
// Quaternion inverse: conjugate scaled by the reciprocal squared norm.
float4 QuatInverse(float4 q)
{
return rcp(dot(q, q)) * q * float4(-1,-1,-1,1);
}
// Build the 3x3 transform R*S from quaternion `rot` and per-axis `scale`.
float3x3 CalcMatrixFromRotationScale(float4 rot, float3 scale)
{
float3x3 ms = float3x3(
scale.x, 0, 0,
0, scale.y, 0,
0, 0, scale.z
);
float x = rot.x;
float y = rot.y;
float z = rot.z;
float w = rot.w;
// Standard quaternion-to-rotation-matrix expansion.
float3x3 mr = float3x3(
1-2*(y*y + z*z), 2*(x*y - w*z), 2*(x*z + w*y),
2*(x*y + w*z), 1-2*(x*x + z*z), 2*(y*z - w*x),
2*(x*z - w*y), 2*(y*z + w*x), 1-2*(x*x + y*y)
);
return mul(mr, ms);
}
// Covariance = M * M^T for M = R*S; the symmetric result is packed as
// sigma0 = (xx, xy, xz) and sigma1 = (yy, yz, zz).
void CalcCovariance3D(float3x3 rotMat, out float3 sigma0, out float3 sigma1)
{
float3x3 sig = mul(rotMat, transpose(rotMat));
sigma0 = float3(sig._m00, sig._m01, sig._m02);
sigma1 = float3(sig._m11, sig._m12, sig._m22);
}
// from "EWA Splatting" (Zwicker et al 2002) eq. 31
// Project the packed 3D covariance (cov3d0 = xx,xy,xz; cov3d1 = yy,yz,zz)
// into a packed 2D screen-space covariance (xx, xy, yy).
float3 CalcCovariance2D(float3 worldPos, float3 cov3d0, float3 cov3d1, float4x4 matrixV, float4x4 matrixP, float4 screenParams)
{
float4x4 viewMatrix = matrixV;
float3 viewPos = mul(viewMatrix, float4(worldPos, 1)).xyz;
// this is needed in order for splats that are visible in view but clipped "quite a lot" to work
float aspect = matrixP._m00 / matrixP._m11;
float tanFovX = rcp(matrixP._m00);
float tanFovY = rcp(matrixP._m11 * aspect);
float limX = 1.3 * tanFovX;
float limY = 1.3 * tanFovY;
viewPos.x = clamp(viewPos.x / viewPos.z, -limX, limX) * viewPos.z;
viewPos.y = clamp(viewPos.y / viewPos.z, -limY, limY) * viewPos.z;
// Affine (Jacobian) approximation of the perspective projection at viewPos.
float focal = screenParams.x * matrixP._m00 / 2;
float3x3 J = float3x3(
focal / viewPos.z, 0, -(focal * viewPos.x) / (viewPos.z * viewPos.z),
0, focal / viewPos.z, -(focal * viewPos.y) / (viewPos.z * viewPos.z),
0, 0, 0
);
float3x3 W = (float3x3)viewMatrix;
float3x3 T = mul(J, W);
// Reassemble the symmetric 3D covariance from its packed upper triangle.
float3x3 V = float3x3(
cov3d0.x, cov3d0.y, cov3d0.z,
cov3d0.y, cov3d1.x, cov3d1.y,
cov3d0.z, cov3d1.y, cov3d1.z
);
float3x3 cov = mul(T, mul(V, transpose(T)));
// Low pass filter to make each splat at least 1px size.
cov._m00 += 0.3;
cov._m11 += 0.3;
return float3(cov._m00, cov._m01, cov._m11);
}
// Invert the packed symmetric 2x2 covariance (xx, xy, yy) to get the
// conic coefficients used to evaluate the Gaussian in screen space.
float3 CalcConic(float3 cov2d)
{
float det = cov2d.x * cov2d.z - cov2d.y * cov2d.y;
return float3(cov2d.z, -cov2d.y, cov2d.x) * rcp(det);
}
// Pixel-space offset of a fragment from a splat center; y is scaled by
// projectionParams.x (presumably Unity's _ProjectionParams.x ±1 flip for
// upside-down projections — confirm at call sites).
float2 CalcScreenSpaceDelta(float2 svPositionXY, float2 centerXY, float4 projectionParams)
{
float2 d = svPositionXY - centerXY;
d.y *= projectionParams.x;
return d;
}
// Gaussian exponent -0.5 * d^T * Conic * d for packed conic (a, b, c).
float CalcPowerFromConic(float3 conic, float2 d)
{
return -0.5 * (conic.x * d.x*d.x + conic.z * d.y*d.y) + conic.y * d.x*d.y;
}
// Morton interleaving 16x16 group i.e. by 4 bits of coordinates, based on this thread:
// https://twitter.com/rygorous/status/986715358852608000
// which is simplified version of https://fgiesen.wordpress.com/2009/12/13/decoding-morton-codes/
// Interleave the low 4 bits of each coordinate into an 8-bit Morton code.
uint EncodeMorton2D_16x16(uint2 c)
{
uint t = ((c.y & 0xF) << 8) | (c.x & 0xF); // ----EFGH----ABCD
t = (t ^ (t << 2)) & 0x3333; // --EF--GH--AB--CD
t = (t ^ (t << 1)) & 0x5555; // -E-F-G-H-A-B-C-D
return (t | (t >> 7)) & 0xFF; // --------EAFBGCHD
}
// Inverse of EncodeMorton2D_16x16: 8-bit Morton code back to 4-bit x/y.
uint2 DecodeMorton2D_16x16(uint t) // --------EAFBGCHD
{
t = (t & 0xFF) | ((t & 0xFE) << 7); // -EAFBGCHEAFBGCHD
t &= 0x5555; // -E-F-G-H-A-B-C-D
t = (t ^ (t >> 1)) & 0x3333; // --EF--GH--AB--CD
t = (t ^ (t >> 2)) & 0x0f0f; // ----EFGH----ABCD
return uint2(t & 0xF, t >> 8); // --------EFGHABCD
}
static const float SH_C1 = 0.4886025;
static const float SH_C2[] = { 1.0925484, -1.0925484, 0.3153916, -1.0925484, 0.5462742 };
static const float SH_C3[] = { -0.5900436, 2.8906114, -0.4570458, 0.3731763, -0.4570458, 1.4453057, -0.5900436 };
// Spherical harmonics data for one splat: col is the precomputed band-0
// (DC) color term; sh1..sh15 are the higher-order RGB coefficients.
struct SplatSHData
{
half3 col, sh1, sh2, sh3, sh4, sh5, sh6, sh7, sh8, sh9, sh10, sh11, sh12, sh13, sh14, sh15;
};
// Evaluate the splat's spherical harmonics in direction `dir`, up to
// shOrder bands. onlySH replaces the DC color with neutral 0.5 so the
// higher-order contribution can be visualized on its own.
half3 ShadeSH(SplatSHData splat, half3 dir, int shOrder, bool onlySH)
{
dir *= -1;
half x = dir.x, y = dir.y, z = dir.z;
// ambient band
half3 res = splat.col; // col = sh0 * SH_C0 + 0.5 is already precomputed
if (onlySH)
res = 0.5;
// 1st degree
if (shOrder >= 1)
{
res += SH_C1 * (-splat.sh1 * y + splat.sh2 * z - splat.sh3 * x);
// 2nd degree
if (shOrder >= 2)
{
half xx = x * x, yy = y * y, zz = z * z;
half xy = x * y, yz = y * z, xz = x * z;
res +=
(SH_C2[0] * xy) * splat.sh4 +
(SH_C2[1] * yz) * splat.sh5 +
(SH_C2[2] * (2 * zz - xx - yy)) * splat.sh6 +
(SH_C2[3] * xz) * splat.sh7 +
(SH_C2[4] * (xx - yy)) * splat.sh8;
// 3rd degree
if (shOrder >= 3)
{
res +=
(SH_C3[0] * y * (3 * xx - yy)) * splat.sh9 +
(SH_C3[1] * xy * z) * splat.sh10 +
(SH_C3[2] * y * (4 * zz - xx - yy)) * splat.sh11 +
(SH_C3[3] * z * (2 * zz - 3 * xx - 3 * yy)) * splat.sh12 +
(SH_C3[4] * x * (4 * zz - xx - yy)) * splat.sh13 +
(SH_C3[5] * z * (xx - yy)) * splat.sh14 +
(SH_C3[6] * x * (xx - 3 * yy)) * splat.sh15;
}
}
}
// SH evaluation can go negative; clamp to a valid color.
return max(res, 0);
}
// Width in texels of the splat color texture.
static const uint kTexWidth = 2048;
// Map a linear splat index to its texel coordinate: splats are stored in
// 16x16 Morton-ordered tiles (low 8 index bits), tiles laid out row-major.
uint3 SplatIndexToPixelIndex(uint idx)
{
uint3 res;
uint2 xy = DecodeMorton2D_16x16(idx);
uint width = kTexWidth / 16;
idx >>= 8;
res.x = (idx % width) * 16 + xy.x;
res.y = (idx / width) * 16 + xy.y;
res.z = 0;
return res;
}
// Per-chunk dequantization bounds. posX/posY/posZ hold (min, max) per
// axis; the uint fields are packed min/max ranges for color, scale and SH
// — assumed half2-style packing, TODO confirm against the asset importer.
struct SplatChunkInfo
{
uint colR, colG, colB, colA;
float2 posX, posY, posZ;
uint sclX, sclY, sclZ;
uint shR, shG, shB;
};
StructuredBuffer<SplatChunkInfo> _SplatChunks;
uint _SplatChunkCount;
// Number of splats covered by one chunk entry.
static const uint kChunkSize = 256;
// Fully decoded splat: position, rotation quaternion, per-axis scale,
// opacity, and spherical harmonics color data.
struct SplatData
{
float3 pos;
float4 rot;
float3 scale;
half opacity;
SplatSHData sh;
};
// Decode quaternion from a "smallest 3" e.g. 10.10.10.2 format
float4 DecodeRotation(float4 pq)
{
uint idx = (uint)round(pq.w * 3.0); // note: need to round or index might come out wrong in some formats (e.g. fp16.fp16.fp16.fp16)
float4 q;
// Expand the three stored components from [0,1] back to [-1/sqrt2, +1/sqrt2].
q.xyz = pq.xyz * sqrt(2.0) - (1.0 / sqrt(2.0));
// Recover the dropped (largest) component from the unit-norm constraint.
q.w = sqrt(1.0 - saturate(dot(q.xyz, q.xyz)));
// Swizzle the recovered component back to its original slot (idx==3 => w).
if (idx == 0) q = q.wxyz;
if (idx == 1) q = q.xwyz;
if (idx == 2) q = q.xywz;
return q;
}
// Encode quaternion q with the "smallest 3" scheme: drop the largest
// component (recoverable from the unit norm), store the other three in
// [0,1] plus the dropped component's index in w. Inverse of DecodeRotation.
float4 PackSmallest3Rotation(float4 q)
{
// find biggest component
float4 absQ = abs(q);
int index = 0;
float maxV = absQ.x;
if (absQ.y > maxV)
{
index = 1;
maxV = absQ.y;
}
if (absQ.z > maxV)
{
index = 2;
maxV = absQ.z;
}
if (absQ.w > maxV)
{
index = 3;
maxV = absQ.w;
}
// Rotate the largest component into w.
if (index == 0) q = q.yzwx;
if (index == 1) q = q.xzwy;
if (index == 2) q = q.xywz;
// q and -q encode the same rotation; flip so the dropped component is >= 0.
float3 three = q.xyz * (q.w >= 0 ? 1 : -1); // -1/sqrt2..+1/sqrt2 range
three = (three * sqrt(2.0)) * 0.5 + 0.5; // 0..1 range
return float4(three, index / 3.0);
}
// Unpack 6.5.5 bits into three normalized [0,1] values.
half3 DecodePacked_6_5_5(uint enc)
{
return half3(
(enc & 63) / 63.0,
((enc >> 6) & 31) / 31.0,
((enc >> 11) & 31) / 31.0);
}
// Unpack 5.6.5 bits into three normalized [0,1] values.
half3 DecodePacked_5_6_5(uint enc)
{
return half3(
(enc & 31) / 31.0,
((enc >> 5) & 63) / 63.0,
((enc >> 11) & 31) / 31.0);
}
// Unpack 11.10.11 bits into three normalized [0,1] values.
half3 DecodePacked_11_10_11(uint enc)
{
return half3(
(enc & 2047) / 2047.0,
((enc >> 11) & 1023) / 1023.0,
((enc >> 21) & 2047) / 2047.0);
}
// Unpack three 16-bit values (across two words) into normalized [0,1] floats.
float3 DecodePacked_16_16_16(uint2 enc)
{
return float3(
(enc.x & 65535) / 65535.0,
((enc.x >> 16) & 65535) / 65535.0,
(enc.y & 65535) / 65535.0);
}
// Unpack 10.10.10.2 bits into four normalized [0,1] values.
float4 DecodePacked_10_10_10_2(uint enc)
{
return float4(
(enc & 1023) / 1023.0,
((enc >> 10) & 1023) / 1023.0,
((enc >> 20) & 1023) / 1023.0,
((enc >> 30) & 3) / 3.0);
}
// Pack four normalized [0,1] values into 10.10.10.2 bits; the .5 scale
// factors make 1.0 truncate exactly to each field's maximum value.
uint EncodeQuatToNorm10(float4 v) // 32 bits: 10.10.10.2
{
return (uint) (v.x * 1023.5f) | ((uint) (v.y * 1023.5f) << 10) | ((uint) (v.z * 1023.5f) << 20) | ((uint) (v.w * 3.5f) << 30);
}
#ifdef SHADER_STAGE_COMPUTE
#define SplatBufferDataType RWByteAddressBuffer
#else
#define SplatBufferDataType ByteAddressBuffer
#endif
SplatBufferDataType _SplatPos;
SplatBufferDataType _SplatOther;
SplatBufferDataType _SplatSH;
Texture2D _SplatColor;
uint _SplatFormat;
// Match GaussianSplatAsset.VectorFormat
#define VECTOR_FMT_32F 0
#define VECTOR_FMT_16 1
#define VECTOR_FMT_11 2
#define VECTOR_FMT_6 3
// Read a 16-bit value at a 2-byte-aligned byte address. ByteAddressBuffer
// loads are 4-byte aligned, so fetch the containing word and select half.
uint LoadUShort(SplatBufferDataType dataBuffer, uint addrU)
{
uint addrA = addrU & ~0x3;
uint val = dataBuffer.Load(addrA);
if (addrU != addrA)
val >>= 16;
return val & 0xFFFF;
}
// Read a 32-bit value at a 2-byte-aligned byte address, stitching it from
// two aligned words when the address straddles a word boundary.
uint LoadUInt(SplatBufferDataType dataBuffer, uint addrU)
{
uint addrA = addrU & ~0x3;
uint val = dataBuffer.Load(addrA);
if (addrU != addrA)
{
uint val1 = dataBuffer.Load(addrA + 4);
val = (val >> 16) | ((val1 & 0xFFFF) << 16);
}
return val;
}
// Load and decode a 3-component vector stored at a (2-byte aligned) byte
// address in one of the VECTOR_FMT_* encodings. Unaligned cases are
// re-stitched from neighboring aligned words.
float3 LoadAndDecodeVector(SplatBufferDataType dataBuffer, uint addrU, uint fmt)
{
uint addrA = addrU & ~0x3;
uint val0 = dataBuffer.Load(addrA);
float3 res = 0;
if (fmt == VECTOR_FMT_32F)
{
// Three full fp32 values (12 bytes).
uint val1 = dataBuffer.Load(addrA + 4);
uint val2 = dataBuffer.Load(addrA + 8);
if (addrU != addrA)
{
uint val3 = dataBuffer.Load(addrA + 12);
val0 = (val0 >> 16) | ((val1 & 0xFFFF) << 16);
val1 = (val1 >> 16) | ((val2 & 0xFFFF) << 16);
val2 = (val2 >> 16) | ((val3 & 0xFFFF) << 16);
}
res = float3(asfloat(val0), asfloat(val1), asfloat(val2));
}
else if (fmt == VECTOR_FMT_16)
{
// Three normalized 16-bit values (6 bytes).
uint val1 = dataBuffer.Load(addrA + 4);
if (addrU != addrA)
{
val0 = (val0 >> 16) | ((val1 & 0xFFFF) << 16);
val1 >>= 16;
}
res = DecodePacked_16_16_16(uint2(val0, val1));
}
else if (fmt == VECTOR_FMT_11)
{
// 11.10.11 bits in one word (4 bytes).
uint val1 = dataBuffer.Load(addrA + 4);
if (addrU != addrA)
{
val0 = (val0 >> 16) | ((val1 & 0xFFFF) << 16);
}
res = DecodePacked_11_10_11(val0);
}
else if (fmt == VECTOR_FMT_6)
{
// 6.5.5 bits in one ushort (2 bytes).
if (addrU != addrA)
val0 >>= 16;
res = DecodePacked_6_5_5(val0);
}
return res;
}
// Load the raw (possibly chunk-relative) position of splat `index`, using
// the position format encoded in the low byte of _SplatFormat.
float3 LoadSplatPosValue(uint index)
{
uint fmt = _SplatFormat & 0xFF;
uint stride = 0; // bytes per position for the given format
if (fmt == VECTOR_FMT_32F)
stride = 12;
else if (fmt == VECTOR_FMT_16)
stride = 6;
else if (fmt == VECTOR_FMT_11)
stride = 4;
else if (fmt == VECTOR_FMT_6)
stride = 2;
return LoadAndDecodeVector(_SplatPos, index * stride, fmt);
}
// Load a splat position and, when chunk data exists, expand the normalized
// chunk-relative value into the chunk's position bounds.
float3 LoadSplatPos(uint idx)
{
float3 pos = LoadSplatPosValue(idx);
uint chunkIdx = idx / kChunkSize;
if (chunkIdx < _SplatChunkCount)
{
SplatChunkInfo chunk = _SplatChunks[chunkIdx];
float3 posMin = float3(chunk.posX.x, chunk.posY.x, chunk.posZ.x);
float3 posMax = float3(chunk.posX.y, chunk.posY.y, chunk.posZ.y);
pos = lerp(posMin, posMax, pos);
}
return pos;
}
// Fetch a splat color texel; coord.z is the mip level (0 from
// SplatIndexToPixelIndex).
half4 LoadSplatColTex(uint3 coord)
{
return _SplatColor.Load(coord);
}
// Loads and fully decodes one splat (position, rotation, scale, color, SH
// coefficients) from the packed GPU buffers, expanding chunk-relative
// quantized data into absolute values where chunk info is present.
SplatData LoadSplatData(uint idx)
{
    SplatData s = (SplatData)0;
    // figure out raw data offsets / locations
    uint3 coord = SplatIndexToPixelIndex(idx);
    // Scale format lives in byte 1 of _SplatFormat, SH format in byte 2.
    uint scaleFmt = (_SplatFormat >> 8) & 0xFF;
    uint shFormat = (_SplatFormat >> 16) & 0xFF;
    uint otherStride = 4; // rotation is 10.10.10.2
    // Scale data follows the rotation in the "other" buffer; stride by format.
    if (scaleFmt == VECTOR_FMT_32F)
        otherStride += 12;
    else if (scaleFmt == VECTOR_FMT_16)
        otherStride += 6;
    else if (scaleFmt == VECTOR_FMT_11)
        otherStride += 4;
    else if (scaleFmt == VECTOR_FMT_6)
        otherStride += 2;
    // SH formats above VECTOR_FMT_6 appear to store a 16-bit SH palette index
    // per splat at the end of the "other" record (see LoadUShort below).
    if (shFormat > VECTOR_FMT_6)
        otherStride += 2;
    uint otherAddr = idx * otherStride;
    // Byte stride of one SH record (15 RGB coefficients), by SH format.
    uint shStride = 0;
    if (shFormat == VECTOR_FMT_32F)
        shStride = 192; // 15*3 fp32, rounded up to multiple of 16
    else if (shFormat == VECTOR_FMT_16 || shFormat > VECTOR_FMT_6)
        shStride = 96; // 15*3 fp16, rounded up to multiple of 16
    else if (shFormat == VECTOR_FMT_11)
        shStride = 60; // 15x uint
    else if (shFormat == VECTOR_FMT_6)
        shStride = 32; // 15x ushort, rounded up to multiple of 4
    // load raw splat data, which might be chunk-relative
    s.pos = LoadSplatPosValue(idx);
    s.rot = DecodeRotation(DecodePacked_10_10_10_2(LoadUInt(_SplatOther, otherAddr)));
    s.scale = LoadAndDecodeVector(_SplatOther, otherAddr + 4, scaleFmt);
    half4 col = LoadSplatColTex(coord);
    // Resolve the SH record index: direct, or indirect through a per-splat
    // palette index for the ">VECTOR_FMT_6" formats.
    uint shIndex = idx;
    if (shFormat > VECTOR_FMT_6)
        shIndex = LoadUShort(_SplatOther, otherAddr + otherStride - 2);
    uint shOffset = shIndex * shStride;
    uint4 shRaw0 = _SplatSH.Load4(shOffset);
    uint4 shRaw1 = _SplatSH.Load4(shOffset + 16);
    if (shFormat == VECTOR_FMT_32F)
    {
        // 15 RGB coefficients as raw fp32 (45 floats, 180 bytes).
        uint4 shRaw2 = _SplatSH.Load4(shOffset + 32);
        uint4 shRaw3 = _SplatSH.Load4(shOffset + 48);
        uint4 shRaw4 = _SplatSH.Load4(shOffset + 64);
        uint4 shRaw5 = _SplatSH.Load4(shOffset + 80);
        uint4 shRaw6 = _SplatSH.Load4(shOffset + 96);
        uint4 shRaw7 = _SplatSH.Load4(shOffset + 112);
        uint4 shRaw8 = _SplatSH.Load4(shOffset + 128);
        uint4 shRaw9 = _SplatSH.Load4(shOffset + 144);
        uint4 shRawA = _SplatSH.Load4(shOffset + 160);
        uint shRawB = _SplatSH.Load(shOffset + 176);
        s.sh.sh1.r = asfloat(shRaw0.x); s.sh.sh1.g = asfloat(shRaw0.y); s.sh.sh1.b = asfloat(shRaw0.z);
        s.sh.sh2.r = asfloat(shRaw0.w); s.sh.sh2.g = asfloat(shRaw1.x); s.sh.sh2.b = asfloat(shRaw1.y);
        s.sh.sh3.r = asfloat(shRaw1.z); s.sh.sh3.g = asfloat(shRaw1.w); s.sh.sh3.b = asfloat(shRaw2.x);
        s.sh.sh4.r = asfloat(shRaw2.y); s.sh.sh4.g = asfloat(shRaw2.z); s.sh.sh4.b = asfloat(shRaw2.w);
        s.sh.sh5.r = asfloat(shRaw3.x); s.sh.sh5.g = asfloat(shRaw3.y); s.sh.sh5.b = asfloat(shRaw3.z);
        s.sh.sh6.r = asfloat(shRaw3.w); s.sh.sh6.g = asfloat(shRaw4.x); s.sh.sh6.b = asfloat(shRaw4.y);
        s.sh.sh7.r = asfloat(shRaw4.z); s.sh.sh7.g = asfloat(shRaw4.w); s.sh.sh7.b = asfloat(shRaw5.x);
        s.sh.sh8.r = asfloat(shRaw5.y); s.sh.sh8.g = asfloat(shRaw5.z); s.sh.sh8.b = asfloat(shRaw5.w);
        s.sh.sh9.r = asfloat(shRaw6.x); s.sh.sh9.g = asfloat(shRaw6.y); s.sh.sh9.b = asfloat(shRaw6.z);
        s.sh.sh10.r = asfloat(shRaw6.w); s.sh.sh10.g = asfloat(shRaw7.x); s.sh.sh10.b = asfloat(shRaw7.y);
        s.sh.sh11.r = asfloat(shRaw7.z); s.sh.sh11.g = asfloat(shRaw7.w); s.sh.sh11.b = asfloat(shRaw8.x);
        s.sh.sh12.r = asfloat(shRaw8.y); s.sh.sh12.g = asfloat(shRaw8.z); s.sh.sh12.b = asfloat(shRaw8.w);
        s.sh.sh13.r = asfloat(shRaw9.x); s.sh.sh13.g = asfloat(shRaw9.y); s.sh.sh13.b = asfloat(shRaw9.z);
        s.sh.sh14.r = asfloat(shRaw9.w); s.sh.sh14.g = asfloat(shRawA.x); s.sh.sh14.b = asfloat(shRawA.y);
        s.sh.sh15.r = asfloat(shRawA.z); s.sh.sh15.g = asfloat(shRawA.w); s.sh.sh15.b = asfloat(shRawB);
    }
    else if (shFormat == VECTOR_FMT_16 || shFormat > VECTOR_FMT_6)
    {
        // 15 RGB coefficients as packed fp16 pairs (45 halves, 90 bytes).
        uint4 shRaw2 = _SplatSH.Load4(shOffset + 32);
        uint4 shRaw3 = _SplatSH.Load4(shOffset + 48);
        uint4 shRaw4 = _SplatSH.Load4(shOffset + 64);
        uint3 shRaw5 = _SplatSH.Load3(shOffset + 80);
        s.sh.sh1.r = f16tof32(shRaw0.x      ); s.sh.sh1.g = f16tof32(shRaw0.x >> 16); s.sh.sh1.b = f16tof32(shRaw0.y      );
        s.sh.sh2.r = f16tof32(shRaw0.y >> 16); s.sh.sh2.g = f16tof32(shRaw0.z      ); s.sh.sh2.b = f16tof32(shRaw0.z >> 16);
        s.sh.sh3.r = f16tof32(shRaw0.w      ); s.sh.sh3.g = f16tof32(shRaw0.w >> 16); s.sh.sh3.b = f16tof32(shRaw1.x      );
        s.sh.sh4.r = f16tof32(shRaw1.x >> 16); s.sh.sh4.g = f16tof32(shRaw1.y      ); s.sh.sh4.b = f16tof32(shRaw1.y >> 16);
        s.sh.sh5.r = f16tof32(shRaw1.z      ); s.sh.sh5.g = f16tof32(shRaw1.z >> 16); s.sh.sh5.b = f16tof32(shRaw1.w      );
        s.sh.sh6.r = f16tof32(shRaw1.w >> 16); s.sh.sh6.g = f16tof32(shRaw2.x      ); s.sh.sh6.b = f16tof32(shRaw2.x >> 16);
        s.sh.sh7.r = f16tof32(shRaw2.y      ); s.sh.sh7.g = f16tof32(shRaw2.y >> 16); s.sh.sh7.b = f16tof32(shRaw2.z      );
        s.sh.sh8.r = f16tof32(shRaw2.z >> 16); s.sh.sh8.g = f16tof32(shRaw2.w      ); s.sh.sh8.b = f16tof32(shRaw2.w >> 16);
        s.sh.sh9.r = f16tof32(shRaw3.x      ); s.sh.sh9.g = f16tof32(shRaw3.x >> 16); s.sh.sh9.b = f16tof32(shRaw3.y      );
        s.sh.sh10.r = f16tof32(shRaw3.y >> 16); s.sh.sh10.g = f16tof32(shRaw3.z      ); s.sh.sh10.b = f16tof32(shRaw3.z >> 16);
        s.sh.sh11.r = f16tof32(shRaw3.w      ); s.sh.sh11.g = f16tof32(shRaw3.w >> 16); s.sh.sh11.b = f16tof32(shRaw4.x      );
        s.sh.sh12.r = f16tof32(shRaw4.x >> 16); s.sh.sh12.g = f16tof32(shRaw4.y      ); s.sh.sh12.b = f16tof32(shRaw4.y >> 16);
        s.sh.sh13.r = f16tof32(shRaw4.z      ); s.sh.sh13.g = f16tof32(shRaw4.z >> 16); s.sh.sh13.b = f16tof32(shRaw4.w      );
        s.sh.sh14.r = f16tof32(shRaw4.w >> 16); s.sh.sh14.g = f16tof32(shRaw5.x      ); s.sh.sh14.b = f16tof32(shRaw5.x >> 16);
        s.sh.sh15.r = f16tof32(shRaw5.y      ); s.sh.sh15.g = f16tof32(shRaw5.y >> 16); s.sh.sh15.b = f16tof32(shRaw5.z      );
    }
    else if (shFormat == VECTOR_FMT_11)
    {
        // 15 coefficients, each one uint packed as 11.10.11 bits.
        uint4 shRaw2 = _SplatSH.Load4(shOffset + 32);
        uint3 shRaw3 = _SplatSH.Load3(shOffset + 48);
        s.sh.sh1 = DecodePacked_11_10_11(shRaw0.x);
        s.sh.sh2 = DecodePacked_11_10_11(shRaw0.y);
        s.sh.sh3 = DecodePacked_11_10_11(shRaw0.z);
        s.sh.sh4 = DecodePacked_11_10_11(shRaw0.w);
        s.sh.sh5 = DecodePacked_11_10_11(shRaw1.x);
        s.sh.sh6 = DecodePacked_11_10_11(shRaw1.y);
        s.sh.sh7 = DecodePacked_11_10_11(shRaw1.z);
        s.sh.sh8 = DecodePacked_11_10_11(shRaw1.w);
        s.sh.sh9 = DecodePacked_11_10_11(shRaw2.x);
        s.sh.sh10 = DecodePacked_11_10_11(shRaw2.y);
        s.sh.sh11 = DecodePacked_11_10_11(shRaw2.z);
        s.sh.sh12 = DecodePacked_11_10_11(shRaw2.w);
        s.sh.sh13 = DecodePacked_11_10_11(shRaw3.x);
        s.sh.sh14 = DecodePacked_11_10_11(shRaw3.y);
        s.sh.sh15 = DecodePacked_11_10_11(shRaw3.z);
    }
    else if (shFormat == VECTOR_FMT_6)
    {
        // 15 coefficients, each one ushort packed as 5.6.5 bits.
        s.sh.sh1 = DecodePacked_5_6_5(shRaw0.x);
        s.sh.sh2 = DecodePacked_5_6_5(shRaw0.x >> 16);
        s.sh.sh3 = DecodePacked_5_6_5(shRaw0.y);
        s.sh.sh4 = DecodePacked_5_6_5(shRaw0.y >> 16);
        s.sh.sh5 = DecodePacked_5_6_5(shRaw0.z);
        s.sh.sh6 = DecodePacked_5_6_5(shRaw0.z >> 16);
        s.sh.sh7 = DecodePacked_5_6_5(shRaw0.w);
        s.sh.sh8 = DecodePacked_5_6_5(shRaw0.w >> 16);
        s.sh.sh9 = DecodePacked_5_6_5(shRaw1.x);
        s.sh.sh10 = DecodePacked_5_6_5(shRaw1.x >> 16);
        s.sh.sh11 = DecodePacked_5_6_5(shRaw1.y);
        s.sh.sh12 = DecodePacked_5_6_5(shRaw1.y >> 16);
        s.sh.sh13 = DecodePacked_5_6_5(shRaw1.z);
        s.sh.sh14 = DecodePacked_5_6_5(shRaw1.z >> 16);
        s.sh.sh15 = DecodePacked_5_6_5(shRaw1.w);
    }
    // if raw data is chunk-relative, convert to final values by interpolating between chunk min/max
    uint chunkIdx = idx / kChunkSize;
    if (chunkIdx < _SplatChunkCount)
    {
        SplatChunkInfo chunk = _SplatChunks[chunkIdx];
        // Min/max pairs are stored as fp16 pairs packed into uints (low=min, high=max).
        float3 posMin = float3(chunk.posX.x, chunk.posY.x, chunk.posZ.x);
        float3 posMax = float3(chunk.posX.y, chunk.posY.y, chunk.posZ.y);
        half3 sclMin = half3(f16tof32(chunk.sclX    ), f16tof32(chunk.sclY    ), f16tof32(chunk.sclZ    ));
        half3 sclMax = half3(f16tof32(chunk.sclX>>16), f16tof32(chunk.sclY>>16), f16tof32(chunk.sclZ>>16));
        half4 colMin = half4(f16tof32(chunk.colR    ), f16tof32(chunk.colG    ), f16tof32(chunk.colB    ), f16tof32(chunk.colA    ));
        half4 colMax = half4(f16tof32(chunk.colR>>16), f16tof32(chunk.colG>>16), f16tof32(chunk.colB>>16), f16tof32(chunk.colA>>16));
        half3 shMin = half3(f16tof32(chunk.shR    ), f16tof32(chunk.shG    ), f16tof32(chunk.shB    ));
        half3 shMax = half3(f16tof32(chunk.shR>>16), f16tof32(chunk.shG>>16), f16tof32(chunk.shB>>16));
        s.pos = lerp(posMin, posMax, s.pos);
        s.scale = lerp(sclMin, sclMax, s.scale);
        // Raise the lerped scale to the 8th power — presumably undoing a
        // 1/8th-power applied at encode time; confirm against the encoder.
        s.scale *= s.scale;
        s.scale *= s.scale;
        s.scale *= s.scale;
        col = lerp(colMin, colMax, col);
        col.a = InvSquareCentered01(col.a);
        // Only the quantized SH formats (strictly between fp32 and the
        // palette formats) store chunk-relative SH values.
        if (shFormat > VECTOR_FMT_32F && shFormat <= VECTOR_FMT_6)
        {
            s.sh.sh1 = lerp(shMin, shMax, s.sh.sh1 );
            s.sh.sh2 = lerp(shMin, shMax, s.sh.sh2 );
            s.sh.sh3 = lerp(shMin, shMax, s.sh.sh3 );
            s.sh.sh4 = lerp(shMin, shMax, s.sh.sh4 );
            s.sh.sh5 = lerp(shMin, shMax, s.sh.sh5 );
            s.sh.sh6 = lerp(shMin, shMax, s.sh.sh6 );
            s.sh.sh7 = lerp(shMin, shMax, s.sh.sh7 );
            s.sh.sh8 = lerp(shMin, shMax, s.sh.sh8 );
            s.sh.sh9 = lerp(shMin, shMax, s.sh.sh9 );
            s.sh.sh10 = lerp(shMin, shMax, s.sh.sh10);
            s.sh.sh11 = lerp(shMin, shMax, s.sh.sh11);
            s.sh.sh12 = lerp(shMin, shMax, s.sh.sh12);
            s.sh.sh13 = lerp(shMin, shMax, s.sh.sh13);
            s.sh.sh14 = lerp(shMin, shMax, s.sh.sh14);
            s.sh.sh15 = lerp(shMin, shMax, s.sh.sh15);
        }
    }
    s.opacity = col.a;
    s.sh.col = col.rgb;
    return s;
}
// Per-splat data produced by view-dependent preprocessing and consumed by the
// "Render Splats" vertex shader.
struct SplatViewData
{
    float4 pos;          // clip-space center position
    float2 axis1, axis2; // projected ellipse axes (pixel units assumed from shader usage — confirm against producer)
    uint2 color; // 4xFP16
};
#endif // GAUSSIAN_SPLATTING_HLSL

View File

@@ -0,0 +1,7 @@
fileFormatVersion: 2
guid: d4087e54957693c48a7be32de91c99e2
ShaderIncludeImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,111 @@
// SPDX-License-Identifier: MIT
// Renders pre-sorted, pre-projected gaussian splats as screen-space quads.
Shader "Gaussian Splatting/Render Splats"
{
    SubShader
    {
        Tags { "RenderType"="Transparent" "Queue"="Transparent" }
        Pass
        {
            ZWrite Off
            // "Under" compositing: dst is weighted by its own alpha coverage,
            // src is premultiplied in the fragment shader.
            Blend OneMinusDstAlpha One
            Cull Off
            CGPROGRAM
            #pragma vertex vert
            #pragma fragment frag
            #pragma require compute
            #pragma use_dxc

            #include "GaussianSplatting.hlsl"

            // Depth-sorted splat indices (sorted by the GPU sort pass).
            StructuredBuffer<uint> _OrderBuffer;

            struct v2f
            {
                half4 col : COLOR0;
                float2 pos : TEXCOORD0; // quad-local coordinate, used for gaussian falloff
                float4 vertex : SV_POSITION;
            };

            StructuredBuffer<SplatViewData> _SplatViewData;
            // One bit per splat: set when the splat is selected in the editor.
            ByteAddressBuffer _SplatSelectedBits;
            uint _SplatBitsValid;

            // Expands each instance (one splat) into a screen-aligned quad.
            v2f vert (uint vtxID : SV_VertexID, uint instID : SV_InstanceID)
            {
                v2f o = (v2f)0;
                // Remap instance ID through the sorted order.
                instID = _OrderBuffer[instID];
                SplatViewData view = _SplatViewData[instID];
                float4 centerClipPos = view.pos;
                bool behindCam = centerClipPos.w <= 0;
                if (behindCam)
                {
                    o.vertex = asfloat(0x7fc00000); // NaN discards the primitive
                }
                else
                {
                    // Unpack 4x FP16 color (high halves first: r, b in high 16 bits).
                    o.col.r = f16tof32(view.color.x >> 16);
                    o.col.g = f16tof32(view.color.x);
                    o.col.b = f16tof32(view.color.y >> 16);
                    o.col.a = f16tof32(view.color.y);
                    uint idx = vtxID;
                    // Quad corner in [-1,1] from the two low vertex ID bits...
                    float2 quadPos = float2(idx&1, (idx>>1)&1) * 2.0 - 1.0;
                    // ...scaled to [-2,2] so the quad covers more of the gaussian.
                    quadPos *= 2;
                    o.pos = quadPos;
                    // Offset along the projected ellipse axes, converted to clip space.
                    float2 deltaScreenPos = (quadPos.x * view.axis1 + quadPos.y * view.axis2) * 2 / _ScreenParams.xy;
                    o.vertex = centerClipPos;
                    o.vertex.xy += deltaScreenPos * centerClipPos.w;

                    // is this splat selected?
                    if (_SplatBitsValid)
                    {
                        uint wordIdx = instID / 32;
                        uint bitIdx = instID & 31;
                        uint selVal = _SplatSelectedBits.Load(wordIdx * 4);
                        if (selVal & (1 << bitIdx))
                        {
                            // Negative alpha marks "selected" for the fragment shader.
                            o.col.a = -1;
                        }
                    }
                }
                return o;
            }

            half4 frag (v2f i) : SV_Target
            {
                // Gaussian falloff from the quad-local coordinate.
                float power = -dot(i.pos, i.pos);
                half alpha = exp(power);
                if (i.col.a >= 0)
                {
                    alpha = saturate(alpha * i.col.a);
                }
                else
                {
                    // "selected" splat: magenta outline, increase opacity, magenta tint
                    half3 selectedColor = half3(1,0,1);
                    if (alpha > 7.0/255.0)
                    {
                        if (alpha < 10.0/255.0)
                        {
                            alpha = 1;
                            i.col.rgb = selectedColor;
                        }
                        alpha = saturate(alpha + 0.3);
                    }
                    i.col.rgb = lerp(i.col.rgb, selectedColor, 0.5);
                }
                if (alpha < 1.0/255.0)
                    discard;
                // Premultiplied-alpha output to match the blend mode above.
                half4 res = half4(i.col.rgb * alpha, alpha);
                return res;
            }
            ENDCG
        }
    }
}
View File

@@ -0,0 +1,9 @@
fileFormatVersion: 2
guid: ed800126ae8844a67aad1974ddddd59c
ShaderImporter:
externalObjects: {}
defaultTextures: []
nonModifiableTextures: []
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,959 @@
/******************************************************************************
* SortCommon
* Common functions for GPUSorting
*
* SPDX-License-Identifier: MIT
* Copyright Thomas Smith 5/17/2024
* https://github.com/b0nes164/GPUSorting
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
******************************************************************************/
#define KEYS_PER_THREAD 15U
#define D_DIM 256U
#define PART_SIZE 3840U
#define D_TOTAL_SMEM 4096U
#define RADIX 256U //Number of digit bins
#define RADIX_MASK 255U //Mask of digit bins
#define HALF_RADIX 128U //For smaller waves where bit packing is necessary
#define HALF_MASK 127U // ''
#define RADIX_LOG 8U //log2(RADIX)
#define RADIX_PASSES 4U //(Key width) / RADIX_LOG
// Constants shared by all GPUSorting kernels.
cbuffer cbGpuSorting : register(b0)
{
    uint e_numKeys;      // total number of keys to sort
    uint e_radixShift;   // bit shift of the current radix pass (0, 8, 16, 24)
    uint e_threadBlocks; // number of partitions / dispatched thread blocks
    uint padding;
};
#if defined(KEY_UINT)
RWStructuredBuffer<uint> b_sort;
RWStructuredBuffer<uint> b_alt;
#elif defined(KEY_INT)
RWStructuredBuffer<int> b_sort;
RWStructuredBuffer<int> b_alt;
#elif defined(KEY_FLOAT)
RWStructuredBuffer<float> b_sort;
RWStructuredBuffer<float> b_alt;
#endif
#if defined(PAYLOAD_UINT)
RWStructuredBuffer<uint> b_sortPayload;
RWStructuredBuffer<uint> b_altPayload;
#elif defined(PAYLOAD_INT)
RWStructuredBuffer<int> b_sortPayload;
RWStructuredBuffer<int> b_altPayload;
#elif defined(PAYLOAD_FLOAT)
RWStructuredBuffer<float> b_sortPayload;
RWStructuredBuffer<float> b_altPayload;
#endif
groupshared uint g_d[D_TOTAL_SMEM]; //Shared memory for DigitBinningPass and DownSweep kernels
// Per-thread register array of keys (always handled as bit-converted uints).
struct KeyStruct
{
    uint k[KEYS_PER_THREAD];
};

// Per-thread scatter offsets into shared memory; 16-bit where supported to
// reduce register pressure.
struct OffsetStruct
{
#if defined(ENABLE_16_BIT)
    uint16_t o[KEYS_PER_THREAD];
#else
    uint o[KEYS_PER_THREAD];
#endif
};

// Per-thread extracted digits, cached between key and payload scatter phases.
struct DigitStruct
{
#if defined(ENABLE_16_BIT)
    uint16_t d[KEYS_PER_THREAD];
#else
    uint d[KEYS_PER_THREAD];
#endif
};
//*****************************************************************************
//HELPER FUNCTIONS
//*****************************************************************************
//Due to a bug with SPIRV pre 1.6, we cannot use WaveGetLaneCount() to get the currently active wavesize
// Returns the number of lanes in the executing wave (see SPIRV note above).
inline uint getWaveSize()
{
#if defined(VULKAN)
    GroupMemoryBarrierWithGroupSync(); //Make absolutely sure the wave is not diverged here
    // With all lanes converged, count the set bits across the full ballot.
    return dot(countbits(WaveActiveBallot(true)), uint4(1, 1, 1, 1));
#else
    return WaveGetLaneCount();
#endif
}
// Index of this thread's wave within the thread group.
inline uint getWaveIndex(uint gtid, uint waveSize)
{
    return gtid / waveSize;
}

//Radix Tricks by Michael Herf
//http://stereopsis.com/radix.html
// Order-preserving bijections to/from uint so floats and ints can be radix
// sorted with plain unsigned digit comparisons.
inline uint FloatToUint(float f)
{
    // Negative floats: flip all bits; non-negative floats: set the sign bit.
    uint mask = -((int) (asuint(f) >> 31)) | 0x80000000;
    return asuint(f) ^ mask;
}

inline float UintToFloat(uint u)
{
    // Inverse of FloatToUint.
    uint mask = ((u >> 31) - 1) | 0x80000000;
    return asfloat(u ^ mask);
}

inline uint IntToUint(int i)
{
    // Toggle the sign bit to shift the signed range onto the unsigned range.
    return asuint(i ^ 0x80000000);
}

inline int UintToInt(uint u)
{
    // Inverse of IntToUint.
    return asint(u ^ 0x80000000);
}
// Number of waves per thread group in the binning pass.
inline uint getWaveCountPass(uint waveSize)
{
    return D_DIM / waveSize;
}

// Digit of the current radix pass (precedence: (key >> shift) & mask).
inline uint ExtractDigit(uint key)
{
    return key >> e_radixShift & RADIX_MASK;
}

// Digit at an explicit shift.
inline uint ExtractDigit(uint key, uint shift)
{
    return key >> shift & RADIX_MASK;
}

// For small waves, two 16-bit counters share one uint; this is the index of
// the shared counter word for the key's digit.
inline uint ExtractPackedIndex(uint key)
{
    return key >> (e_radixShift + 1) & HALF_MASK;
}

// Which half of the packed counter word the key's digit uses (0 or 16 bits).
inline uint ExtractPackedShift(uint key)
{
    return (key >> e_radixShift & 1) ? 16 : 0;
}

// Reads the 16-bit counter for this key out of a packed counter word.
inline uint ExtractPackedValue(uint packed, uint key)
{
    return packed >> ExtractPackedShift(key) & 0xffff;
}

// Keys handled by one wave (waves >= 16 lanes).
inline uint SubPartSizeWGE16(uint waveSize)
{
    return KEYS_PER_THREAD * waveSize;
}

// This thread's base offset into shared memory (waves >= 16 lanes).
inline uint SharedOffsetWGE16(uint gtid, uint waveSize)
{
    return WaveGetLaneIndex() + getWaveIndex(gtid, waveSize) * SubPartSizeWGE16(waveSize);
}

// Keys handled by one serialized wave group (waves < 16 lanes).
inline uint SubPartSizeWLT16(uint waveSize, uint _serialIterations)
{
    return KEYS_PER_THREAD * waveSize * _serialIterations;
}

// This thread's base offset into shared memory (waves < 16 lanes).
inline uint SharedOffsetWLT16(uint gtid, uint waveSize, uint _serialIterations)
{
    return WaveGetLaneIndex() +
        (getWaveIndex(gtid, waveSize) / _serialIterations * SubPartSizeWLT16(waveSize, _serialIterations)) +
        (getWaveIndex(gtid, waveSize) % _serialIterations * waveSize);
}

// Device (global buffer) offset for this thread's keys (waves >= 16 lanes).
inline uint DeviceOffsetWGE16(uint gtid, uint waveSize, uint partIndex)
{
    return SharedOffsetWGE16(gtid, waveSize) + partIndex * PART_SIZE;
}

// Device (global buffer) offset for this thread's keys (waves < 16 lanes).
inline uint DeviceOffsetWLT16(uint gtid, uint waveSize, uint partIndex, uint serialIterations)
{
    return SharedOffsetWLT16(gtid, waveSize, serialIterations) + partIndex * PART_SIZE;
}

// Base offset of this pass' histogram: (e_radixShift / 8) * RADIX.
inline uint GlobalHistOffset()
{
    return e_radixShift << 5;
}

// Shared memory used by per-wave histograms (waves >= 16 lanes).
inline uint WaveHistsSizeWGE16(uint waveSize)
{
    return D_DIM / waveSize * RADIX;
}

// Small waves use all of shared memory for packed histograms.
inline uint WaveHistsSizeWLT16()
{
    return D_TOTAL_SMEM;
}
//*****************************************************************************
//FUNCTIONS COMMON TO THE DOWNSWEEP / DIGIT BINNING PASS
//*****************************************************************************
//If the size of a wave is too small, we do not have enough space in
//shared memory to assign a histogram to each wave, so instead,
//some operations are peformed serially.
inline uint SerialIterations(uint waveSize)
{
    return (D_DIM / waveSize + 31) >> 5;
}

// Zero the per-wave histogram region of shared memory; size depends on
// whether waves are >= 16 lanes (histogram per wave) or smaller (packed).
inline void ClearWaveHists(uint gtid, uint waveSize)
{
    const uint histsEnd = waveSize >= 16 ?
        WaveHistsSizeWGE16(waveSize) : WaveHistsSizeWLT16();
    for (uint i = gtid; i < histsEnd; i += D_DIM)
        g_d[i] = 0;
}

// Loads one key and bit-converts it to a sortable uint per the key type.
inline void LoadKey(inout uint key, uint index)
{
#if defined(KEY_UINT)
    key = b_sort[index];
#elif defined(KEY_INT)
    key = UintToInt(b_sort[index]);
#elif defined(KEY_FLOAT)
    key = FloatToUint(b_sort[index]);
#endif
}

// Sentinel key for out-of-range slots; sorts last in ascending order.
inline void LoadDummyKey(inout uint key)
{
    key = 0xffffffff;
}
// Loads a full partition's worth of keys into registers (waves >= 16 lanes).
inline KeyStruct LoadKeysWGE16(uint gtid, uint waveSize, uint partIndex)
{
    KeyStruct keys;
    [unroll]
    for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, partIndex);
        i < KEYS_PER_THREAD;
        ++i, t += waveSize)
    {
        LoadKey(keys.k[i], t);
    }
    return keys;
}

// Loads a full partition's worth of keys into registers (waves < 16 lanes).
inline KeyStruct LoadKeysWLT16(uint gtid, uint waveSize, uint partIndex, uint serialIterations)
{
    KeyStruct keys;
    [unroll]
    for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations);
        i < KEYS_PER_THREAD;
        ++i, t += waveSize * serialIterations)
    {
        LoadKey(keys.k[i], t);
    }
    return keys;
}

// As LoadKeysWGE16, but for the final (partial) partition: out-of-range
// slots get the 0xffffffff dummy key.
inline KeyStruct LoadKeysPartialWGE16(uint gtid, uint waveSize, uint partIndex)
{
    KeyStruct keys;
    [unroll]
    for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, partIndex);
        i < KEYS_PER_THREAD;
        ++i, t += waveSize)
    {
        if (t < e_numKeys)
            LoadKey(keys.k[i], t);
        else
            LoadDummyKey(keys.k[i]);
    }
    return keys;
}

// As LoadKeysWLT16, but for the final (partial) partition.
inline KeyStruct LoadKeysPartialWLT16(uint gtid, uint waveSize, uint partIndex, uint serialIterations)
{
    KeyStruct keys;
    [unroll]
    for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations);
        i < KEYS_PER_THREAD;
        ++i, t += waveSize * serialIterations)
    {
        if (t < e_numKeys)
            LoadKey(keys.k[i], t);
        else
            LoadDummyKey(keys.k[i]);
    }
    return keys;
}
// Initial ballot mask with one bit per lane; all ones when the wave fills a
// whole 32-bit ballot word (waveSize a multiple of 32, where 1<<32 would be UB).
inline uint WaveFlagsWGE16(uint waveSize)
{
    return (waveSize & 31) ? (1U << waveSize) - 1 : 0xffffffff;
}
// Initial ballot mask for waves of fewer than 16 lanes: one bit per lane.
// (A wave this small never fills a ballot word, so no overflow guard needed.)
inline uint WaveFlagsWLT16(uint waveSize)
{
    return (1U << waveSize) - 1; // fixed: removed stray empty statement (";;")
}
// Narrows waveFlags down to the set of lanes whose key shares this lane's
// digit, one radix bit at a time (waves >= 16 lanes, up to 128-lane ballot).
inline void WarpLevelMultiSplitWGE16(uint key, inout uint4 waveFlags)
{
    [unroll]
    for (uint k = 0; k < RADIX_LOG; ++k)
    {
        // NOTE: '+' binds tighter than '<<', so this is 1 << (k + e_radixShift).
        const uint currentBit = 1 << k + e_radixShift;
        const bool t = (key & currentBit) != 0;
        GroupMemoryBarrierWithGroupSync(); //Play on the safe side, throw in a barrier for convergence
        const uint4 ballot = WaveActiveBallot(t);
        // Keep only lanes that agree with this lane on the current bit.
        if(t)
            waveFlags &= ballot;
        else
            waveFlags &= (~ballot);
    }
}
// From the matched-lane ballot, returns:
//   .x = number of matching lanes with a lower lane index than this lane
//   .y = total number of matching lanes
// Iterates 32 lanes (one ballot word) at a time to support wide waves.
inline uint2 CountBitsWGE16(uint waveSize, uint ltMask, uint4 waveFlags)
{
    uint2 count = uint2(0, 0);
    for(uint wavePart = 0; wavePart < waveSize; wavePart += 32)
    {
        uint t = countbits(waveFlags[wavePart >> 5]);
        if (WaveGetLaneIndex() >= wavePart)
        {
            if (WaveGetLaneIndex() >= wavePart + 32)
                count.x += t; // whole word below this lane
            else
                count.x += countbits(waveFlags[wavePart >> 5] & ltMask); // partial word
        }
        count.y += t;
    }
    return count;
}
// Single-word variant of the multi-split for waves of fewer than 16 lanes:
// the whole ballot fits in one uint.
inline void WarpLevelMultiSplitWLT16(uint key, inout uint waveFlags)
{
    [unroll]
    for (uint k = 0; k < RADIX_LOG; ++k)
    {
        const bool t = key >> (k + e_radixShift) & 1;
        // Keep lanes that agree with this lane on the current bit.
        waveFlags &= (t ? 0 : 0xffffffff) ^ (uint) WaveActiveBallot(t);
    }
}
// Ranks each key within its digit bin for this wave (waves >= 16 lanes):
// offset = running histogram count for the digit + number of lower-lane peers.
// The lowest matching lane then bumps the shared histogram by the peer count.
inline OffsetStruct RankKeysWGE16(
    uint waveSize,
    uint waveOffset,
    KeyStruct keys)
{
    OffsetStruct offsets;
    const uint initialFlags = WaveFlagsWGE16(waveSize);
    const uint ltMask = (1U << (WaveGetLaneIndex() & 31)) - 1;
    [unroll]
    for (uint i = 0; i < KEYS_PER_THREAD; ++i)
    {
        uint4 waveFlags = initialFlags;
        WarpLevelMultiSplitWGE16(keys.k[i], waveFlags);
        const uint index = ExtractDigit(keys.k[i]) + waveOffset;
        const uint2 bitCount = CountBitsWGE16(waveSize, ltMask, waveFlags);
        offsets.o[i] = g_d[index] + bitCount.x;
        GroupMemoryBarrierWithGroupSync();
        // Exactly one lane per digit (the one with no lower peers) updates the bin.
        if (bitCount.x == 0)
            g_d[index] += bitCount.y;
        GroupMemoryBarrierWithGroupSync();
    }
    return offsets;
}
// Ranks keys for waves of fewer than 16 lanes. Histograms are bit-packed
// (two 16-bit counters per uint), and wave groups take turns serially since
// shared memory cannot hold one histogram per wave.
inline OffsetStruct RankKeysWLT16(uint waveSize, uint waveIndex, KeyStruct keys, uint serialIterations)
{
    OffsetStruct offsets;
    const uint ltMask = (1U << WaveGetLaneIndex()) - 1;
    const uint initialFlags = WaveFlagsWLT16(waveSize);
    [unroll]
    for (uint i = 0; i < KEYS_PER_THREAD; ++i)
    {
        uint waveFlags = initialFlags;
        WarpLevelMultiSplitWLT16(keys.k[i], waveFlags);
        const uint index = ExtractPackedIndex(keys.k[i]) +
            (waveIndex / serialIterations * HALF_RADIX);
        const uint peerBits = countbits(waveFlags & ltMask);
        // Serialize the histogram update across the wave groups sharing a histogram.
        for (uint k = 0; k < serialIterations; ++k)
        {
            if (waveIndex % serialIterations == k)
                offsets.o[i] = ExtractPackedValue(g_d[index], keys.k[i]) + peerBits;
            GroupMemoryBarrierWithGroupSync();
            // Lowest matching lane bumps the packed 16-bit counter.
            if (waveIndex % serialIterations == k && peerBits == 0)
            {
                InterlockedAdd(g_d[index],
                    countbits(waveFlags) << ExtractPackedShift(keys.k[i]));
            }
            GroupMemoryBarrierWithGroupSync();
        }
    }
    return offsets;
}
// Scans one digit's counts across the per-wave histograms: each wave slot is
// replaced by the exclusive prefix of the waves before it, and the total
// (inclusive reduction) is returned for the cross-digit scan.
inline uint WaveHistInclusiveScanCircularShiftWGE16(uint gtid, uint waveSize)
{
    uint histReduction = g_d[gtid];
    for (uint i = gtid + RADIX; i < WaveHistsSizeWGE16(waveSize); i += RADIX)
    {
        histReduction += g_d[i];
        g_d[i] = histReduction - g_d[i]; // store exclusive prefix in place
    }
    return histReduction;
}

// Same scan for the packed (waves < 16 lanes) histogram layout, where each
// uint holds two 16-bit digit counters.
inline uint WaveHistInclusiveScanCircularShiftWLT16(uint gtid)
{
    uint histReduction = g_d[gtid];
    for (uint i = gtid + HALF_RADIX; i < WaveHistsSizeWLT16(); i += HALF_RADIX)
    {
        histReduction += g_d[i];
        g_d[i] = histReduction - g_d[i]; // store exclusive prefix in place
    }
    return histReduction;
}
// Turns the per-digit reductions into an exclusive prefix sum over all RADIX
// digits, using a circular-shift write, per-wave WavePrefixSum, and a final
// cross-wave fixup.
inline void WaveHistReductionExclusiveScanWGE16(uint gtid, uint waveSize, uint histReduction)
{
    if (gtid < RADIX)
    {
        // Circular shift within each wave-sized span: lane i writes slot i+1,
        // making the subsequent inclusive WavePrefixSum effectively exclusive.
        const uint laneMask = waveSize - 1;
        g_d[((WaveGetLaneIndex() + 1) & laneMask) + (gtid & ~laneMask)] = histReduction;
    }
    GroupMemoryBarrierWithGroupSync();
    // Scan the per-span totals (one per wave-sized span of digits).
    if (gtid < RADIX / waveSize)
    {
        g_d[gtid * waveSize] =
            WavePrefixSum(g_d[gtid * waveSize]);
    }
    GroupMemoryBarrierWithGroupSync();
    // Propagate each span's base (held in lane 0) to the rest of the span.
    uint t = WaveReadLaneAt(g_d[gtid], 0);
    if (gtid < RADIX && WaveGetLaneIndex())
        g_d[gtid] += t;
}
//inclusive/exclusive prefix sum up the histograms,
//use a blelloch scan for in place packed exclusive
// Packed variant for waves < 16 lanes: each uint holds two 16-bit counters,
// so a Blelloch up-sweep/down-sweep runs over the packed halves in place.
inline void WaveHistReductionExclusiveScanWLT16(uint gtid)
{
    uint shift = 1;
    // Up-sweep (reduce) phase over the high 16-bit halves.
    for (uint j = RADIX >> 2; j > 0; j >>= 1)
    {
        GroupMemoryBarrierWithGroupSync();
        if (gtid < j)
        {
            g_d[((((gtid << 1) + 2) << shift) - 1) >> 1] +=
                g_d[((((gtid << 1) + 1) << shift) - 1) >> 1] & 0xffff0000;
        }
        shift++;
    }
    GroupMemoryBarrierWithGroupSync();
    // Clear the root (total) before the down-sweep, as Blelloch requires.
    if (gtid == 0)
        g_d[HALF_RADIX - 1] &= 0xffff;
    // Down-sweep phase.
    for (uint j = 1; j < RADIX >> 1; j <<= 1)
    {
        --shift;
        GroupMemoryBarrierWithGroupSync();
        if (gtid < j)
        {
            const uint t = ((((gtid << 1) + 1) << shift) - 1) >> 1;
            const uint t2 = ((((gtid << 1) + 2) << shift) - 1) >> 1;
            const uint t3 = g_d[t];
            g_d[t] = (g_d[t] & 0xffff) | (g_d[t2] & 0xffff0000);
            g_d[t2] += t3 & 0xffff0000;
        }
    }
    GroupMemoryBarrierWithGroupSync();
    // Final in-word fixup to produce the packed exclusive prefix layout.
    if (gtid < HALF_RADIX)
    {
        const uint t = g_d[gtid];
        g_d[gtid] = (t >> 16) + (t << 16) + (t & 0xffff0000);
    }
}
// Converts per-wave ranks into group-wide shared-memory offsets by adding the
// digit's group-wide prefix and (for waves past the first) the per-wave prefix.
inline void UpdateOffsetsWGE16(
    uint gtid,
    uint waveSize,
    inout OffsetStruct offsets,
    KeyStruct keys)
{
    if (gtid >= waveSize)
    {
        // Not the first wave: add this wave's histogram prefix too.
        const uint t = getWaveIndex(gtid, waveSize) * RADIX;
        [unroll]
        for (uint i = 0; i < KEYS_PER_THREAD; ++i)
        {
            const uint t2 = ExtractDigit(keys.k[i]);
            offsets.o[i] += g_d[t2 + t] + g_d[t2];
        }
    }
    else
    {
        // First wave: only the group-wide digit prefix applies.
        [unroll]
        for (uint i = 0; i < KEYS_PER_THREAD; ++i)
            offsets.o[i] += g_d[ExtractDigit(keys.k[i])];
    }
}

// Same conversion for the packed (waves < 16 lanes) histogram layout.
inline void UpdateOffsetsWLT16(
    uint gtid,
    uint waveSize,
    uint serialIterations,
    inout OffsetStruct offsets,
    KeyStruct keys)
{
    if (gtid >= waveSize * serialIterations)
    {
        const uint t = getWaveIndex(gtid, waveSize) / serialIterations * HALF_RADIX;
        [unroll]
        for (uint i = 0; i < KEYS_PER_THREAD; ++i)
        {
            const uint t2 = ExtractPackedIndex(keys.k[i]);
            offsets.o[i] += ExtractPackedValue(g_d[t2 + t] + g_d[t2], keys.k[i]);
        }
    }
    else
    {
        [unroll]
        for (uint i = 0; i < KEYS_PER_THREAD; ++i)
            offsets.o[i] += ExtractPackedValue(g_d[ExtractPackedIndex(keys.k[i])], keys.k[i]);
    }
}
// Scatters this thread's keys into shared memory at their ranked offsets,
// producing a locally sorted partition in g_d[0..PART_SIZE).
inline void ScatterKeysShared(OffsetStruct offsets, KeyStruct keys)
{
    [unroll]
    for (uint i = 0; i < KEYS_PER_THREAD; ++i)
        g_d[offsets.o[i]] = keys.k[i];
}

// Mirrors a device index for descending-order output.
inline uint DescendingIndex(uint deviceIndex)
{
    return e_numKeys - deviceIndex - 1;
}
// Writes one key from shared memory to the alternate buffer, converting the
// sortable uint back to the original key type.
inline void WriteKey(uint deviceIndex, uint groupSharedIndex)
{
#if defined(KEY_UINT)
    b_alt[deviceIndex] = g_d[groupSharedIndex];
#elif defined(KEY_INT)
    b_alt[deviceIndex] = UintToInt(g_d[groupSharedIndex]);
#elif defined(KEY_FLOAT)
    b_alt[deviceIndex] = UintToFloat(g_d[groupSharedIndex]);
#endif
}

// Loads one payload as raw bits (payloads are carried, never compared).
inline void LoadPayload(inout uint payload, uint deviceIndex)
{
#if defined(PAYLOAD_UINT)
    payload = b_sortPayload[deviceIndex];
#elif defined(PAYLOAD_INT) || defined(PAYLOAD_FLOAT)
    payload = asuint(b_sortPayload[deviceIndex]);
#endif
}

// Payloads reuse the key scatter: same offsets, raw-bit values.
inline void ScatterPayloadsShared(OffsetStruct offsets, KeyStruct payloads)
{
    ScatterKeysShared(offsets, payloads);
}

// Writes one payload from shared memory to the alternate payload buffer,
// reinterpreting the raw bits back to the payload type.
inline void WritePayload(uint deviceIndex, uint groupSharedIndex)
{
#if defined(PAYLOAD_UINT)
    b_altPayload[deviceIndex] = g_d[groupSharedIndex];
#elif defined(PAYLOAD_INT)
    b_altPayload[deviceIndex] = asint(g_d[groupSharedIndex]);
#elif defined(PAYLOAD_FLOAT)
    b_altPayload[deviceIndex] = asfloat(g_d[groupSharedIndex]);
#endif
}
//*****************************************************************************
//SCATTERING: FULL PARTITIONS
//*****************************************************************************
//KEYS ONLY
// Scatter the locally sorted keys in g_d[0..PART_SIZE) to the device buffer;
// g_d[digit + PART_SIZE] holds the global base offset for each digit bin.
inline void ScatterKeysOnlyDeviceAscending(uint gtid)
{
    for (uint i = gtid; i < PART_SIZE; i += D_DIM)
        WriteKey(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i, i);
}

// Descending output only mirrors indices on the final pass (shift 24);
// earlier passes reuse the ascending scatter (presumably because only the
// final pass determines overall order — confirm against the dispatch logic).
inline void ScatterKeysOnlyDeviceDescending(uint gtid)
{
    if (e_radixShift == 24)
    {
        for (uint i = gtid; i < PART_SIZE; i += D_DIM)
            WriteKey(DescendingIndex(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i), i);
    }
    else
    {
        ScatterKeysOnlyDeviceAscending(gtid);
    }
}

// Keys-only device scatter, direction chosen at compile time.
inline void ScatterKeysOnlyDevice(uint gtid)
{
#if defined(SHOULD_ASCEND)
    ScatterKeysOnlyDeviceAscending(gtid);
#else
    ScatterKeysOnlyDeviceDescending(gtid);
#endif
}

//KEY VALUE PAIRS
// Scatters keys and records each slot's digit so the payload phase can reuse
// the same destination computation without re-reading keys.
inline void ScatterPairsKeyPhaseAscending(
    uint gtid,
    inout DigitStruct digits)
{
    [unroll]
    for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
    {
        digits.d[i] = ExtractDigit(g_d[t]);
        WriteKey(g_d[digits.d[i] + PART_SIZE] + t, t);
    }
}

// Descending key phase; mirrors indices only on the final pass (shift 24).
inline void ScatterPairsKeyPhaseDescending(
    uint gtid,
    inout DigitStruct digits)
{
    if (e_radixShift == 24)
    {
        [unroll]
        for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
        {
            digits.d[i] = ExtractDigit(g_d[t]);
            WriteKey(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t);
        }
    }
    else
    {
        ScatterPairsKeyPhaseAscending(gtid, digits);
    }
}

// Loads this thread's payloads into registers (waves >= 16 lanes).
inline void LoadPayloadsWGE16(
    uint gtid,
    uint waveSize,
    uint partIndex,
    inout KeyStruct payloads)
{
    [unroll]
    for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, partIndex);
        i < KEYS_PER_THREAD;
        ++i, t += waveSize)
    {
        LoadPayload(payloads.k[i], t);
    }
}

// Loads this thread's payloads into registers (waves < 16 lanes).
inline void LoadPayloadsWLT16(
    uint gtid,
    uint waveSize,
    uint partIndex,
    uint serialIterations,
    inout KeyStruct payloads)
{
    [unroll]
    for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations);
        i < KEYS_PER_THREAD;
        ++i, t += waveSize * serialIterations)
    {
        LoadPayload(payloads.k[i], t);
    }
}

// Scatters payloads to the device using the digits cached in the key phase.
inline void ScatterPayloadsAscending(uint gtid, DigitStruct digits)
{
    [unroll]
    for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
        WritePayload(g_d[digits.d[i] + PART_SIZE] + t, t);
}

// Descending payload scatter; mirrors indices only on the final pass.
inline void ScatterPayloadsDescending(uint gtid, DigitStruct digits)
{
    if (e_radixShift == 24)
    {
        [unroll]
        for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
            WritePayload(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t);
    }
    else
    {
        ScatterPayloadsAscending(gtid, digits);
    }
}

// Full key+payload device scatter: keys first (caching digits), then payloads
// routed through shared memory with the same offsets.
inline void ScatterPairsDevice(
    uint gtid,
    uint waveSize,
    uint partIndex,
    OffsetStruct offsets)
{
    DigitStruct digits;
#if defined(SHOULD_ASCEND)
    ScatterPairsKeyPhaseAscending(gtid, digits);
#else
    ScatterPairsKeyPhaseDescending(gtid, digits);
#endif
    GroupMemoryBarrierWithGroupSync();

    KeyStruct payloads;
    if (waveSize >= 16)
        LoadPayloadsWGE16(gtid, waveSize, partIndex, payloads);
    else
        LoadPayloadsWLT16(gtid, waveSize, partIndex, SerialIterations(waveSize), payloads);
    ScatterPayloadsShared(offsets, payloads);
    GroupMemoryBarrierWithGroupSync();

#if defined(SHOULD_ASCEND)
    ScatterPayloadsAscending(gtid, digits);
#else
    ScatterPayloadsDescending(gtid, digits);
#endif
}

// Entry point for scattering a full partition: pairs or keys-only, chosen at
// compile time.
inline void ScatterDevice(
    uint gtid,
    uint waveSize,
    uint partIndex,
    OffsetStruct offsets)
{
#if defined(SORT_PAIRS)
    ScatterPairsDevice(
        gtid,
        waveSize,
        partIndex,
        offsets);
#else
    ScatterKeysOnlyDevice(gtid);
#endif
}
//*****************************************************************************
//SCATTERING: PARTIAL PARTITIONS
//*****************************************************************************
//KEYS ONLY
inline void ScatterKeysOnlyDevicePartialAscending(uint gtid, uint finalPartSize)
{
for (uint i = gtid; i < PART_SIZE; i += D_DIM)
{
if (i < finalPartSize)
WriteKey(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i, i);
}
}
inline void ScatterKeysOnlyDevicePartialDescending(uint gtid, uint finalPartSize)
{
if (e_radixShift == 24)
{
for (uint i = gtid; i < PART_SIZE; i += D_DIM)
{
if (i < finalPartSize)
WriteKey(DescendingIndex(g_d[ExtractDigit(g_d[i]) + PART_SIZE] + i), i);
}
}
else
{
ScatterKeysOnlyDevicePartialAscending(gtid, finalPartSize);
}
}
inline void ScatterKeysOnlyDevicePartial(uint gtid, uint partIndex)
{
const uint finalPartSize = e_numKeys - partIndex * PART_SIZE;
#if defined(SHOULD_ASCEND)
ScatterKeysOnlyDevicePartialAscending(gtid, finalPartSize);
#else
ScatterKeysOnlyDevicePartialDescending(gtid, finalPartSize);
#endif
}
//KEY VALUE PAIRS
inline void ScatterPairsKeyPhaseAscendingPartial(
uint gtid,
uint finalPartSize,
inout DigitStruct digits)
{
[unroll]
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
{
if (t < finalPartSize)
{
digits.d[i] = ExtractDigit(g_d[t]);
WriteKey(g_d[digits.d[i] + PART_SIZE] + t, t);
}
}
}
inline void ScatterPairsKeyPhaseDescendingPartial(
uint gtid,
uint finalPartSize,
inout DigitStruct digits)
{
if (e_radixShift == 24)
{
[unroll]
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
{
if (t < finalPartSize)
{
digits.d[i] = ExtractDigit(g_d[t]);
WriteKey(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t);
}
}
}
else
{
ScatterPairsKeyPhaseAscendingPartial(gtid, finalPartSize, digits);
}
}
inline void LoadPayloadsPartialWGE16(
uint gtid,
uint waveSize,
uint partIndex,
inout KeyStruct payloads)
{
[unroll]
for (uint i = 0, t = DeviceOffsetWGE16(gtid, waveSize, partIndex);
i < KEYS_PER_THREAD;
++i, t += waveSize)
{
if (t < e_numKeys)
LoadPayload(payloads.k[i], t);
}
}
inline void LoadPayloadsPartialWLT16(
uint gtid,
uint waveSize,
uint partIndex,
uint serialIterations,
inout KeyStruct payloads)
{
[unroll]
for (uint i = 0, t = DeviceOffsetWLT16(gtid, waveSize, partIndex, serialIterations);
i < KEYS_PER_THREAD;
++i, t += waveSize * serialIterations)
{
if (t < e_numKeys)
LoadPayload(payloads.k[i], t);
}
}
inline void ScatterPayloadsAscendingPartial(
uint gtid,
uint finalPartSize,
DigitStruct digits)
{
[unroll]
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
{
if (t < finalPartSize)
WritePayload(g_d[digits.d[i] + PART_SIZE] + t, t);
}
}
inline void ScatterPayloadsDescendingPartial(
uint gtid,
uint finalPartSize,
DigitStruct digits)
{
if (e_radixShift == 24)
{
[unroll]
for (uint i = 0, t = gtid; i < KEYS_PER_THREAD; ++i, t += D_DIM)
{
if (t < finalPartSize)
WritePayload(DescendingIndex(g_d[digits.d[i] + PART_SIZE] + t), t);
}
}
else
{
ScatterPayloadsAscendingPartial(gtid, finalPartSize, digits);
}
}
// Scatters one partially filled (final) partition of key/payload pairs out
// to device memory: keys first, then payloads, with group barriers between
// the phases because both phases stage data through the same shared memory.
inline void ScatterPairsDevicePartial(
    uint gtid,
    uint waveSize,
    uint partIndex,
    OffsetStruct offsets)
{
    // Per-element digit records produced by the key phase, reused by the
    // payload phase to compute the same device destinations.
    DigitStruct digits;
    // Number of valid elements in this final partition.
    const uint finalPartSize = e_numKeys - partIndex * PART_SIZE;
#if defined(SHOULD_ASCEND)
    ScatterPairsKeyPhaseAscendingPartial(gtid, finalPartSize, digits);
#else
    ScatterPairsKeyPhaseDescendingPartial(gtid, finalPartSize, digits);
#endif
    GroupMemoryBarrierWithGroupSync();
    // Pull this thread's payloads in from device memory and stage them in
    // shared memory at the offsets computed during the key phase.
    KeyStruct payloads;
    if (waveSize >= 16)
        LoadPayloadsPartialWGE16(gtid, waveSize, partIndex, payloads);
    else
        LoadPayloadsPartialWLT16(gtid, waveSize, partIndex, SerialIterations(waveSize), payloads);
    ScatterPayloadsShared(offsets, payloads);
    GroupMemoryBarrierWithGroupSync();
    // Finally write the staged payloads next to their already-written keys.
#if defined(SHOULD_ASCEND)
    ScatterPayloadsAscendingPartial(gtid, finalPartSize, digits);
#else
    ScatterPayloadsDescendingPartial(gtid, finalPartSize, digits);
#endif
}
// Scatter entry point for the final, partially filled partition; dispatches
// at compile time to the pairs or keys-only implementation.
inline void ScatterDevicePartial(
    uint gtid,
    uint waveSize,
    uint partIndex,
    OffsetStruct offsets)
{
#if defined(SORT_PAIRS)
    ScatterPairsDevicePartial(gtid, waveSize, partIndex, offsets);
#else
    ScatterKeysOnlyDevicePartial(gtid, partIndex);
#endif
}

View File

@@ -0,0 +1,7 @@
fileFormatVersion: 2
guid: 268e5936ab6d79f4b8aeef8f5d14e7ee
ShaderIncludeImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,212 @@
// SPDX-License-Identifier: MIT
#ifndef SPHERICAL_HARMONICS_HLSL
#define SPHERICAL_HARMONICS_HLSL
// SH rotation based on https://github.com/andrewwillmott/sh-lib (Unlicense / public domain)
#define SH_MAX_ORDER 4
#define SH_MAX_COEFFS_COUNT (SH_MAX_ORDER*SH_MAX_ORDER)
// Weighted sum of 3 consecutive SH coefficient vectors starting at vidx.
float3 Dot3(int vidx, float3 v[SH_MAX_COEFFS_COUNT], float f[3])
{
    float3 sum = v[vidx] * f[0];
    sum += v[vidx + 1] * f[1];
    sum += v[vidx + 2] * f[2];
    return sum;
}
// Weighted sum of 5 consecutive SH coefficient vectors starting at vidx.
float3 Dot5(int vidx, float3 v[SH_MAX_COEFFS_COUNT], float f[5])
{
    float3 sum = v[vidx] * f[0];
    sum += v[vidx + 1] * f[1];
    sum += v[vidx + 2] * f[2];
    sum += v[vidx + 3] * f[3];
    sum += v[vidx + 4] * f[4];
    return sum;
}
// Weighted sum of 7 consecutive SH coefficient vectors starting at vidx.
float3 Dot7(int vidx, float3 v[SH_MAX_COEFFS_COUNT], float f[7])
{
    float3 sum = v[vidx] * f[0];
    sum += v[vidx + 1] * f[1];
    sum += v[vidx + 2] * f[2];
    sum += v[vidx + 3] * f[3];
    sum += v[vidx + 4] * f[4];
    sum += v[vidx + 5] * f[5];
    sum += v[vidx + 6] * f[6];
    return sum;
}
// Rotates SH coefficients by the rotation matrix `orient`, for the first
// `n` SH bands (n in 1..SH_MAX_ORDER). coeffsIn/coeffs hold the 16 float3
// (RGB) coefficients in standard band order: 1 + 3 + 5 + 7 entries.
// Band k is rotated by a (2k+1)x(2k+1) matrix built recursively from the
// band-1 rotation, following the sh-lib reference linked above.
// Assumes `orient` is a pure (orthonormal) rotation — TODO confirm callers
// strip scale first.
void RotateSH(float3x3 orient, int n, float3 coeffsIn[SH_MAX_COEFFS_COUNT], out float3 coeffs[SH_MAX_COEFFS_COUNT])
{
    // Precomputed square-root ratios used by the band-2/band-3 recurrences.
    const float kSqrt03_02 = sqrt( 3.0 / 2.0);
    const float kSqrt01_03 = sqrt( 1.0 / 3.0);
    const float kSqrt02_03 = sqrt( 2.0 / 3.0);
    const float kSqrt04_03 = sqrt( 4.0 / 3.0);
    const float kSqrt01_04 = sqrt( 1.0 / 4.0);
    const float kSqrt03_04 = sqrt( 3.0 / 4.0);
    const float kSqrt01_05 = sqrt( 1.0 / 5.0);
    const float kSqrt03_05 = sqrt( 3.0 / 5.0);
    const float kSqrt06_05 = sqrt( 6.0 / 5.0);
    const float kSqrt08_05 = sqrt( 8.0 / 5.0);
    const float kSqrt09_05 = sqrt( 9.0 / 5.0);
    const float kSqrt05_06 = sqrt( 5.0 / 6.0);
    const float kSqrt01_06 = sqrt( 1.0 / 6.0);
    const float kSqrt03_08 = sqrt( 3.0 / 8.0);
    const float kSqrt05_08 = sqrt( 5.0 / 8.0);
    const float kSqrt07_08 = sqrt( 7.0 / 8.0);
    const float kSqrt09_08 = sqrt( 9.0 / 8.0);
    const float kSqrt05_09 = sqrt( 5.0 / 9.0);
    const float kSqrt08_09 = sqrt( 8.0 / 9.0);
    const float kSqrt01_10 = sqrt( 1.0 / 10.0);
    const float kSqrt03_10 = sqrt( 3.0 / 10.0);
    const float kSqrt01_12 = sqrt( 1.0 / 12.0);
    const float kSqrt04_15 = sqrt( 4.0 / 15.0);
    const float kSqrt01_16 = sqrt( 1.0 / 16.0);
    const float kSqrt07_16 = sqrt( 7.0 / 16.0);
    const float kSqrt15_16 = sqrt(15.0 / 16.0);
    const float kSqrt01_18 = sqrt( 1.0 / 18.0);
    const float kSqrt03_25 = sqrt( 3.0 / 25.0);
    const float kSqrt14_25 = sqrt(14.0 / 25.0);
    const float kSqrt15_25 = sqrt(15.0 / 25.0);
    const float kSqrt18_25 = sqrt(18.0 / 25.0);
    const float kSqrt01_32 = sqrt( 1.0 / 32.0);
    const float kSqrt03_32 = sqrt( 3.0 / 32.0);
    const float kSqrt15_32 = sqrt(15.0 / 32.0);
    const float kSqrt21_32 = sqrt(21.0 / 32.0);
    const float kSqrt01_50 = sqrt( 1.0 / 50.0);
    const float kSqrt03_50 = sqrt( 3.0 / 50.0);
    const float kSqrt21_50 = sqrt(21.0 / 50.0);
    int srcIdx = 0;
    int dstIdx = 0;
    // band 0: a constant, unaffected by rotation
    coeffs[dstIdx++] = coeffsIn[0];
    if (n < 2)
        return;
    // band 1: rotated by the 3x3 rotation itself, with rows/columns
    // permuted into SH (y, z, x) component order
    srcIdx += 1;
    float sh1[3][3] =
    {
        // NOTE: change from upstream code at https://github.com/andrewwillmott/sh-lib, some of the
        // values need to have "-" in front of them.
        orient._22, -orient._23, orient._21,
        -orient._32, orient._33, -orient._31,
        orient._12, -orient._13, orient._11
    };
    coeffs[dstIdx++] = Dot3(srcIdx, coeffsIn, sh1[0]);
    coeffs[dstIdx++] = Dot3(srcIdx, coeffsIn, sh1[1]);
    coeffs[dstIdx++] = Dot3(srcIdx, coeffsIn, sh1[2]);
    if (n < 3)
        return;
    // band 2: each 5x5 rotation row is built from products of band-1 entries
    srcIdx += 3;
    float sh2[5][5];
    sh2[0][0] = kSqrt01_04 * ((sh1[2][2] * sh1[0][0] + sh1[2][0] * sh1[0][2]) + (sh1[0][2] * sh1[2][0] + sh1[0][0] * sh1[2][2]));
    sh2[0][1] = (sh1[2][1] * sh1[0][0] + sh1[0][1] * sh1[2][0]);
    sh2[0][2] = kSqrt03_04 * (sh1[2][1] * sh1[0][1] + sh1[0][1] * sh1[2][1]);
    sh2[0][3] = (sh1[2][1] * sh1[0][2] + sh1[0][1] * sh1[2][2]);
    sh2[0][4] = kSqrt01_04 * ((sh1[2][2] * sh1[0][2] - sh1[2][0] * sh1[0][0]) + (sh1[0][2] * sh1[2][2] - sh1[0][0] * sh1[2][0]));
    coeffs[dstIdx++] = Dot5(srcIdx, coeffsIn, sh2[0]);
    sh2[1][0] = kSqrt01_04 * ((sh1[1][2] * sh1[0][0] + sh1[1][0] * sh1[0][2]) + (sh1[0][2] * sh1[1][0] + sh1[0][0] * sh1[1][2]));
    sh2[1][1] = sh1[1][1] * sh1[0][0] + sh1[0][1] * sh1[1][0];
    sh2[1][2] = kSqrt03_04 * (sh1[1][1] * sh1[0][1] + sh1[0][1] * sh1[1][1]);
    sh2[1][3] = sh1[1][1] * sh1[0][2] + sh1[0][1] * sh1[1][2];
    sh2[1][4] = kSqrt01_04 * ((sh1[1][2] * sh1[0][2] - sh1[1][0] * sh1[0][0]) + (sh1[0][2] * sh1[1][2] - sh1[0][0] * sh1[1][0]));
    coeffs[dstIdx++] = Dot5(srcIdx, coeffsIn, sh2[1]);
    sh2[2][0] = kSqrt01_03 * (sh1[1][2] * sh1[1][0] + sh1[1][0] * sh1[1][2]) + -kSqrt01_12 * ((sh1[2][2] * sh1[2][0] + sh1[2][0] * sh1[2][2]) + (sh1[0][2] * sh1[0][0] + sh1[0][0] * sh1[0][2]));
    sh2[2][1] = kSqrt04_03 * sh1[1][1] * sh1[1][0] + -kSqrt01_03 * (sh1[2][1] * sh1[2][0] + sh1[0][1] * sh1[0][0]);
    sh2[2][2] = sh1[1][1] * sh1[1][1] + -kSqrt01_04 * (sh1[2][1] * sh1[2][1] + sh1[0][1] * sh1[0][1]);
    sh2[2][3] = kSqrt04_03 * sh1[1][1] * sh1[1][2] + -kSqrt01_03 * (sh1[2][1] * sh1[2][2] + sh1[0][1] * sh1[0][2]);
    sh2[2][4] = kSqrt01_03 * (sh1[1][2] * sh1[1][2] - sh1[1][0] * sh1[1][0]) + -kSqrt01_12 * ((sh1[2][2] * sh1[2][2] - sh1[2][0] * sh1[2][0]) + (sh1[0][2] * sh1[0][2] - sh1[0][0] * sh1[0][0]));
    coeffs[dstIdx++] = Dot5(srcIdx, coeffsIn, sh2[2]);
    sh2[3][0] = kSqrt01_04 * ((sh1[1][2] * sh1[2][0] + sh1[1][0] * sh1[2][2]) + (sh1[2][2] * sh1[1][0] + sh1[2][0] * sh1[1][2]));
    sh2[3][1] = sh1[1][1] * sh1[2][0] + sh1[2][1] * sh1[1][0];
    sh2[3][2] = kSqrt03_04 * (sh1[1][1] * sh1[2][1] + sh1[2][1] * sh1[1][1]);
    sh2[3][3] = sh1[1][1] * sh1[2][2] + sh1[2][1] * sh1[1][2];
    sh2[3][4] = kSqrt01_04 * ((sh1[1][2] * sh1[2][2] - sh1[1][0] * sh1[2][0]) + (sh1[2][2] * sh1[1][2] - sh1[2][0] * sh1[1][0]));
    coeffs[dstIdx++] = Dot5(srcIdx, coeffsIn, sh2[3]);
    sh2[4][0] = kSqrt01_04 * ((sh1[2][2] * sh1[2][0] + sh1[2][0] * sh1[2][2]) - (sh1[0][2] * sh1[0][0] + sh1[0][0] * sh1[0][2]));
    sh2[4][1] = (sh1[2][1] * sh1[2][0] - sh1[0][1] * sh1[0][0]);
    sh2[4][2] = kSqrt03_04 * (sh1[2][1] * sh1[2][1] - sh1[0][1] * sh1[0][1]);
    sh2[4][3] = (sh1[2][1] * sh1[2][2] - sh1[0][1] * sh1[0][2]);
    sh2[4][4] = kSqrt01_04 * ((sh1[2][2] * sh1[2][2] - sh1[2][0] * sh1[2][0]) - (sh1[0][2] * sh1[0][2] - sh1[0][0] * sh1[0][0]));
    coeffs[dstIdx++] = Dot5(srcIdx, coeffsIn, sh2[4]);
    if (n < 4)
        return;
    // band 3: each 7x7 rotation row is built from products of band-1 and
    // band-2 entries
    srcIdx += 5;
    float sh3[7][7];
    sh3[0][0] = kSqrt01_04 * ((sh1[2][2] * sh2[0][0] + sh1[2][0] * sh2[0][4]) + (sh1[0][2] * sh2[4][0] + sh1[0][0] * sh2[4][4]));
    sh3[0][1] = kSqrt03_02 * (sh1[2][1] * sh2[0][0] + sh1[0][1] * sh2[4][0]);
    sh3[0][2] = kSqrt15_16 * (sh1[2][1] * sh2[0][1] + sh1[0][1] * sh2[4][1]);
    sh3[0][3] = kSqrt05_06 * (sh1[2][1] * sh2[0][2] + sh1[0][1] * sh2[4][2]);
    sh3[0][4] = kSqrt15_16 * (sh1[2][1] * sh2[0][3] + sh1[0][1] * sh2[4][3]);
    sh3[0][5] = kSqrt03_02 * (sh1[2][1] * sh2[0][4] + sh1[0][1] * sh2[4][4]);
    sh3[0][6] = kSqrt01_04 * ((sh1[2][2] * sh2[0][4] - sh1[2][0] * sh2[0][0]) + (sh1[0][2] * sh2[4][4] - sh1[0][0] * sh2[4][0]));
    coeffs[dstIdx++] = Dot7(srcIdx, coeffsIn, sh3[0]);
    sh3[1][0] = kSqrt01_06 * (sh1[1][2] * sh2[0][0] + sh1[1][0] * sh2[0][4]) + kSqrt01_06 * ((sh1[2][2] * sh2[1][0] + sh1[2][0] * sh2[1][4]) + (sh1[0][2] * sh2[3][0] + sh1[0][0] * sh2[3][4]));
    sh3[1][1] = sh1[1][1] * sh2[0][0] + (sh1[2][1] * sh2[1][0] + sh1[0][1] * sh2[3][0]);
    sh3[1][2] = kSqrt05_08 * sh1[1][1] * sh2[0][1] + kSqrt05_08 * (sh1[2][1] * sh2[1][1] + sh1[0][1] * sh2[3][1]);
    sh3[1][3] = kSqrt05_09 * sh1[1][1] * sh2[0][2] + kSqrt05_09 * (sh1[2][1] * sh2[1][2] + sh1[0][1] * sh2[3][2]);
    sh3[1][4] = kSqrt05_08 * sh1[1][1] * sh2[0][3] + kSqrt05_08 * (sh1[2][1] * sh2[1][3] + sh1[0][1] * sh2[3][3]);
    sh3[1][5] = sh1[1][1] * sh2[0][4] + (sh1[2][1] * sh2[1][4] + sh1[0][1] * sh2[3][4]);
    sh3[1][6] = kSqrt01_06 * (sh1[1][2] * sh2[0][4] - sh1[1][0] * sh2[0][0]) + kSqrt01_06 * ((sh1[2][2] * sh2[1][4] - sh1[2][0] * sh2[1][0]) + (sh1[0][2] * sh2[3][4] - sh1[0][0] * sh2[3][0]));
    coeffs[dstIdx++] = Dot7(srcIdx, coeffsIn, sh3[1]);
    sh3[2][0] = kSqrt04_15 * (sh1[1][2] * sh2[1][0] + sh1[1][0] * sh2[1][4]) + kSqrt01_05 * (sh1[0][2] * sh2[2][0] + sh1[0][0] * sh2[2][4]) + -sqrt(1.0 / 60.0) * ((sh1[2][2] * sh2[0][0] + sh1[2][0] * sh2[0][4]) - (sh1[0][2] * sh2[4][0] + sh1[0][0] * sh2[4][4]));
    sh3[2][1] = kSqrt08_05 * sh1[1][1] * sh2[1][0] + kSqrt06_05 * sh1[0][1] * sh2[2][0] + -kSqrt01_10 * (sh1[2][1] * sh2[0][0] - sh1[0][1] * sh2[4][0]);
    sh3[2][2] = sh1[1][1] * sh2[1][1] + kSqrt03_04 * sh1[0][1] * sh2[2][1] + -kSqrt01_16 * (sh1[2][1] * sh2[0][1] - sh1[0][1] * sh2[4][1]);
    sh3[2][3] = kSqrt08_09 * sh1[1][1] * sh2[1][2] + kSqrt02_03 * sh1[0][1] * sh2[2][2] + -kSqrt01_18 * (sh1[2][1] * sh2[0][2] - sh1[0][1] * sh2[4][2]);
    sh3[2][4] = sh1[1][1] * sh2[1][3] + kSqrt03_04 * sh1[0][1] * sh2[2][3] + -kSqrt01_16 * (sh1[2][1] * sh2[0][3] - sh1[0][1] * sh2[4][3]);
    sh3[2][5] = kSqrt08_05 * sh1[1][1] * sh2[1][4] + kSqrt06_05 * sh1[0][1] * sh2[2][4] + -kSqrt01_10 * (sh1[2][1] * sh2[0][4] - sh1[0][1] * sh2[4][4]);
    sh3[2][6] = kSqrt04_15 * (sh1[1][2] * sh2[1][4] - sh1[1][0] * sh2[1][0]) + kSqrt01_05 * (sh1[0][2] * sh2[2][4] - sh1[0][0] * sh2[2][0]) + -sqrt(1.0 / 60.0) * ((sh1[2][2] * sh2[0][4] - sh1[2][0] * sh2[0][0]) - (sh1[0][2] * sh2[4][4] - sh1[0][0] * sh2[4][0]));
    coeffs[dstIdx++] = Dot7(srcIdx, coeffsIn, sh3[2]);
    sh3[3][0] = kSqrt03_10 * (sh1[1][2] * sh2[2][0] + sh1[1][0] * sh2[2][4]) + -kSqrt01_10 * ((sh1[2][2] * sh2[3][0] + sh1[2][0] * sh2[3][4]) + (sh1[0][2] * sh2[1][0] + sh1[0][0] * sh2[1][4]));
    sh3[3][1] = kSqrt09_05 * sh1[1][1] * sh2[2][0] + -kSqrt03_05 * (sh1[2][1] * sh2[3][0] + sh1[0][1] * sh2[1][0]);
    sh3[3][2] = kSqrt09_08 * sh1[1][1] * sh2[2][1] + -kSqrt03_08 * (sh1[2][1] * sh2[3][1] + sh1[0][1] * sh2[1][1]);
    sh3[3][3] = sh1[1][1] * sh2[2][2] + -kSqrt01_03 * (sh1[2][1] * sh2[3][2] + sh1[0][1] * sh2[1][2]);
    sh3[3][4] = kSqrt09_08 * sh1[1][1] * sh2[2][3] + -kSqrt03_08 * (sh1[2][1] * sh2[3][3] + sh1[0][1] * sh2[1][3]);
    sh3[3][5] = kSqrt09_05 * sh1[1][1] * sh2[2][4] + -kSqrt03_05 * (sh1[2][1] * sh2[3][4] + sh1[0][1] * sh2[1][4]);
    sh3[3][6] = kSqrt03_10 * (sh1[1][2] * sh2[2][4] - sh1[1][0] * sh2[2][0]) + -kSqrt01_10 * ((sh1[2][2] * sh2[3][4] - sh1[2][0] * sh2[3][0]) + (sh1[0][2] * sh2[1][4] - sh1[0][0] * sh2[1][0]));
    coeffs[dstIdx++] = Dot7(srcIdx, coeffsIn, sh3[3]);
    sh3[4][0] = kSqrt04_15 * (sh1[1][2] * sh2[3][0] + sh1[1][0] * sh2[3][4]) + kSqrt01_05 * (sh1[2][2] * sh2[2][0] + sh1[2][0] * sh2[2][4]) + -sqrt(1.0 / 60.0) * ((sh1[2][2] * sh2[4][0] + sh1[2][0] * sh2[4][4]) + (sh1[0][2] * sh2[0][0] + sh1[0][0] * sh2[0][4]));
    sh3[4][1] = kSqrt08_05 * sh1[1][1] * sh2[3][0] + kSqrt06_05 * sh1[2][1] * sh2[2][0] + -kSqrt01_10 * (sh1[2][1] * sh2[4][0] + sh1[0][1] * sh2[0][0]);
    sh3[4][2] = sh1[1][1] * sh2[3][1] + kSqrt03_04 * sh1[2][1] * sh2[2][1] + -kSqrt01_16 * (sh1[2][1] * sh2[4][1] + sh1[0][1] * sh2[0][1]);
    sh3[4][3] = kSqrt08_09 * sh1[1][1] * sh2[3][2] + kSqrt02_03 * sh1[2][1] * sh2[2][2] + -kSqrt01_18 * (sh1[2][1] * sh2[4][2] + sh1[0][1] * sh2[0][2]);
    sh3[4][4] = sh1[1][1] * sh2[3][3] + kSqrt03_04 * sh1[2][1] * sh2[2][3] + -kSqrt01_16 * (sh1[2][1] * sh2[4][3] + sh1[0][1] * sh2[0][3]);
    sh3[4][5] = kSqrt08_05 * sh1[1][1] * sh2[3][4] + kSqrt06_05 * sh1[2][1] * sh2[2][4] + -kSqrt01_10 * (sh1[2][1] * sh2[4][4] + sh1[0][1] * sh2[0][4]);
    sh3[4][6] = kSqrt04_15 * (sh1[1][2] * sh2[3][4] - sh1[1][0] * sh2[3][0]) + kSqrt01_05 * (sh1[2][2] * sh2[2][4] - sh1[2][0] * sh2[2][0]) + -sqrt(1.0 / 60.0) * ((sh1[2][2] * sh2[4][4] - sh1[2][0] * sh2[4][0]) + (sh1[0][2] * sh2[0][4] - sh1[0][0] * sh2[0][0]));
    coeffs[dstIdx++] = Dot7(srcIdx, coeffsIn, sh3[4]);
    sh3[5][0] = kSqrt01_06 * (sh1[1][2] * sh2[4][0] + sh1[1][0] * sh2[4][4]) + kSqrt01_06 * ((sh1[2][2] * sh2[3][0] + sh1[2][0] * sh2[3][4]) - (sh1[0][2] * sh2[1][0] + sh1[0][0] * sh2[1][4]));
    sh3[5][1] = sh1[1][1] * sh2[4][0] + (sh1[2][1] * sh2[3][0] - sh1[0][1] * sh2[1][0]);
    sh3[5][2] = kSqrt05_08 * sh1[1][1] * sh2[4][1] + kSqrt05_08 * (sh1[2][1] * sh2[3][1] - sh1[0][1] * sh2[1][1]);
    sh3[5][3] = kSqrt05_09 * sh1[1][1] * sh2[4][2] + kSqrt05_09 * (sh1[2][1] * sh2[3][2] - sh1[0][1] * sh2[1][2]);
    sh3[5][4] = kSqrt05_08 * sh1[1][1] * sh2[4][3] + kSqrt05_08 * (sh1[2][1] * sh2[3][3] - sh1[0][1] * sh2[1][3]);
    sh3[5][5] = sh1[1][1] * sh2[4][4] + (sh1[2][1] * sh2[3][4] - sh1[0][1] * sh2[1][4]);
    sh3[5][6] = kSqrt01_06 * (sh1[1][2] * sh2[4][4] - sh1[1][0] * sh2[4][0]) + kSqrt01_06 * ((sh1[2][2] * sh2[3][4] - sh1[2][0] * sh2[3][0]) - (sh1[0][2] * sh2[1][4] - sh1[0][0] * sh2[1][0]));
    coeffs[dstIdx++] = Dot7(srcIdx, coeffsIn, sh3[5]);
    sh3[6][0] = kSqrt01_04 * ((sh1[2][2] * sh2[4][0] + sh1[2][0] * sh2[4][4]) - (sh1[0][2] * sh2[0][0] + sh1[0][0] * sh2[0][4]));
    sh3[6][1] = kSqrt03_02 * (sh1[2][1] * sh2[4][0] - sh1[0][1] * sh2[0][0]);
    sh3[6][2] = kSqrt15_16 * (sh1[2][1] * sh2[4][1] - sh1[0][1] * sh2[0][1]);
    sh3[6][3] = kSqrt05_06 * (sh1[2][1] * sh2[4][2] - sh1[0][1] * sh2[0][2]);
    sh3[6][4] = kSqrt15_16 * (sh1[2][1] * sh2[4][3] - sh1[0][1] * sh2[0][3]);
    sh3[6][5] = kSqrt03_02 * (sh1[2][1] * sh2[4][4] - sh1[0][1] * sh2[0][4]);
    sh3[6][6] = kSqrt01_04 * ((sh1[2][2] * sh2[4][4] - sh1[2][0] * sh2[4][0]) - (sh1[0][2] * sh2[0][4] - sh1[0][0] * sh2[0][0]));
    coeffs[dstIdx++] = Dot7(srcIdx, coeffsIn, sh3[6]);
}
#endif // SPHERICAL_HARMONICS_HLSL

View File

@@ -0,0 +1,7 @@
fileFormatVersion: 2
guid: 0e45617a7c5ba4b4eb55e897761dcb31
ShaderIncludeImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,757 @@
// SPDX-License-Identifier: MIT
#define GROUP_SIZE 1024
#pragma kernel CSSetIndices
#pragma kernel CSCalcDistances
#pragma kernel CSCalcViewData
#pragma kernel CSUpdateEditData
#pragma kernel CSInitEditData
#pragma kernel CSClearBuffer
#pragma kernel CSInvertSelection
#pragma kernel CSSelectAll
#pragma kernel CSOrBuffers
#pragma kernel CSSelectionUpdate
#pragma kernel CSTranslateSelection
#pragma kernel CSRotateSelection
#pragma kernel CSScaleSelection
#pragma kernel CSExportData
#pragma kernel CSCopySplats
// DeviceRadixSort
#pragma multi_compile __ KEY_UINT KEY_INT KEY_FLOAT
#pragma multi_compile __ PAYLOAD_UINT PAYLOAD_INT PAYLOAD_FLOAT
#pragma multi_compile __ SHOULD_ASCEND
#pragma multi_compile __ SORT_PAIRS
#pragma multi_compile __ VULKAN
#pragma kernel InitDeviceRadixSort
#pragma kernel Upsweep
#pragma kernel Scan
#pragma kernel Downsweep
// GPU sorting needs wave ops
#pragma require wavebasic
#pragma require waveballot
#pragma use_dxc
#include "DeviceRadixSort.hlsl"
#include "GaussianSplatting.hlsl"
#include "UnityCG.cginc"
// Per-dispatch transform / camera state and sort buffers, set from C#.
float4x4 _MatrixObjectToWorld;  // splat object space -> world
float4x4 _MatrixWorldToObject;  // world -> splat object space
float4x4 _MatrixMV;             // model-view (object -> view space)
float4 _VecScreenParams;        // xy = render target size in pixels
float4 _VecWorldSpaceCameraPos; // world-space camera position
int _SelectionMode;             // nonzero = add to selection, zero = subtract
RWStructuredBuffer<uint> _SplatSortDistances; // sort keys: encoded view-space depth
RWStructuredBuffer<uint> _SplatSortKeys;      // sort payload: splat indices
uint _SplatCount;
// Remaps float bits so that unsigned integer ordering matches the float's
// numeric ordering (radix sort etc. friendly), see http://stereopsis.com/radix.html
uint FloatToSortableUint(float f)
{
    uint bits = asuint(f);
    // negative values: flip every bit; non-negative: flip just the sign bit
    uint flip = ((bits >> 31) != 0) ? 0xFFFFFFFF : 0x80000000;
    return bits ^ flip;
}
// Initializes the sort payload with the identity permutation:
// splat i initially sorts as itself.
[numthreads(GROUP_SIZE,1,1)]
void CSSetIndices (uint3 id : SV_DispatchThreadID)
{
    const uint splat = id.x;
    if (splat < _SplatCount)
        _SplatSortKeys[splat] = splat;
}
// Computes the sort key for each slot: the view-space depth of the splat the
// slot currently refers to, encoded so uint ordering equals float ordering.
[numthreads(GROUP_SIZE,1,1)]
void CSCalcDistances (uint3 id : SV_DispatchThreadID)
{
    const uint slot = id.x;
    if (slot >= _SplatCount)
        return;
    const uint splatIdx = _SplatSortKeys[slot];
    const float3 viewPos = mul(_MatrixMV, float4(LoadSplatPos(splatIdx), 1)).xyz;
    _SplatSortDistances[slot] = FloatToSortableUint(viewPos.z);
}
// Rendering parameters and edit-state bitfields.
RWStructuredBuffer<SplatViewData> _SplatViewData; // per-splat output of CSCalcViewData
float _SplatScale;        // uniform multiplier applied to splat covariance
float _SplatOpacityScale; // uniform multiplier applied to splat opacity
uint _SHOrder;            // SH order passed to ShadeSH
uint _SHOnly;             // nonzero flag passed to ShadeSH (presumably: show SH contribution only)
uint _SplatCutoutsCount;
#define SPLAT_CUTOUT_TYPE_ELLIPSOID 0
#define SPLAT_CUTOUT_TYPE_BOX 1
struct GaussianCutoutShaderData // match GaussianCutout.ShaderData in C#
{
    float4x4 mat;      // transforms splat-space position into cutout-local unit space
    uint typeAndFlags; // low byte: cutout type (0xFF = invalid); bits 8-15: invert flag
};
StructuredBuffer<GaussianCutoutShaderData> _SplatCutouts;
RWByteAddressBuffer _SplatSelectedBits; // 1 bit per splat: selected
ByteAddressBuffer _SplatDeletedBits;    // 1 bit per splat: deleted
uint _SplatBitsValid;                   // nonzero when the bitfields above are usable
// Decomposes a projected 2D covariance (packed as xx, xy, yy in cov2d) into
// the splat quad's two screen-space axes: eigenvectors of the 2x2 covariance
// scaled by sqrt of the corresponding eigenvalues.
void DecomposeCovariance(float3 cov2d, out float2 v1, out float2 v2)
{
#if 0 // does not quite give the correct results?
    // https://jsfiddle.net/mattrossman/ehxmtgw6/
    // References:
    // - https://www.youtube.com/watch?v=e50Bj7jn9IQ
    // - https://en.wikipedia.org/wiki/Eigenvalue_algorithm#2%C3%972_matrices
    // - https://people.math.harvard.edu/~knill/teaching/math21b2004/exhibits/2dmatrices/index.html
    float a = cov2d.x;
    float b = cov2d.y;
    float d = cov2d.z;
    float det = a * d - b * b; // matrix is symmetric, so "c" is same as "b"
    float trace = a + d;
    float mean = 0.5 * trace;
    float dist = sqrt(mean * mean - det);
    float lambda1 = mean + dist; // 1st eigenvalue
    float lambda2 = mean - dist; // 2nd eigenvalue
    if (b == 0) {
        // https://twitter.com/the_ross_man/status/1706342719776551360
        if (a > d) v1 = float2(1, 0);
        else v1 = float2(0, 1);
    } else
        v1 = normalize(float2(b, d - lambda2));
    v1.y = -v1.y;
    // The 2nd eigenvector is just a 90 degree rotation of the first since Gaussian axes are orthogonal
    v2 = float2(v1.y, -v1.x);
    // scaling components
    v1 *= sqrt(lambda1);
    v2 *= sqrt(lambda2);
    float radius = 1.5;
    v1 *= radius;
    v2 *= radius;
#else
    // same as in antimatter15/splat
    float diag1 = cov2d.x, diag2 = cov2d.z, offDiag = cov2d.y;
    float mid = 0.5f * (diag1 + diag2);
    float radius = length(float2((diag1 - diag2) / 2.0, offDiag));
    float lambda1 = mid + radius;
    float lambda2 = max(mid - radius, 0.1); // clamp to avoid a degenerate (zero-area) axis
    float2 diagVec = normalize(float2(offDiag, lambda1 - diag1));
    diagVec.y = -diagVec.y;
    float maxSize = 4096.0; // cap axis length to keep quads bounded
    v1 = min(sqrt(2.0 * lambda1), maxSize) * diagVec;
    v2 = min(sqrt(2.0 * lambda2), maxSize) * float2(diagVec.y, -diagVec.x);
#endif
}
// Returns true when the splat at (object-space) position `pos` is removed by
// the active cutout volumes. If the point lies inside some cutout shape, that
// cutout's invert flag immediately decides the result; if it lies inside no
// cutout, it is cut whenever at least one non-inverted cutout exists (i.e.
// non-inverted cutouts keep only their interior).
bool IsSplatCut(float3 pos)
{
    bool finalCut = false;
    for (uint i = 0; i < _SplatCutoutsCount; ++i)
    {
        GaussianCutoutShaderData cutData = _SplatCutouts[i];
        uint type = cutData.typeAndFlags & 0xFF;
        if (type == 0xFF) // invalid/null cutout, ignore
            continue;
        bool invert = (cutData.typeAndFlags & 0xFF00) != 0;
        // transform into the cutout's local unit-volume space
        float3 cutoutPos = mul(cutData.mat, float4(pos, 1)).xyz;
        if (type == SPLAT_CUTOUT_TYPE_ELLIPSOID)
        {
            if (dot(cutoutPos, cutoutPos) <= 1) return invert;
        }
        if (type == SPLAT_CUTOUT_TYPE_BOX)
        {
            if (all(abs(cutoutPos) <= 1)) return invert;
        }
        finalCut |= !invert;
    }
    return finalCut;
}
// Computes per-splat view-dependent render data: clip-space center, the two
// projected covariance axes, and SH-shaded color packed as four halves.
// Deleted or cut splats are flagged by forcing pos.w to 0.
[numthreads(GROUP_SIZE,1,1)]
void CSCalcViewData (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _SplatCount)
        return;
    SplatData splat = LoadSplatData(idx);
    SplatViewData view = (SplatViewData)0;
    float3 centerWorldPos = mul(_MatrixObjectToWorld, float4(splat.pos,1)).xyz;
    float4 centerClipPos = mul(UNITY_MATRIX_VP, float4(centerWorldPos, 1));
    half opacityScale = _SplatOpacityScale;
    float splatScale = _SplatScale;
    // deleted?
    if (_SplatBitsValid)
    {
        uint wordIdx = idx / 32;
        uint bitIdx = idx & 31;
        uint wordVal = _SplatDeletedBits.Load(wordIdx * 4);
        if (wordVal & (1 << bitIdx))
        {
            centerClipPos.w = 0;
        }
    }
    // cutouts
    if (IsSplatCut(splat.pos))
    {
        centerClipPos.w = 0;
    }
    view.pos = centerClipPos;
    bool behindCam = centerClipPos.w <= 0;
    if (!behindCam)
    {
        float4 boxRot = splat.rot;
        float3 boxSize = splat.scale;
        // 3D covariance from rotation+scale, then uniform scale applied as
        // its square (covariance is quadratic in scale).
        float3x3 splatRotScaleMat = CalcMatrixFromRotationScale(boxRot, boxSize);
        float3 cov3d0, cov3d1;
        CalcCovariance3D(splatRotScaleMat, cov3d0, cov3d1);
        float splatScale2 = splatScale * splatScale;
        cov3d0 *= splatScale2;
        cov3d1 *= splatScale2;
        float3 cov2d = CalcCovariance2D(splat.pos, cov3d0, cov3d1, _MatrixMV, UNITY_MATRIX_P, _VecScreenParams);
        DecomposeCovariance(cov2d, view.axis1, view.axis2);
        // shade SH along the object-space view direction
        float3 worldViewDir = _VecWorldSpaceCameraPos.xyz - centerWorldPos;
        float3 objViewDir = mul((float3x3)_MatrixWorldToObject, worldViewDir);
        objViewDir = normalize(objViewDir);
        half4 col;
        col.rgb = ShadeSH(splat.sh, objViewDir, _SHOrder, _SHOnly != 0);
        col.a = min(splat.opacity * opacityScale, 65000);
        // pack half-precision RGBA into two uints
        view.color.x = (f32tof16(col.r) << 16) | f32tof16(col.g);
        view.color.y = (f32tof16(col.b) << 16) | f32tof16(col.a);
    }
    _SplatViewData[idx] = view;
}
// Generic word-addressed buffers used by the edit kernels below.
RWByteAddressBuffer _DstBuffer; // destination: bitfield words, or the edit-stats block in CSUpdateEditData
ByteAddressBuffer _SrcBuffer;   // source bitfield for CSOrBuffers
uint _BufferSize;               // size of the buffers, in 32-bit words
// Returns the [start, end) splat index range covered by 32-bit word `idx`
// of a per-splat bitfield, clamped to the total splat count.
uint2 GetSplatIndicesFromWord(uint idx)
{
    const uint first = idx * 32;
    return uint2(first, min(first + 32, _SplatCount));
}
// One thread per 32-splat word of the selection/deletion bitfields.
// Accumulates into _DstBuffer: selected/deleted/cut counts at byte offsets
// 0/4/8, and the selection's position bounds (min at 12..20, max at 24..32)
// stored as sortable uints so integer Interlocked min/max works on floats.
[numthreads(GROUP_SIZE,1,1)]
void CSUpdateEditData (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _BufferSize)
        return;
    uint valSel = _SplatSelectedBits.Load(idx * 4);
    uint valDel = _SplatDeletedBits.Load(idx * 4);
    valSel &= ~valDel; // don't count deleted splats as selected
    uint2 splatIndices = GetSplatIndicesFromWord(idx);
    // update selection bounds
    float3 bmin = 1.0e38;
    float3 bmax = -1.0e38;
    uint mask = 1;
    uint valCut = 0;
    for (uint sidx = splatIndices.x; sidx < splatIndices.y; ++sidx, mask <<= 1)
    {
        float3 spos = LoadSplatPos(sidx);
        // don't count cut splats as selected
        if (IsSplatCut(spos))
        {
            valSel &= ~mask;
            valCut |= mask;
        }
        if (valSel & mask)
        {
            bmin = min(bmin, spos);
            bmax = max(bmax, spos);
        }
    }
    valCut &= ~valDel; // don't count deleted splats as cut
    if (valSel != 0)
    {
        // fold this word's bounds into the global selection bounds
        _DstBuffer.InterlockedMin(12, FloatToSortableUint(bmin.x));
        _DstBuffer.InterlockedMin(16, FloatToSortableUint(bmin.y));
        _DstBuffer.InterlockedMin(20, FloatToSortableUint(bmin.z));
        _DstBuffer.InterlockedMax(24, FloatToSortableUint(bmax.x));
        _DstBuffer.InterlockedMax(28, FloatToSortableUint(bmax.y));
        _DstBuffer.InterlockedMax(32, FloatToSortableUint(bmax.z));
    }
    uint sumSel = countbits(valSel);
    uint sumDel = countbits(valDel);
    uint sumCut = countbits(valCut);
    _DstBuffer.InterlockedAdd(0, sumSel);
    _DstBuffer.InterlockedAdd(4, sumDel);
    _DstBuffer.InterlockedAdd(8, sumCut);
}
// Resets the edit-stats block before CSUpdateEditData runs: zeroes the
// selected/deleted/cut counters and seeds the selection bounds with
// "+infinity-ish" mins and "-infinity-ish" maxes (as sortable uints).
[numthreads(1,1,1)]
void CSInitEditData (uint3 id : SV_DispatchThreadID)
{
    _DstBuffer.Store3(0, uint3(0,0,0)); // selected, deleted, cut counts
    const uint boundsMin = FloatToSortableUint(1.0e38);
    const uint boundsMax = FloatToSortableUint(-1.0e38);
    _DstBuffer.Store3(12, uint3(boundsMin, boundsMin, boundsMin));
    _DstBuffer.Store3(24, uint3(boundsMax, boundsMax, boundsMax));
}
// Zeroes _DstBuffer, one 32-bit word per thread.
[numthreads(GROUP_SIZE,1,1)]
void CSClearBuffer (uint3 id : SV_DispatchThreadID)
{
    const uint word = id.x;
    if (word < _BufferSize)
        _DstBuffer.Store(word * 4, 0);
}
// Inverts the selection bitfield, then clears the bits of cut splats so
// they never become selected.
[numthreads(GROUP_SIZE,1,1)]
void CSInvertSelection (uint3 id : SV_DispatchThreadID)
{
    const uint word = id.x;
    if (word >= _BufferSize)
        return;
    uint bits = ~_DstBuffer.Load(word * 4);
    const uint2 range = GetSplatIndicesFromWord(word);
    uint bit = 1;
    for (uint sidx = range.x; sidx < range.y; ++sidx, bit <<= 1)
    {
        if (IsSplatCut(LoadSplatPos(sidx)))
            bits &= ~bit;
    }
    _DstBuffer.Store(word * 4, bits);
}
// Sets every selection bit, except for cut splats which must stay
// unselected.
[numthreads(GROUP_SIZE,1,1)]
void CSSelectAll (uint3 id : SV_DispatchThreadID)
{
    const uint word = id.x;
    if (word >= _BufferSize)
        return;
    uint bits = ~0;
    const uint2 range = GetSplatIndicesFromWord(word);
    uint bit = 1;
    for (uint sidx = range.x; sidx < range.y; ++sidx, bit <<= 1)
    {
        if (IsSplatCut(LoadSplatPos(sidx)))
            bits &= ~bit;
    }
    _DstBuffer.Store(word * 4, bits);
}
// Dst |= Src, one 32-bit word per thread.
[numthreads(GROUP_SIZE,1,1)]
void CSOrBuffers (uint3 id : SV_DispatchThreadID)
{
    const uint word = id.x;
    if (word >= _BufferSize)
        return;
    _DstBuffer.Store(word * 4, _SrcBuffer.Load(word * 4) | _DstBuffer.Load(word * 4));
}
float4 _SelectionRect; // selection rectangle in pixels: xy = min corner, zw = max corner
// Rectangle selection: projects each splat center to pixel coordinates and,
// when inside _SelectionRect, sets (add mode) or clears (subtract mode) its
// selection bit. Cut and behind-camera splats are skipped.
[numthreads(GROUP_SIZE,1,1)]
void CSSelectionUpdate (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _SplatCount)
        return;
    float3 pos = LoadSplatPos(idx);
    if (IsSplatCut(pos))
        return;
    float3 centerWorldPos = mul(_MatrixObjectToWorld, float4(pos,1)).xyz;
    float4 centerClipPos = mul(UNITY_MATRIX_VP, float4(centerWorldPos, 1));
    bool behindCam = centerClipPos.w <= 0;
    if (behindCam)
        return;
    // NDC -> pixel coordinates (y flipped)
    float2 pixelPos = (centerClipPos.xy / centerClipPos.w * float2(0.5, -0.5) + 0.5) * _VecScreenParams.xy;
    if (pixelPos.x < _SelectionRect.x || pixelPos.x > _SelectionRect.z ||
        pixelPos.y < _SelectionRect.y || pixelPos.y > _SelectionRect.w)
    {
        return;
    }
    uint wordIdx = idx / 32;
    uint bitIdx = idx & 31;
    if (_SelectionMode)
        _SplatSelectedBits.InterlockedOr(wordIdx * 4, 1u << bitIdx); // +
    else
        _SplatSelectedBits.InterlockedAnd(wordIdx * 4, ~(1u << bitIdx)); // -
}
float3 _SelectionDelta; // per-axis delta: translation in CSTranslateSelection, scale in CSScaleSelection
// True when splat `idx` has its bit set in the selection bitfield.
bool IsSplatSelected(uint idx)
{
    const uint word = _SplatSelectedBits.Load((idx / 32) * 4);
    return (word & (1u << (idx & 31))) != 0;
}
// Translates every selected splat by _SelectionDelta (object space).
// NOTE: only operates on uncompressed data (no chunks, fp32 positions);
// any other format is silently left unchanged.
[numthreads(GROUP_SIZE,1,1)]
void CSTranslateSelection (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _SplatCount)
        return;
    if (!IsSplatSelected(idx))
        return;
    uint fmt = _SplatFormat & 0xFF;
    if (_SplatChunkCount == 0 && fmt == VECTOR_FMT_32F)
    {
        uint stride = 12; // 3x fp32 per position
        float3 pos = asfloat(_SplatPos.Load3(idx * stride));
        pos += _SelectionDelta;
        _SplatPos.Store3(idx * stride, asuint(pos));
    }
}
// State for rotate/scale drag operations.
float3 _SelectionCenter;    // pivot point, splat object space
float4 _SelectionDeltaRot;  // rotation delta quaternion (applied in world space, see CSRotateSelection)
ByteAddressBuffer _SplatPosMouseDown;   // splat positions snapshotted at drag start
ByteAddressBuffer _SplatOtherMouseDown; // splat rotation/scale data snapshotted at drag start
// Rotates selected splats around _SelectionCenter by _SelectionDeltaRot:
// positions are rotated in world space relative to the drag-start snapshot,
// and each splat's own orientation quaternion is adjusted. Only operates on
// uncompressed data (no chunks, fp32 formats); SH rotation is still a TODO.
[numthreads(GROUP_SIZE,1,1)]
void CSRotateSelection (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _SplatCount)
        return;
    if (!IsSplatSelected(idx))
        return;
    uint posFmt = _SplatFormat & 0xFF;
    if (_SplatChunkCount == 0 && posFmt == VECTOR_FMT_32F)
    {
        uint posStride = 12;
        // rotate drag-start position around the pivot, in world space
        float3 pos = asfloat(_SplatPosMouseDown.Load3(idx * posStride));
        pos -= _SelectionCenter;
        pos = mul(_MatrixObjectToWorld, float4(pos,1)).xyz;
        pos = QuatRotateVector(pos, _SelectionDeltaRot);
        pos = mul(_MatrixWorldToObject, float4(pos,1)).xyz;
        pos += _SelectionCenter;
        _SplatPos.Store3(idx * posStride, asuint(pos));
    }
    uint scaleFmt = (_SplatFormat >> 8) & 0xFF;
    uint shFormat = (_SplatFormat >> 16) & 0xFF;
    if (_SplatChunkCount == 0 && scaleFmt == VECTOR_FMT_32F && shFormat == VECTOR_FMT_32F)
    {
        uint otherStride = 4 + 12; // packed rotation uint + 3x fp32 scale
        uint rotVal = _SplatOtherMouseDown.Load(idx * otherStride);
        float4 rot = DecodeRotation(DecodePacked_10_10_10_2(rotVal));
        //@TODO: correct rotation
        rot = QuatMul(rot, _SelectionDeltaRot);
        rotVal = EncodeQuatToNorm10(PackSmallest3Rotation(rot));
        _SplatOther.Store(idx * otherStride, rotVal);
    }
    //@TODO: rotate SHs
}
//@TODO: maybe scale the splat scale itself too?
// Scales the positions of selected splats around _SelectionCenter by
// _SelectionDelta (per-axis, applied in world space, starting from the
// drag-start snapshot). Only operates on uncompressed fp32 positions.
[numthreads(GROUP_SIZE,1,1)]
void CSScaleSelection (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _SplatCount)
        return;
    if (!IsSplatSelected(idx))
        return;
    uint fmt = _SplatFormat & 0xFF;
    if (_SplatChunkCount == 0 && fmt == VECTOR_FMT_32F)
    {
        uint stride = 12;
        float3 pos = asfloat(_SplatPosMouseDown.Load3(idx * stride));
        pos -= _SelectionCenter;
        pos = mul(_MatrixObjectToWorld, float4(pos,1)).xyz;
        pos *= _SelectionDelta;
        pos = mul(_MatrixWorldToObject, float4(pos,1)).xyz;
        pos += _SelectionCenter;
        _SplatPos.Store3(idx * stride, asuint(pos));
    }
}
// Layout of one exported splat, presumably consumed by the C# export code —
// TODO confirm field order matches the CPU-side reader.
struct ExportSplatData
{
    float3 pos;
    float3 nor;   // normally 0; set to 1 by CSExportData to mark a cut splat as skipped
    float3 dc0;   // SH l=0 (DC) color term
    float4 shR14; float4 shR58; float4 shR9C; float3 shRDF; // red channel SH coeffs 1..15
    float4 shG14; float4 shG58; float4 shG9C; float3 shGDF; // green channel SH coeffs 1..15
    float4 shB14; float4 shB58; float4 shB9C; float3 shBDF; // blue channel SH coeffs 1..15
    float opacity; // stored as inverse sigmoid
    float3 scale;  // stored as log
    float4 rot;    // quaternion, wxyz component order
};
RWStructuredBuffer<ExportSplatData> _ExportBuffer; // export output, one entry per splat
// Converts a base color back to the SH DC (l=0, m=0) coefficient.
// 0.2820948 is Y_0^0 = 1 / (2 * sqrt(pi)).
float3 ColorToSH0(float3 col)
{
    const float kSH_C0 = 0.2820948;
    return (col - 0.5) / kSH_C0;
}
// Inverse of the logistic sigmoid; the denominator is clamped away from
// zero so values near 1 don't produce log(0) / a division by zero.
float InvSigmoid(float v)
{
    float denom = max(1 - v, 1.0e-6);
    return log(v / denom);
}
// SH rotation
#include "SphericalHarmonics.hlsl"
// Rotates a splat's SH color data by `rot`: unpacks the 16 float3
// coefficients into an array, runs the generic 4-band RotateSH from
// SphericalHarmonics.hlsl, and packs the result back.
void RotateSH(inout SplatSHData sh, float3x3 rot)
{
    float3 shin[16];
    float3 shout[16];
    shin[0] = sh.col; // DC term travels through the band-0 slot
    shin[1] = sh.sh1;
    shin[2] = sh.sh2;
    shin[3] = sh.sh3;
    shin[4] = sh.sh4;
    shin[5] = sh.sh5;
    shin[6] = sh.sh6;
    shin[7] = sh.sh7;
    shin[8] = sh.sh8;
    shin[9] = sh.sh9;
    shin[10] = sh.sh10;
    shin[11] = sh.sh11;
    shin[12] = sh.sh12;
    shin[13] = sh.sh13;
    shin[14] = sh.sh14;
    shin[15] = sh.sh15;
    RotateSH(rot, 4, shin, shout);
    sh.col = shout[0];
    sh.sh1 = shout[1];
    sh.sh2 = shout[2];
    sh.sh3 = shout[3];
    sh.sh4 = shout[4];
    sh.sh5 = shout[5];
    sh.sh6 = shout[6];
    sh.sh7 = shout[7];
    sh.sh8 = shout[8];
    sh.sh9 = shout[9];
    sh.sh10 = shout[10];
    sh.sh11 = shout[11];
    sh.sh12 = shout[12];
    sh.sh13 = shout[13];
    sh.sh14 = shout[14];
    sh.sh15 = shout[15];
}
// Extracts the rotation part of a transform for SH rotation: takes the
// upper-left 3x3 of objToWorld and divides each basis row by its length to
// remove scale.
float3x3 CalcSHRotMatrix(float4x4 objToWorld)
{
    float3x3 m = (float3x3)objToWorld;
    const float invSX = 1.0 / length(m[0]);
    const float invSY = 1.0 / length(m[1]);
    const float invSZ = 1.0 / length(m[2]);
    m[0] *= invSX;
    m[1] *= invSY;
    m[2] *= invSZ;
    return m;
}
float4 _ExportTransformRotation; // rotation (quaternion) applied to exported splat orientations
float3 _ExportTransformScale;    // scale applied to exported splats; sign handles axis flips
uint _ExportTransformFlags;      // nonzero = apply the export transform in CSExportData
// Converts one splat into the export layout: optional object-to-world
// transform (including SH rotation and axis-flip fixups), color -> SH DC,
// opacity -> inverse sigmoid, scale -> log, rotation reordered to wxyz.
// Cut splats get nor = 1 so they can be skipped on the CPU side.
[numthreads(GROUP_SIZE,1,1)]
void CSExportData (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _SplatCount)
        return;
    SplatData src = LoadSplatData(idx);
    bool isCut = IsSplatCut(src.pos);
    // transform splat by matrix, if needed
    if (_ExportTransformFlags != 0)
    {
        src.pos = mul(_MatrixObjectToWorld, float4(src.pos,1)).xyz;
        // note: this only handles axis flips from scale, not any arbitrary scaling
        if (_ExportTransformScale.x < 0)
            src.rot.yz = -src.rot.yz;
        if (_ExportTransformScale.y < 0)
            src.rot.xz = -src.rot.xz;
        if (_ExportTransformScale.z < 0)
            src.rot.xy = -src.rot.xy;
        src.rot = QuatMul(_ExportTransformRotation, src.rot);
        src.scale *= abs(_ExportTransformScale);
        float3x3 shRot = CalcSHRotMatrix(_MatrixObjectToWorld);
        RotateSH(src.sh, shRot);
    }
    ExportSplatData dst;
    dst.pos = src.pos;
    dst.nor = 0;
    dst.dc0 = ColorToSH0(src.sh.col);
    // SH coefficients regrouped per color channel
    dst.shR14 = float4(src.sh.sh1.r, src.sh.sh2.r, src.sh.sh3.r, src.sh.sh4.r);
    dst.shR58 = float4(src.sh.sh5.r, src.sh.sh6.r, src.sh.sh7.r, src.sh.sh8.r);
    dst.shR9C = float4(src.sh.sh9.r, src.sh.sh10.r, src.sh.sh11.r, src.sh.sh12.r);
    dst.shRDF = float3(src.sh.sh13.r, src.sh.sh14.r, src.sh.sh15.r);
    dst.shG14 = float4(src.sh.sh1.g, src.sh.sh2.g, src.sh.sh3.g, src.sh.sh4.g);
    dst.shG58 = float4(src.sh.sh5.g, src.sh.sh6.g, src.sh.sh7.g, src.sh.sh8.g);
    dst.shG9C = float4(src.sh.sh9.g, src.sh.sh10.g, src.sh.sh11.g, src.sh.sh12.g);
    dst.shGDF = float3(src.sh.sh13.g, src.sh.sh14.g, src.sh.sh15.g);
    dst.shB14 = float4(src.sh.sh1.b, src.sh.sh2.b, src.sh.sh3.b, src.sh.sh4.b);
    dst.shB58 = float4(src.sh.sh5.b, src.sh.sh6.b, src.sh.sh7.b, src.sh.sh8.b);
    dst.shB9C = float4(src.sh.sh9.b, src.sh.sh10.b, src.sh.sh11.b, src.sh.sh12.b);
    dst.shBDF = float3(src.sh.sh13.b, src.sh.sh14.b, src.sh.sh15.b);
    dst.opacity = InvSigmoid(src.opacity);
    dst.scale = log(src.scale);
    dst.rot = src.rot.wxyz;
    if (isCut)
        dst.nor = 1; // mark as skipped for export
    _ExportBuffer[idx] = dst;
}
// Destination buffers and parameters for CSCopySplats.
RWByteAddressBuffer _CopyDstPos;         // positions, 3x fp32 per splat
RWByteAddressBuffer _CopyDstOther;       // packed rotation uint + 3x fp32 scale per splat
RWByteAddressBuffer _CopyDstSH;          // SH coefficients, 15x float3 per splat (padded stride)
RWByteAddressBuffer _CopyDstEditDeleted; // destination deleted-bits bitfield
RWTexture2D<float4> _CopyDstColor;       // color (rgb) + opacity (a) texture
uint _CopyDstSize, _CopySrcStartIndex, _CopyDstStartIndex, _CopyCount;
float4x4 _CopyTransformMatrix;  // source -> destination object space
float4 _CopyTransformRotation;  // rotation part of the transform, as quaternion
float3 _CopyTransformScale;     // scale part of the transform; sign handles axis flips
// Copies _CopyCount splats starting at _CopySrcStartIndex in this splat
// object into the destination buffers starting at _CopyDstStartIndex,
// applying the given rotation/scale/matrix transform on the way.
// Destination data is written uncompressed (fp32 pos/scale, norm10
// rotation, fp32 SH), and source deleted bits are carried over.
[numthreads(GROUP_SIZE,1,1)]
void CSCopySplats (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _CopyCount)
        return;
    uint srcIdx = _CopySrcStartIndex + idx;
    uint dstIdx = _CopyDstStartIndex + idx;
    if (srcIdx >= _SplatCount || dstIdx >= _CopyDstSize)
        return;
    // BUGFIX: load the splat at the source range offset. This previously
    // loaded LoadSplatData(idx), silently ignoring _CopySrcStartIndex even
    // though the deleted-bits copy below already used srcIdx, so copies
    // from a nonzero source offset duplicated the wrong splats.
    SplatData src = LoadSplatData(srcIdx);
    // transform the splat
    src.pos = mul(_CopyTransformMatrix, float4(src.pos,1)).xyz;
    // note: this only handles axis flips from scale, not any arbitrary scaling
    if (_CopyTransformScale.x < 0)
        src.rot.yz = -src.rot.yz;
    if (_CopyTransformScale.y < 0)
        src.rot.xz = -src.rot.xz;
    if (_CopyTransformScale.z < 0)
        src.rot.xy = -src.rot.xy;
    src.rot = QuatMul(_CopyTransformRotation, src.rot);
    src.scale *= abs(_CopyTransformScale);
    float3x3 shRot = CalcSHRotMatrix(_CopyTransformMatrix);
    RotateSH(src.sh, shRot);
    // output data into destination:
    // pos
    uint posStride = 12;
    _CopyDstPos.Store3(dstIdx * posStride, asuint(src.pos));
    // rot + scale
    uint otherStride = 4 + 12;
    uint rotVal = EncodeQuatToNorm10(PackSmallest3Rotation(src.rot));
    _CopyDstOther.Store4(dstIdx * otherStride, uint4(
        rotVal,
        asuint(src.scale.x),
        asuint(src.scale.y),
        asuint(src.scale.z)));
    // color
    uint3 pixelIndex = SplatIndexToPixelIndex(dstIdx);
    _CopyDstColor[pixelIndex.xy] = float4(src.sh.col, src.opacity);
    // SH
    uint shStride = 192; // 15*3 fp32, rounded up to multiple of 16
    uint shOffset = dstIdx * shStride;
    _CopyDstSH.Store3(shOffset + 12 * 0, asuint(src.sh.sh1));
    _CopyDstSH.Store3(shOffset + 12 * 1, asuint(src.sh.sh2));
    _CopyDstSH.Store3(shOffset + 12 * 2, asuint(src.sh.sh3));
    _CopyDstSH.Store3(shOffset + 12 * 3, asuint(src.sh.sh4));
    _CopyDstSH.Store3(shOffset + 12 * 4, asuint(src.sh.sh5));
    _CopyDstSH.Store3(shOffset + 12 * 5, asuint(src.sh.sh6));
    _CopyDstSH.Store3(shOffset + 12 * 6, asuint(src.sh.sh7));
    _CopyDstSH.Store3(shOffset + 12 * 7, asuint(src.sh.sh8));
    _CopyDstSH.Store3(shOffset + 12 * 8, asuint(src.sh.sh9));
    _CopyDstSH.Store3(shOffset + 12 * 9, asuint(src.sh.sh10));
    _CopyDstSH.Store3(shOffset + 12 * 10, asuint(src.sh.sh11));
    _CopyDstSH.Store3(shOffset + 12 * 11, asuint(src.sh.sh12));
    _CopyDstSH.Store3(shOffset + 12 * 12, asuint(src.sh.sh13));
    _CopyDstSH.Store3(shOffset + 12 * 13, asuint(src.sh.sh14));
    _CopyDstSH.Store3(shOffset + 12 * 14, asuint(src.sh.sh15));
    // deleted bits
    uint srcWordIdx = srcIdx / 32;
    uint srcBitIdx = srcIdx & 31;
    if (_SplatDeletedBits.Load(srcWordIdx * 4) & (1u << srcBitIdx))
    {
        uint dstWordIdx = dstIdx / 32;
        uint dstBitIdx = dstIdx & 31;
        _CopyDstEditDeleted.InterlockedOr(dstWordIdx * 4, 1u << dstBitIdx);
    }
}

View File

@@ -0,0 +1,7 @@
fileFormatVersion: 2
guid: ec84f78b836bd4f96a105d6b804f08bd
ComputeShaderImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant: