Files
XCEngine/MVS/3DGS-Unity/Shaders/SplatUtilities.compute

758 lines
22 KiB
Plaintext
Raw Normal View History

2026-03-29 01:36:53 +08:00
// SPDX-License-Identifier: MIT
// Thread-group width (X dimension) used by the [numthreads(GROUP_SIZE,1,1)] kernels below.
#define GROUP_SIZE 1024
// Kernel entry points defined in this file.
#pragma kernel CSSetIndices
#pragma kernel CSCalcDistances
#pragma kernel CSCalcViewData
#pragma kernel CSUpdateEditData
#pragma kernel CSInitEditData
#pragma kernel CSClearBuffer
#pragma kernel CSInvertSelection
#pragma kernel CSSelectAll
#pragma kernel CSOrBuffers
#pragma kernel CSSelectionUpdate
#pragma kernel CSTranslateSelection
#pragma kernel CSRotateSelection
#pragma kernel CSScaleSelection
#pragma kernel CSExportData
#pragma kernel CSCopySplats
// DeviceRadixSort
// Shader variants and kernel entry points for the radix sort; the kernels themselves are
// not defined in this file (presumably they come from DeviceRadixSort.hlsl -- confirm).
#pragma multi_compile __ KEY_UINT KEY_INT KEY_FLOAT
#pragma multi_compile __ PAYLOAD_UINT PAYLOAD_INT PAYLOAD_FLOAT
#pragma multi_compile __ SHOULD_ASCEND
#pragma multi_compile __ SORT_PAIRS
#pragma multi_compile __ VULKAN
#pragma kernel InitDeviceRadixSort
#pragma kernel Upsweep
#pragma kernel Scan
#pragma kernel Downsweep
// GPU sorting needs wave ops
#pragma require wavebasic
#pragma require waveballot
#pragma use_dxc
// Note: GROUP_SIZE is defined before these includes, so it is visible inside them.
#include "DeviceRadixSort.hlsl"
#include "GaussianSplatting.hlsl"
#include "UnityCG.cginc"
// Object <-> world transforms of the splat object; applied to splat positions and view directions below.
float4x4 _MatrixObjectToWorld;
float4x4 _MatrixWorldToObject;
// Transform applied to object-space splat positions to get the view-space depth used for sorting;
// also passed to CalcCovariance2D.
float4x4 _MatrixMV;
// .xy is used as the pixel scale when converting clip-space positions to pixel coordinates.
float4 _VecScreenParams;
// .xyz is the camera position used to build the per-splat view direction for SH shading.
float4 _VecWorldSpaceCameraPos;
// Rectangle selection mode: non-zero sets selection bits, zero clears them (see CSSelectionUpdate).
int _SelectionMode;
// Per-slot sort value (sortable-uint encoding of view-space z), written by CSCalcDistances.
RWStructuredBuffer<uint> _SplatSortDistances;
// Per-slot splat index; initialised to the identity mapping by CSSetIndices.
RWStructuredBuffer<uint> _SplatSortKeys;
// Number of splats; used as the upper bound for per-splat kernels.
uint _SplatCount;
// Re-encodes a float's bit pattern as a uint that is friendly to radix sort and unsigned
// comparisons, see http://stereopsis.com/radix.html
// Negative inputs get all bits flipped; non-negative inputs get only the sign bit flipped.
uint FloatToSortableUint(float f)
{
    const uint bits = asuint(f);
    const bool isNegative = (bits & 0x80000000u) != 0;
    const uint flipMask = isNegative ? 0xFFFFFFFFu : 0x80000000u;
    return bits ^ flipMask;
}
// Initialises the sort key buffer with the identity mapping: _SplatSortKeys[i] = i.
[numthreads(GROUP_SIZE,1,1)]
void CSSetIndices (uint3 id : SV_DispatchThreadID)
{
    const uint splatIdx = id.x;
    // threads past the end of the splat array do nothing
    if (splatIdx < _SplatCount)
        _SplatSortKeys[splatIdx] = splatIdx;
}
// For every sort slot, looks up the splat currently stored in that slot and writes the
// sortable-uint encoding of its _MatrixMV-transformed z coordinate into _SplatSortDistances.
[numthreads(GROUP_SIZE,1,1)]
void CSCalcDistances (uint3 id : SV_DispatchThreadID)
{
    const uint slot = id.x;
    if (slot >= _SplatCount)
        return;

    // the key buffer holds the (possibly already reordered) splat index for this slot
    const uint splatIdx = _SplatSortKeys[slot];
    const float4 objectPos = float4(LoadSplatPos(splatIdx), 1);
    const float viewZ = mul(_MatrixMV, objectPos).z;
    _SplatSortDistances[slot] = FloatToSortableUint(viewZ);
}
// Per-splat view data produced by CSCalcViewData (clip position, 2D axes, packed color).
RWStructuredBuffer<SplatViewData> _SplatViewData;
// Uniform scale applied (squared) to the 3D covariance of every splat.
float _SplatScale;
// Multiplier applied to splat opacity before it is packed into the view data.
float _SplatOpacityScale;
// Passed through to ShadeSH as the SH order argument.
uint _SHOrder;
// Non-zero is passed to ShadeSH as its "SH only" flag.
uint _SHOnly;
// Number of valid entries in _SplatCutouts.
uint _SplatCutoutsCount;
// Cutout shape identifiers stored in the low byte of GaussianCutoutShaderData.typeAndFlags.
#define SPLAT_CUTOUT_TYPE_ELLIPSOID 0
#define SPLAT_CUTOUT_TYPE_BOX 1
struct GaussianCutoutShaderData // match GaussianCutout.ShaderData in C#
{
// Transforms a splat position into the cutout's unit shape space (see IsSplatCut).
float4x4 mat;
// Bits 0-7: cutout type (0xFF = invalid/ignored); any bit set in bits 8-15: inverted cutout.
uint typeAndFlags;
};
StructuredBuffer<GaussianCutoutShaderData> _SplatCutouts;
// One bit per splat, 32 splats per 4-byte word: selection state.
RWByteAddressBuffer _SplatSelectedBits;
// One bit per splat, 32 splats per 4-byte word: deletion state.
ByteAddressBuffer _SplatDeletedBits;
// Non-zero when CSCalcViewData should consult _SplatDeletedBits.
uint _SplatBitsValid;
// Decomposes a symmetric 2x2 covariance matrix into two perpendicular, scaled axis vectors.
//   cov2d: x = first diagonal term, y = off-diagonal term, z = second diagonal term
//   v1:    axis along the larger eigenvalue, scaled by min(sqrt(2 * lambda1), 4096)
//   v2:    perpendicular axis, scaled by min(sqrt(2 * lambda2), 4096), with lambda2 clamped to >= 0.1
void DecomposeCovariance(float3 cov2d, out float2 v1, out float2 v2)
{
    #if 0 // does not quite give the correct results?

    // https://jsfiddle.net/mattrossman/ehxmtgw6/
    // References:
    // - https://www.youtube.com/watch?v=e50Bj7jn9IQ
    // - https://en.wikipedia.org/wiki/Eigenvalue_algorithm#2%C3%972_matrices
    // - https://people.math.harvard.edu/~knill/teaching/math21b2004/exhibits/2dmatrices/index.html
    float a = cov2d.x;
    float b = cov2d.y;
    float d = cov2d.z;
    float det = a * d - b * b; // matrix is symmetric, so "c" is same as "b"
    float trace = a + d;

    float mean = 0.5 * trace;
    float dist = sqrt(mean * mean - det);

    float lambda1 = mean + dist; // 1st eigenvalue
    float lambda2 = mean - dist; // 2nd eigenvalue

    if (b == 0) {
        // https://twitter.com/the_ross_man/status/1706342719776551360
        if (a > d) v1 = float2(1, 0);
        else v1 = float2(0, 1);
    } else
        v1 = normalize(float2(b, d - lambda2));

    v1.y = -v1.y;
    // The 2nd eigenvector is just a 90 degree rotation of the first since Gaussian axes are orthogonal
    v2 = float2(v1.y, -v1.x);

    // scaling components
    v1 *= sqrt(lambda1);
    v2 *= sqrt(lambda2);

    float radius = 1.5;
    v1 *= radius;
    v2 *= radius;

    #else

    // same as in antimatter15/splat
    float diag1 = cov2d.x, diag2 = cov2d.z, offDiag = cov2d.y;
    float mid = 0.5f * (diag1 + diag2);
    float radius = length(float2((diag1 - diag2) / 2.0, offDiag));
    float lambda1 = mid + radius;
    float lambda2 = max(mid - radius, 0.1);

    // Eigenvector for lambda1. It is exactly (0,0) when the matrix is already diagonal with
    // diag1 >= diag2 (offDiag == 0 and lambda1 == diag1); normalizing that would produce
    // NaN axes. In that case the major axis is the x axis.
    float2 diagVec = float2(offDiag, lambda1 - diag1);
    if (dot(diagVec, diagVec) > 0)
        diagVec = normalize(diagVec);
    else
        diagVec = float2(1, 0);
    diagVec.y = -diagVec.y;

    float maxSize = 4096.0;
    v1 = min(sqrt(2.0 * lambda1), maxSize) * diagVec;
    // second axis is the first one rotated by 90 degrees
    v2 = min(sqrt(2.0 * lambda2), maxSize) * float2(diagVec.y, -diagVec.x);

    #endif
}
// Tests a splat position against all cutouts in _SplatCutouts.
// The first valid cutout that contains the position decides the result: a regular cutout
// keeps the splat (returns false), an inverted cutout removes it (returns true).
// If no cutout contains the position, the splat is cut when at least one valid,
// non-inverted cutout exists.
bool IsSplatCut(float3 pos)
{
    bool cutByRegularCutout = false;
    for (uint cutoutIdx = 0; cutoutIdx < _SplatCutoutsCount; ++cutoutIdx)
    {
        GaussianCutoutShaderData cutout = _SplatCutouts[cutoutIdx];
        const uint shape = cutout.typeAndFlags & 0xFF;
        if (shape != 0xFF) // 0xFF marks an invalid/null cutout, which is ignored
        {
            const bool inverted = (cutout.typeAndFlags & 0xFF00) != 0;
            // position in the cutout's unit shape space
            const float3 localPos = mul(cutout.mat, float4(pos, 1)).xyz;

            bool inside = false;
            if (shape == SPLAT_CUTOUT_TYPE_ELLIPSOID)
                inside = dot(localPos, localPos) <= 1;
            else if (shape == SPLAT_CUTOUT_TYPE_BOX)
                inside = all(abs(localPos) <= 1);

            if (inside)
                return inverted;

            cutByRegularCutout = cutByRegularCutout || !inverted;
        }
    }
    return cutByRegularCutout;
}
// Computes the per-splat view data (clip-space center, 2D screen axes, packed color) for
// every splat and stores it into _SplatViewData.
// Deleted or cut splats get clip-space w = 0, which makes them take the "behind camera"
// path: axes and color stay zero.
[numthreads(GROUP_SIZE,1,1)]
void CSCalcViewData (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _SplatCount)
        return;

    SplatData splat = LoadSplatData(idx);
    SplatViewData view = (SplatViewData)0;

    float3 centerWorldPos = mul(_MatrixObjectToWorld, float4(splat.pos,1)).xyz;
    float4 centerClipPos = mul(UNITY_MATRIX_VP, float4(centerWorldPos, 1));
    half opacityScale = _SplatOpacityScale;
    float splatScale = _SplatScale;

    // deleted?
    if (_SplatBitsValid)
    {
        uint wordIdx = idx / 32;
        uint bitIdx = idx & 31;
        uint wordVal = _SplatDeletedBits.Load(wordIdx * 4);
        // unsigned shift, consistent with the other bit tests in this file
        if (wordVal & (1u << bitIdx))
        {
            centerClipPos.w = 0;
        }
    }

    // cutouts
    if (IsSplatCut(splat.pos))
    {
        centerClipPos.w = 0;
    }

    view.pos = centerClipPos;
    bool behindCam = centerClipPos.w <= 0;
    if (!behindCam)
    {
        float4 boxRot = splat.rot;
        float3 boxSize = splat.scale;

        // 3D covariance from rotation and scale, then apply the global splat scale (squared,
        // since covariance is quadratic in scale)
        float3x3 splatRotScaleMat = CalcMatrixFromRotationScale(boxRot, boxSize);
        float3 cov3d0, cov3d1;
        CalcCovariance3D(splatRotScaleMat, cov3d0, cov3d1);
        float splatScale2 = splatScale * splatScale;
        cov3d0 *= splatScale2;
        cov3d1 *= splatScale2;

        // project to 2D and extract the two screen-space axes
        float3 cov2d = CalcCovariance2D(splat.pos, cov3d0, cov3d1, _MatrixMV, UNITY_MATRIX_P, _VecScreenParams);
        DecomposeCovariance(cov2d, view.axis1, view.axis2);

        // view direction in object space for SH evaluation
        float3 worldViewDir = _VecWorldSpaceCameraPos.xyz - centerWorldPos;
        float3 objViewDir = mul((float3x3)_MatrixWorldToObject, worldViewDir);
        objViewDir = normalize(objViewDir);

        half4 col;
        col.rgb = ShadeSH(splat.sh, objViewDir, _SHOrder, _SHOnly != 0);
        col.a = min(splat.opacity * opacityScale, 65000);
        // pack as four fp16 values: x = (r << 16) | g, y = (b << 16) | a
        view.color.x = (f32tof16(col.r) << 16) | f32tof16(col.g);
        view.color.y = (f32tof16(col.b) << 16) | f32tof16(col.a);
    }

    _SplatViewData[idx] = view;
}
// Generic destination buffer for the edit kernels below (edit statistics, or a bit buffer
// with one 4-byte word per 32 splats, depending on the kernel).
RWByteAddressBuffer _DstBuffer;
// Generic read-only source buffer (used by CSOrBuffers).
ByteAddressBuffer _SrcBuffer;
// Number of 4-byte words processed by the per-word kernels below.
uint _BufferSize;
// Returns the half-open range [x, y) of splat indices covered by bit-buffer word `idx`
// (32 splats per word); the end is clamped to _SplatCount.
uint2 GetSplatIndicesFromWord(uint idx)
{
    const uint firstSplat = idx * 32;
    const uint endSplat = min(firstSplat + 32, _SplatCount);
    return uint2(firstSplat, endSplat);
}
// Accumulates edit statistics into _DstBuffer, one thread per 32-splat word:
//   bytes  0..11: selected / deleted / cut counts (atomic add)
//   bytes 12..23: selection bounds minimum xyz, as sortable uints (atomic min)
//   bytes 24..35: selection bounds maximum xyz, as sortable uints (atomic max)
// Deleted and cut splats do not count as selected; deleted splats do not count as cut.
// NOTE(review): bits of the last word that lie past _SplatCount are never visited by the
// loop below, so if they are set in the selection buffer they are still counted as selected.
[numthreads(GROUP_SIZE,1,1)]
void CSUpdateEditData (uint3 id : SV_DispatchThreadID)
{
    const uint wordIdx = id.x;
    if (wordIdx >= _BufferSize)
        return;

    const uint byteOffset = wordIdx * 4;
    uint selectedBits = _SplatSelectedBits.Load(byteOffset);
    const uint deletedBits = _SplatDeletedBits.Load(byteOffset);
    selectedBits &= ~deletedBits; // deleted splats are never selected

    const uint2 splatRange = GetSplatIndicesFromWord(wordIdx);

    // selection bounds of this word, plus which of its splats are cut
    float3 boundsMin = 1.0e38;
    float3 boundsMax = -1.0e38;
    uint cutBits = 0;
    uint bit = 1;
    for (uint splatIdx = splatRange.x; splatIdx < splatRange.y; ++splatIdx)
    {
        const float3 splatPos = LoadSplatPos(splatIdx);
        if (IsSplatCut(splatPos))
        {
            // cut splats are never selected
            cutBits |= bit;
            selectedBits &= ~bit;
        }
        else if ((selectedBits & bit) != 0)
        {
            boundsMin = min(boundsMin, splatPos);
            boundsMax = max(boundsMax, splatPos);
        }
        bit <<= 1;
    }
    cutBits &= ~deletedBits; // deleted splats are never counted as cut

    if (selectedBits != 0)
    {
        _DstBuffer.InterlockedMin(12, FloatToSortableUint(boundsMin.x));
        _DstBuffer.InterlockedMin(16, FloatToSortableUint(boundsMin.y));
        _DstBuffer.InterlockedMin(20, FloatToSortableUint(boundsMin.z));
        _DstBuffer.InterlockedMax(24, FloatToSortableUint(boundsMax.x));
        _DstBuffer.InterlockedMax(28, FloatToSortableUint(boundsMax.y));
        _DstBuffer.InterlockedMax(32, FloatToSortableUint(boundsMax.z));
    }

    _DstBuffer.InterlockedAdd(0, countbits(selectedBits));
    _DstBuffer.InterlockedAdd(4, countbits(deletedBits));
    _DstBuffer.InterlockedAdd(8, countbits(cutBits));
}
// Resets the edit statistics in _DstBuffer before CSUpdateEditData accumulates into them:
// zero counts, bounds minimum seeded with +1e38 and bounds maximum with -1e38
// (both in sortable-uint encoding).
[numthreads(1,1,1)]
void CSInitEditData (uint3 id : SV_DispatchThreadID)
{
    const uint minSeed = FloatToSortableUint(1.0e38);
    const uint maxSeed = FloatToSortableUint(-1.0e38);

    _DstBuffer.Store3(0, uint3(0, 0, 0)); // selected, deleted, cut counts
    _DstBuffer.Store3(12, minSeed.xxx);   // bounds min xyz
    _DstBuffer.Store3(24, maxSeed.xxx);   // bounds max xyz
}
// Zeroes the first _BufferSize 4-byte words of _DstBuffer, one word per thread.
[numthreads(GROUP_SIZE,1,1)]
void CSClearBuffer (uint3 id : SV_DispatchThreadID)
{
    const uint wordIdx = id.x;
    if (wordIdx < _BufferSize)
        _DstBuffer.Store(wordIdx * 4, 0);
}
// Inverts the selection bits stored in _DstBuffer, one 32-splat word per thread.
// Only bits that belong to existing splats (index < _SplatCount) can end up set, and
// splats that are cut are never selected. Padding bits past _SplatCount in the last word
// are written as zero, so they cannot inflate bit counts computed from this buffer.
[numthreads(GROUP_SIZE,1,1)]
void CSInvertSelection (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _BufferSize)
        return;

    uint inverted = ~_DstBuffer.Load(idx * 4);

    // rebuild the word from valid splats only
    uint v = 0;
    uint2 splatIndices = GetSplatIndicesFromWord(idx);
    uint mask = 1;
    for (uint sidx = splatIndices.x; sidx < splatIndices.y; ++sidx, mask <<= 1)
    {
        float3 spos = LoadSplatPos(sidx);
        // do not select splats that are cut
        if ((inverted & mask) != 0 && !IsSplatCut(spos))
            v |= mask;
    }
    _DstBuffer.Store(idx * 4, v);
}
// Selects every existing splat that is not cut, one 32-splat word of _DstBuffer per thread.
// Padding bits past _SplatCount in the last word are written as zero, so they cannot
// inflate bit counts computed from this buffer.
[numthreads(GROUP_SIZE,1,1)]
void CSSelectAll (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _BufferSize)
        return;

    // build the word from valid splats only
    uint v = 0;
    uint2 splatIndices = GetSplatIndicesFromWord(idx);
    uint mask = 1;
    for (uint sidx = splatIndices.x; sidx < splatIndices.y; ++sidx, mask <<= 1)
    {
        float3 spos = LoadSplatPos(sidx);
        // do not select splats that are cut
        if (!IsSplatCut(spos))
            v |= mask;
    }
    _DstBuffer.Store(idx * 4, v);
}
// Bitwise-ORs _SrcBuffer into _DstBuffer, one 4-byte word per thread.
[numthreads(GROUP_SIZE,1,1)]
void CSOrBuffers (uint3 id : SV_DispatchThreadID)
{
    const uint wordIdx = id.x;
    if (wordIdx >= _BufferSize)
        return;

    const uint byteOffset = wordIdx * 4;
    const uint srcWord = _SrcBuffer.Load(byteOffset);
    const uint dstWord = _DstBuffer.Load(byteOffset);
    _DstBuffer.Store(byteOffset, srcWord | dstWord);
}
// Selection rectangle in pixels: xy = minimum corner, zw = maximum corner.
float4 _SelectionRect;

