Files
XCEngine/MVS/3DGS-D3D12/src/GaussianPlyLoader.cpp

618 lines
21 KiB
C++

#include "XC3DGSD3D12/GaussianPlyLoader.h"
#include <algorithm>
#include <array>
#include <cmath>
#include <cstring>
#include <fstream>
#include <limits>
#include <sstream>
#include <string_view>
#include <unordered_map>
#include <vector>
namespace XC3DGSD3D12 {
namespace {
// Real spherical-harmonics band-0 basis constant, 1 / (2 * sqrt(pi)).
// SH0ToColor uses it to turn the f_dc_* coefficients into base colors.
constexpr float kSHC0 = 0.2820948f;
// Scalar storage types the loader accepts for vertex properties. ParsePropertyType
// maps header type tokens onto these; PropertyTypeSize gives their byte sizes.
enum class PlyPropertyType {
None, // Unset / unrecognized; PropertyTypeSize reports 0 bytes.
Float32, // PLY "float", 4 bytes.
Float64, // PLY "double", 8 bytes.
UInt8, // PLY "uchar", 1 byte.
};
// One scalar property of the vertex element, as declared in the PLY header.
struct PlyProperty {
std::string name; // Property name from the header, e.g. "x" or "f_dc_0".
PlyPropertyType type = PlyPropertyType::None;
uint32_t offset = 0; // Byte offset of this property inside one vertex record.
uint32_t size = 0; // Byte size of `type` (from PropertyTypeSize).
};
// Vertex-element description produced by ParsePlyHeader.
struct PlyHeader {
uint32_t vertexCount = 0; // Record count from the "element vertex N" line.
uint32_t vertexStride = 0; // Bytes per vertex record: sum of all vertex property sizes.
std::vector<PlyProperty> properties; // Vertex properties in header declaration order.
};
// Minimal four-float vector used for rotation processing in this file. A freshly
// read rotation holds rot_0..rot_3 in x..w; NormalizeSwizzleRotation reorders it,
// and PackSmallest3Rotation replaces it with the packed [0, 1] encoding.
struct Float4 {
float x = 0.0f;
float y = 0.0f;
float z = 0.0f;
float w = 0.0f;
};
// Attributes of one splat as read from a PLY vertex record. LinearizeGaussianSplat
// later rewrites several fields in place (noted per field).
struct RawGaussianSplat {
Float3 position = {}; // From properties x, y, z.
Float3 dc0 = {}; // f_dc_0..2; replaced by SH0ToColor's base color.
std::array<Float3, GaussianSplatRuntimeData::kShCoefficientCount> sh = {}; // One RGB triple per f_rest coefficient (regrouped by ReadGaussianSplat).
float opacity = 0.0f; // Raw value as stored; passed through Sigmoid by LinearizeGaussianSplat.
Float3 scale = {}; // scale_0..2, treated as log-scale (LinearScale exponentiates it).
Float4 rotation = {}; // rot_0..3 as read; packed smallest-three encoding after linearization.
};
// Pre-resolved property pointers for every attribute the loader reads, so the
// per-vertex decode does no name lookups. The pointers come from BuildPropertyMap
// and reference entries of a PlyHeader's `properties` vector; they are only valid
// while that header is alive and unmodified.
struct GaussianPlyPropertyLayout {
const PlyProperty* position[3] = {}; // x, y, z
const PlyProperty* dc0[3] = {}; // f_dc_0..2
const PlyProperty* opacity = nullptr;
const PlyProperty* scale[3] = {}; // scale_0..2
const PlyProperty* rotation[4] = {}; // rot_0..3
std::array<const PlyProperty*, GaussianSplatRuntimeData::kShCoefficientCount * 3> sh = {}; // f_rest_0.. in file order
};
// Strips exactly one trailing '\r' so that header lines from CRLF files compare
// equal to their LF equivalents. Any other trailing whitespace is left intact.
std::string TrimTrailingCarriageReturn(std::string line) {
    if (line.empty() || line.back() != '\r') {
        return line;
    }
    line.pop_back();
    return line;
}
// Byte width of a single value of `type` inside a binary vertex record.
// PlyPropertyType::None (and any out-of-range value) has width 0.
uint32_t PropertyTypeSize(PlyPropertyType type) {
    uint32_t byteCount = 0;
    switch (type) {
    case PlyPropertyType::UInt8:
        byteCount = 1;
        break;
    case PlyPropertyType::Float32:
        byteCount = 4;
        break;
    case PlyPropertyType::Float64:
        byteCount = 8;
        break;
    case PlyPropertyType::None:
        break;
    }
    return byteCount;
}
// Maps a PLY scalar type token to PlyPropertyType.
// Both the classic names ("float", "double", "uchar") and the sized aliases from
// the PLY format ("float32", "float64", "uint8") are accepted, because exporters
// emit either spelling.
// Returns true on success. Any other token (including "list") sets `outType` to
// None and returns false.
bool ParsePropertyType(const std::string& token, PlyPropertyType& outType) {
    if (token == "float" || token == "float32") {
        outType = PlyPropertyType::Float32;
        return true;
    }
    if (token == "double" || token == "float64") {
        outType = PlyPropertyType::Float64;
        return true;
    }
    if (token == "uchar" || token == "uint8") {
        outType = PlyPropertyType::UInt8;
        return true;
    }
    outType = PlyPropertyType::None;
    return false;
}
// Parses the ASCII header of a PLY file and describes the vertex element.
//
// On success this returns true, leaves `input` positioned just after the
// `end_header` line, and fills `outHeader` with:
//   - the vertex count,
//   - the per-vertex stride,
//   - each vertex property's type and byte offset, in declaration order.
// On failure it returns false and stores a readable reason in `outErrorMessage`.
//
// Only binary_little_endian files are accepted. The caller reads vertex records
// immediately after the header, so the vertex element must be the first element
// that carries records. Otherwise the payload offset would be wrong, and the file
// is rejected.
bool ParsePlyHeader(std::ifstream& input, PlyHeader& outHeader, std::string& outErrorMessage) {
    std::string line;
    if (!std::getline(input, line)) {
        outErrorMessage = "Failed to read PLY magic line.";
        return false;
    }
    if (TrimTrailingCarriageReturn(line) != "ply") {
        outErrorMessage = "Input file is not a valid PLY file.";
        return false;
    }
    bool sawFormat = false;
    bool sawEndHeader = false;
    bool sawVertexElement = false;
    // Set when an element with records is declared ahead of "vertex". Those records
    // would sit in front of the vertex data in the binary payload.
    bool sawRecordsBeforeVertex = false;
    std::string currentElement;
    while (std::getline(input, line)) {
        line = TrimTrailingCarriageReturn(line);
        std::istringstream stream(line);
        std::string token;
        stream >> token;
        // Comparing the first token (not the whole line) also tolerates trailing spaces.
        if (token == "end_header") {
            sawEndHeader = true;
            break;
        }
        if (token.empty() || token == "comment") {
            continue;
        }
        if (token == "format") {
            std::string formatName;
            std::string version;
            stream >> formatName >> version;
            if (formatName != "binary_little_endian") {
                outErrorMessage = "Only binary_little_endian PLY files are supported.";
                return false;
            }
            sawFormat = true;
            continue;
        }
        if (token == "element") {
            // Parse into a wide signed type so negative or oversized counts are
            // rejected instead of silently wrapping into a huge uint32_t.
            long long elementCount = -1;
            stream >> currentElement >> elementCount;
            if (stream.fail() || elementCount < 0 ||
                elementCount > static_cast<long long>(std::numeric_limits<uint32_t>::max())) {
                outErrorMessage = "Malformed PLY element declaration: " + line;
                return false;
            }
            if (currentElement == "vertex") {
                if (sawVertexElement) {
                    outErrorMessage = "PLY header declares more than one vertex element.";
                    return false;
                }
                sawVertexElement = true;
                outHeader.vertexCount = static_cast<uint32_t>(elementCount);
            } else if (!sawVertexElement && elementCount > 0) {
                sawRecordsBeforeVertex = true;
            }
            continue;
        }
        if (token == "property" && currentElement == "vertex") {
            std::string typeToken;
            std::string name;
            stream >> typeToken >> name;
            PlyPropertyType propertyType = PlyPropertyType::None;
            if (!ParsePropertyType(typeToken, propertyType)) {
                outErrorMessage = "Unsupported PLY vertex property type: " + typeToken;
                return false;
            }
            if (name.empty()) {
                outErrorMessage = "Malformed PLY vertex property declaration: " + line;
                return false;
            }
            PlyProperty property;
            property.name = name;
            property.type = propertyType;
            property.offset = outHeader.vertexStride;
            property.size = PropertyTypeSize(propertyType);
            outHeader.vertexStride += property.size;
            outHeader.properties.push_back(property);
        }
        // Any other keyword (obj_info, properties of other elements, ...) is ignored.
    }
    if (!sawEndHeader) {
        outErrorMessage = "PLY header is missing the end_header line.";
        return false;
    }
    if (!sawFormat) {
        outErrorMessage = "PLY header is missing a valid format declaration.";
        return false;
    }
    if (outHeader.vertexCount == 0) {
        outErrorMessage = "PLY file does not contain any vertex data.";
        return false;
    }
    if (sawRecordsBeforeVertex) {
        outErrorMessage = "PLY vertex element must be the first element that contains data.";
        return false;
    }
    if (outHeader.vertexStride == 0 || outHeader.properties.empty()) {
        outErrorMessage = "PLY vertex layout is empty.";
        return false;
    }
    return true;
}
// Reads the value of `property` from one raw vertex record and widens or narrows
// it to float. Values are copied with memcpy because record offsets need not be
// aligned. Returns false only for PlyPropertyType::None or unknown types; in that
// case `outValue` is left untouched.
bool ReadPropertyAsFloat(
    const std::byte* vertexBytes,
    const PlyProperty& property,
    float& outValue) {
    const std::byte* const source = vertexBytes + property.offset;
    if (property.type == PlyPropertyType::Float32) {
        std::memcpy(&outValue, source, sizeof(float));
        return true;
    }
    if (property.type == PlyPropertyType::Float64) {
        double wideValue = 0.0;
        std::memcpy(&wideValue, source, sizeof(double));
        outValue = static_cast<float>(wideValue);
        return true;
    }
    if (property.type == PlyPropertyType::UInt8) {
        const uint8_t byteValue = static_cast<uint8_t>(*source);
        outValue = static_cast<float>(byteValue);
        return true;
    }
    return false;
}
// Indexes the header's vertex properties by name. Keys are views into the names
// stored in `header.properties`, so the map must not outlive `header`.
// Returns false (with `outErrorMessage` set) at the first duplicated name; the map
// then contains only the entries inserted before the duplicate.
bool BuildPropertyMap(
    const PlyHeader& header,
    std::unordered_map<std::string_view, const PlyProperty*>& outMap,
    std::string& outErrorMessage) {
    const std::vector<PlyProperty>& properties = header.properties;
    outMap.clear();
    outMap.reserve(properties.size());
    for (size_t index = 0; index < properties.size(); ++index) {
        const PlyProperty& property = properties[index];
        const bool isNewName = outMap.try_emplace(std::string_view(property.name), &property).second;
        if (!isNewName) {
            outErrorMessage = "Duplicate PLY vertex property found: " + property.name;
            return false;
        }
    }
    return true;
}
// Looks up `name` in `propertyMap`. If found, stores the property pointer in
// `outProperty` and returns true. Otherwise `outProperty` is left unchanged,
// `outErrorMessage` names the missing property, and false is returned.
bool RequireProperty(
    const std::unordered_map<std::string_view, const PlyProperty*>& propertyMap,
    std::string_view name,
    const PlyProperty*& outProperty,
    std::string& outErrorMessage) {
    if (const auto found = propertyMap.find(name); found != propertyMap.end()) {
        outProperty = found->second;
        return true;
    }
    outErrorMessage = "Missing required PLY property: " + std::string(name);
    return false;
}
// Resolves every property the Gaussian splat decoder needs into `outLayout`.
// Required names are checked in a fixed order: position, f_dc_*, opacity,
// scale_*, rot_*, then f_rest_0 upward. The first missing name stops the search,
// sets `outErrorMessage` (via RequireProperty), and returns false.
bool BuildGaussianPlyPropertyLayout(
    const std::unordered_map<std::string_view, const PlyProperty*>& propertyMap,
    GaussianPlyPropertyLayout& outLayout,
    std::string& outErrorMessage) {
    outLayout = {};

    struct NamedSlot {
        std::string_view name;
        const PlyProperty** slot;
    };
    const NamedSlot fixedSlots[] = {
        { "x", &outLayout.position[0] },
        { "y", &outLayout.position[1] },
        { "z", &outLayout.position[2] },
        { "f_dc_0", &outLayout.dc0[0] },
        { "f_dc_1", &outLayout.dc0[1] },
        { "f_dc_2", &outLayout.dc0[2] },
        { "opacity", &outLayout.opacity },
        { "scale_0", &outLayout.scale[0] },
        { "scale_1", &outLayout.scale[1] },
        { "scale_2", &outLayout.scale[2] },
        { "rot_0", &outLayout.rotation[0] },
        { "rot_1", &outLayout.rotation[1] },
        { "rot_2", &outLayout.rotation[2] },
        { "rot_3", &outLayout.rotation[3] },
    };
    for (const NamedSlot& entry : fixedSlots) {
        if (!RequireProperty(propertyMap, entry.name, *entry.slot, outErrorMessage)) {
            return false;
        }
    }

    // Higher-order SH coefficients are stored as f_rest_0 .. f_rest_(3K - 1).
    for (size_t restIndex = 0; restIndex < outLayout.sh.size(); ++restIndex) {
        const std::string restName = "f_rest_" + std::to_string(restIndex);
        if (!RequireProperty(propertyMap, restName, outLayout.sh[restIndex], outErrorMessage)) {
            return false;
        }
    }
    return true;
}
// Component-wise minimum of two vectors (grows the lower corner of the bounds).
Float3 Min(const Float3& a, const Float3& b) {
    const float lowX = std::min(a.x, b.x);
    const float lowY = std::min(a.y, b.y);
    const float lowZ = std::min(a.z, b.z);
    return { lowX, lowY, lowZ };
}
// Component-wise maximum of two vectors (grows the upper corner of the bounds).
Float3 Max(const Float3& a, const Float3& b) {
    const float highX = std::max(a.x, b.x);
    const float highY = std::max(a.y, b.y);
    const float highZ = std::max(a.z, b.z);
    return { highX, highY, highZ };
}
// Four-component dot product, summed left to right in a single expression.
// NormalizeSwizzleRotation uses it to compute a quaternion's squared length.
float Dot(const Float4& a, const Float4& b) {
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
}
// Normalizes a quaternion read in (w, x, y, z) order and returns it in
// (x, y, z, w) order. On input, the x..w fields hold rot_0..rot_3.
// Near-zero quaternions cannot be normalized, and neither can non-finite ones
// (NaN/Inf from corrupt data). Both fall back to the identity rotation (0, 0, 0, 1).
Float4 NormalizeSwizzleRotation(const Float4& wxyz) {
    const float lengthSquared = Dot(wxyz, wxyz);
    // A NaN length fails `<= epsilon`, so without the isfinite check it would be
    // propagated into every component of the packed rotation.
    if (!std::isfinite(lengthSquared) || lengthSquared <= std::numeric_limits<float>::epsilon()) {
        return { 0.0f, 0.0f, 0.0f, 1.0f };
    }
    const float inverseLength = 1.0f / std::sqrt(lengthSquared);
    return {
        wxyz.y * inverseLength,
        wxyz.z * inverseLength,
        wxyz.w * inverseLength,
        wxyz.x * inverseLength,
    };
}
// Smallest-three quaternion packing for a unit quaternion in (x, y, z, w) order.
//
// 1. The largest-magnitude component is moved to w. The other three keep their
//    relative order in xyz.
// 2. xyz are multiplied by the sign of that largest component. This is the same
//    as negating the whole quaternion when it is negative, so the dropped
//    component is non-negative.
// 3. xyz are remapped from [-1/sqrt(2), 1/sqrt(2)] to [0, 1]. For a unit
//    quaternion no non-largest component can exceed 1/sqrt(2) in magnitude.
//
// The returned w is the dropped component's index divided by 3 (0, 1/3, 2/3 or 1).
// Presumably the decoder rebuilds the dropped component from unit length;
// that code is not in this file.
Float4 PackSmallest3Rotation(Float4 rotation) {
    const float magnitudes[4] = {
        std::fabs(rotation.x),
        std::fabs(rotation.y),
        std::fabs(rotation.z),
        std::fabs(rotation.w),
    };
    // Strict '>' keeps the earliest index on ties.
    int largestIndex = 0;
    for (int candidate = 1; candidate < 4; ++candidate) {
        if (magnitudes[candidate] > magnitudes[largestIndex]) {
            largestIndex = candidate;
        }
    }
    if (largestIndex == 0) {
        rotation = { rotation.y, rotation.z, rotation.w, rotation.x };
    } else if (largestIndex == 1) {
        rotation = { rotation.x, rotation.z, rotation.w, rotation.y };
    } else if (largestIndex == 2) {
        rotation = { rotation.x, rotation.y, rotation.w, rotation.z };
    }
    const float sign = rotation.w >= 0.0f ? 1.0f : -1.0f;
    const float sqrt2 = std::sqrt(2.0f);
    // Same operation order as before: (c * sign * sqrt2) * 0.5 + 0.5.
    const auto remapToUnit = [sign, sqrt2](float component) {
        return (component * sign * sqrt2) * 0.5f + 0.5f;
    };
    return {
        remapToUnit(rotation.x),
        remapToUnit(rotation.y),
        remapToUnit(rotation.z),
        static_cast<float>(largestIndex) / 3.0f,
    };
}
// Packs the output of PackSmallest3Rotation into 32 bits as 10.10.10.2 unsigned
// normalized fields: x in bits 0-9, y in 10-19, z in 20-29, w in 30-31.
// Each input is saturated to [0, 1] first. The 1023.5 / 3.5 scale factors make
// an input of exactly 1.0 truncate to the maximum code (1023 / 3).
uint32_t EncodeQuatToNorm10(const Float4& packedRotation) {
    // Written as a comparison so NaN maps to 0. std::clamp passes NaN through
    // unchanged, and converting NaN to an unsigned integer is undefined behavior.
    const auto saturate = [](float value) {
        return value > 0.0f ? std::min(value, 1.0f) : 0.0f;
    };
    const uint32_t x = static_cast<uint32_t>(saturate(packedRotation.x) * 1023.5f);
    const uint32_t y = static_cast<uint32_t>(saturate(packedRotation.y) * 1023.5f);
    const uint32_t z = static_cast<uint32_t>(saturate(packedRotation.z) * 1023.5f);
    const uint32_t w = static_cast<uint32_t>(saturate(packedRotation.w) * 3.5f);
    return x | (y << 10) | (z << 20) | (w << 30);
}
// Converts a per-axis log-scale into a linear scale by applying exp (then fabs)
// to each component.
Float3 LinearScale(const Float3& logarithmicScale) {
    const auto toLinear = [](float logValue) {
        return std::fabs(std::exp(logValue));
    };
    return {
        toLinear(logarithmicScale.x),
        toLinear(logarithmicScale.y),
        toLinear(logarithmicScale.z),
    };
}
// Evaluates the band-0 SH term for each channel and offsets it by 0.5, which
// turns f_dc_* coefficients into a base color. The result is not clamped.
Float3 SH0ToColor(const Float3& dc0) {
    const auto channelColor = [](float coefficient) {
        return coefficient * kSHC0 + 0.5f;
    };
    return {
        channelColor(dc0.x),
        channelColor(dc0.y),
        channelColor(dc0.z),
    };
}
// Logistic function 1 / (1 + e^-value). Large negative inputs overflow exp to
// +inf and yield 0; large positive inputs yield 1.
float Sigmoid(float value) {
    const float decay = std::exp(-value);
    return 1.0f / (1.0f + decay);
}
// De-interleaves the low 8 bits of `value` as a 2D Morton (Z-order) code.
// Even bits (0, 2, 4, 6) form x and odd bits (1, 3, 5, 7) form y, so both lie in
// [0, 15]. Bits above bit 7 are ignored. Returns {x, y}.
std::array<uint32_t, 2> DecodeMorton2D16x16(uint32_t value) {
    uint32_t column = 0;
    uint32_t row = 0;
    for (uint32_t bit = 0; bit < 4u; ++bit) {
        column |= ((value >> (2u * bit)) & 1u) << bit;
        row |= ((value >> (2u * bit + 1u)) & 1u) << bit;
    }
    return { column, row };
}
// Maps a linear splat index to a texel index in the color texture.
// Splats are grouped 256 per 16x16 tile, and tiles fill the texture row-major
// across kColorTextureWidth. Inside a tile, the low 8 bits of the index choose
// the texel in Morton order.
uint32_t SplatIndexToTextureIndex(uint32_t index) {
    constexpr uint32_t kTileSize = 16u;
    const uint32_t textureWidth = GaussianSplatRuntimeData::kColorTextureWidth;
    const uint32_t tilesPerRow = textureWidth / kTileSize;
    const std::array<uint32_t, 2> withinTile = DecodeMorton2D16x16(index);
    const uint32_t tileIndex = index >> 8u;
    const uint32_t texelX = (tileIndex % tilesPerRow) * kTileSize + withinTile[0];
    const uint32_t texelY = (tileIndex / tilesPerRow) * kTileSize + withinTile[1];
    return texelY * textureWidth + texelX;
}
// Copies the object representation of `value` into `bytes` starting at `offset`.
// The caller must ensure offset + sizeof(T) <= bytes.size(); no bounds check is
// performed.
template <typename T>
void WriteValue(std::vector<std::byte>& bytes, size_t offset, const T& value) {
    std::byte* const destination = bytes.data() + offset;
    std::memcpy(destination, &value, sizeof value);
}
// Writes value.x, value.y, value.z as three consecutive floats at `offset`.
void WriteFloat3(std::vector<std::byte>& bytes, size_t offset, const Float3& value) {
    const float components[] = { value.x, value.y, value.z };
    size_t cursor = offset;
    for (const float component : components) {
        WriteValue(bytes, cursor, component);
        cursor += sizeof(float);
    }
}
// Writes x, y, z, w as four consecutive floats at `offset` (one RGBA texel when
// used for the color buffer).
void WriteFloat4(std::vector<std::byte>& bytes, size_t offset, float x, float y, float z, float w) {
    const float components[] = { x, y, z, w };
    size_t cursor = offset;
    for (const float component : components) {
        WriteValue(bytes, cursor, component);
        cursor += sizeof(float);
    }
}
bool ReadGaussianSplat(
const std::byte* vertexBytes,
const GaussianPlyPropertyLayout& propertyLayout,
RawGaussianSplat& outSplat,
std::string& outErrorMessage) {
auto readFloat = [&](const PlyProperty* property, float& outValue) -> bool {
if (property == nullptr) {
outErrorMessage = "Gaussian PLY property layout is incomplete.";
return false;
}
return ReadPropertyAsFloat(vertexBytes, *property, outValue);
};
if (!readFloat(propertyLayout.position[0], outSplat.position.x) ||
!readFloat(propertyLayout.position[1], outSplat.position.y) ||
!readFloat(propertyLayout.position[2], outSplat.position.z) ||
!readFloat(propertyLayout.dc0[0], outSplat.dc0.x) ||
!readFloat(propertyLayout.dc0[1], outSplat.dc0.y) ||
!readFloat(propertyLayout.dc0[2], outSplat.dc0.z) ||
!readFloat(propertyLayout.opacity, outSplat.opacity) ||
!readFloat(propertyLayout.scale[0], outSplat.scale.x) ||
!readFloat(propertyLayout.scale[1], outSplat.scale.y) ||
!readFloat(propertyLayout.scale[2], outSplat.scale.z) ||
!readFloat(propertyLayout.rotation[0], outSplat.rotation.x) ||
!readFloat(propertyLayout.rotation[1], outSplat.rotation.y) ||
!readFloat(propertyLayout.rotation[2], outSplat.rotation.z) ||
!readFloat(propertyLayout.rotation[3], outSplat.rotation.w)) {
if (outErrorMessage.empty()) {
outErrorMessage = "Failed to read required Gaussian splat PLY properties.";
}
return false;
}
std::array<float, GaussianSplatRuntimeData::kShCoefficientCount * 3> shRaw = {};
for (uint32_t index = 0; index < shRaw.size(); ++index) {
if (!readFloat(propertyLayout.sh[index], shRaw[index])) {
if (outErrorMessage.empty()) {
outErrorMessage = "Failed to read SH rest coefficients from PLY.";
}
return false;
}
}
for (uint32_t coefficientIndex = 0; coefficientIndex < GaussianSplatRuntimeData::kShCoefficientCount; ++coefficientIndex) {
outSplat.sh[coefficientIndex] = {
shRaw[coefficientIndex + 0],
shRaw[coefficientIndex + GaussianSplatRuntimeData::kShCoefficientCount],
shRaw[coefficientIndex + GaussianSplatRuntimeData::kShCoefficientCount * 2],
};
}
return true;
}
// Converts the stored attributes of a splat into their runtime form, in place:
// - opacity becomes Sigmoid(opacity);
// - dc0 becomes the band-0 SH base color;
// - scale goes from log-scale to linear;
// - rotation is normalized, swizzled to xyzw, and packed as smallest-three.
// Each step touches a different field, so the order between them is irrelevant.
void LinearizeGaussianSplat(RawGaussianSplat& splat) {
    splat.opacity = Sigmoid(splat.opacity);
    splat.dc0 = SH0ToColor(splat.dc0);
    splat.scale = LinearScale(splat.scale);
    splat.rotation = PackSmallest3Rotation(NormalizeSwizzleRotation(splat.rotation));
}
} // namespace
// Loads a Gaussian splat scene from a binary little-endian PLY file and converts
// it into the runtime buffers of `outData`:
//   - positionData: one Float3 position per splat (kPositionStride apart).
//   - otherData:    10.10.10.2 packed rotation followed by the linear Float3
//                   scale (kOtherStride apart).
//   - colorData:    RGBA float texels holding base color and sigmoid opacity. They
//                   are placed by SplatIndexToTextureIndex in a texture whose
//                   height is rounded up to a multiple of 16.
//   - shData:       kShCoefficientCount float triples per splat (kShStride apart).
// Also computes the axis-aligned bounds of all splat positions.
//
// Returns true on success. On failure it returns false, sets `outErrorMessage`,
// and resets `outData`, so a partially built scene is never exposed.
bool LoadGaussianSceneFromPly(
    const std::filesystem::path& filePath,
    GaussianSplatRuntimeData& outData,
    std::string& outErrorMessage) {
    outData = {};
    outErrorMessage.clear();
    // Every failure path returns through here so outData is never left half-filled.
    const auto fail = [&outData]() {
        outData = {};
        return false;
    };
    std::ifstream input(filePath, std::ios::binary);
    if (!input.is_open()) {
        outErrorMessage = "Failed to open PLY file: " + filePath.string();
        return fail();
    }
    PlyHeader header;
    if (!ParsePlyHeader(input, header, outErrorMessage)) {
        return fail();
    }
    // Keys are views into header.properties, so `header` must outlive this map.
    std::unordered_map<std::string_view, const PlyProperty*> propertyMap;
    if (!BuildPropertyMap(header, propertyMap, outErrorMessage)) {
        return fail();
    }
    GaussianPlyPropertyLayout propertyLayout;
    if (!BuildGaussianPlyPropertyLayout(propertyMap, propertyLayout, outErrorMessage)) {
        return fail();
    }

    // The vertex count comes from an untrusted header. Make sure the file actually
    // holds that much vertex data before allocating buffers sized from it.
    const uint64_t requiredPayloadBytes =
        static_cast<uint64_t>(header.vertexCount) * static_cast<uint64_t>(header.vertexStride);
    const std::streampos payloadStart = input.tellg();
    if (payloadStart != std::streampos(-1)) {
        input.seekg(0, std::ios::end);
        const std::streampos fileEnd = input.tellg();
        // Restore the read position even if measuring the file size failed.
        input.clear();
        input.seekg(payloadStart);
        if (fileEnd != std::streampos(-1)) {
            const std::streamoff availableBytes = fileEnd - payloadStart;
            if (availableBytes >= 0 && static_cast<uint64_t>(availableBytes) < requiredPayloadBytes) {
                outErrorMessage = "PLY file is truncated: vertex data needs " +
                    std::to_string(requiredPayloadBytes) + " bytes but only " +
                    std::to_string(availableBytes) + " remain.";
                return fail();
            }
        }
    }

    outData.splatCount = header.vertexCount;
    outData.colorTextureWidth = GaussianSplatRuntimeData::kColorTextureWidth;
    outData.colorTextureHeight =
        std::max<uint32_t>(1u, (header.vertexCount + outData.colorTextureWidth - 1u) / outData.colorTextureWidth);
    // Round up to whole 16-row tiles; SplatIndexToTextureIndex addresses 16x16 tiles.
    outData.colorTextureHeight = (outData.colorTextureHeight + 15u) / 16u * 16u;
    outData.positionData.resize(static_cast<size_t>(header.vertexCount) * GaussianSplatRuntimeData::kPositionStride);
    outData.otherData.resize(static_cast<size_t>(header.vertexCount) * GaussianSplatRuntimeData::kOtherStride);
    outData.colorData.resize(
        static_cast<size_t>(outData.colorTextureWidth) *
        static_cast<size_t>(outData.colorTextureHeight) *
        GaussianSplatRuntimeData::kColorStride);
    outData.shData.resize(static_cast<size_t>(header.vertexCount) * GaussianSplatRuntimeData::kShStride);
    outData.boundsMin = {
        std::numeric_limits<float>::infinity(),
        std::numeric_limits<float>::infinity(),
        std::numeric_limits<float>::infinity(),
    };
    outData.boundsMax = {
        -std::numeric_limits<float>::infinity(),
        -std::numeric_limits<float>::infinity(),
        -std::numeric_limits<float>::infinity(),
    };
    std::vector<std::byte> vertexBytes(header.vertexStride);
    for (uint32_t splatIndex = 0; splatIndex < header.vertexCount; ++splatIndex) {
        input.read(reinterpret_cast<char*>(vertexBytes.data()), static_cast<std::streamsize>(vertexBytes.size()));
        if (input.gcount() != static_cast<std::streamsize>(vertexBytes.size())) {
            outErrorMessage =
                "Unexpected end of file while reading Gaussian splat vertex " + std::to_string(splatIndex) + ".";
            return fail();
        }
        RawGaussianSplat splat;
        if (!ReadGaussianSplat(vertexBytes.data(), propertyLayout, splat, outErrorMessage)) {
            return fail();
        }
        LinearizeGaussianSplat(splat);
        outData.boundsMin = Min(outData.boundsMin, splat.position);
        outData.boundsMax = Max(outData.boundsMax, splat.position);

        const size_t positionOffset = static_cast<size_t>(splatIndex) * GaussianSplatRuntimeData::kPositionStride;
        WriteFloat3(outData.positionData, positionOffset, splat.position);

        // otherData layout: packed rotation (uint32) followed by linear scale (Float3).
        const size_t otherOffset = static_cast<size_t>(splatIndex) * GaussianSplatRuntimeData::kOtherStride;
        const uint32_t packedRotation = EncodeQuatToNorm10(splat.rotation);
        WriteValue(outData.otherData, otherOffset, packedRotation);
        WriteFloat3(outData.otherData, otherOffset + sizeof(uint32_t), splat.scale);

        const size_t shOffset = static_cast<size_t>(splatIndex) * GaussianSplatRuntimeData::kShStride;
        for (uint32_t coefficientIndex = 0; coefficientIndex < GaussianSplatRuntimeData::kShCoefficientCount; ++coefficientIndex) {
            const size_t coefficientOffset = shOffset + static_cast<size_t>(coefficientIndex) * sizeof(float) * 3u;
            WriteFloat3(outData.shData, coefficientOffset, splat.sh[coefficientIndex]);
        }

        const uint32_t textureIndex = SplatIndexToTextureIndex(splatIndex);
        const size_t colorOffset = static_cast<size_t>(textureIndex) * GaussianSplatRuntimeData::kColorStride;
        WriteFloat4(outData.colorData, colorOffset, splat.dc0.x, splat.dc0.y, splat.dc0.z, splat.opacity);
    }
    return true;
}
// Writes a plain-text summary of a loaded scene to `filePath`, one `key=value`
// line per entry. It records:
//   - the splat count,
//   - the color texture size,
//   - the position bounds,
//   - the byte size of each runtime buffer.
// Returns false and sets `outErrorMessage` if the file cannot be opened, or if
// any write fails, including the final flush when the file is closed.
bool WriteGaussianSceneSummary(
    const std::filesystem::path& filePath,
    const GaussianSplatRuntimeData& data,
    std::string& outErrorMessage) {
    outErrorMessage.clear();
    std::ofstream output(filePath, std::ios::binary | std::ios::trunc);
    if (!output.is_open()) {
        outErrorMessage = "Failed to open summary output file: " + filePath.string();
        return false;
    }
    output << "splat_count=" << data.splatCount << '\n';
    output << "color_texture_width=" << data.colorTextureWidth << '\n';
    output << "color_texture_height=" << data.colorTextureHeight << '\n';
    output << "bounds_min=" << data.boundsMin.x << "," << data.boundsMin.y << "," << data.boundsMin.z << '\n';
    output << "bounds_max=" << data.boundsMax.x << "," << data.boundsMax.y << "," << data.boundsMax.z << '\n';
    output << "position_bytes=" << data.positionData.size() << '\n';
    output << "other_bytes=" << data.otherData.size() << '\n';
    output << "color_bytes=" << data.colorData.size() << '\n';
    output << "sh_bytes=" << data.shData.size() << '\n';
    // Close explicitly so buffered output is flushed here. A failed flush (e.g. a
    // full disk) sets failbit and is reported, rather than being lost in the destructor.
    output.close();
    if (!output) {
        outErrorMessage = "Failed to write summary output file: " + filePath.string();
        return false;
    }
    return true;
}
} // namespace XC3DGSD3D12