// Compute pass that turns raw splat attributes (position, packed rotation,
// scale, colour/opacity, spherical-harmonic coefficients) into one
// PreparedSplatView record per splat: clip-space centre, two screen-space
// ellipse axes and a half-float packed colour/opacity.
#include "PreparedSplatView.hlsli"

// Threads per group along X; MainCS processes one splat per thread.
#define GROUP_SIZE 64

cbuffer FrameConstants : register(b0)
{
    float4x4 gViewProjection; // used as mul(float4(worldPos, 1), gViewProjection)
    float4x4 gView;           // used as mul(float4(worldPos, 1), gView)
    float4x4 gProjection;     // only _m00 and _m11 are read
    float4 gCameraWorldPos;   // xyz: camera position used for the SH view direction
    float4 gScreenParams;     // x: scales the focal length (presumably viewport width in pixels - confirm with host code)
    float4 gSettings;         // x: splat count, y: opacity multiplier, z: SH order, w: splat scale
};

ByteAddressBuffer gPositions : register(t0); // one float3 per splat, 12 bytes apart
ByteAddressBuffer gOther : register(t1);     // per splat: packed rotation (uint) at +0, float3 scale at +4
Texture2D<float4> gColor : register(t2);     // rgb: base colour, a: opacity; texel chosen by SplatIndexToPixelIndex
ByteAddressBuffer gSh : register(t3);        // 15 float3 coefficients per splat
// The element type must be spelled out: RWStructuredBuffer has no default
// template argument, and MainCS stores PreparedSplatView records through it.
RWStructuredBuffer<PreparedSplatView> gPreparedViews : register(u0);

// Spherical-harmonic basis constants for bands 1, 2 and 3.
static const float SH_C1 = 0.4886025;
static const float SH_C2[] = { 1.0925484, -1.0925484, 0.3153916, -1.0925484, 0.5462742 };
static const float SH_C3[] = { -0.5900436, 2.8906114, -0.4570458, 0.3731763, -0.4570458, 1.4453057, -0.5900436 };

static const uint kColorTextureWidth = 2048; // width in texels assumed for gColor addressing
static const uint kOtherStride = 16;         // bytes between consecutive splats in gOther
static const uint kShStride = 192;           // bytes between consecutive splats in gSh

// Base colour plus the 15 higher-band SH coefficients of one splat.
struct SplatSHData
{
    float3 col;
    float3 sh[15];
};

// Reads three consecutive 32-bit words at byteOffset and reinterprets them as floats.
float3 LoadFloat3(ByteAddressBuffer buffer, uint byteOffset)
{
    return asfloat(buffer.Load3(byteOffset));
}

// Interleaves the low 4 bits of c.x and c.y into an 8-bit Morton code:
// x bits end up in the even bit positions, y bits in the odd ones.
uint EncodeMorton2D_16x16(uint2 c)
{
    uint t = ((c.y & 0xF) << 8) | (c.x & 0xF); // y in bits 8..11, x in bits 0..3
    t = (t ^ (t << 2)) & 0x3333;               // spread each nibble into two bit pairs
    t = (t ^ (t << 1)) & 0x5555;               // spread each pair into single bits
    return (t | (t >> 7)) & 0xFF;              // fold the y bits into the odd positions
}

// Inverse of EncodeMorton2D_16x16: splits the low 8 bits of t back into
// a 4-bit x (returned in .x) and a 4-bit y (returned in .y).
uint2 DecodeMorton2D_16x16(uint t)
{
    t = (t & 0xFF) | ((t & 0xFE) << 7); // move the odd (y) bits up into the high byte
    t &= 0x5555;
    t = (t ^ (t >> 1)) & 0x3333;
    t = (t ^ (t >> 2)) & 0x0F0F;
    return uint2(t & 0xF, t >> 8);
}