// Rectangle selection: for every splat that is not cut, is in front of the camera and whose
// projected center lies inside _SelectionRect, sets (_SelectionMode != 0) or clears
// (_SelectionMode == 0) its bit in _SplatSelectedBits.
[numthreads(GROUP_SIZE,1,1)]
void CSSelectionUpdate (uint3 id : SV_DispatchThreadID)
{
    const uint splatIdx = id.x;
    if (splatIdx >= _SplatCount)
        return;

    const float3 objectPos = LoadSplatPos(splatIdx);
    if (IsSplatCut(objectPos))
        return;

    const float3 worldPos = mul(_MatrixObjectToWorld, float4(objectPos, 1)).xyz;
    const float4 clipPos = mul(UNITY_MATRIX_VP, float4(worldPos, 1));
    if (clipPos.w <= 0) // behind the camera
        return;

    // clip space -> pixel coordinates (y is flipped)
    const float2 ndc = clipPos.xy / clipPos.w;
    const float2 pixelPos = (ndc * float2(0.5, -0.5) + 0.5) * _VecScreenParams.xy;

    const bool outsideX = pixelPos.x < _SelectionRect.x || pixelPos.x > _SelectionRect.z;
    const bool outsideY = pixelPos.y < _SelectionRect.y || pixelPos.y > _SelectionRect.w;
    if (outsideX || outsideY)
        return;

    const uint wordOffset = (splatIdx / 32) * 4;
    const uint bitMask = 1u << (splatIdx & 31);
    if (_SelectionMode)
        _SplatSelectedBits.InterlockedOr(wordOffset, bitMask); // +
    else
        _SplatSelectedBits.InterlockedAnd(wordOffset, ~bitMask); // -
}
// Per-axis delta applied to selected splats: added to positions by CSTranslateSelection,
// multiplied into positions by CSScaleSelection.
float3 _SelectionDelta;

