// File: XCEngine/MVS/3DGS-D3D12/shaders/PrepareGaussiansCS.hlsl
// (HLSL, 273 lines, 8.7 KiB)

#include "PreparedSplatView.hlsli"
#define GROUP_SIZE 64
// Per-frame constants consumed by the splat preparation pass.
cbuffer FrameConstants : register(b0)
{
// Transforms float4(worldPosition, 1) (row vector on the left of mul) to clip space.
float4x4 gViewProjection;
// Transforms float4(worldPosition, 1) (row vector on the left of mul) to view space;
// its upper 3x3 is also used when projecting the covariance.
float4x4 gView;
// Only _m00 and _m11 are read here, to derive the clamp limits and the focal length.
float4x4 gProjection;
// xyz = camera position in world space, used to build the per-splat view direction.
float4 gCameraWorldPos;
// x is multiplied by gProjection._m00 * 0.5 to obtain the focal length
// (presumably the viewport width in pixels - TODO confirm against the CPU side).
float4 gScreenParams;
// x = splat count, y = opacity multiplier, z = spherical-harmonics order, w = splat scale.
float4 gSettings;
};
// Splat centers: one float3 (12 bytes) per splat.
ByteAddressBuffer gPositions : register(t0);
// Per-splat record of kOtherStride bytes: a 10-10-10-2 packed rotation (uint) at offset 0,
// followed by a float3 scale at offset 4.
ByteAddressBuffer gOther : register(t1);
// One texel per splat, addressed through SplatIndexToPixelIndex; rgb = base color, a = opacity.
Texture2D<float4> gColor : register(t2);
// Per-splat record of kShStride bytes; 15 float3 coefficients are read from its start.
ByteAddressBuffer gSh : register(t3);
// Output: one PreparedSplatView per splat, written at the splat's index.
RWStructuredBuffer<PreparedSplatView> gPreparedViews : register(u0);
// Basis constants applied by ShadeSH to the order-1, order-2 and order-3 coefficient groups.
static const float SH_C1 = 0.4886025;
static const float SH_C2[] = { 1.0925484, -1.0925484, 0.3153916, -1.0925484, 0.5462742 };
static const float SH_C3[] = { -0.5900436, 2.8906114, -0.4570458, 0.3731763, -0.4570458, 1.4453057, -0.5900436 };
// Width used to map a splat index to a gColor texel, in 16x16 tiles
// (assumes gColor is 2048 texels wide - must match the code that fills the texture).
static const uint kColorTextureWidth = 2048;
// Byte stride of one splat record in gOther (4-byte packed rotation + 12-byte scale).
static const uint kOtherStride = 16;
// Byte stride of one splat record in gSh. LoadSplatSH reads only the first 180 bytes
// (15 * 12); the remaining 12 bytes are presumably padding - TODO confirm.
static const uint kShStride = 192;
// Color data for one splat as loaded by LoadSplatSH.
struct SplatSHData
{
// Base color term, taken from the rgb channels of the splat's gColor texel.
float3 col;
// Higher-order coefficients read from gSh: [0..2] are used for order 1,
// [3..7] for order 2 and [8..14] for order 3 in ShadeSH.
float3 sh[15];
};
// Reads three consecutive 32-bit words starting at byteOffset and
// reinterprets their bit patterns as a float3.
float3 LoadFloat3(ByteAddressBuffer buffer, uint byteOffset)
{
    const uint3 rawWords = buffer.Load3(byteOffset);
    return asfloat(rawWords);
}
// Interleaves the low 4 bits of c.x and c.y into an 8-bit Morton code:
// x lands on the even bits, y on the odd bits.
uint EncodeMorton2D_16x16(uint2 c)
{
    const uint xNibble = c.x & 0xF;
    const uint yNibble = c.y & 0xF;

    uint spread = (yNibble << 8) | xNibble;         // ----yyyy----xxxx
    spread = (spread ^ (spread << 2)) & 0x3333;     // --yy--yy--xx--xx
    spread = (spread ^ (spread << 1)) & 0x5555;     // -y-y-y-y-x-x-x-x

    // Fold the y half down by 7 so it fills the odd bits of the low byte.
    const uint folded = spread | (spread >> 7);
    return folded & 0xFF;
}
// Inverse of EncodeMorton2D_16x16: splits the low 8 bits of t back into a
// 4-bit x (from the even bits) and a 4-bit y (from the odd bits).
uint2 DecodeMorton2D_16x16(uint t)
{
    // Copy the code into the upper byte shifted so the odd (y) bits align with even slots.
    uint bits = (t & 0xFF) | ((t & 0xFE) << 7);
    bits &= 0x5555;                             // -y-y-y-y-x-x-x-x
    bits = (bits ^ (bits >> 1)) & 0x3333;       // --yy--yy--xx--xx
    bits = (bits ^ (bits >> 2)) & 0x0F0F;       // ----yyyy----xxxx

    uint2 coord;
    coord.x = bits & 0xF;
    coord.y = bits >> 8;
    return coord;
}
// Maps a linear splat index to a texel coordinate (z is always 0).
// The low 8 bits select a Morton-ordered position inside a 16x16 tile;
// the remaining bits select the tile, laid out row by row across
// kColorTextureWidth / 16 tiles per row.
uint3 SplatIndexToPixelIndex(uint index)
{
    const uint tilesPerRow = kColorTextureWidth / 16;
    const uint2 inTile = DecodeMorton2D_16x16(index);

    const uint tile = index >> 8;
    const uint tileColumn = tile % tilesPerRow;
    const uint tileRow = tile / tilesPerRow;

    return uint3(tileColumn * 16 + inTile.x, tileRow * 16 + inTile.y, 0);
}
// Unpacks a 10-10-10-2 bit-field into four floats in [0, 1]:
// x = bits 0-9, y = bits 10-19, z = bits 20-29, w = bits 30-31.
float4 DecodePacked_10_10_10_2(uint encoded)
{
    const uint4 fields = uint4(
        encoded & 1023,
        (encoded >> 10) & 1023,
        (encoded >> 20) & 1023,
        (encoded >> 30) & 3);
    // Each field is divided by its own maximum value.
    const float4 fieldMax = float4(1023.0, 1023.0, 1023.0, 3.0);
    return float4(fields) / fieldMax;
}
// Rebuilds a four-component rotation from its packed form.
// packedRotation.xyz hold three stored components remapped to [0, 1];
// packedRotation.w * 3 (rounded) is the slot of the component that was
// dropped, which is reconstructed as sqrt(1 - |stored|^2).
float4 DecodeRotation(float4 packedRotation)
{
    const uint droppedSlot = (uint)round(packedRotation.w * 3.0);

    // Undo the [0, 1] remapping of the three stored components.
    const float3 stored = packedRotation.xyz * sqrt(2.0) - (1.0 / sqrt(2.0));
    const float rebuilt = sqrt(1.0 - saturate(dot(stored, stored)));

    // Re-insert the rebuilt component at the slot it was dropped from.
    float4 rotation;
    if (droppedSlot == 0)
    {
        rotation = float4(rebuilt, stored.x, stored.y, stored.z);
    }
    else if (droppedSlot == 1)
    {
        rotation = float4(stored.x, rebuilt, stored.y, stored.z);
    }
    else if (droppedSlot == 2)
    {
        rotation = float4(stored.x, stored.y, rebuilt, stored.z);
    }
    else
    {
        rotation = float4(stored, rebuilt);
    }
    return rotation;
}
// Builds the 3x3 matrix R * S, where R is the rotation matrix of the
// quaternion `rotation` (xyz = vector part, w = scalar part) and S is the
// diagonal matrix of `scale`.
float3x3 CalcMatrixFromRotationScale(float4 rotation, float3 scale)
{
    const float x = rotation.x;
    const float y = rotation.y;
    const float z = rotation.z;
    const float w = rotation.w;

    const float3 rotationRow0 = float3(1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y));
    const float3 rotationRow1 = float3(2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x));
    const float3 rotationRow2 = float3(2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y));
    const float3x3 rotationMatrix = float3x3(rotationRow0, rotationRow1, rotationRow2);

    const float3x3 scaleMatrix = float3x3(
        float3(scale.x, 0, 0),
        float3(0, scale.y, 0),
        float3(0, 0, scale.z));

    return mul(rotationMatrix, scaleMatrix);
}
// Computes M * M^T for the given rotation-scale matrix and returns the six
// unique entries of that symmetric result:
//   sigma0 = (m00, m01, m02), sigma1 = (m11, m12, m22).
void CalcCovariance3D(float3x3 rotationScaleMatrix, out float3 sigma0, out float3 sigma1)
{
    const float3x3 covariance = mul(rotationScaleMatrix, transpose(rotationScaleMatrix));
    sigma0 = covariance._m00_m01_m02;
    sigma1 = covariance._m11_m12_m22;
}
// Projects a 3D covariance (packed as covariance0 = (c00, c01, c02),
// covariance1 = (c11, c12, c22)) located at worldPosition into screen space.
// Returns the packed symmetric 2x2 result as (m00, m01, m11).
float3 CalcCovariance2D(float3 worldPosition, float3 covariance0, float3 covariance1)
{
    float3 viewPos = mul(float4(worldPosition, 1.0), gView).xyz;

    // Clamp the view-space x/z and y/z ratios to 1.3x the limits derived
    // from the projection matrix before building the Jacobian.
    const float projX = gProjection._m00;
    const float projY = gProjection._m11;
    const float aspect = projX / projY;
    const float limitX = 1.3 * rcp(projX);
    const float limitY = 1.3 * rcp(projY * aspect);
    const float depth = viewPos.z;
    viewPos.x = clamp(viewPos.x / depth, -limitX, limitX) * depth;
    viewPos.y = clamp(viewPos.y / depth, -limitY, limitY) * depth;

    // Jacobian of the perspective projection at viewPos; the third row is zero.
    const float focal = gScreenParams.x * projX * 0.5;
    const float focalOverDepth = focal / depth;
    const float depthSquared = depth * depth;
    const float3x3 jacobian = float3x3(
        focalOverDepth, 0, -(focal * viewPos.x) / depthSquared,
        0, focalOverDepth, -(focal * viewPos.y) / depthSquared,
        0, 0, 0);

    // gView is applied with the vector on the left, so its upper 3x3 is
    // transposed before being combined with the Jacobian.
    const float3x3 viewRotation = transpose((float3x3)gView);
    const float3x3 transform = mul(jacobian, viewRotation);

    // Expand the packed symmetric covariance into a full matrix.
    const float3x3 covariance3D = float3x3(
        covariance0.x, covariance0.y, covariance0.z,
        covariance0.y, covariance1.x, covariance1.y,
        covariance0.z, covariance1.y, covariance1.z);

    const float3x3 projected = mul(transform, mul(covariance3D, transpose(transform)));

    // Both diagonal terms are inflated by 0.3 (presumably a low-pass
    // filter that keeps every splat at least about a pixel wide).
    return float3(projected._m00 + 0.3, projected._m01, projected._m11 + 0.3);
}
// Decomposes a packed symmetric 2x2 covariance (m00, m01, m11) into two
// perpendicular screen-space axes.
//   axis1: direction of the larger eigenvalue, length sqrt(2 * lambda0).
//   axis2: perpendicular direction, length sqrt(2 * lambda1), where lambda1
//          is floored at 0.1.
// Both lengths are capped at 4096.
void DecomposeCovariance(float3 covariance2D, out float2 axis1, out float2 axis2)
{
    const float diagonal0 = covariance2D.x;
    const float diagonal1 = covariance2D.z;
    const float offDiagonal = covariance2D.y;

    // Eigenvalues of [[d0, o], [o, d1]] are mid +/- radius.
    const float mid = 0.5 * (diagonal0 + diagonal1);
    const float radius = length(float2((diagonal0 - diagonal1) * 0.5, offDiagonal));
    const float lambda0 = mid + radius;
    const float lambda1 = max(mid - radius, 0.1);

    // Eigenvector for lambda0. When offDiagonal == 0 and diagonal0 >= diagonal1
    // this vector is exactly (0, 0); normalizing it would produce NaN axes.
    // In that case the matrix is already diagonal with its larger entry on x,
    // so the major axis is (1, 0).
    float2 diagonalVector = float2(offDiagonal, lambda0 - diagonal0);
    const float lengthSquared = dot(diagonalVector, diagonalVector);
    if (lengthSquared > 0.0)
    {
        diagonalVector = normalize(diagonalVector);
    }
    else
    {
        diagonalVector = float2(1.0, 0.0);
    }
    diagonalVector.y = -diagonalVector.y;

    const float maxSize = 4096.0;
    axis1 = min(sqrt(2.0 * lambda0), maxSize) * diagonalVector;
    axis2 = min(sqrt(2.0 * lambda1), maxSize) * float2(diagonalVector.y, -diagonalVector.x);
}
// Loads the color data of one splat: the base color from the rgb channels
// of its gColor texel and 15 float3 coefficients from its gSh record.
SplatSHData LoadSplatSH(uint index)
{
    SplatSHData data;

    const uint3 texel = SplatIndexToPixelIndex(index);
    data.col = gColor.Load(int3(texel.xy, 0)).rgb;

    // Coefficients are tightly packed float3s (12 bytes each) at the start
    // of the splat's kShStride-byte record.
    const uint recordStart = index * kShStride;
    [unroll]
    for (uint i = 0; i < 15; ++i)
    {
        const uint coefficientOffset = recordStart + i * 12;
        data.sh[i] = LoadFloat3(gSh, coefficientOffset);
    }
    return data;
}
// Evaluates the splat color for a given direction.
//   sh        : base color plus 15 higher-order coefficients.
//   direction : direction vector; it is negated before evaluation.
//   shOrder   : highest order to include (0 = base color only, up to 3).
// Returns the accumulated color clamped to be non-negative.
float3 ShadeSH(SplatSHData sh, float3 direction, int shOrder)
{
    const float3 d = direction * -1.0;
    const float x = d.x;
    const float y = d.y;
    const float z = d.z;

    // Second-degree products shared by the order-2 and order-3 terms.
    const float xx = x * x;
    const float yy = y * y;
    const float zz = z * z;
    const float xy = x * y;
    const float yz = y * z;
    const float xz = x * z;

    float3 color = sh.col;

    // Each order is added on top of the lower ones; shOrder >= k implies
    // shOrder >= k - 1, so the checks can be made independently.
    if (shOrder >= 1)
    {
        color += SH_C1 * (-sh.sh[0] * y + sh.sh[1] * z - sh.sh[2] * x);
    }
    if (shOrder >= 2)
    {
        color +=
            (SH_C2[0] * xy) * sh.sh[3] +
            (SH_C2[1] * yz) * sh.sh[4] +
            (SH_C2[2] * (2 * zz - xx - yy)) * sh.sh[5] +
            (SH_C2[3] * xz) * sh.sh[6] +
            (SH_C2[4] * (xx - yy)) * sh.sh[7];
    }
    if (shOrder >= 3)
    {
        color +=
            (SH_C3[0] * y * (3 * xx - yy)) * sh.sh[8] +
            (SH_C3[1] * xy * z) * sh.sh[9] +
            (SH_C3[2] * y * (4 * zz - xx - yy)) * sh.sh[10] +
            (SH_C3[3] * z * (2 * zz - 3 * xx - 3 * yy)) * sh.sh[11] +
            (SH_C3[4] * x * (4 * zz - xx - yy)) * sh.sh[12] +
            (SH_C3[5] * z * (xx - yy)) * sh.sh[13] +
            (SH_C3[6] * x * (xx - 3 * yy)) * sh.sh[14];
    }

    return max(color, 0.0);
}
// Compute entry point: one thread per splat. Writes a PreparedSplatView for
// every index below the splat count (gSettings.x). Splats whose clip-space w
// is not positive keep zeroed axes and color; only clipPosition is filled in.
[numthreads(GROUP_SIZE, 1, 1)]
void MainCS(uint3 dispatchThreadId : SV_DispatchThreadID)
{
    const uint splatIndex = dispatchThreadId.x;
    if (splatIndex >= (uint)gSettings.x)
    {
        return;
    }

    // Fetch the raw per-splat inputs.
    const float3 worldPos = LoadFloat3(gPositions, splatIndex * 12);
    const uint otherOffset = splatIndex * kOtherStride;
    const uint rotationBits = gOther.Load(otherOffset);
    const float4 orientation = DecodeRotation(DecodePacked_10_10_10_2(rotationBits));
    const float3 splatScale = LoadFloat3(gOther, otherOffset + 4);
    const float4 baseColorOpacity = gColor.Load(int3(SplatIndexToPixelIndex(splatIndex).xy, 0));

    PreparedSplatView prepared = (PreparedSplatView)0;
    prepared.clipPosition = mul(float4(worldPos, 1.0), gViewProjection);

    if (prepared.clipPosition.w > 0.0)
    {
        // 3D covariance from rotation and scale, multiplied by the squared
        // global splat scale (gSettings.w).
        float3 cov3D0;
        float3 cov3D1;
        CalcCovariance3D(CalcMatrixFromRotationScale(orientation, splatScale), cov3D0, cov3D1);
        const float scaleSquared = gSettings.w * gSettings.w;
        cov3D0 *= scaleSquared;
        cov3D1 *= scaleSquared;

        // Screen-space footprint axes.
        const float3 cov2D = CalcCovariance2D(worldPos, cov3D0, cov3D1);
        DecomposeCovariance(cov2D, prepared.axis1, prepared.axis2);

        // Color for the direction from the splat to the camera, at the
        // spherical-harmonics order given by gSettings.z.
        const SplatSHData shData = LoadSplatSH(splatIndex);
        const float3 toCamera = normalize(gCameraWorldPos.xyz - worldPos);
        const float3 rgb = ShadeSH(shData, toCamera, (int)gSettings.z);
        const float alpha = saturate(baseColorOpacity.a * gSettings.y);

        // Pack as four halfs: x = (r << 16) | g, y = (b << 16) | alpha.
        prepared.packedColor.x = (f32tof16(rgb.r) << 16) | f32tof16(rgb.g);
        prepared.packedColor.y = (f32tof16(rgb.b) << 16) | f32tof16(alpha);
    }

    gPreparedViews[splatIndex] = prepared;
}