// SPDX-License-Identifier: MIT
|
||
|
|
#define GROUP_SIZE 1024
|
||
|
|
|
||
|
|
#pragma kernel CSSetIndices
|
||
|
|
#pragma kernel CSCalcDistances
|
||
|
|
#pragma kernel CSCalcViewData
|
||
|
|
#pragma kernel CSUpdateEditData
|
||
|
|
#pragma kernel CSInitEditData
|
||
|
|
#pragma kernel CSClearBuffer
|
||
|
|
#pragma kernel CSInvertSelection
|
||
|
|
#pragma kernel CSSelectAll
|
||
|
|
#pragma kernel CSOrBuffers
|
||
|
|
#pragma kernel CSSelectionUpdate
|
||
|
|
#pragma kernel CSTranslateSelection
|
||
|
|
#pragma kernel CSRotateSelection
|
||
|
|
#pragma kernel CSScaleSelection
|
||
|
|
#pragma kernel CSExportData
|
||
|
|
#pragma kernel CSCopySplats
|
||
|
|
|
||
|
|
// DeviceRadixSort
|
||
|
|
#pragma multi_compile __ KEY_UINT KEY_INT KEY_FLOAT
|
||
|
|
#pragma multi_compile __ PAYLOAD_UINT PAYLOAD_INT PAYLOAD_FLOAT
|
||
|
|
#pragma multi_compile __ SHOULD_ASCEND
|
||
|
|
#pragma multi_compile __ SORT_PAIRS
|
||
|
|
#pragma multi_compile __ VULKAN
|
||
|
|
#pragma kernel InitDeviceRadixSort
|
||
|
|
#pragma kernel Upsweep
|
||
|
|
#pragma kernel Scan
|
||
|
|
#pragma kernel Downsweep
|
||
|
|
|
||
|
|
// GPU sorting needs wave ops
|
||
|
|
#pragma require wavebasic
|
||
|
|
#pragma require waveballot
|
||
|
|
#pragma use_dxc
|
||
|
|
|
||
|
|
#include "DeviceRadixSort.hlsl"
|
||
|
|
#include "GaussianSplatting.hlsl"
|
||
|
|
#include "UnityCG.cginc"
|
||
|
|
|
||
|
|
float4x4 _MatrixObjectToWorld;
|
||
|
|
float4x4 _MatrixWorldToObject;
|
||
|
|
float4x4 _MatrixMV;
|
||
|
|
float4 _VecScreenParams;
|
||
|
|
float4 _VecWorldSpaceCameraPos;
|
||
|
|
int _SelectionMode;
|
||
|
|
|
||
|
|
RWStructuredBuffer<uint> _SplatSortDistances;
|
||
|
|
RWStructuredBuffer<uint> _SplatSortKeys;
|
||
|
|
uint _SplatCount;
|
||
|
|
|
||
|
|
// radix sort etc. friendly, see http://stereopsis.com/radix.html
|
||
|
|
// Converts a float into a uint whose unsigned ordering matches the ordering of
// the original float values: negative inputs have all bits flipped, all other
// inputs only have the sign bit flipped.
uint FloatToSortableUint(float f)
{
    uint bits = asuint(f);
    bool signBitSet = (bits >> 31) != 0;
    uint flip = signBitSet ? 0xFFFFFFFFu : 0x80000000u;
    return bits ^ flip;
}
|
||
|
|
|
||
|
|
// Initializes the sort key buffer so that entry i holds the value i.
[numthreads(GROUP_SIZE,1,1)]
void CSSetIndices (uint3 id : SV_DispatchThreadID)
{
    uint slot = id.x;
    // threads past the end of the splat range have nothing to write
    if (slot < _SplatCount)
        _SplatSortKeys[slot] = slot;
}
|
||
|
|
|
||
|
|
// Writes one sortable distance per sort slot: the z component of the splat
// position after transforming it by _MatrixMV, encoded with FloatToSortableUint.
// The splat for a slot is looked up through _SplatSortKeys.
[numthreads(GROUP_SIZE,1,1)]
void CSCalcDistances (uint3 id : SV_DispatchThreadID)
{
    uint slot = id.x;
    if (slot < _SplatCount)
    {
        uint splatIdx = _SplatSortKeys[slot];
        float3 objPos = LoadSplatPos(splatIdx);
        float3 viewPos = mul(_MatrixMV, float4(objPos.xyz, 1)).xyz;
        _SplatSortDistances[slot] = FloatToSortableUint(viewPos.z);
    }
}
|
||
|
|
|
||
|
|
RWStructuredBuffer<SplatViewData> _SplatViewData;
|
||
|
|
|
||
|
|
float _SplatScale;
|
||
|
|
float _SplatOpacityScale;
|
||
|
|
uint _SHOrder;
|
||
|
|
uint _SHOnly;
|
||
|
|
|
||
|
|
uint _SplatCutoutsCount;
|
||
|
|
|
||
|
|
#define SPLAT_CUTOUT_TYPE_ELLIPSOID 0
|
||
|
|
#define SPLAT_CUTOUT_TYPE_BOX 1
|
||
|
|
|
||
|
|
// One cutout volume as consumed by IsSplatCut.
struct GaussianCutoutShaderData // match GaussianCutout.ShaderData in C#
{
    // Applied to the splat position; the result is tested against a unit
    // sphere (ellipsoid cutouts) or a [-1,1] cube (box cutouts).
    float4x4 mat;
    // Bits 0-7: SPLAT_CUTOUT_TYPE_* value, 0xFF marks an entry to ignore.
    // Bits 8-15: any bit set inverts the cutout.
    uint typeAndFlags;
};
|
||
|
|
StructuredBuffer<GaussianCutoutShaderData> _SplatCutouts;
|
||
|
|
|
||
|
|
RWByteAddressBuffer _SplatSelectedBits;
|
||
|
|
ByteAddressBuffer _SplatDeletedBits;
|
||
|
|
uint _SplatBitsValid;
|
||
|
|
|
||
|
|
// Splits a symmetric 2x2 covariance matrix into two perpendicular, scaled axes.
//   cov2d: the matrix packed as (xx, xy, yy).
//   v1:    receives the axis belonging to the larger eigenvalue.
//   v2:    receives the axis perpendicular to v1 (smaller eigenvalue).
// In the active code path the axis lengths are sqrt(2 * eigenvalue), clamped
// to 4096, and the smaller eigenvalue is never taken below 0.1.
void DecomposeCovariance(float3 cov2d, out float2 v1, out float2 v2)
{
    #if 0 // does not quite give the correct results?

    // https://jsfiddle.net/mattrossman/ehxmtgw6/
    // References:
    // - https://www.youtube.com/watch?v=e50Bj7jn9IQ
    // - https://en.wikipedia.org/wiki/Eigenvalue_algorithm#2%C3%972_matrices
    // - https://people.math.harvard.edu/~knill/teaching/math21b2004/exhibits/2dmatrices/index.html
    float a = cov2d.x;
    float b = cov2d.y;
    float d = cov2d.z;
    float det = a * d - b * b; // matrix is symmetric, so "c" is same as "b"
    float trace = a + d;

    float mean = 0.5 * trace;
    float dist = sqrt(mean * mean - det);

    float lambda1 = mean + dist; // 1st eigenvalue
    float lambda2 = mean - dist; // 2nd eigenvalue

    if (b == 0) {
        // https://twitter.com/the_ross_man/status/1706342719776551360
        if (a > d) v1 = float2(1, 0);
        else v1 = float2(0, 1);
    } else
        v1 = normalize(float2(b, d - lambda2));

    v1.y = -v1.y;
    // The 2nd eigenvector is just a 90 degree rotation of the first since Gaussian axes are orthogonal
    v2 = float2(v1.y, -v1.x);

    // scaling components
    v1 *= sqrt(lambda1);
    v2 *= sqrt(lambda2);

    float radius = 1.5;
    v1 *= radius;
    v2 *= radius;

    #else

    // same as in antimatter15/splat
    float diag1 = cov2d.x, diag2 = cov2d.z, offDiag = cov2d.y;
    float mid = 0.5f * (diag1 + diag2);
    float radius = length(float2((diag1 - diag2) / 2.0, offDiag));
    float lambda1 = mid + radius;
    float lambda2 = max(mid - radius, 0.1);

    // Eigenvector of lambda1. It degenerates to (0,0) when the matrix is
    // already diagonal with diag1 >= diag2; normalizing that would produce
    // NaN axes, so use the x axis (the exact eigenvector for that case).
    float2 eigenDir = float2(offDiag, lambda1 - diag1);
    float2 diagVec;
    if (dot(eigenDir, eigenDir) > 0)
        diagVec = normalize(eigenDir);
    else
        diagVec = float2(1, 0);
    diagVec.y = -diagVec.y;

    float maxSize = 4096.0;
    v1 = min(sqrt(2.0 * lambda1), maxSize) * diagVec;
    v2 = min(sqrt(2.0 * lambda2), maxSize) * float2(diagVec.y, -diagVec.x);

    #endif
}
|
||
|
|
|
||
|
|
// Tests a splat position against the _SplatCutouts list.
// The first cutout volume that contains the position decides the result
// (cut only when that cutout is inverted). If no volume contains it, the
// position is cut as soon as at least one valid non-inverted cutout exists.
bool IsSplatCut(float3 pos)
{
    bool cutByMissedVolume = false;
    for (uint cutIdx = 0; cutIdx < _SplatCutoutsCount; ++cutIdx)
    {
        GaussianCutoutShaderData cutout = _SplatCutouts[cutIdx];
        uint shape = cutout.typeAndFlags & 0xFF;
        if (shape != 0xFF) // 0xFF: invalid/null cutout, ignore
        {
            bool inverted = (cutout.typeAndFlags & 0xFF00) != 0;
            float3 localPos = mul(cutout.mat, float4(pos, 1)).xyz;

            bool inside = false;
            if (shape == SPLAT_CUTOUT_TYPE_ELLIPSOID)
                inside = dot(localPos, localPos) <= 1;
            else if (shape == SPLAT_CUTOUT_TYPE_BOX)
                inside = all(abs(localPos) <= 1);

            if (inside)
                return inverted;
            if (!inverted)
                cutByMissedVolume = true;
        }
    }
    return cutByMissedVolume;
}
|
||
|
|
|
||
|
|
// Computes the per-splat view data (clip position, two screen-space axes and
// a packed half-float color) and stores it in _SplatViewData.
// Deleted or cut splats get a clip-space w of 0, which also makes them skip
// the axis/color computation; their remaining fields stay zero.
[numthreads(GROUP_SIZE,1,1)]
void CSCalcViewData (uint3 id : SV_DispatchThreadID)
{
    uint splatIdx = id.x;
    if (splatIdx >= _SplatCount)
        return;

    SplatData splat = LoadSplatData(splatIdx);
    SplatViewData result = (SplatViewData)0;

    float3 worldPos = mul(_MatrixObjectToWorld, float4(splat.pos,1)).xyz;
    float4 clipPos = mul(UNITY_MATRIX_VP, float4(worldPos, 1));
    half opacityScale = _SplatOpacityScale;
    float scaleSq = _SplatScale * _SplatScale;

    // deleted?
    bool hidden = false;
    if (_SplatBitsValid)
    {
        uint deletedWord = _SplatDeletedBits.Load((splatIdx / 32) * 4);
        hidden = ((deletedWord >> (splatIdx & 31)) & 1u) != 0;
    }
    // cutouts
    if (IsSplatCut(splat.pos))
        hidden = true;
    if (hidden)
        clipPos.w = 0;

    result.pos = clipPos;
    if (!(clipPos.w <= 0)) // not behind the camera and not hidden
    {
        float3x3 rotScaleMat = CalcMatrixFromRotationScale(splat.rot, splat.scale);

        float3 cov3d0, cov3d1;
        CalcCovariance3D(rotScaleMat, cov3d0, cov3d1);
        cov3d0 *= scaleSq;
        cov3d1 *= scaleSq;
        float3 cov2d = CalcCovariance2D(splat.pos, cov3d0, cov3d1, _MatrixMV, UNITY_MATRIX_P, _VecScreenParams);

        DecomposeCovariance(cov2d, result.axis1, result.axis2);

        // direction from the splat towards the camera, in object space
        float3 toCameraWorld = _VecWorldSpaceCameraPos.xyz - worldPos;
        float3 toCameraObj = normalize(mul((float3x3)_MatrixWorldToObject, toCameraWorld));

        half4 shaded;
        shaded.rgb = ShadeSH(splat.sh, toCameraObj, _SHOrder, _SHOnly != 0);
        shaded.a = min(splat.opacity * opacityScale, 65000);

        // two half floats per uint: (r,g) in x and (b,a) in y
        result.color.x = (f32tof16(shaded.r) << 16) | f32tof16(shaded.g);
        result.color.y = (f32tof16(shaded.b) << 16) | f32tof16(shaded.a);
    }

    _SplatViewData[splatIdx] = result;
}
|
||
|
|
|
||
|
|
|
||
|
|
RWByteAddressBuffer _DstBuffer;
|
||
|
|
ByteAddressBuffer _SrcBuffer;
|
||
|
|
uint _BufferSize;
|
||
|
|
|
||
|
|
// Returns the [begin, end) range of splat indices whose bits live in 32-bit
// word number idx; the end is limited to _SplatCount.
uint2 GetSplatIndicesFromWord(uint idx)
{
    uint first = idx * 32;
    uint onePastLast = min(first + 32, _SplatCount);
    return uint2(first, onePastLast);
}
|
||
|
|
|
||
|
|
// Accumulates editing statistics into _DstBuffer, one thread per 32-bit word
// of the selected/deleted bit buffers:
//   byte  0: number of selected splats (deleted and cut ones excluded)
//   byte  4: number of deleted splats
//   byte  8: number of cut splats (deleted ones excluded)
//   bytes 12-20 / 24-32: min / max position of the selected splats, stored as
//   FloatToSortableUint values so integer atomics can be used.
[numthreads(GROUP_SIZE,1,1)]
void CSUpdateEditData (uint3 id : SV_DispatchThreadID)
{
    uint wordIdx = id.x;
    if (wordIdx >= _BufferSize)
        return;

    uint byteOffset = wordIdx * 4;
    uint selectedBits = _SplatSelectedBits.Load(byteOffset);
    uint deletedBits = _SplatDeletedBits.Load(byteOffset);
    selectedBits &= ~deletedBits; // don't count deleted splats as selected
    uint cutBits = 0;

    // update selection bounds
    float3 boundsMin = 1.0e38;
    float3 boundsMax = -1.0e38;

    uint2 range = GetSplatIndicesFromWord(wordIdx);
    uint bit = 1;
    for (uint splatIdx = range.x; splatIdx < range.y; ++splatIdx, bit <<= 1)
    {
        float3 splatPos = LoadSplatPos(splatIdx);
        if (IsSplatCut(splatPos))
        {
            // don't count cut splats as selected
            selectedBits &= ~bit;
            cutBits |= bit;
        }
        else if (selectedBits & bit)
        {
            boundsMin = min(boundsMin, splatPos);
            boundsMax = max(boundsMax, splatPos);
        }
    }
    cutBits &= ~deletedBits; // don't count deleted splats as cut

    if (selectedBits != 0)
    {
        _DstBuffer.InterlockedMin(12, FloatToSortableUint(boundsMin.x));
        _DstBuffer.InterlockedMin(16, FloatToSortableUint(boundsMin.y));
        _DstBuffer.InterlockedMin(20, FloatToSortableUint(boundsMin.z));
        _DstBuffer.InterlockedMax(24, FloatToSortableUint(boundsMax.x));
        _DstBuffer.InterlockedMax(28, FloatToSortableUint(boundsMax.y));
        _DstBuffer.InterlockedMax(32, FloatToSortableUint(boundsMax.z));
    }

    _DstBuffer.InterlockedAdd(0, countbits(selectedBits));
    _DstBuffer.InterlockedAdd(4, countbits(deletedBits));
    _DstBuffer.InterlockedAdd(8, countbits(cutBits));
}
|
||
|
|
|
||
|
|
// Resets the edit statistics in _DstBuffer: three zero counters at bytes 0-8,
// then the bounds minimum (bytes 12-20) set to a huge positive value and the
// bounds maximum (bytes 24-32) set to a huge negative value, both encoded with
// FloatToSortableUint.
[numthreads(1,1,1)]
void CSInitEditData (uint3 id : SV_DispatchThreadID)
{
    // selected, deleted, cut counts
    _DstBuffer.Store3(0, uint3(0,0,0));

    uint emptyMin = FloatToSortableUint(1.0e38);
    _DstBuffer.Store3(12, uint3(emptyMin, emptyMin, emptyMin));

    uint emptyMax = FloatToSortableUint(-1.0e38);
    _DstBuffer.Store3(24, uint3(emptyMax, emptyMax, emptyMax));
}
|
||
|
|
|
||
|
|
// Zeroes the first _BufferSize 32-bit words of _DstBuffer, one word per thread.
[numthreads(GROUP_SIZE,1,1)]
void CSClearBuffer (uint3 id : SV_DispatchThreadID)
{
    uint wordIdx = id.x;
    if (wordIdx < _BufferSize)
        _DstBuffer.Store(wordIdx * 4, 0);
}
|
||
|
|
|
||
|
|
// Inverts the selection bits in _DstBuffer, one 32-bit word per thread.
// A bit can only end up set when it maps to an existing splat that is not
// cut: bits past _SplatCount in the last word are cleared, so padding can
// never be counted as selected splats afterwards.
[numthreads(GROUP_SIZE,1,1)]
void CSInvertSelection (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _BufferSize)
        return;
    uint v = ~_DstBuffer.Load(idx * 4);

    // collect the bits that are allowed to be selected:
    // existing splats that are not cut
    uint selectable = 0;
    uint2 splatIndices = GetSplatIndicesFromWord(idx);
    uint mask = 1;
    for (uint sidx = splatIndices.x; sidx < splatIndices.y; ++sidx, mask <<= 1)
    {
        float3 spos = LoadSplatPos(sidx);
        if (!IsSplatCut(spos))
            selectable |= mask;
    }

    _DstBuffer.Store(idx * 4, v & selectable);
}
|
||
|
|
|
||
|
|
// Selects every splat that is not cut, one 32-bit word of _DstBuffer per
// thread. The word is built up from zero, so bits past _SplatCount in the
// last word stay clear and padding can never be counted as selected splats.
[numthreads(GROUP_SIZE,1,1)]
void CSSelectAll (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _BufferSize)
        return;

    uint v = 0;
    uint2 splatIndices = GetSplatIndicesFromWord(idx);
    uint mask = 1;
    for (uint sidx = splatIndices.x; sidx < splatIndices.y; ++sidx, mask <<= 1)
    {
        // do not select splats that are cut
        float3 spos = LoadSplatPos(sidx);
        if (!IsSplatCut(spos))
            v |= mask;
    }

    _DstBuffer.Store(idx * 4, v);
}
|
||
|
|
|
||
|
|
|
||
|
|
// Bitwise-ORs _SrcBuffer into _DstBuffer, one 32-bit word per thread, for the
// first _BufferSize words.
[numthreads(GROUP_SIZE,1,1)]
void CSOrBuffers (uint3 id : SV_DispatchThreadID)
{
    uint wordIdx = id.x;
    if (wordIdx < _BufferSize)
    {
        uint byteOffset = wordIdx * 4;
        uint merged = _SrcBuffer.Load(byteOffset) | _DstBuffer.Load(byteOffset);
        _DstBuffer.Store(byteOffset, merged);
    }
}
|
||
|
|
|
||
|
|
float4 _SelectionRect;
|
||
|
|
|
||
|
|
// Rectangle selection: for every splat that is not cut, is in front of the
// camera and whose projected center lies inside _SelectionRect
// (x/y = min corner, z/w = max corner, in pixels), sets its bit in
// _SplatSelectedBits when _SelectionMode is non-zero, or clears it otherwise.
[numthreads(GROUP_SIZE,1,1)]
void CSSelectionUpdate (uint3 id : SV_DispatchThreadID)
{
    uint splatIdx = id.x;
    if (splatIdx >= _SplatCount)
        return;

    float3 objPos = LoadSplatPos(splatIdx);
    if (IsSplatCut(objPos))
        return;

    float3 worldPos = mul(_MatrixObjectToWorld, float4(objPos,1)).xyz;
    float4 clipPos = mul(UNITY_MATRIX_VP, float4(worldPos, 1));
    if (clipPos.w <= 0) // behind the camera
        return;

    // clip space -> pixel coordinates, with y flipped
    float2 ndc = clipPos.xy / clipPos.w;
    float2 pixelPos = (ndc * float2(0.5, -0.5) + 0.5) * _VecScreenParams.xy;
    bool outsideRect = any(pixelPos < _SelectionRect.xy) || any(pixelPos > _SelectionRect.zw);
    if (outsideRect)
        return;

    uint wordOffset = (splatIdx / 32) * 4;
    uint bit = 1u << (splatIdx & 31);
    if (_SelectionMode)
        _SplatSelectedBits.InterlockedOr(wordOffset, bit); // +
    else
        _SplatSelectedBits.InterlockedAnd(wordOffset, ~bit); // -
}
|
||
|
|
|
||
|
|
float3 _SelectionDelta;
|
||
|
|
|
||
|
|
// Returns true when the bit for splat idx is set in _SplatSelectedBits
// (one bit per splat, 32 splats per word).
bool IsSplatSelected(uint idx)
{
    uint word = _SplatSelectedBits.Load((idx / 32) * 4);
    return ((word >> (idx & 31)) & 1u) != 0;
}
|
||
|
|
|
||
|
|
// Adds _SelectionDelta to the position of every selected splat.
// Only acts when the data has no chunks and positions are stored as 32-bit
// floats (12 bytes per splat); otherwise nothing is written.
[numthreads(GROUP_SIZE,1,1)]
void CSTranslateSelection (uint3 id : SV_DispatchThreadID)
{
    uint splatIdx = id.x;
    if (splatIdx >= _SplatCount)
        return;
    if (!IsSplatSelected(splatIdx))
        return;

    uint posFormat = _SplatFormat & 0xFF;
    if (_SplatChunkCount != 0 || posFormat != VECTOR_FMT_32F)
        return;

    uint posOffset = splatIdx * 12;
    float3 moved = asfloat(_SplatPos.Load3(posOffset));
    moved += _SelectionDelta;
    _SplatPos.Store3(posOffset, asuint(moved));
}
|
||
|
|
|
||
|
|
float3 _SelectionCenter;
|
||
|
|
float4 _SelectionDeltaRot;
|
||
|
|
ByteAddressBuffer _SplatPosMouseDown;
|
||
|
|
ByteAddressBuffer _SplatOtherMouseDown;
|
||
|
|
|
||
|
|
// Rotates every selected splat by _SelectionDeltaRot around _SelectionCenter.
// Positions and rotations are read from the *MouseDown snapshot buffers and
// the results are written to the live _SplatPos / _SplatOther buffers.
// Each part only runs for unchunked data in the expected 32-bit float formats.
[numthreads(GROUP_SIZE,1,1)]
void CSRotateSelection (uint3 id : SV_DispatchThreadID)
{
    uint splatIdx = id.x;
    if (splatIdx >= _SplatCount)
        return;
    if (!IsSplatSelected(splatIdx))
        return;

    bool unchunked = _SplatChunkCount == 0;

    // position: rotate the offset from the selection center
    uint posFmt = _SplatFormat & 0xFF;
    if (unchunked && posFmt == VECTOR_FMT_32F)
    {
        uint posOffset = splatIdx * 12;
        float3 p = asfloat(_SplatPosMouseDown.Load3(posOffset));
        p -= _SelectionCenter;
        p = mul(_MatrixObjectToWorld, float4(p,1)).xyz;
        p = QuatRotateVector(p, _SelectionDeltaRot);
        p = mul(_MatrixWorldToObject, float4(p,1)).xyz;
        p += _SelectionCenter;
        _SplatPos.Store3(posOffset, asuint(p));
    }

    // orientation: packed rotation is the first uint of each 16-byte record
    uint scaleFmt = (_SplatFormat >> 8) & 0xFF;
    uint shFormat = (_SplatFormat >> 16) & 0xFF;
    if (unchunked && scaleFmt == VECTOR_FMT_32F && shFormat == VECTOR_FMT_32F)
    {
        uint otherOffset = splatIdx * (4 + 12);
        uint packedRot = _SplatOtherMouseDown.Load(otherOffset);
        float4 q = DecodeRotation(DecodePacked_10_10_10_2(packedRot));

        //@TODO: correct rotation
        q = QuatMul(q, _SelectionDeltaRot);

        packedRot = EncodeQuatToNorm10(PackSmallest3Rotation(q));
        _SplatOther.Store(otherOffset, packedRot);
    }

    //@TODO: rotate SHs
}
|
||
|
|
|
||
|
|
//@TODO: maybe scale the splat scale itself too?
|
||
|
|
// Scales the positions of selected splats about _SelectionCenter by the
// per-axis factors in _SelectionDelta. Positions are read from the
// _SplatPosMouseDown snapshot and written to _SplatPos. Only acts on
// unchunked data with 32-bit float positions.
[numthreads(GROUP_SIZE,1,1)]
void CSScaleSelection (uint3 id : SV_DispatchThreadID)
{
    uint splatIdx = id.x;
    if (splatIdx >= _SplatCount)
        return;
    if (!IsSplatSelected(splatIdx))
        return;

    uint posFormat = _SplatFormat & 0xFF;
    if (_SplatChunkCount != 0 || posFormat != VECTOR_FMT_32F)
        return;

    uint posOffset = splatIdx * 12;
    float3 p = asfloat(_SplatPosMouseDown.Load3(posOffset));
    p -= _SelectionCenter;
    p = mul(_MatrixObjectToWorld, float4(p,1)).xyz;
    p *= _SelectionDelta;
    p = mul(_MatrixWorldToObject, float4(p,1)).xyz;
    p += _SelectionCenter;
    _SplatPos.Store3(posOffset, asuint(p));
}
|
||
|
|
|
||
|
|
// One exported splat record as written by CSExportData.
struct ExportSplatData
{
    float3 pos;
    float3 nor;     // written as 0, or 1 to mark a cut splat as skipped for export
    float3 dc0;     // base color converted with ColorToSH0
    // SH coefficients 1..15, red channel
    float4 shR14;
    float4 shR58;
    float4 shR9C;
    float3 shRDF;
    // SH coefficients 1..15, green channel
    float4 shG14;
    float4 shG58;
    float4 shG9C;
    float3 shGDF;
    // SH coefficients 1..15, blue channel
    float4 shB14;
    float4 shB58;
    float4 shB9C;
    float3 shBDF;
    float opacity;  // written through InvSigmoid
    float3 scale;   // written as log(scale)
    float4 rot;     // written in wxyz component order
};
|
||
|
|
RWStructuredBuffer<ExportSplatData> _ExportBuffer;
|
||
|
|
|
||
|
|
// Converts a color into its order-0 SH coefficient: (col - 0.5) / 0.2820948.
float3 ColorToSH0(float3 col)
{
    const float sh0Factor = 0.2820948;
    float3 centered = col - 0.5;
    return centered / sh0Factor;
}
|
||
|
|
// Inverse of the logistic sigmoid: log(v / (1 - v)).
// Both ends are clamped so the result stays finite: the denominator keeps
// v == 1 from dividing by zero, and the numerator keeps v <= 0 from turning
// into log(0) = -inf (or NaN for negative input) in exported data.
float InvSigmoid(float v)
{
    const float kEps = 1.0e-6;
    return log(max(v, kEps) / max(1 - v, kEps));
}
|
||
|
|
|
||
|
|
// SH rotation
|
||
|
|
#include "SphericalHarmonics.hlsl"
|
||
|
|
|
||
|
|
// Rotates all 16 SH coefficients of a splat in place: the named fields are
// copied into an array, passed to the array-based RotateSH overload together
// with the rotation matrix, and the results are copied back.
void RotateSH(inout SplatSHData sh, float3x3 rot)
{
    float3 coeffsIn[16];
    float3 coeffsOut[16];

    // gather
    coeffsIn[0] = sh.col;
    coeffsIn[1] = sh.sh1;   coeffsIn[2] = sh.sh2;   coeffsIn[3] = sh.sh3;
    coeffsIn[4] = sh.sh4;   coeffsIn[5] = sh.sh5;   coeffsIn[6] = sh.sh6;
    coeffsIn[7] = sh.sh7;   coeffsIn[8] = sh.sh8;   coeffsIn[9] = sh.sh9;
    coeffsIn[10] = sh.sh10; coeffsIn[11] = sh.sh11; coeffsIn[12] = sh.sh12;
    coeffsIn[13] = sh.sh13; coeffsIn[14] = sh.sh14; coeffsIn[15] = sh.sh15;

    RotateSH(rot, 4, coeffsIn, coeffsOut);

    // scatter
    sh.col = coeffsOut[0];
    sh.sh1 = coeffsOut[1];   sh.sh2 = coeffsOut[2];   sh.sh3 = coeffsOut[3];
    sh.sh4 = coeffsOut[4];   sh.sh5 = coeffsOut[5];   sh.sh6 = coeffsOut[6];
    sh.sh7 = coeffsOut[7];   sh.sh8 = coeffsOut[8];   sh.sh9 = coeffsOut[9];
    sh.sh10 = coeffsOut[10]; sh.sh11 = coeffsOut[11]; sh.sh12 = coeffsOut[12];
    sh.sh13 = coeffsOut[13]; sh.sh14 = coeffsOut[14]; sh.sh15 = coeffsOut[15];
}
|
||
|
|
|
||
|
|
// Takes the upper-left 3x3 of a transform and rescales each of its rows to
// unit length.
// NOTE(review): each row is normalized independently; whether that removes
// the scale exactly for non-uniformly scaled transforms should be confirmed.
float3x3 CalcSHRotMatrix(float4x4 objToWorld)
{
    float3x3 rot = (float3x3)objToWorld;

    float invLen0 = 1.0 / length(rot[0]);
    float invLen1 = 1.0 / length(rot[1]);
    float invLen2 = 1.0 / length(rot[2]);

    rot[0] *= invLen0;
    rot[1] *= invLen1;
    rot[2] *= invLen2;
    return rot;
}
|
||
|
|
|
||
|
|
|
||
|
|
float4 _ExportTransformRotation;
|
||
|
|
float3 _ExportTransformScale;
|
||
|
|
uint _ExportTransformFlags;
|
||
|
|
|
||
|
|
// Converts every splat into an ExportSplatData record in _ExportBuffer.
// When _ExportTransformFlags is non-zero the splat is first transformed:
// position by _MatrixObjectToWorld, rotation by _ExportTransformRotation,
// scale by |_ExportTransformScale|, and the SH coefficients are rotated.
// Cut splats are still written, but flagged through the nor field.
[numthreads(GROUP_SIZE,1,1)]
void CSExportData (uint3 id : SV_DispatchThreadID)
{
    uint splatIdx = id.x;
    if (splatIdx >= _SplatCount)
        return;
    SplatData splat = LoadSplatData(splatIdx);

    // the cutout test uses the position before any export transform
    bool skipped = IsSplatCut(splat.pos);

    // transform splat by matrix, if needed
    if (_ExportTransformFlags != 0)
    {
        splat.pos = mul(_MatrixObjectToWorld, float4(splat.pos,1)).xyz;

        // note: this only handles axis flips from scale, not any arbitrary scaling
        float3 xformScale = _ExportTransformScale;
        if (xformScale.x < 0)
            splat.rot.yz = -splat.rot.yz;
        if (xformScale.y < 0)
            splat.rot.xz = -splat.rot.xz;
        if (xformScale.z < 0)
            splat.rot.xy = -splat.rot.xy;
        splat.rot = QuatMul(_ExportTransformRotation, splat.rot);
        splat.scale *= abs(xformScale);

        RotateSH(splat.sh, CalcSHRotMatrix(_MatrixObjectToWorld));
    }

    SplatSHData sh = splat.sh;

    ExportSplatData dst;
    dst.pos = splat.pos;
    dst.nor = skipped ? float3(1, 1, 1) : float3(0, 0, 0); // 1 = mark as skipped for export
    dst.dc0 = ColorToSH0(sh.col);

    // SH coefficients are stored per color channel
    dst.shR14 = float4(sh.sh1.r, sh.sh2.r, sh.sh3.r, sh.sh4.r);
    dst.shR58 = float4(sh.sh5.r, sh.sh6.r, sh.sh7.r, sh.sh8.r);
    dst.shR9C = float4(sh.sh9.r, sh.sh10.r, sh.sh11.r, sh.sh12.r);
    dst.shRDF = float3(sh.sh13.r, sh.sh14.r, sh.sh15.r);

    dst.shG14 = float4(sh.sh1.g, sh.sh2.g, sh.sh3.g, sh.sh4.g);
    dst.shG58 = float4(sh.sh5.g, sh.sh6.g, sh.sh7.g, sh.sh8.g);
    dst.shG9C = float4(sh.sh9.g, sh.sh10.g, sh.sh11.g, sh.sh12.g);
    dst.shGDF = float3(sh.sh13.g, sh.sh14.g, sh.sh15.g);

    dst.shB14 = float4(sh.sh1.b, sh.sh2.b, sh.sh3.b, sh.sh4.b);
    dst.shB58 = float4(sh.sh5.b, sh.sh6.b, sh.sh7.b, sh.sh8.b);
    dst.shB9C = float4(sh.sh9.b, sh.sh10.b, sh.sh11.b, sh.sh12.b);
    dst.shBDF = float3(sh.sh13.b, sh.sh14.b, sh.sh15.b);

    dst.opacity = InvSigmoid(splat.opacity);
    dst.scale = log(splat.scale);
    dst.rot = splat.rot.wxyz;

    _ExportBuffer[splatIdx] = dst;
}
|
||
|
|
|
||
|
|
RWByteAddressBuffer _CopyDstPos;
|
||
|
|
RWByteAddressBuffer _CopyDstOther;
|
||
|
|
RWByteAddressBuffer _CopyDstSH;
|
||
|
|
RWByteAddressBuffer _CopyDstEditDeleted;
|
||
|
|
RWTexture2D<float4> _CopyDstColor;
|
||
|
|
uint _CopyDstSize, _CopySrcStartIndex, _CopyDstStartIndex, _CopyCount;
|
||
|
|
|
||
|
|
float4x4 _CopyTransformMatrix;
|
||
|
|
float4 _CopyTransformRotation;
|
||
|
|
float3 _CopyTransformScale;
|
||
|
|
|
||
|
|
// Copies _CopyCount splats, starting at _CopySrcStartIndex in the source
// data, into the _CopyDst* resources starting at _CopyDstStartIndex.
// Each splat is transformed on the way: position by _CopyTransformMatrix,
// rotation by _CopyTransformRotation, scale by |_CopyTransformScale|, and the
// SH coefficients are rotated. The destination is written as 32-bit float
// position (12 bytes), packed rotation + float scale (16 bytes), color +
// opacity texel and 15 float3 SH coefficients (192 byte stride).
// Source deleted bits are carried over to _CopyDstEditDeleted.
[numthreads(GROUP_SIZE,1,1)]
void CSCopySplats (uint3 id : SV_DispatchThreadID)
{
    uint idx = id.x;
    if (idx >= _CopyCount)
        return;
    uint srcIdx = _CopySrcStartIndex + idx;
    uint dstIdx = _CopyDstStartIndex + idx;
    if (srcIdx >= _SplatCount || dstIdx >= _CopyDstSize)
        return;

    // read from the offset source index (not the raw thread index), so a
    // non-zero _CopySrcStartIndex copies the intended range
    SplatData src = LoadSplatData(srcIdx);

    // transform the splat
    src.pos = mul(_CopyTransformMatrix, float4(src.pos,1)).xyz;
    // note: this only handles axis flips from scale, not any arbitrary scaling
    if (_CopyTransformScale.x < 0)
        src.rot.yz = -src.rot.yz;
    if (_CopyTransformScale.y < 0)
        src.rot.xz = -src.rot.xz;
    if (_CopyTransformScale.z < 0)
        src.rot.xy = -src.rot.xy;
    src.rot = QuatMul(_CopyTransformRotation, src.rot);
    src.scale *= abs(_CopyTransformScale);

    float3x3 shRot = CalcSHRotMatrix(_CopyTransformMatrix);
    RotateSH(src.sh, shRot);

    // output data into destination:
    // pos
    uint posStride = 12;
    _CopyDstPos.Store3(dstIdx * posStride, asuint(src.pos));
    // rot + scale
    uint otherStride = 4 + 12;
    uint rotVal = EncodeQuatToNorm10(PackSmallest3Rotation(src.rot));
    _CopyDstOther.Store4(dstIdx * otherStride, uint4(
        rotVal,
        asuint(src.scale.x),
        asuint(src.scale.y),
        asuint(src.scale.z)));
    // color
    uint3 pixelIndex = SplatIndexToPixelIndex(dstIdx);
    _CopyDstColor[pixelIndex.xy] = float4(src.sh.col, src.opacity);

    // SH
    uint shStride = 192; // 15*3 fp32, rounded up to multiple of 16
    uint shOffset = dstIdx * shStride;
    _CopyDstSH.Store3(shOffset + 12 * 0, asuint(src.sh.sh1));
    _CopyDstSH.Store3(shOffset + 12 * 1, asuint(src.sh.sh2));
    _CopyDstSH.Store3(shOffset + 12 * 2, asuint(src.sh.sh3));
    _CopyDstSH.Store3(shOffset + 12 * 3, asuint(src.sh.sh4));
    _CopyDstSH.Store3(shOffset + 12 * 4, asuint(src.sh.sh5));
    _CopyDstSH.Store3(shOffset + 12 * 5, asuint(src.sh.sh6));
    _CopyDstSH.Store3(shOffset + 12 * 6, asuint(src.sh.sh7));
    _CopyDstSH.Store3(shOffset + 12 * 7, asuint(src.sh.sh8));
    _CopyDstSH.Store3(shOffset + 12 * 8, asuint(src.sh.sh9));
    _CopyDstSH.Store3(shOffset + 12 * 9, asuint(src.sh.sh10));
    _CopyDstSH.Store3(shOffset + 12 * 10, asuint(src.sh.sh11));
    _CopyDstSH.Store3(shOffset + 12 * 11, asuint(src.sh.sh12));
    _CopyDstSH.Store3(shOffset + 12 * 12, asuint(src.sh.sh13));
    _CopyDstSH.Store3(shOffset + 12 * 13, asuint(src.sh.sh14));
    _CopyDstSH.Store3(shOffset + 12 * 14, asuint(src.sh.sh15));

    // deleted bits
    uint srcWordIdx = srcIdx / 32;
    uint srcBitIdx = srcIdx & 31;
    if (_SplatDeletedBits.Load(srcWordIdx * 4) & (1u << srcBitIdx))
    {
        uint dstWordIdx = dstIdx / 32;
        uint dstBitIdx = dstIdx & 31;
        _CopyDstEditDeleted.InterlockedOr(dstWordIdx * 4, 1u << dstBitIdx);
    }
}
|