// Returns true when the bit for splat `idx` is set in _SplatSelectedBits
// (one bit per splat, 32 splats per 4-byte word).
bool IsSplatSelected(uint idx)
{
    uint wordIdx = idx / 32;
    uint bitIdx = idx & 31;
    uint selVal = _SplatSelectedBits.Load(wordIdx * 4);
    // unsigned shift, consistent with the other bit tests in this file
    return (selVal & (1u << bitIdx)) != 0;
}
// Adds _SelectionDelta to the stored position of every selected splat.
// Only acts when the data is unchunked (_SplatChunkCount == 0) and positions are stored
// as VECTOR_FMT_32F (12 bytes per splat); otherwise the kernel does nothing.
[numthreads(GROUP_SIZE,1,1)]
void CSTranslateSelection (uint3 id : SV_DispatchThreadID)
{
    const uint splatIdx = id.x;
    if (splatIdx >= _SplatCount)
        return;
    if (!IsSplatSelected(splatIdx))
        return;

    const uint posFormat = _SplatFormat & 0xFF;
    const bool canEdit = _SplatChunkCount == 0 && posFormat == VECTOR_FMT_32F;
    if (!canEdit)
        return;

    const uint byteOffset = splatIdx * 12; // three fp32 values per position
    const float3 movedPos = asfloat(_SplatPos.Load3(byteOffset)) + _SelectionDelta;
    _SplatPos.Store3(byteOffset, asuint(movedPos));
}
// Pivot of the rotate/scale operations; subtracted from and added back to the positions.
float3 _SelectionCenter;
// Rotation (quaternion) applied to the selection by CSRotateSelection.
float4 _SelectionDeltaRot;
// Position and rotation/scale data as read at the start of the edit; the rotate and scale
// kernels read from these and write into the live _SplatPos / _SplatOther buffers.
ByteAddressBuffer _SplatPosMouseDown;
ByteAddressBuffer _SplatOtherMouseDown;