// Maps a linear splat index to a texel coordinate in gColor.
// The low 8 bits select a texel inside a 16x16 tile (Morton order); the
// remaining bits select the tile, with tiles laid out row by row across
// kColorTextureWidth / 16 tiles per row. z is always 0.
uint3 SplatIndexToPixelIndex(uint index)
{
    uint2 xy = DecodeMorton2D_16x16(index);
    uint tileWidth = kColorTextureWidth / 16;
    index >>= 8;

    uint3 result;
    result.x = (index % tileWidth) * 16 + xy.x;
    result.y = (index / tileWidth) * 16 + xy.y;
    result.z = 0;
    return result;
}

// Unpacks a 10:10:10:2 value (x in the lowest bits, w in the top two)
// into four floats, each normalised to [0, 1].
float4 DecodePacked_10_10_10_2(uint encoded)
{
    return float4(
        (encoded & 1023) / 1023.0,
        ((encoded >> 10) & 1023) / 1023.0,
        ((encoded >> 20) & 1023) / 1023.0,
        ((encoded >> 30) & 3) / 3.0);
}

// Rebuilds a quaternion from three stored components plus the index of the
// component that was dropped.
// packedRotation.xyz: stored components in [0, 1]; packedRotation.w: dropped
// index encoded as index / 3.
float4 DecodeRotation(float4 packedRotation)
{
    // Round rather than truncate so small errors in w cannot select the wrong index.
    uint droppedIndex = (uint)round(packedRotation.w * 3.0);

    float4 rotation;
    // Remap [0, 1] to [-1/sqrt(2), 1/sqrt(2)].
    rotation.xyz = packedRotation.xyz * sqrt(2.0) - (1.0 / sqrt(2.0));
    // The missing component follows from unit length; saturate keeps the sqrt argument non-negative.
    rotation.w = sqrt(1.0 - saturate(dot(rotation.xyz, rotation.xyz)));

    // Move the reconstructed component into its original slot (index 3 is already in place).
    if (droppedIndex == 0)
    {
        rotation = rotation.wxyz;
    }
    if (droppedIndex == 1)
    {
        rotation = rotation.xwyz;
    }
    if (droppedIndex == 2)
    {
        rotation = rotation.xywz;
    }
    return rotation;
}

// Returns mul(R, S), where R is the rotation matrix built from the quaternion
// `rotation` (x, y, z, w) and S is the diagonal matrix of `scale`.
float3x3 CalcMatrixFromRotationScale(float4 rotation, float3 scale)
{
    float3x3 scaleMatrix = float3x3(
        scale.x, 0, 0,
        0, scale.y, 0,
        0, 0, scale.z);

    float x = rotation.x;
    float y = rotation.y;
    float z = rotation.z;
    float w = rotation.w;

    float3x3 rotationMatrix = float3x3(
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y),
        2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y));

    return mul(rotationMatrix, scaleMatrix);
}

// Computes M * transpose(M) and returns the six unique entries of that
// symmetric matrix: sigma0 = (m00, m01, m02), sigma1 = (m11, m12, m22).
void CalcCovariance3D(float3x3 rotationScaleMatrix, out float3 sigma0, out float3 sigma1)
{
    float3x3 sigma = mul(rotationScaleMatrix, transpose(rotationScaleMatrix));
    sigma0 = float3(sigma._m00, sigma._m01, sigma._m02);
    sigma1 = float3(sigma._m11, sigma._m12, sigma._m22);
}

