refactor(rhi): untangle d3d12 descriptor bindings

This commit is contained in:
2026-03-26 12:21:49 +08:00
parent d8e14df78a
commit 36d2f479cd
8 changed files with 511 additions and 120 deletions

View File

@@ -4,6 +4,9 @@
#include "XCEngine/RHI/D3D12/D3D12ResourceView.h"
#include "XCEngine/RHI/D3D12/D3D12Sampler.h"
#include <algorithm>
#include <cstring>
namespace XCEngine {
namespace RHI {
@@ -14,6 +17,18 @@ uint64_t AlignConstantBufferSize(size_t size) {
return (minSize + 255ull) & ~255ull;
}
bool UsesDescriptorHeap(DescriptorHeapType heapType, DescriptorType bindingType) {
if (heapType == DescriptorHeapType::CBV_SRV_UAV) {
return bindingType == DescriptorType::SRV || bindingType == DescriptorType::UAV;
}
if (heapType == DescriptorHeapType::Sampler) {
return bindingType == DescriptorType::Sampler;
}
return false;
}
} // namespace
D3D12DescriptorSet::D3D12DescriptorSet()
@@ -21,7 +36,7 @@ D3D12DescriptorSet::D3D12DescriptorSet()
, m_offset(0)
, m_count(0)
, m_bindingCount(0)
, m_bindings(nullptr) {
, m_bindings() {
}
D3D12DescriptorSet::~D3D12DescriptorSet() {
@@ -33,14 +48,57 @@ bool D3D12DescriptorSet::Initialize(D3D12DescriptorHeap* heap, uint32_t offset,
m_offset = offset;
m_count = count;
m_bindingCount = layout.bindingCount;
m_bindings.clear();
m_bindingRecords.clear();
m_bindingToRecordIndex.clear();
if (layout.bindingCount > 0 && layout.bindings != nullptr) {
m_bindings = new DescriptorSetLayoutBinding[layout.bindingCount];
m_bindings.reserve(layout.bindingCount);
m_bindingRecords.reserve(layout.bindingCount);
for (uint32_t i = 0; i < layout.bindingCount; ++i) {
m_bindings[i] = layout.bindings[i];
m_bindings.push_back(layout.bindings[i]);
BindingRecord record = {};
record.layout = layout.bindings[i];
m_bindingRecords.push_back(std::move(record));
m_bindingToRecordIndex[layout.bindings[i].binding] = i;
}
uint32_t nextDescriptorIndex = 0;
const DescriptorType orderedTypes[] = {
DescriptorType::SRV,
DescriptorType::UAV,
DescriptorType::Sampler,
};
for (DescriptorType type : orderedTypes) {
if (!UsesDescriptorHeap(m_heap->GetType(), type)) {
continue;
}
std::vector<uint32_t> matchingIndices;
matchingIndices.reserve(m_bindingRecords.size());
for (uint32_t i = 0; i < m_bindingRecords.size(); ++i) {
if (static_cast<DescriptorType>(m_bindingRecords[i].layout.type) == type) {
matchingIndices.push_back(i);
}
}
std::sort(
matchingIndices.begin(),
matchingIndices.end(),
[this](uint32_t left, uint32_t right) {
return m_bindingRecords[left].layout.binding < m_bindingRecords[right].layout.binding;
});
for (uint32_t recordIndex : matchingIndices) {
m_bindingRecords[recordIndex].descriptorIndex = nextDescriptorIndex;
nextDescriptorIndex += m_bindingRecords[recordIndex].layout.count;
}
}
}
return true;
}
@@ -49,12 +107,9 @@ void D3D12DescriptorSet::Shutdown() {
m_offset = 0;
m_count = 0;
m_bindingCount = 0;
m_constantBuffer.reset();
m_constantBufferCapacity = 0;
if (m_bindings != nullptr) {
delete[] m_bindings;
m_bindings = nullptr;
}
m_bindings.clear();
m_bindingRecords.clear();
m_bindingToRecordIndex.clear();
}
void D3D12DescriptorSet::Bind() {
@@ -68,17 +123,8 @@ void D3D12DescriptorSet::Update(uint32_t offset, RHIResourceView* view) {
return;
}
uint32_t descriptorOffset = 0;
bool foundBinding = false;
for (uint32_t i = 0; i < m_bindingCount; ++i) {
if (m_bindings[i].binding == offset) {
foundBinding = true;
break;
}
descriptorOffset += m_bindings[i].count;
}
if (!foundBinding) {
const uint32_t descriptorOffset = GetDescriptorIndexForBinding(offset);
if (descriptorOffset == UINT32_MAX) {
return;
}
@@ -101,17 +147,8 @@ void D3D12DescriptorSet::UpdateSampler(uint32_t offset, RHISampler* sampler) {
return;
}
uint32_t descriptorOffset = 0;
bool foundBinding = false;
for (uint32_t i = 0; i < m_bindingCount; ++i) {
if (m_bindings[i].binding == offset) {
foundBinding = true;
break;
}
descriptorOffset += m_bindings[i].count;
}
if (!foundBinding) {
const uint32_t descriptorOffset = GetDescriptorIndexForBinding(offset);
if (descriptorOffset == UINT32_MAX) {
return;
}
@@ -122,26 +159,75 @@ void D3D12DescriptorSet::UpdateSampler(uint32_t offset, RHISampler* sampler) {
}
D3D12_GPU_DESCRIPTOR_HANDLE D3D12DescriptorSet::GetGPUHandle(uint32_t index) const {
if (m_heap == nullptr) {
if (m_heap == nullptr || index >= m_count) {
return D3D12_GPU_DESCRIPTOR_HANDLE{0};
}
GPUDescriptorHandle handle = m_heap->GetGPUDescriptorHandle(m_offset + index);
return D3D12_GPU_DESCRIPTOR_HANDLE{ handle.ptr };
}
void D3D12DescriptorSet::WriteConstant(uint32_t binding, const void* data, size_t size, size_t offset) {
(void)binding;
size_t requiredSize = offset + size;
if (m_constantBufferData.size() < requiredSize) {
m_constantBufferData.resize(requiredSize);
D3D12_GPU_DESCRIPTOR_HANDLE D3D12DescriptorSet::GetGPUHandleForBinding(uint32_t binding) const {
const uint32_t descriptorIndex = GetDescriptorIndexForBinding(binding);
if (descriptorIndex == UINT32_MAX) {
return D3D12_GPU_DESCRIPTOR_HANDLE{0};
}
return GetGPUHandle(descriptorIndex);
}
void D3D12DescriptorSet::WriteConstant(uint32_t binding, const void* data, size_t size, size_t offset) {
BindingRecord* bindingRecord = FindBindingRecord(binding);
if (bindingRecord == nullptr ||
static_cast<DescriptorType>(bindingRecord->layout.type) != DescriptorType::CBV ||
data == nullptr ||
size == 0) {
return;
}
size_t requiredSize = offset + size;
if (bindingRecord->constantBufferData.size() < requiredSize) {
bindingRecord->constantBufferData.resize(requiredSize);
}
memcpy(bindingRecord->constantBufferData.data() + offset, data, size);
bindingRecord->constantBufferDirty = true;
}
void* D3D12DescriptorSet::GetConstantBufferData() {
BindingRecord* bindingRecord = FindFirstBindingRecordOfType(DescriptorType::CBV);
if (bindingRecord == nullptr || bindingRecord->constantBufferData.empty()) {
return nullptr;
}
return bindingRecord->constantBufferData.data();
}
size_t D3D12DescriptorSet::GetConstantBufferSize() const {
const BindingRecord* bindingRecord = FindFirstBindingRecordOfType(DescriptorType::CBV);
return bindingRecord != nullptr ? bindingRecord->constantBufferData.size() : 0;
}
bool D3D12DescriptorSet::IsConstantDirty() const {
for (const BindingRecord& bindingRecord : m_bindingRecords) {
if (static_cast<DescriptorType>(bindingRecord.layout.type) == DescriptorType::CBV &&
bindingRecord.constantBufferDirty) {
return true;
}
}
return false;
}
void D3D12DescriptorSet::MarkConstantClean() {
for (BindingRecord& bindingRecord : m_bindingRecords) {
if (static_cast<DescriptorType>(bindingRecord.layout.type) == DescriptorType::CBV) {
bindingRecord.constantBufferDirty = false;
}
}
memcpy(m_constantBufferData.data() + offset, data, size);
m_constantBufferDirty = true;
}
bool D3D12DescriptorSet::HasBindingType(DescriptorType type) const {
for (uint32_t i = 0; i < m_bindingCount; ++i) {
if (static_cast<DescriptorType>(m_bindings[i].type) == type) {
for (const BindingRecord& bindingRecord : m_bindingRecords) {
if (static_cast<DescriptorType>(bindingRecord.layout.type) == type) {
return true;
}
}
@@ -149,21 +235,42 @@ bool D3D12DescriptorSet::HasBindingType(DescriptorType type) const {
}
uint32_t D3D12DescriptorSet::GetFirstBindingOfType(DescriptorType type) const {
for (uint32_t i = 0; i < m_bindingCount; ++i) {
if (static_cast<DescriptorType>(m_bindings[i].type) == type) {
return m_bindings[i].binding;
uint32_t firstBinding = UINT32_MAX;
for (const BindingRecord& bindingRecord : m_bindingRecords) {
if (static_cast<DescriptorType>(bindingRecord.layout.type) == type) {
if (bindingRecord.layout.binding < firstBinding) {
firstBinding = bindingRecord.layout.binding;
}
}
}
return UINT32_MAX;
return firstBinding;
}
bool D3D12DescriptorSet::UploadConstantBuffer() {
if (!HasBindingType(DescriptorType::CBV) || m_heap == nullptr || m_heap->GetDevice() == nullptr) {
uint32_t D3D12DescriptorSet::GetDescriptorIndexForBinding(uint32_t binding) const {
const BindingRecord* bindingRecord = FindBindingRecord(binding);
if (bindingRecord == nullptr || m_heap == nullptr) {
return UINT32_MAX;
}
if (!UsesDescriptorHeap(m_heap->GetType(), static_cast<DescriptorType>(bindingRecord->layout.type))) {
return UINT32_MAX;
}
return bindingRecord->descriptorIndex;
}
bool D3D12DescriptorSet::UploadConstantBuffer(uint32_t binding) {
BindingRecord* bindingRecord = FindBindingRecord(binding);
if (bindingRecord == nullptr ||
static_cast<DescriptorType>(bindingRecord->layout.type) != DescriptorType::CBV ||
m_heap == nullptr ||
m_heap->GetDevice() == nullptr) {
return false;
}
const uint64_t alignedSize = AlignConstantBufferSize(m_constantBufferData.size());
if (!m_constantBuffer || m_constantBufferCapacity < alignedSize) {
const uint64_t alignedSize = AlignConstantBufferSize(bindingRecord->constantBufferData.size());
if (!bindingRecord->constantBuffer || bindingRecord->constantBufferCapacity < alignedSize) {
auto constantBuffer = std::make_unique<D3D12Buffer>();
if (!constantBuffer->Initialize(
m_heap->GetDevice(),
@@ -175,21 +282,66 @@ bool D3D12DescriptorSet::UploadConstantBuffer() {
constantBuffer->SetBufferType(BufferType::Constant);
constantBuffer->SetStride(static_cast<uint32_t>(alignedSize));
m_constantBuffer = std::move(constantBuffer);
m_constantBufferCapacity = alignedSize;
m_constantBufferDirty = true;
bindingRecord->constantBuffer = std::move(constantBuffer);
bindingRecord->constantBufferCapacity = alignedSize;
bindingRecord->constantBufferDirty = true;
}
if (m_constantBufferDirty && !m_constantBufferData.empty()) {
m_constantBuffer->SetData(m_constantBufferData.data(), m_constantBufferData.size());
m_constantBufferDirty = false;
if (bindingRecord->constantBufferDirty && !bindingRecord->constantBufferData.empty()) {
bindingRecord->constantBuffer->SetData(
bindingRecord->constantBufferData.data(),
bindingRecord->constantBufferData.size());
bindingRecord->constantBufferDirty = false;
}
return m_constantBuffer != nullptr;
return bindingRecord->constantBuffer != nullptr;
}
D3D12_GPU_VIRTUAL_ADDRESS D3D12DescriptorSet::GetConstantBufferGPUAddress() const {
return m_constantBuffer ? m_constantBuffer->GetGPUVirtualAddress() : 0;
D3D12_GPU_VIRTUAL_ADDRESS D3D12DescriptorSet::GetConstantBufferGPUAddress(uint32_t binding) const {
const BindingRecord* bindingRecord = FindBindingRecord(binding);
if (bindingRecord == nullptr || bindingRecord->constantBuffer == nullptr) {
return 0;
}
return bindingRecord->constantBuffer->GetGPUVirtualAddress();
}
D3D12DescriptorSet::BindingRecord* D3D12DescriptorSet::FindBindingRecord(uint32_t binding) {
auto it = m_bindingToRecordIndex.find(binding);
if (it == m_bindingToRecordIndex.end()) {
return nullptr;
}
return &m_bindingRecords[it->second];
}
const D3D12DescriptorSet::BindingRecord* D3D12DescriptorSet::FindBindingRecord(uint32_t binding) const {
auto it = m_bindingToRecordIndex.find(binding);
if (it == m_bindingToRecordIndex.end()) {
return nullptr;
}
return &m_bindingRecords[it->second];
}
D3D12DescriptorSet::BindingRecord* D3D12DescriptorSet::FindFirstBindingRecordOfType(DescriptorType type) {
for (BindingRecord& bindingRecord : m_bindingRecords) {
if (static_cast<DescriptorType>(bindingRecord.layout.type) == type) {
return &bindingRecord;
}
}
return nullptr;
}
const D3D12DescriptorSet::BindingRecord* D3D12DescriptorSet::FindFirstBindingRecordOfType(DescriptorType type) const {
for (const BindingRecord& bindingRecord : m_bindingRecords) {
if (static_cast<DescriptorType>(bindingRecord.layout.type) == type) {
return &bindingRecord;
}
}
return nullptr;
}
} // namespace RHI