// Rotates every selected splat by _SelectionDeltaRot around _SelectionCenter.
// Positions are only updated for unchunked VECTOR_FMT_32F position data; rotations are only
// updated when the two other format fields checked below are VECTOR_FMT_32F as well.
// NOTE(review): the pivot-relative offset is transformed with w = 1, so the translation part
// of _MatrixObjectToWorld / _MatrixWorldToObject takes part in the rotation -- confirm that
// the caller passes matrices for which this is intended.
[numthreads(GROUP_SIZE,1,1)]
void CSRotateSelection (uint3 id : SV_DispatchThreadID)
{
    const uint splatIdx = id.x;
    if (splatIdx >= _SplatCount)
        return;
    if (!IsSplatSelected(splatIdx))
        return;

    const bool unchunked = _SplatChunkCount == 0;

    // position: object -> world, rotate, world -> object, all relative to the pivot
    const uint posFmt = _SplatFormat & 0xFF;
    if (unchunked && posFmt == VECTOR_FMT_32F)
    {
        const uint posOffset = splatIdx * 12;
        float3 p = asfloat(_SplatPosMouseDown.Load3(posOffset));
        p -= _SelectionCenter;
        p = mul(_MatrixObjectToWorld, float4(p, 1)).xyz;
        p = QuatRotateVector(p, _SelectionDeltaRot);
        p = mul(_MatrixWorldToObject, float4(p, 1)).xyz;
        p += _SelectionCenter;
        _SplatPos.Store3(posOffset, asuint(p));
    }

    // rotation: stored as one packed 10.10.10.2 word at the start of each 16-byte record
    const uint scaleFmt = (_SplatFormat >> 8) & 0xFF;
    const uint shFormat = (_SplatFormat >> 16) & 0xFF;
    if (unchunked && scaleFmt == VECTOR_FMT_32F && shFormat == VECTOR_FMT_32F)
    {
        const uint otherOffset = splatIdx * (4 + 12);
        const uint packedRot = _SplatOtherMouseDown.Load(otherOffset);
        float4 rot = DecodeRotation(DecodePacked_10_10_10_2(packedRot));

        //@TODO: correct rotation
        rot = QuatMul(rot, _SelectionDeltaRot);

        _SplatOther.Store(otherOffset, EncodeQuatToNorm10(PackSmallest3Rotation(rot)));
    }

    //@TODO: rotate SHs
}
//@TODO: maybe scale the splat scale itself too?
// Scales the positions of all selected splats by _SelectionDelta (per axis) around
// _SelectionCenter. Reads the positions captured in _SplatPosMouseDown and writes the
// result to _SplatPos. Only acts on unchunked VECTOR_FMT_32F position data.
[numthreads(GROUP_SIZE,1,1)]
void CSScaleSelection (uint3 id : SV_DispatchThreadID)
{
    const uint splatIdx = id.x;
    if (splatIdx >= _SplatCount)
        return;
    if (!IsSplatSelected(splatIdx))
        return;

    const uint posFormat = _SplatFormat & 0xFF;
    const bool canEdit = _SplatChunkCount == 0 && posFormat == VECTOR_FMT_32F;
    if (!canEdit)
        return;

    const uint byteOffset = splatIdx * 12; // three fp32 values per position
    float3 p = asfloat(_SplatPosMouseDown.Load3(byteOffset));
    p -= _SelectionCenter;
    // object -> world, scale, world -> object
    p = mul(_MatrixObjectToWorld, float4(p, 1)).xyz;
    p *= _SelectionDelta;
    p = mul(_MatrixWorldToObject, float4(p, 1)).xyz;
    p += _SelectionCenter;
    _SplatPos.Store3(byteOffset, asuint(p));
}
// One exported splat record, as written by CSExportData. Field order defines the layout of
// _ExportBuffer elements, so it must not be changed.
struct ExportSplatData
{
    float3 pos;     // position
    float3 nor;     // 0 normally; set to 1 to mark a cut splat as skipped for export
    float3 dc0;     // color converted with ColorToSH0
    // higher-order SH coefficients 1..15, red channel
    float4 shR14;
    float4 shR58;
    float4 shR9C;
    float3 shRDF;
    // higher-order SH coefficients 1..15, green channel
    float4 shG14;
    float4 shG58;
    float4 shG9C;
    float3 shGDF;
    // higher-order SH coefficients 1..15, blue channel
    float4 shB14;
    float4 shB58;
    float4 shB9C;
    float3 shBDF;
    float opacity;  // InvSigmoid of the splat opacity
    float3 scale;   // natural log of the splat scale
    float4 rot;     // rotation quaternion, stored in wxyz order
};
RWStructuredBuffer<ExportSplatData> _ExportBuffer;
// Converts a color into an order-0 SH coefficient: (col - 0.5) / 0.2820948.
// 0.2820948 is numerically 1 / (2 * sqrt(pi)), the order-0 SH basis constant.
float3 ColorToSH0(float3 col)
{
    const float kSH_C0 = 0.2820948;
    const float3 centered = col - 0.5;
    return centered / kSH_C0;
}
// Inverse of the sigmoid function: log(v / (1 - v)).
// The denominator is clamped to at least 1e-6 so that v close to 1 does not divide by zero.
float InvSigmoid(float v)
{
    const float oneMinusV = max(1 - v, 1.0e-6);
    return log(v / oneMinusV);
}
// SH rotation
#include "SphericalHarmonics.hlsl"
// Rotates all 16 SH coefficients of a splat in place.
// Flattens the SplatSHData fields into an array, calls the array-based RotateSH overload
// with 4 as its second argument, and writes the results back field by field.
void RotateSH(inout SplatSHData sh, float3x3 rot)
{
    float3 shin[16] =
    {
        sh.col,  sh.sh1,  sh.sh2,  sh.sh3,
        sh.sh4,  sh.sh5,  sh.sh6,  sh.sh7,
        sh.sh8,  sh.sh9,  sh.sh10, sh.sh11,
        sh.sh12, sh.sh13, sh.sh14, sh.sh15
    };
    float3 shout[16];

    RotateSH(rot, 4, shin, shout);

    sh.col  = shout[0];
    sh.sh1  = shout[1];
    sh.sh2  = shout[2];
    sh.sh3  = shout[3];
    sh.sh4  = shout[4];
    sh.sh5  = shout[5];
    sh.sh6  = shout[6];
    sh.sh7  = shout[7];
    sh.sh8  = shout[8];
    sh.sh9  = shout[9];
    sh.sh10 = shout[10];
    sh.sh11 = shout[11];
    sh.sh12 = shout[12];
    sh.sh13 = shout[13];
    sh.sh14 = shout[14];
    sh.sh15 = shout[15];
}
// Builds the 3x3 matrix used to rotate SH coefficients: takes the upper-left 3x3 of
// objToWorld and rescales each of its rows (m[0], m[1], m[2]) to unit length.
// No guard against zero-length rows.
float3x3 CalcSHRotMatrix(float4x4 objToWorld)
{
    float3x3 m = (float3x3)objToWorld;

    const float invLen0 = 1.0 / length(m[0]);
    const float invLen1 = 1.0 / length(m[1]);
    const float invLen2 = 1.0 / length(m[2]);

    m[0] *= invLen0;
    m[1] *= invLen1;
    m[2] *= invLen2;
    return m;
}
// Rotation (quaternion) pre-multiplied onto exported splat rotations.
float4 _ExportTransformRotation;
// Scale applied to exported splat scales; negative components flip the rotation.
float3 _ExportTransformScale;
// Non-zero: apply the object-to-world transform to exported splats.
uint _ExportTransformFlags;