// Projects a 3D covariance (packed as produced by CalcCovariance3D) at
// worldPosition to a 2D covariance, returned as (m00, m01, m11).
float3 CalcCovariance2D(float3 worldPosition, float3 covariance0, float3 covariance1)
{
    float3 viewPosition = mul(float4(worldPosition, 1.0), gView).xyz;

    // Clamp the view-space x/z and y/z ratios before building the Jacobian.
    // Note: gProjection._m11 * aspect equals gProjection._m00, so both limits
    // evaluate to the same value as written.
    float aspect = gProjection._m00 / gProjection._m11;
    float tanFovX = rcp(gProjection._m00);
    float tanFovY = rcp(gProjection._m11 * aspect);
    float clampX = 1.3 * tanFovX;
    float clampY = 1.3 * tanFovY;
    viewPosition.x = clamp(viewPosition.x / viewPosition.z, -clampX, clampX) * viewPosition.z;
    viewPosition.y = clamp(viewPosition.y / viewPosition.z, -clampY, clampY) * viewPosition.z;

    float focal = gScreenParams.x * gProjection._m00 * 0.5;

    // Jacobian of the perspective divide at viewPosition; the third row is zero
    // because only the 2x2 screen-space part of the result is used.
    float3x3 jacobian = float3x3(
        focal / viewPosition.z, 0, -(focal * viewPosition.x) / (viewPosition.z * viewPosition.z),
        0, focal / viewPosition.z, -(focal * viewPosition.y) / (viewPosition.z * viewPosition.z),
        0, 0, 0);

    // gView is applied to row vectors above, so its 3x3 part is transposed
    // here to act on column vectors.
    float3x3 worldToView = transpose((float3x3)gView);
    float3x3 transform = mul(jacobian, worldToView);

    // Expand the packed symmetric covariance back into a full matrix.
    float3x3 covariance = float3x3(
        covariance0.x, covariance0.y, covariance0.z,
        covariance0.y, covariance1.x, covariance1.y,
        covariance0.z, covariance1.y, covariance1.z);

    float3x3 projected = mul(transform, mul(covariance, transpose(transform)));

    // Inflate both screen-space variances by a constant so the footprint never collapses.
    projected._m00 += 0.3;
    projected._m11 += 0.3;
    return float3(projected._m00, projected._m01, projected._m11);
}

// Splits a 2D covariance (m00, m01, m11) into two perpendicular axis vectors
// whose lengths are sqrt(2 * eigenvalue), each capped at 4096.
void DecomposeCovariance(float3 covariance2D, out float2 axis1, out float2 axis2)
{
    float diagonal0 = covariance2D.x;
    float diagonal1 = covariance2D.z;
    float offDiagonal = covariance2D.y;

    // Eigenvalues of a symmetric 2x2 matrix: mid +/- radius.
    float mid = 0.5 * (diagonal0 + diagonal1);
    float radius = length(float2((diagonal0 - diagonal1) * 0.5, offDiagonal));
    float lambda0 = mid + radius;
    float lambda1 = max(mid - radius, 0.1); // keep the minor axis from degenerating

    // NOTE(review): when offDiagonal is 0 and diagonal0 >= diagonal1 the vector
    // below is (0, 0) up to rounding, and normalize() of it is not well defined -
    // confirm whether this case needs a guard.
    float2 diagonalVector = normalize(float2(offDiagonal, lambda0 - diagonal0));
    diagonalVector.y = -diagonalVector.y;

    const float maxSize = 4096.0;
    axis1 = min(sqrt(2.0 * lambda0), maxSize) * diagonalVector;
    axis2 = min(sqrt(2.0 * lambda1), maxSize) * float2(diagonalVector.y, -diagonalVector.x);
}

// Loads the base colour (from gColor) and the 15 SH coefficients (from gSh)
// of splat `index`.
SplatSHData LoadSplatSH(uint index)
{
    SplatSHData sh;
    const uint shBaseOffset = index * kShStride;
    sh.col = gColor.Load(int3(SplatIndexToPixelIndex(index).xy, 0)).rgb;

    [unroll]
    for (uint coefficientIndex = 0; coefficientIndex < 15; ++coefficientIndex)
    {
        // Coefficients are tightly packed float3s (12 bytes each) from the splat's base offset.
        sh.sh[coefficientIndex] = LoadFloat3(gSh, shBaseOffset + coefficientIndex * 12);
    }
    return sh;
}

