chore: sync workspace state
This commit is contained in:
757
MVS/3DGS-Unity/Shaders/SplatUtilities.compute
Normal file
757
MVS/3DGS-Unity/Shaders/SplatUtilities.compute
Normal file
@@ -0,0 +1,757 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
#define GROUP_SIZE 1024
|
||||
|
||||
#pragma kernel CSSetIndices
|
||||
#pragma kernel CSCalcDistances
|
||||
#pragma kernel CSCalcViewData
|
||||
#pragma kernel CSUpdateEditData
|
||||
#pragma kernel CSInitEditData
|
||||
#pragma kernel CSClearBuffer
|
||||
#pragma kernel CSInvertSelection
|
||||
#pragma kernel CSSelectAll
|
||||
#pragma kernel CSOrBuffers
|
||||
#pragma kernel CSSelectionUpdate
|
||||
#pragma kernel CSTranslateSelection
|
||||
#pragma kernel CSRotateSelection
|
||||
#pragma kernel CSScaleSelection
|
||||
#pragma kernel CSExportData
|
||||
#pragma kernel CSCopySplats
|
||||
|
||||
// DeviceRadixSort
|
||||
#pragma multi_compile __ KEY_UINT KEY_INT KEY_FLOAT
|
||||
#pragma multi_compile __ PAYLOAD_UINT PAYLOAD_INT PAYLOAD_FLOAT
|
||||
#pragma multi_compile __ SHOULD_ASCEND
|
||||
#pragma multi_compile __ SORT_PAIRS
|
||||
#pragma multi_compile __ VULKAN
|
||||
#pragma kernel InitDeviceRadixSort
|
||||
#pragma kernel Upsweep
|
||||
#pragma kernel Scan
|
||||
#pragma kernel Downsweep
|
||||
|
||||
// GPU sorting needs wave ops
|
||||
#pragma require wavebasic
|
||||
#pragma require waveballot
|
||||
#pragma use_dxc
|
||||
|
||||
#include "DeviceRadixSort.hlsl"
|
||||
#include "GaussianSplatting.hlsl"
|
||||
#include "UnityCG.cginc"
|
||||
|
||||
float4x4 _MatrixObjectToWorld;
|
||||
float4x4 _MatrixWorldToObject;
|
||||
float4x4 _MatrixMV;
|
||||
float4 _VecScreenParams;
|
||||
float4 _VecWorldSpaceCameraPos;
|
||||
int _SelectionMode;
|
||||
|
||||
RWStructuredBuffer<uint> _SplatSortDistances;
|
||||
RWStructuredBuffer<uint> _SplatSortKeys;
|
||||
uint _SplatCount;
|
||||
|
||||
// radix sort etc. friendly, see http://stereopsis.com/radix.html
|
||||
uint FloatToSortableUint(float f)
|
||||
{
|
||||
uint fu = asuint(f);
|
||||
uint mask = -((int)(fu >> 31)) | 0x80000000;
|
||||
return fu ^ mask;
|
||||
}
|
||||
|
||||
[numthreads(GROUP_SIZE,1,1)]
|
||||
void CSSetIndices (uint3 id : SV_DispatchThreadID)
|
||||
{
|
||||
uint idx = id.x;
|
||||
if (idx >= _SplatCount)
|
||||
return;
|
||||
|
||||
_SplatSortKeys[idx] = idx;
|
||||
}
|
||||
|
||||
[numthreads(GROUP_SIZE,1,1)]
|
||||
void CSCalcDistances (uint3 id : SV_DispatchThreadID)
|
||||
{
|
||||
uint idx = id.x;
|
||||
if (idx >= _SplatCount)
|
||||
return;
|
||||
|
||||
uint origIdx = _SplatSortKeys[idx];
|
||||
|
||||
float3 pos = LoadSplatPos(origIdx);
|
||||
pos = mul(_MatrixMV, float4(pos.xyz, 1)).xyz;
|
||||
|
||||
_SplatSortDistances[idx] = FloatToSortableUint(pos.z);
|
||||
}
|
||||
|
||||
RWStructuredBuffer<SplatViewData> _SplatViewData;
|
||||
|
||||
float _SplatScale;
|
||||
float _SplatOpacityScale;
|
||||
uint _SHOrder;
|
||||
uint _SHOnly;
|
||||
|
||||
uint _SplatCutoutsCount;
|
||||
|
||||
#define SPLAT_CUTOUT_TYPE_ELLIPSOID 0
|
||||
#define SPLAT_CUTOUT_TYPE_BOX 1
|
||||
|
||||
struct GaussianCutoutShaderData // match GaussianCutout.ShaderData in C#
|
||||
{
|
||||
float4x4 mat;
|
||||
uint typeAndFlags;
|
||||
};
|
||||
StructuredBuffer<GaussianCutoutShaderData> _SplatCutouts;
|
||||
|
||||
RWByteAddressBuffer _SplatSelectedBits;
|
||||
ByteAddressBuffer _SplatDeletedBits;
|
||||
uint _SplatBitsValid;
|
||||
|
||||
void DecomposeCovariance(float3 cov2d, out float2 v1, out float2 v2)
|
||||
{
|
||||
#if 0 // does not quite give the correct results?
|
||||
|
||||
// https://jsfiddle.net/mattrossman/ehxmtgw6/
|
||||
// References:
|
||||
// - https://www.youtube.com/watch?v=e50Bj7jn9IQ
|
||||
// - https://en.wikipedia.org/wiki/Eigenvalue_algorithm#2%C3%972_matrices
|
||||
// - https://people.math.harvard.edu/~knill/teaching/math21b2004/exhibits/2dmatrices/index.html
|
||||
float a = cov2d.x;
|
||||
float b = cov2d.y;
|
||||
float d = cov2d.z;
|
||||
float det = a * d - b * b; // matrix is symmetric, so "c" is same as "b"
|
||||
float trace = a + d;
|
||||
|
||||
float mean = 0.5 * trace;
|
||||
float dist = sqrt(mean * mean - det);
|
||||
|
||||
float lambda1 = mean + dist; // 1st eigenvalue
|
||||
float lambda2 = mean - dist; // 2nd eigenvalue
|
||||
|
||||
if (b == 0) {
|
||||
// https://twitter.com/the_ross_man/status/1706342719776551360
|
||||
if (a > d) v1 = float2(1, 0);
|
||||
else v1 = float2(0, 1);
|
||||
} else
|
||||
v1 = normalize(float2(b, d - lambda2));
|
||||
|
||||
v1.y = -v1.y;
|
||||
// The 2nd eigenvector is just a 90 degree rotation of the first since Gaussian axes are orthogonal
|
||||
v2 = float2(v1.y, -v1.x);
|
||||
|
||||
// scaling components
|
||||
v1 *= sqrt(lambda1);
|
||||
v2 *= sqrt(lambda2);
|
||||
|
||||
float radius = 1.5;
|
||||
v1 *= radius;
|
||||
v2 *= radius;
|
||||
|
||||
#else
|
||||
|
||||
// same as in antimatter15/splat
|
||||
float diag1 = cov2d.x, diag2 = cov2d.z, offDiag = cov2d.y;
|
||||
float mid = 0.5f * (diag1 + diag2);
|
||||
float radius = length(float2((diag1 - diag2) / 2.0, offDiag));
|
||||
float lambda1 = mid + radius;
|
||||
float lambda2 = max(mid - radius, 0.1);
|
||||
float2 diagVec = normalize(float2(offDiag, lambda1 - diag1));
|
||||
diagVec.y = -diagVec.y;
|
||||
float maxSize = 4096.0;
|
||||
v1 = min(sqrt(2.0 * lambda1), maxSize) * diagVec;
|
||||
v2 = min(sqrt(2.0 * lambda2), maxSize) * float2(diagVec.y, -diagVec.x);
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
bool IsSplatCut(float3 pos)
|
||||
{
|
||||
bool finalCut = false;
|
||||
for (uint i = 0; i < _SplatCutoutsCount; ++i)
|
||||
{
|
||||
GaussianCutoutShaderData cutData = _SplatCutouts[i];
|
||||
uint type = cutData.typeAndFlags & 0xFF;
|
||||
if (type == 0xFF) // invalid/null cutout, ignore
|
||||
continue;
|
||||
bool invert = (cutData.typeAndFlags & 0xFF00) != 0;
|
||||
|
||||
float3 cutoutPos = mul(cutData.mat, float4(pos, 1)).xyz;
|
||||
if (type == SPLAT_CUTOUT_TYPE_ELLIPSOID)
|
||||
{
|
||||
if (dot(cutoutPos, cutoutPos) <= 1) return invert;
|
||||
}
|
||||
if (type == SPLAT_CUTOUT_TYPE_BOX)
|
||||
{
|
||||
if (all(abs(cutoutPos) <= 1)) return invert;
|
||||
}
|
||||
finalCut |= !invert;
|
||||
}
|
||||
return finalCut;
|
||||
}
|
||||
|
||||
[numthreads(GROUP_SIZE,1,1)]
|
||||
void CSCalcViewData (uint3 id : SV_DispatchThreadID)
|
||||
{
|
||||
uint idx = id.x;
|
||||
if (idx >= _SplatCount)
|
||||
return;
|
||||
|
||||
SplatData splat = LoadSplatData(idx);
|
||||
SplatViewData view = (SplatViewData)0;
|
||||
|
||||
float3 centerWorldPos = mul(_MatrixObjectToWorld, float4(splat.pos,1)).xyz;
|
||||
float4 centerClipPos = mul(UNITY_MATRIX_VP, float4(centerWorldPos, 1));
|
||||
half opacityScale = _SplatOpacityScale;
|
||||
float splatScale = _SplatScale;
|
||||
|
||||
// deleted?
|
||||
if (_SplatBitsValid)
|
||||
{
|
||||
uint wordIdx = idx / 32;
|
||||
uint bitIdx = idx & 31;
|
||||
uint wordVal = _SplatDeletedBits.Load(wordIdx * 4);
|
||||
if (wordVal & (1 << bitIdx))
|
||||
{
|
||||
centerClipPos.w = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// cutouts
|
||||
if (IsSplatCut(splat.pos))
|
||||
{
|
||||
centerClipPos.w = 0;
|
||||
}
|
||||
|
||||
view.pos = centerClipPos;
|
||||
bool behindCam = centerClipPos.w <= 0;
|
||||
if (!behindCam)
|
||||
{
|
||||
float4 boxRot = splat.rot;
|
||||
float3 boxSize = splat.scale;
|
||||
|
||||
float3x3 splatRotScaleMat = CalcMatrixFromRotationScale(boxRot, boxSize);
|
||||
|
||||
float3 cov3d0, cov3d1;
|
||||
CalcCovariance3D(splatRotScaleMat, cov3d0, cov3d1);
|
||||
float splatScale2 = splatScale * splatScale;
|
||||
cov3d0 *= splatScale2;
|
||||
cov3d1 *= splatScale2;
|
||||
float3 cov2d = CalcCovariance2D(splat.pos, cov3d0, cov3d1, _MatrixMV, UNITY_MATRIX_P, _VecScreenParams);
|
||||
|
||||
DecomposeCovariance(cov2d, view.axis1, view.axis2);
|
||||
|
||||
float3 worldViewDir = _VecWorldSpaceCameraPos.xyz - centerWorldPos;
|
||||
float3 objViewDir = mul((float3x3)_MatrixWorldToObject, worldViewDir);
|
||||
objViewDir = normalize(objViewDir);
|
||||
|
||||
half4 col;
|
||||
col.rgb = ShadeSH(splat.sh, objViewDir, _SHOrder, _SHOnly != 0);
|
||||
col.a = min(splat.opacity * opacityScale, 65000);
|
||||
view.color.x = (f32tof16(col.r) << 16) | f32tof16(col.g);
|
||||
view.color.y = (f32tof16(col.b) << 16) | f32tof16(col.a);
|
||||
}
|
||||
|
||||
_SplatViewData[idx] = view;
|
||||
}
|
||||
|
||||
|
||||
RWByteAddressBuffer _DstBuffer;
|
||||
ByteAddressBuffer _SrcBuffer;
|
||||
uint _BufferSize;
|
||||
|
||||
uint2 GetSplatIndicesFromWord(uint idx)
|
||||
{
|
||||
uint idxStart = idx * 32;
|
||||
uint idxEnd = min(idxStart + 32, _SplatCount);
|
||||
return uint2(idxStart, idxEnd);
|
||||
}
|
||||
|
||||
[numthreads(GROUP_SIZE,1,1)]
|
||||
void CSUpdateEditData (uint3 id : SV_DispatchThreadID)
|
||||
{
|
||||
uint idx = id.x;
|
||||
if (idx >= _BufferSize)
|
||||
return;
|
||||
|
||||
uint valSel = _SplatSelectedBits.Load(idx * 4);
|
||||
uint valDel = _SplatDeletedBits.Load(idx * 4);
|
||||
valSel &= ~valDel; // don't count deleted splats as selected
|
||||
uint2 splatIndices = GetSplatIndicesFromWord(idx);
|
||||
|
||||
// update selection bounds
|
||||
float3 bmin = 1.0e38;
|
||||
float3 bmax = -1.0e38;
|
||||
uint mask = 1;
|
||||
uint valCut = 0;
|
||||
for (uint sidx = splatIndices.x; sidx < splatIndices.y; ++sidx, mask <<= 1)
|
||||
{
|
||||
float3 spos = LoadSplatPos(sidx);
|
||||
// don't count cut splats as selected
|
||||
if (IsSplatCut(spos))
|
||||
{
|
||||
valSel &= ~mask;
|
||||
valCut |= mask;
|
||||
}
|
||||
if (valSel & mask)
|
||||
{
|
||||
bmin = min(bmin, spos);
|
||||
bmax = max(bmax, spos);
|
||||
}
|
||||
}
|
||||
valCut &= ~valDel; // don't count deleted splats as cut
|
||||
|
||||
if (valSel != 0)
|
||||
{
|
||||
_DstBuffer.InterlockedMin(12, FloatToSortableUint(bmin.x));
|
||||
_DstBuffer.InterlockedMin(16, FloatToSortableUint(bmin.y));
|
||||
_DstBuffer.InterlockedMin(20, FloatToSortableUint(bmin.z));
|
||||
_DstBuffer.InterlockedMax(24, FloatToSortableUint(bmax.x));
|
||||
_DstBuffer.InterlockedMax(28, FloatToSortableUint(bmax.y));
|
||||
_DstBuffer.InterlockedMax(32, FloatToSortableUint(bmax.z));
|
||||
}
|
||||
uint sumSel = countbits(valSel);
|
||||
uint sumDel = countbits(valDel);
|
||||
uint sumCut = countbits(valCut);
|
||||
_DstBuffer.InterlockedAdd(0, sumSel);
|
||||
_DstBuffer.InterlockedAdd(4, sumDel);
|
||||
_DstBuffer.InterlockedAdd(8, sumCut);
|
||||
}
|
||||
|
||||
[numthreads(1,1,1)]
|
||||
void CSInitEditData (uint3 id : SV_DispatchThreadID)
|
||||
{
|
||||
_DstBuffer.Store3(0, uint3(0,0,0)); // selected, deleted, cut counts
|
||||
uint initMin = FloatToSortableUint(1.0e38);
|
||||
uint initMax = FloatToSortableUint(-1.0e38);
|
||||
_DstBuffer.Store3(12, uint3(initMin, initMin, initMin));
|
||||
_DstBuffer.Store3(24, uint3(initMax, initMax, initMax));
|
||||
}
|
||||
|
||||
[numthreads(GROUP_SIZE,1,1)]
|
||||
void CSClearBuffer (uint3 id : SV_DispatchThreadID)
|
||||
{
|
||||
uint idx = id.x;
|
||||
if (idx >= _BufferSize)
|
||||
return;
|
||||
_DstBuffer.Store(idx * 4, 0);
|
||||
}
|
||||
|
||||
[numthreads(GROUP_SIZE,1,1)]
|
||||
void CSInvertSelection (uint3 id : SV_DispatchThreadID)
|
||||
{
|
||||
uint idx = id.x;
|
||||
if (idx >= _BufferSize)
|
||||
return;
|
||||
uint v = _DstBuffer.Load(idx * 4);
|
||||
v = ~v;
|
||||
|
||||
// do not select splats that are cut
|
||||
uint2 splatIndices = GetSplatIndicesFromWord(idx);
|
||||
uint mask = 1;
|
||||
for (uint sidx = splatIndices.x; sidx < splatIndices.y; ++sidx, mask <<= 1)
|
||||
{
|
||||
float3 spos = LoadSplatPos(sidx);
|
||||
if (IsSplatCut(spos))
|
||||
v &= ~mask;
|
||||
}
|
||||
|
||||
_DstBuffer.Store(idx * 4, v);
|
||||
}
|
||||
|
||||
[numthreads(GROUP_SIZE,1,1)]
|
||||
void CSSelectAll (uint3 id : SV_DispatchThreadID)
|
||||
{
|
||||
uint idx = id.x;
|
||||
if (idx >= _BufferSize)
|
||||
return;
|
||||
uint v = ~0;
|
||||
|
||||
// do not select splats that are cut
|
||||
uint2 splatIndices = GetSplatIndicesFromWord(idx);
|
||||
uint mask = 1;
|
||||
for (uint sidx = splatIndices.x; sidx < splatIndices.y; ++sidx, mask <<= 1)
|
||||
{
|
||||
float3 spos = LoadSplatPos(sidx);
|
||||
if (IsSplatCut(spos))
|
||||
v &= ~mask;
|
||||
}
|
||||
|
||||
_DstBuffer.Store(idx * 4, v);
|
||||
}
|
||||
|
||||
|
||||
[numthreads(GROUP_SIZE,1,1)]
|
||||
void CSOrBuffers (uint3 id : SV_DispatchThreadID)
|
||||
{
|
||||
uint idx = id.x;
|
||||
if (idx >= _BufferSize)
|
||||
return;
|
||||
uint a = _SrcBuffer.Load(idx * 4);
|
||||
uint b = _DstBuffer.Load(idx * 4);
|
||||
_DstBuffer.Store(idx * 4, a | b);
|
||||
}
|
||||
|
||||
float4 _SelectionRect;
|
||||
|
||||
[numthreads(GROUP_SIZE,1,1)]
|
||||
void CSSelectionUpdate (uint3 id : SV_DispatchThreadID)
|
||||
{
|
||||
uint idx = id.x;
|
||||
if (idx >= _SplatCount)
|
||||
return;
|
||||
|
||||
float3 pos = LoadSplatPos(idx);
|
||||
if (IsSplatCut(pos))
|
||||
return;
|
||||
|
||||
float3 centerWorldPos = mul(_MatrixObjectToWorld, float4(pos,1)).xyz;
|
||||
float4 centerClipPos = mul(UNITY_MATRIX_VP, float4(centerWorldPos, 1));
|
||||
bool behindCam = centerClipPos.w <= 0;
|
||||
if (behindCam)
|
||||
return;
|
||||
|
||||
float2 pixelPos = (centerClipPos.xy / centerClipPos.w * float2(0.5, -0.5) + 0.5) * _VecScreenParams.xy;
|
||||
if (pixelPos.x < _SelectionRect.x || pixelPos.x > _SelectionRect.z ||
|
||||
pixelPos.y < _SelectionRect.y || pixelPos.y > _SelectionRect.w)
|
||||
{
|
||||
return;
|
||||
}
|
||||
uint wordIdx = idx / 32;
|
||||
uint bitIdx = idx & 31;
|
||||
if (_SelectionMode)
|
||||
_SplatSelectedBits.InterlockedOr(wordIdx * 4, 1u << bitIdx); // +
|
||||
else
|
||||
_SplatSelectedBits.InterlockedAnd(wordIdx * 4, ~(1u << bitIdx)); // -
|
||||
}
|
||||
|
||||
float3 _SelectionDelta;
|
||||
|
||||
bool IsSplatSelected(uint idx)
|
||||
{
|
||||
uint wordIdx = idx / 32;
|
||||
uint bitIdx = idx & 31;
|
||||
uint selVal = _SplatSelectedBits.Load(wordIdx * 4);
|
||||
return (selVal & (1 << bitIdx)) != 0;
|
||||
}
|
||||
|
||||
[numthreads(GROUP_SIZE,1,1)]
|
||||
void CSTranslateSelection (uint3 id : SV_DispatchThreadID)
|
||||
{
|
||||
uint idx = id.x;
|
||||
if (idx >= _SplatCount)
|
||||
return;
|
||||
if (!IsSplatSelected(idx))
|
||||
return;
|
||||
|
||||
uint fmt = _SplatFormat & 0xFF;
|
||||
if (_SplatChunkCount == 0 && fmt == VECTOR_FMT_32F)
|
||||
{
|
||||
uint stride = 12;
|
||||
float3 pos = asfloat(_SplatPos.Load3(idx * stride));
|
||||
pos += _SelectionDelta;
|
||||
_SplatPos.Store3(idx * stride, asuint(pos));
|
||||
}
|
||||
}
|
||||
|
||||
float3 _SelectionCenter;
|
||||
float4 _SelectionDeltaRot;
|
||||
ByteAddressBuffer _SplatPosMouseDown;
|
||||
ByteAddressBuffer _SplatOtherMouseDown;
|
||||
|
||||
[numthreads(GROUP_SIZE,1,1)]
|
||||
void CSRotateSelection (uint3 id : SV_DispatchThreadID)
|
||||
{
|
||||
uint idx = id.x;
|
||||
if (idx >= _SplatCount)
|
||||
return;
|
||||
if (!IsSplatSelected(idx))
|
||||
return;
|
||||
|
||||
uint posFmt = _SplatFormat & 0xFF;
|
||||
if (_SplatChunkCount == 0 && posFmt == VECTOR_FMT_32F)
|
||||
{
|
||||
uint posStride = 12;
|
||||
float3 pos = asfloat(_SplatPosMouseDown.Load3(idx * posStride));
|
||||
pos -= _SelectionCenter;
|
||||
pos = mul(_MatrixObjectToWorld, float4(pos,1)).xyz;
|
||||
pos = QuatRotateVector(pos, _SelectionDeltaRot);
|
||||
pos = mul(_MatrixWorldToObject, float4(pos,1)).xyz;
|
||||
pos += _SelectionCenter;
|
||||
_SplatPos.Store3(idx * posStride, asuint(pos));
|
||||
}
|
||||
|
||||
uint scaleFmt = (_SplatFormat >> 8) & 0xFF;
|
||||
uint shFormat = (_SplatFormat >> 16) & 0xFF;
|
||||
if (_SplatChunkCount == 0 && scaleFmt == VECTOR_FMT_32F && shFormat == VECTOR_FMT_32F)
|
||||
{
|
||||
uint otherStride = 4 + 12;
|
||||
uint rotVal = _SplatOtherMouseDown.Load(idx * otherStride);
|
||||
float4 rot = DecodeRotation(DecodePacked_10_10_10_2(rotVal));
|
||||
|
||||
//@TODO: correct rotation
|
||||
rot = QuatMul(rot, _SelectionDeltaRot);
|
||||
|
||||
rotVal = EncodeQuatToNorm10(PackSmallest3Rotation(rot));
|
||||
_SplatOther.Store(idx * otherStride, rotVal);
|
||||
}
|
||||
|
||||
//@TODO: rotate SHs
|
||||
}
|
||||
|
||||
//@TODO: maybe scale the splat scale itself too?
|
||||
[numthreads(GROUP_SIZE,1,1)]
|
||||
void CSScaleSelection (uint3 id : SV_DispatchThreadID)
|
||||
{
|
||||
uint idx = id.x;
|
||||
if (idx >= _SplatCount)
|
||||
return;
|
||||
if (!IsSplatSelected(idx))
|
||||
return;
|
||||
|
||||
uint fmt = _SplatFormat & 0xFF;
|
||||
if (_SplatChunkCount == 0 && fmt == VECTOR_FMT_32F)
|
||||
{
|
||||
uint stride = 12;
|
||||
float3 pos = asfloat(_SplatPosMouseDown.Load3(idx * stride));
|
||||
pos -= _SelectionCenter;
|
||||
pos = mul(_MatrixObjectToWorld, float4(pos,1)).xyz;
|
||||
pos *= _SelectionDelta;
|
||||
pos = mul(_MatrixWorldToObject, float4(pos,1)).xyz;
|
||||
pos += _SelectionCenter;
|
||||
_SplatPos.Store3(idx * stride, asuint(pos));
|
||||
}
|
||||
}
|
||||
|
||||
struct ExportSplatData
|
||||
{
|
||||
float3 pos;
|
||||
float3 nor;
|
||||
float3 dc0;
|
||||
float4 shR14; float4 shR58; float4 shR9C; float3 shRDF;
|
||||
float4 shG14; float4 shG58; float4 shG9C; float3 shGDF;
|
||||
float4 shB14; float4 shB58; float4 shB9C; float3 shBDF;
|
||||
float opacity;
|
||||
float3 scale;
|
||||
float4 rot;
|
||||
};
|
||||
RWStructuredBuffer<ExportSplatData> _ExportBuffer;
|
||||
|
||||
float3 ColorToSH0(float3 col)
|
||||
{
|
||||
return (col - 0.5) / 0.2820948;
|
||||
}
|
||||
float InvSigmoid(float v)
|
||||
{
|
||||
return log(v / max(1 - v, 1.0e-6));
|
||||
}
|
||||
|
||||
// SH rotation
|
||||
#include "SphericalHarmonics.hlsl"
|
||||
|
||||
void RotateSH(inout SplatSHData sh, float3x3 rot)
|
||||
{
|
||||
float3 shin[16];
|
||||
float3 shout[16];
|
||||
shin[0] = sh.col;
|
||||
shin[1] = sh.sh1;
|
||||
shin[2] = sh.sh2;
|
||||
shin[3] = sh.sh3;
|
||||
shin[4] = sh.sh4;
|
||||
shin[5] = sh.sh5;
|
||||
shin[6] = sh.sh6;
|
||||
shin[7] = sh.sh7;
|
||||
shin[8] = sh.sh8;
|
||||
shin[9] = sh.sh9;
|
||||
shin[10] = sh.sh10;
|
||||
shin[11] = sh.sh11;
|
||||
shin[12] = sh.sh12;
|
||||
shin[13] = sh.sh13;
|
||||
shin[14] = sh.sh14;
|
||||
shin[15] = sh.sh15;
|
||||
RotateSH(rot, 4, shin, shout);
|
||||
sh.col = shout[0];
|
||||
sh.sh1 = shout[1];
|
||||
sh.sh2 = shout[2];
|
||||
sh.sh3 = shout[3];
|
||||
sh.sh4 = shout[4];
|
||||
sh.sh5 = shout[5];
|
||||
sh.sh6 = shout[6];
|
||||
sh.sh7 = shout[7];
|
||||
sh.sh8 = shout[8];
|
||||
sh.sh9 = shout[9];
|
||||
sh.sh10 = shout[10];
|
||||
sh.sh11 = shout[11];
|
||||
sh.sh12 = shout[12];
|
||||
sh.sh13 = shout[13];
|
||||
sh.sh14 = shout[14];
|
||||
sh.sh15 = shout[15];
|
||||
}
|
||||
|
||||
float3x3 CalcSHRotMatrix(float4x4 objToWorld)
|
||||
{
|
||||
float3x3 m = (float3x3)objToWorld;
|
||||
float sx = length(float3(m[0][0], m[0][1], m[0][2]));
|
||||
float sy = length(float3(m[1][0], m[1][1], m[1][2]));
|
||||
float sz = length(float3(m[2][0], m[2][1], m[2][2]));
|
||||
|
||||
float invSX = 1.0 / sx;
|
||||
float invSY = 1.0 / sy;
|
||||
float invSZ = 1.0 / sz;
|
||||
|
||||
m[0][0] *= invSX;
|
||||
m[0][1] *= invSX;
|
||||
m[0][2] *= invSX;
|
||||
m[1][0] *= invSY;
|
||||
m[1][1] *= invSY;
|
||||
m[1][2] *= invSY;
|
||||
m[2][0] *= invSZ;
|
||||
m[2][1] *= invSZ;
|
||||
m[2][2] *= invSZ;
|
||||
return m;
|
||||
}
|
||||
|
||||
|
||||
float4 _ExportTransformRotation;
|
||||
float3 _ExportTransformScale;
|
||||
uint _ExportTransformFlags;
|
||||
|
||||
[numthreads(GROUP_SIZE,1,1)]
|
||||
void CSExportData (uint3 id : SV_DispatchThreadID)
|
||||
{
|
||||
uint idx = id.x;
|
||||
if (idx >= _SplatCount)
|
||||
return;
|
||||
SplatData src = LoadSplatData(idx);
|
||||
|
||||
bool isCut = IsSplatCut(src.pos);
|
||||
|
||||
// transform splat by matrix, if needed
|
||||
if (_ExportTransformFlags != 0)
|
||||
{
|
||||
src.pos = mul(_MatrixObjectToWorld, float4(src.pos,1)).xyz;
|
||||
|
||||
// note: this only handles axis flips from scale, not any arbitrary scaling
|
||||
if (_ExportTransformScale.x < 0)
|
||||
src.rot.yz = -src.rot.yz;
|
||||
if (_ExportTransformScale.y < 0)
|
||||
src.rot.xz = -src.rot.xz;
|
||||
if (_ExportTransformScale.z < 0)
|
||||
src.rot.xy = -src.rot.xy;
|
||||
src.rot = QuatMul(_ExportTransformRotation, src.rot);
|
||||
src.scale *= abs(_ExportTransformScale);
|
||||
|
||||
float3x3 shRot = CalcSHRotMatrix(_MatrixObjectToWorld);
|
||||
RotateSH(src.sh, shRot);
|
||||
}
|
||||
|
||||
ExportSplatData dst;
|
||||
dst.pos = src.pos;
|
||||
dst.nor = 0;
|
||||
dst.dc0 = ColorToSH0(src.sh.col);
|
||||
|
||||
dst.shR14 = float4(src.sh.sh1.r, src.sh.sh2.r, src.sh.sh3.r, src.sh.sh4.r);
|
||||
dst.shR58 = float4(src.sh.sh5.r, src.sh.sh6.r, src.sh.sh7.r, src.sh.sh8.r);
|
||||
dst.shR9C = float4(src.sh.sh9.r, src.sh.sh10.r, src.sh.sh11.r, src.sh.sh12.r);
|
||||
dst.shRDF = float3(src.sh.sh13.r, src.sh.sh14.r, src.sh.sh15.r);
|
||||
|
||||
dst.shG14 = float4(src.sh.sh1.g, src.sh.sh2.g, src.sh.sh3.g, src.sh.sh4.g);
|
||||
dst.shG58 = float4(src.sh.sh5.g, src.sh.sh6.g, src.sh.sh7.g, src.sh.sh8.g);
|
||||
dst.shG9C = float4(src.sh.sh9.g, src.sh.sh10.g, src.sh.sh11.g, src.sh.sh12.g);
|
||||
dst.shGDF = float3(src.sh.sh13.g, src.sh.sh14.g, src.sh.sh15.g);
|
||||
|
||||
dst.shB14 = float4(src.sh.sh1.b, src.sh.sh2.b, src.sh.sh3.b, src.sh.sh4.b);
|
||||
dst.shB58 = float4(src.sh.sh5.b, src.sh.sh6.b, src.sh.sh7.b, src.sh.sh8.b);
|
||||
dst.shB9C = float4(src.sh.sh9.b, src.sh.sh10.b, src.sh.sh11.b, src.sh.sh12.b);
|
||||
dst.shBDF = float3(src.sh.sh13.b, src.sh.sh14.b, src.sh.sh15.b);
|
||||
|
||||
dst.opacity = InvSigmoid(src.opacity);
|
||||
dst.scale = log(src.scale);
|
||||
dst.rot = src.rot.wxyz;
|
||||
|
||||
if (isCut)
|
||||
dst.nor = 1; // mark as skipped for export
|
||||
|
||||
_ExportBuffer[idx] = dst;
|
||||
}
|
||||
|
||||
RWByteAddressBuffer _CopyDstPos;
|
||||
RWByteAddressBuffer _CopyDstOther;
|
||||
RWByteAddressBuffer _CopyDstSH;
|
||||
RWByteAddressBuffer _CopyDstEditDeleted;
|
||||
RWTexture2D<float4> _CopyDstColor;
|
||||
uint _CopyDstSize, _CopySrcStartIndex, _CopyDstStartIndex, _CopyCount;
|
||||
|
||||
float4x4 _CopyTransformMatrix;
|
||||
float4 _CopyTransformRotation;
|
||||
float3 _CopyTransformScale;
|
||||
|
||||
[numthreads(GROUP_SIZE,1,1)]
|
||||
void CSCopySplats (uint3 id : SV_DispatchThreadID)
|
||||
{
|
||||
uint idx = id.x;
|
||||
if (idx >= _CopyCount)
|
||||
return;
|
||||
uint srcIdx = _CopySrcStartIndex + idx;
|
||||
uint dstIdx = _CopyDstStartIndex + idx;
|
||||
if (srcIdx >= _SplatCount || dstIdx >= _CopyDstSize)
|
||||
return;
|
||||
|
||||
SplatData src = LoadSplatData(idx);
|
||||
|
||||
// transform the splat
|
||||
src.pos = mul(_CopyTransformMatrix, float4(src.pos,1)).xyz;
|
||||
// note: this only handles axis flips from scale, not any arbitrary scaling
|
||||
if (_CopyTransformScale.x < 0)
|
||||
src.rot.yz = -src.rot.yz;
|
||||
if (_CopyTransformScale.y < 0)
|
||||
src.rot.xz = -src.rot.xz;
|
||||
if (_CopyTransformScale.z < 0)
|
||||
src.rot.xy = -src.rot.xy;
|
||||
src.rot = QuatMul(_CopyTransformRotation, src.rot);
|
||||
src.scale *= abs(_CopyTransformScale);
|
||||
|
||||
float3x3 shRot = CalcSHRotMatrix(_CopyTransformMatrix);
|
||||
RotateSH(src.sh, shRot);
|
||||
|
||||
// output data into destination:
|
||||
// pos
|
||||
uint posStride = 12;
|
||||
_CopyDstPos.Store3(dstIdx * posStride, asuint(src.pos));
|
||||
// rot + scale
|
||||
uint otherStride = 4 + 12;
|
||||
uint rotVal = EncodeQuatToNorm10(PackSmallest3Rotation(src.rot));
|
||||
_CopyDstOther.Store4(dstIdx * otherStride, uint4(
|
||||
rotVal,
|
||||
asuint(src.scale.x),
|
||||
asuint(src.scale.y),
|
||||
asuint(src.scale.z)));
|
||||
// color
|
||||
uint3 pixelIndex = SplatIndexToPixelIndex(dstIdx);
|
||||
_CopyDstColor[pixelIndex.xy] = float4(src.sh.col, src.opacity);
|
||||
|
||||
// SH
|
||||
uint shStride = 192; // 15*3 fp32, rounded up to multiple of 16
|
||||
uint shOffset = dstIdx * shStride;
|
||||
_CopyDstSH.Store3(shOffset + 12 * 0, asuint(src.sh.sh1));
|
||||
_CopyDstSH.Store3(shOffset + 12 * 1, asuint(src.sh.sh2));
|
||||
_CopyDstSH.Store3(shOffset + 12 * 2, asuint(src.sh.sh3));
|
||||
_CopyDstSH.Store3(shOffset + 12 * 3, asuint(src.sh.sh4));
|
||||
_CopyDstSH.Store3(shOffset + 12 * 4, asuint(src.sh.sh5));
|
||||
_CopyDstSH.Store3(shOffset + 12 * 5, asuint(src.sh.sh6));
|
||||
_CopyDstSH.Store3(shOffset + 12 * 6, asuint(src.sh.sh7));
|
||||
_CopyDstSH.Store3(shOffset + 12 * 7, asuint(src.sh.sh8));
|
||||
_CopyDstSH.Store3(shOffset + 12 * 8, asuint(src.sh.sh9));
|
||||
_CopyDstSH.Store3(shOffset + 12 * 9, asuint(src.sh.sh10));
|
||||
_CopyDstSH.Store3(shOffset + 12 * 10, asuint(src.sh.sh11));
|
||||
_CopyDstSH.Store3(shOffset + 12 * 11, asuint(src.sh.sh12));
|
||||
_CopyDstSH.Store3(shOffset + 12 * 12, asuint(src.sh.sh13));
|
||||
_CopyDstSH.Store3(shOffset + 12 * 13, asuint(src.sh.sh14));
|
||||
_CopyDstSH.Store3(shOffset + 12 * 14, asuint(src.sh.sh15));
|
||||
|
||||
// deleted bits
|
||||
uint srcWordIdx = srcIdx / 32;
|
||||
uint srcBitIdx = srcIdx & 31;
|
||||
if (_SplatDeletedBits.Load(srcWordIdx * 4) & (1u << srcBitIdx))
|
||||
{
|
||||
uint dstWordIdx = dstIdx / 32;
|
||||
uint dstBitIdx = dstIdx & 31;
|
||||
_CopyDstEditDeleted.InterlockedOr(dstWordIdx * 4, 1u << dstBitIdx);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user