// Converts every splat into an ExportSplatData record in _ExportBuffer.
// Optionally (when _ExportTransformFlags != 0) bakes the object transform into position,
// rotation, scale and SH coefficients. Splats that are cut (tested on the untransformed
// position) are still written, but with nor = 1 as a "skip" marker.
[numthreads(GROUP_SIZE,1,1)]
void CSExportData (uint3 id : SV_DispatchThreadID)
{
    const uint splatIdx = id.x;
    if (splatIdx >= _SplatCount)
        return;

    SplatData splat = LoadSplatData(splatIdx);
    const bool skipExport = IsSplatCut(splat.pos);

    // transform splat by matrix, if needed
    if (_ExportTransformFlags != 0)
    {
        splat.pos = mul(_MatrixObjectToWorld, float4(splat.pos, 1)).xyz;

        // note: this only handles axis flips from scale, not any arbitrary scaling
        if (_ExportTransformScale.x < 0)
            splat.rot.yz = -splat.rot.yz;
        if (_ExportTransformScale.y < 0)
            splat.rot.xz = -splat.rot.xz;
        if (_ExportTransformScale.z < 0)
            splat.rot.xy = -splat.rot.xy;
        splat.rot = QuatMul(_ExportTransformRotation, splat.rot);
        splat.scale *= abs(_ExportTransformScale);

        RotateSH(splat.sh, CalcSHRotMatrix(_MatrixObjectToWorld));
    }

    SplatSHData sh = splat.sh;

    ExportSplatData dst;
    dst.pos = splat.pos;
    dst.nor = skipExport ? 1.0 : 0.0; // 1 marks the splat as skipped for export
    dst.dc0 = ColorToSH0(sh.col);

    // red channel of SH coefficients 1..15
    dst.shR14 = float4(sh.sh1.r,  sh.sh2.r,  sh.sh3.r,  sh.sh4.r);
    dst.shR58 = float4(sh.sh5.r,  sh.sh6.r,  sh.sh7.r,  sh.sh8.r);
    dst.shR9C = float4(sh.sh9.r,  sh.sh10.r, sh.sh11.r, sh.sh12.r);
    dst.shRDF = float3(sh.sh13.r, sh.sh14.r, sh.sh15.r);
    // green channel
    dst.shG14 = float4(sh.sh1.g,  sh.sh2.g,  sh.sh3.g,  sh.sh4.g);
    dst.shG58 = float4(sh.sh5.g,  sh.sh6.g,  sh.sh7.g,  sh.sh8.g);
    dst.shG9C = float4(sh.sh9.g,  sh.sh10.g, sh.sh11.g, sh.sh12.g);
    dst.shGDF = float3(sh.sh13.g, sh.sh14.g, sh.sh15.g);
    // blue channel
    dst.shB14 = float4(sh.sh1.b,  sh.sh2.b,  sh.sh3.b,  sh.sh4.b);
    dst.shB58 = float4(sh.sh5.b,  sh.sh6.b,  sh.sh7.b,  sh.sh8.b);
    dst.shB9C = float4(sh.sh9.b,  sh.sh10.b, sh.sh11.b, sh.sh12.b);
    dst.shBDF = float3(sh.sh13.b, sh.sh14.b, sh.sh15.b);

    dst.opacity = InvSigmoid(splat.opacity);
    dst.scale = log(splat.scale);
    dst.rot = splat.rot.wxyz;

    _ExportBuffer[splatIdx] = dst;
}
// Destination buffers of CSCopySplats: positions (12 bytes/splat), rotation + scale
// (16 bytes/splat), SH coefficients (192 bytes/splat), deleted bits (1 bit/splat) and color.
RWByteAddressBuffer _CopyDstPos;
RWByteAddressBuffer _CopyDstOther;
RWByteAddressBuffer _CopyDstSH;
RWByteAddressBuffer _CopyDstEditDeleted;
RWTexture2D<float4> _CopyDstColor;