// Evaluates the splat colour for `direction` using SH bands up to shOrder
// (0 = base colour only, 3 = all 15 coefficients). The direction is negated
// before evaluation and the result is clamped to be non-negative.
float3 ShadeSH(SplatSHData sh, float3 direction, int shOrder)
{
    direction *= -1.0;
    float x = direction.x;
    float y = direction.y;
    float z = direction.z;

    float3 result = sh.col;

    // Band 1.
    if (shOrder >= 1)
    {
        result += SH_C1 * (-sh.sh[0] * y + sh.sh[1] * z - sh.sh[2] * x);

        // Band 2.
        if (shOrder >= 2)
        {
            float xx = x * x;
            float yy = y * y;
            float zz = z * z;
            float xy = x * y;
            float yz = y * z;
            float xz = x * z;

            result +=
                (SH_C2[0] * xy) * sh.sh[3] +
                (SH_C2[1] * yz) * sh.sh[4] +
                (SH_C2[2] * (2 * zz - xx - yy)) * sh.sh[5] +
                (SH_C2[3] * xz) * sh.sh[6] +
                (SH_C2[4] * (xx - yy)) * sh.sh[7];

            // Band 3.
            if (shOrder >= 3)
            {
                result +=
                    (SH_C3[0] * y * (3 * xx - yy)) * sh.sh[8] +
                    (SH_C3[1] * xy * z) * sh.sh[9] +
                    (SH_C3[2] * y * (4 * zz - xx - yy)) * sh.sh[10] +
                    (SH_C3[3] * z * (2 * zz - 3 * xx - 3 * yy)) * sh.sh[11] +
                    (SH_C3[4] * x * (4 * zz - xx - yy)) * sh.sh[12] +
                    (SH_C3[5] * z * (xx - yy)) * sh.sh[13] +
                    (SH_C3[6] * x * (xx - 3 * yy)) * sh.sh[14];
            }
        }
    }

    return max(result, 0.0);
}

// Entry point: one thread per splat. Writes gPreparedViews[index] for every
// index below the splat count in gSettings.x. Splats whose clip-space w is not
// positive get only clipPosition filled in; all other fields stay zero.
[numthreads(GROUP_SIZE, 1, 1)]
void MainCS(uint3 dispatchThreadId : SV_DispatchThreadID)
{
    uint index = dispatchThreadId.x;
    uint splatCount = (uint)gSettings.x;
    if (index >= splatCount)
    {
        // Threads past the end of the data in the last group have nothing to do.
        return;
    }

    PreparedSplatView view = (PreparedSplatView)0;

    float3 position = LoadFloat3(gPositions, index * 12);
    uint packedRotation = gOther.Load(index * kOtherStride);
    float4 rotation = DecodeRotation(DecodePacked_10_10_10_2(packedRotation));
    float3 scale = LoadFloat3(gOther, index * kOtherStride + 4);
    float4 colorOpacity = gColor.Load(int3(SplatIndexToPixelIndex(index).xy, 0));

    view.clipPosition = mul(float4(position, 1.0), gViewProjection);

    if (view.clipPosition.w > 0.0)
    {
        float3x3 rotationScale = CalcMatrixFromRotationScale(rotation, scale);

        float3 covariance0;
        float3 covariance1;
        CalcCovariance3D(rotationScale, covariance0, covariance1);

        // Covariance scales with the square of a linear size factor.
        float splatScaleSquared = gSettings.w * gSettings.w;
        covariance0 *= splatScaleSquared;
        covariance1 *= splatScaleSquared;

        float3 covariance2D = CalcCovariance2D(position, covariance0, covariance1);
        DecomposeCovariance(covariance2D, view.axis1, view.axis2);

        SplatSHData sh = LoadSplatSH(index);
        float3 viewDirection = normalize(gCameraWorldPos.xyz - position);
        float3 shadedColor = ShadeSH(sh, viewDirection, (int)gSettings.z);
        float opacity = saturate(colorOpacity.a * gSettings.y);

        // Two half floats per uint: x = (r << 16) | g, y = (b << 16) | opacity.
        view.packedColor.x = (f32tof16(shadedColor.r) << 16) | f32tof16(shadedColor.g);
        view.packedColor.y = (f32tof16(shadedColor.b) << 16) | f32tof16(opacity);
    }

    gPreparedViews[index] = view;
}