// _CopyDstSize: number of splats the destination can hold; _CopySrcStartIndex /
// _CopyDstStartIndex: first source / destination splat; _CopyCount: number of splats to copy.
uint _CopyDstSize, _CopySrcStartIndex, _CopyDstStartIndex, _CopyCount;
// Transform applied to the copied splats (matrix for positions and SH rotation,
// quaternion for rotations, scale for scales and axis flips).
float4x4 _CopyTransformMatrix;
float4 _CopyTransformRotation;
float3 _CopyTransformScale;

// Copies _CopyCount splats, starting at _CopySrcStartIndex in the source data, into the
// destination buffers starting at _CopyDstStartIndex, applying the copy transform on the way.
// Threads whose source or destination index is out of range do nothing.
[numthreads(GROUP_SIZE,1,1)]
void CSCopySplats (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _CopyCount)
        return;
    uint srcIdx = _CopySrcStartIndex + idx;
    uint dstIdx = _CopyDstStartIndex + idx;
    if (srcIdx >= _SplatCount || dstIdx >= _CopyDstSize)
        return;

    // Load from the source index (not the thread index), so that a non-zero
    // _CopySrcStartIndex copies the splats that were actually requested and matches the
    // deleted-bit lookup below.
    SplatData src = LoadSplatData(srcIdx);

    // transform the splat
    src.pos = mul(_CopyTransformMatrix, float4(src.pos,1)).xyz;
    // note: this only handles axis flips from scale, not any arbitrary scaling
    if (_CopyTransformScale.x < 0)
        src.rot.yz = -src.rot.yz;
    if (_CopyTransformScale.y < 0)
        src.rot.xz = -src.rot.xz;
    if (_CopyTransformScale.z < 0)
        src.rot.xy = -src.rot.xy;
    src.rot = QuatMul(_CopyTransformRotation, src.rot);
    src.scale *= abs(_CopyTransformScale);

    float3x3 shRot = CalcSHRotMatrix(_CopyTransformMatrix);
    RotateSH(src.sh, shRot);

    // output data into destination:
    // pos
    uint posStride = 12;
    _CopyDstPos.Store3(dstIdx * posStride, asuint(src.pos));
    // rot + scale
    uint otherStride = 4 + 12;
    uint rotVal = EncodeQuatToNorm10(PackSmallest3Rotation(src.rot));
    _CopyDstOther.Store4(dstIdx * otherStride, uint4(
        rotVal,
        asuint(src.scale.x),
        asuint(src.scale.y),
        asuint(src.scale.z)));
    // color
    uint3 pixelIndex = SplatIndexToPixelIndex(dstIdx);
    _CopyDstColor[pixelIndex.xy] = float4(src.sh.col, src.opacity);

    // SH
    uint shStride = 192; // 15*3 fp32, rounded up to multiple of 16
    uint shOffset = dstIdx * shStride;
    _CopyDstSH.Store3(shOffset + 12 * 0, asuint(src.sh.sh1));
    _CopyDstSH.Store3(shOffset + 12 * 1, asuint(src.sh.sh2));
    _CopyDstSH.Store3(shOffset + 12 * 2, asuint(src.sh.sh3));
    _CopyDstSH.Store3(shOffset + 12 * 3, asuint(src.sh.sh4));
    _CopyDstSH.Store3(shOffset + 12 * 4, asuint(src.sh.sh5));
    _CopyDstSH.Store3(shOffset + 12 * 5, asuint(src.sh.sh6));
    _CopyDstSH.Store3(shOffset + 12 * 6, asuint(src.sh.sh7));
    _CopyDstSH.Store3(shOffset + 12 * 7, asuint(src.sh.sh8));
    _CopyDstSH.Store3(shOffset + 12 * 8, asuint(src.sh.sh9));
    _CopyDstSH.Store3(shOffset + 12 * 9, asuint(src.sh.sh10));
    _CopyDstSH.Store3(shOffset + 12 * 10, asuint(src.sh.sh11));
    _CopyDstSH.Store3(shOffset + 12 * 11, asuint(src.sh.sh12));
    _CopyDstSH.Store3(shOffset + 12 * 12, asuint(src.sh.sh13));
    _CopyDstSH.Store3(shOffset + 12 * 13, asuint(src.sh.sh14));
    _CopyDstSH.Store3(shOffset + 12 * 14, asuint(src.sh.sh15));

    // deleted bits: carry the source splat's deleted flag over to the destination
    uint srcWordIdx = srcIdx / 32;
    uint srcBitIdx = srcIdx & 31;
    if (_SplatDeletedBits.Load(srcWordIdx * 4) & (1u << srcBitIdx))
    {
        uint dstWordIdx = dstIdx / 32;
        uint dstBitIdx = dstIdx & 31;
        _CopyDstEditDeleted.InterlockedOr(dstWordIdx * 4, 1u << dstBitIdx);
    }
}