From 36d2f479cd54a8571f80972e6ac4272cca70bad8 Mon Sep 17 00:00:00 2001 From: ssdfasd <2156608475@qq.com> Date: Thu, 26 Mar 2026 12:21:49 +0800 Subject: [PATCH] refactor(rhi): untangle d3d12 descriptor bindings --- .../XCEngine/RHI/D3D12/D3D12DescriptorSet.h | 39 ++- .../XCEngine/RHI/D3D12/D3D12PipelineLayout.h | 15 +- engine/src/RHI/D3D12/D3D12CommandList.cpp | 95 ++++-- engine/src/RHI/D3D12/D3D12DescriptorHeap.cpp | 20 +- engine/src/RHI/D3D12/D3D12DescriptorSet.cpp | 272 ++++++++++++++---- engine/src/RHI/D3D12/D3D12PipelineLayout.cpp | 66 ++++- tests/RHI/unit/test_descriptor_set.cpp | 99 +++++++ tests/RHI/unit/test_pipeline_layout.cpp | 25 +- 8 files changed, 511 insertions(+), 120 deletions(-) diff --git a/engine/include/XCEngine/RHI/D3D12/D3D12DescriptorSet.h b/engine/include/XCEngine/RHI/D3D12/D3D12DescriptorSet.h index 8cadc62e..b07f7c19 100644 --- a/engine/include/XCEngine/RHI/D3D12/D3D12DescriptorSet.h +++ b/engine/include/XCEngine/RHI/D3D12/D3D12DescriptorSet.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -32,32 +33,46 @@ public: void WriteConstant(uint32_t binding, const void* data, size_t size, size_t offset = 0) override; uint32_t GetBindingCount() const override { return m_bindingCount; } - const DescriptorSetLayoutBinding* GetBindings() const override { return m_bindings; } + const DescriptorSetLayoutBinding* GetBindings() const override { return m_bindings.empty() ? nullptr : m_bindings.data(); } - void* GetConstantBufferData() override { return m_constantBufferData.data(); } - size_t GetConstantBufferSize() const override { return m_constantBufferData.size(); } - bool IsConstantDirty() const override { return m_constantBufferDirty; } - void MarkConstantClean() override { m_constantBufferDirty = false; } + void* GetConstantBufferData() override; + size_t GetConstantBufferSize() const override; + bool IsConstantDirty() const override; + void MarkConstantClean() override; D3D12_GPU_DESCRIPTOR_HANDLE GetGPUHandle(uint32_t index = 0) const; + D3D12_GPU_DESCRIPTOR_HANDLE GetGPUHandleForBinding(uint32_t binding) const; uint32_t GetOffset() const { return m_offset; } uint32_t GetCount() const { return m_count; } D3D12DescriptorHeap* GetHeap() const { return m_heap; } bool HasBindingType(DescriptorType type) const; uint32_t GetFirstBindingOfType(DescriptorType type) const; - bool UploadConstantBuffer(); - D3D12_GPU_VIRTUAL_ADDRESS GetConstantBufferGPUAddress() const; + uint32_t GetDescriptorIndexForBinding(uint32_t binding) const; + bool UploadConstantBuffer(uint32_t binding); + D3D12_GPU_VIRTUAL_ADDRESS GetConstantBufferGPUAddress(uint32_t binding) const; private: + struct BindingRecord { + DescriptorSetLayoutBinding layout = {}; + uint32_t descriptorIndex = UINT32_MAX; + std::vector constantBufferData; + bool constantBufferDirty = false; + std::unique_ptr constantBuffer; + uint64_t constantBufferCapacity = 0; + }; + + BindingRecord* FindBindingRecord(uint32_t binding); + const BindingRecord* FindBindingRecord(uint32_t binding) const; + BindingRecord* FindFirstBindingRecordOfType(DescriptorType type); + const BindingRecord* FindFirstBindingRecordOfType(DescriptorType type) const; + D3D12DescriptorHeap* m_heap; uint32_t m_offset; uint32_t m_count; uint32_t m_bindingCount; - DescriptorSetLayoutBinding* m_bindings; - std::vector m_constantBufferData; - bool m_constantBufferDirty = false; - std::unique_ptr m_constantBuffer; - uint64_t m_constantBufferCapacity = 0; + std::vector m_bindings; + std::vector m_bindingRecords; + std::unordered_map m_bindingToRecordIndex; }; } // namespace RHI diff --git a/engine/include/XCEngine/RHI/D3D12/D3D12PipelineLayout.h b/engine/include/XCEngine/RHI/D3D12/D3D12PipelineLayout.h index 8301082f..802941ab 100644 --- a/engine/include/XCEngine/RHI/D3D12/D3D12PipelineLayout.h +++ b/engine/include/XCEngine/RHI/D3D12/D3D12PipelineLayout.h @@ -27,8 +27,14 @@ public: void* GetNativeHandle() override { return m_rootSignature.Get(); } ID3D12RootSignature* GetRootSignature() const { return m_rootSignature.Get(); } - uint32_t GetRootParameterIndex(uint32_t shaderRegister) const; - bool HasRootParameter(uint32_t shaderRegister) const; + bool HasConstantBufferBinding(uint32_t binding) const; + uint32_t GetConstantBufferRootParameterIndex(uint32_t binding) const; + bool HasShaderResourceTable() const; + uint32_t GetShaderResourceTableRootParameterIndex() const; + bool HasUnorderedAccessTable() const; + uint32_t GetUnorderedAccessTableRootParameterIndex() const; + bool HasSamplerTable() const; + uint32_t GetSamplerTableRootParameterIndex() const; const RHIPipelineLayoutDesc& GetDesc() const { return m_desc; } private: @@ -37,7 +43,10 @@ private: ComPtr m_rootSignature; D3D12Device* m_device; RHIPipelineLayoutDesc m_desc = {}; - std::unordered_map m_registerToRootIndex; + std::unordered_map m_constantBufferRootIndices; + uint32_t m_shaderResourceTableRootIndex = UINT32_MAX; + uint32_t m_unorderedAccessTableRootIndex = UINT32_MAX; + uint32_t m_samplerTableRootIndex = UINT32_MAX; std::vector m_rootParameters; std::vector m_descriptorRanges; }; diff --git a/engine/src/RHI/D3D12/D3D12CommandList.cpp b/engine/src/RHI/D3D12/D3D12CommandList.cpp index fdffcc55..8d4aee1b 100644 --- a/engine/src/RHI/D3D12/D3D12CommandList.cpp +++ b/engine/src/RHI/D3D12/D3D12CommandList.cpp @@ -198,14 +198,15 @@ void D3D12CommandList::SetGraphicsDescriptorSets( } D3D12DescriptorSet* d3d12Set = static_cast(descriptorSets[i]); - if (d3d12Set->HasBindingType(DescriptorType::CBV)) { - const uint32_t cbvBinding = d3d12Set->GetFirstBindingOfType(DescriptorType::CBV); - if (cbvBinding != UINT32_MAX && - d3d12Layout->HasRootParameter(cbvBinding) && - d3d12Set->UploadConstantBuffer()) { + const DescriptorSetLayoutBinding* bindings = d3d12Set->GetBindings(); + for (uint32_t bindingIndex = 0; bindingIndex < d3d12Set->GetBindingCount(); ++bindingIndex) { + const DescriptorSetLayoutBinding& binding = bindings[bindingIndex]; + if (static_cast(binding.type) == DescriptorType::CBV && + d3d12Layout->HasConstantBufferBinding(binding.binding) && + d3d12Set->UploadConstantBuffer(binding.binding)) { SetGraphicsRootConstantBufferView( - d3d12Layout->GetRootParameterIndex(cbvBinding), - d3d12Set->GetConstantBufferGPUAddress()); + d3d12Layout->GetConstantBufferRootParameterIndex(binding.binding), + d3d12Set->GetConstantBufferGPUAddress(binding.binding)); } } @@ -214,11 +215,30 @@ void D3D12CommandList::SetGraphicsDescriptorSets( continue; } - D3D12_GPU_DESCRIPTOR_HANDLE gpuHandle = d3d12Set->GetGPUHandle(); - if (heap->GetType() == DescriptorHeapType::CBV_SRV_UAV && HasDescriptorTableBindings(d3d12Set) && d3d12Layout->HasRootParameter(100)) { - SetGraphicsRootDescriptorTable(d3d12Layout->GetRootParameterIndex(100), gpuHandle); - } else if (heap->GetType() == DescriptorHeapType::Sampler && HasSamplerBindings(d3d12Set) && d3d12Layout->HasRootParameter(200)) { - SetGraphicsRootDescriptorTable(d3d12Layout->GetRootParameterIndex(200), gpuHandle); + if (heap->GetType() == DescriptorHeapType::CBV_SRV_UAV) { + if (d3d12Set->HasBindingType(DescriptorType::SRV) && d3d12Layout->HasShaderResourceTable()) { + const uint32_t srvBinding = d3d12Set->GetFirstBindingOfType(DescriptorType::SRV); + const D3D12_GPU_DESCRIPTOR_HANDLE gpuHandle = d3d12Set->GetGPUHandleForBinding(srvBinding); + if (gpuHandle.ptr != 0) { + SetGraphicsRootDescriptorTable(d3d12Layout->GetShaderResourceTableRootParameterIndex(), gpuHandle); + } + } + + if (d3d12Set->HasBindingType(DescriptorType::UAV) && d3d12Layout->HasUnorderedAccessTable()) { + const uint32_t uavBinding = d3d12Set->GetFirstBindingOfType(DescriptorType::UAV); + const D3D12_GPU_DESCRIPTOR_HANDLE gpuHandle = d3d12Set->GetGPUHandleForBinding(uavBinding); + if (gpuHandle.ptr != 0) { + SetGraphicsRootDescriptorTable(d3d12Layout->GetUnorderedAccessTableRootParameterIndex(), gpuHandle); + } + } + } else if (heap->GetType() == DescriptorHeapType::Sampler && + HasSamplerBindings(d3d12Set) && + d3d12Layout->HasSamplerTable()) { + const uint32_t samplerBinding = d3d12Set->GetFirstBindingOfType(DescriptorType::Sampler); + const D3D12_GPU_DESCRIPTOR_HANDLE gpuHandle = d3d12Set->GetGPUHandleForBinding(samplerBinding); + if (gpuHandle.ptr != 0) { + SetGraphicsRootDescriptorTable(d3d12Layout->GetSamplerTableRootParameterIndex(), gpuHandle); + } } } } @@ -234,6 +254,7 @@ void D3D12CommandList::SetComputeDescriptorSets( D3D12PipelineLayout* d3d12Layout = static_cast(pipelineLayout); SetPipelineLayout(d3d12Layout); + m_commandList->SetComputeRootSignature(d3d12Layout->GetRootSignature()); std::vector descriptorHeaps; descriptorHeaps.reserve(2); @@ -265,14 +286,15 @@ void D3D12CommandList::SetComputeDescriptorSets( } D3D12DescriptorSet* d3d12Set = static_cast(descriptorSets[i]); - if (d3d12Set->HasBindingType(DescriptorType::CBV)) { - const uint32_t cbvBinding = d3d12Set->GetFirstBindingOfType(DescriptorType::CBV); - if (cbvBinding != UINT32_MAX && - d3d12Layout->HasRootParameter(cbvBinding) && - d3d12Set->UploadConstantBuffer()) { + const DescriptorSetLayoutBinding* bindings = d3d12Set->GetBindings(); + for (uint32_t bindingIndex = 0; bindingIndex < d3d12Set->GetBindingCount(); ++bindingIndex) { + const DescriptorSetLayoutBinding& binding = bindings[bindingIndex]; + if (static_cast(binding.type) == DescriptorType::CBV && + d3d12Layout->HasConstantBufferBinding(binding.binding) && + d3d12Set->UploadConstantBuffer(binding.binding)) { m_commandList->SetComputeRootConstantBufferView( - d3d12Layout->GetRootParameterIndex(cbvBinding), - d3d12Set->GetConstantBufferGPUAddress()); + d3d12Layout->GetConstantBufferRootParameterIndex(binding.binding), + d3d12Set->GetConstantBufferGPUAddress(binding.binding)); } } @@ -281,11 +303,36 @@ void D3D12CommandList::SetComputeDescriptorSets( continue; } - D3D12_GPU_DESCRIPTOR_HANDLE gpuHandle = d3d12Set->GetGPUHandle(); - if (heap->GetType() == DescriptorHeapType::CBV_SRV_UAV && HasDescriptorTableBindings(d3d12Set) && d3d12Layout->HasRootParameter(100)) { - m_commandList->SetComputeRootDescriptorTable(d3d12Layout->GetRootParameterIndex(100), gpuHandle); - } else if (heap->GetType() == DescriptorHeapType::Sampler && HasSamplerBindings(d3d12Set) && d3d12Layout->HasRootParameter(200)) { - m_commandList->SetComputeRootDescriptorTable(d3d12Layout->GetRootParameterIndex(200), gpuHandle); + if (heap->GetType() == DescriptorHeapType::CBV_SRV_UAV) { + if (d3d12Set->HasBindingType(DescriptorType::SRV) && d3d12Layout->HasShaderResourceTable()) { + const uint32_t srvBinding = d3d12Set->GetFirstBindingOfType(DescriptorType::SRV); + const D3D12_GPU_DESCRIPTOR_HANDLE gpuHandle = d3d12Set->GetGPUHandleForBinding(srvBinding); + if (gpuHandle.ptr != 0) { + m_commandList->SetComputeRootDescriptorTable( + d3d12Layout->GetShaderResourceTableRootParameterIndex(), + gpuHandle); + } + } + + if (d3d12Set->HasBindingType(DescriptorType::UAV) && d3d12Layout->HasUnorderedAccessTable()) { + const uint32_t uavBinding = d3d12Set->GetFirstBindingOfType(DescriptorType::UAV); + const D3D12_GPU_DESCRIPTOR_HANDLE gpuHandle = d3d12Set->GetGPUHandleForBinding(uavBinding); + if (gpuHandle.ptr != 0) { + m_commandList->SetComputeRootDescriptorTable( + d3d12Layout->GetUnorderedAccessTableRootParameterIndex(), + gpuHandle); + } + } + } else if (heap->GetType() == DescriptorHeapType::Sampler && + HasSamplerBindings(d3d12Set) && + d3d12Layout->HasSamplerTable()) { + const uint32_t samplerBinding = d3d12Set->GetFirstBindingOfType(DescriptorType::Sampler); + const D3D12_GPU_DESCRIPTOR_HANDLE gpuHandle = d3d12Set->GetGPUHandleForBinding(samplerBinding); + if (gpuHandle.ptr != 0) { + m_commandList->SetComputeRootDescriptorTable( + d3d12Layout->GetSamplerTableRootParameterIndex(), + gpuHandle); + } } } } diff --git a/engine/src/RHI/D3D12/D3D12DescriptorHeap.cpp b/engine/src/RHI/D3D12/D3D12DescriptorHeap.cpp index eb68d03b..a44156e5 100644 --- a/engine/src/RHI/D3D12/D3D12DescriptorHeap.cpp +++ b/engine/src/RHI/D3D12/D3D12DescriptorHeap.cpp @@ -4,6 +4,22 @@ namespace XCEngine { namespace RHI { +namespace { + +bool UsesDescriptorHeap(DescriptorHeapType heapType, DescriptorType bindingType) { + if (heapType == DescriptorHeapType::CBV_SRV_UAV) { + return bindingType == DescriptorType::SRV || bindingType == DescriptorType::UAV; + } + + if (heapType == DescriptorHeapType::Sampler) { + return bindingType == DescriptorType::Sampler; + } + + return false; +} + +} // namespace + D3D12DescriptorHeap::D3D12DescriptorHeap() : m_type(DescriptorHeapType::CBV_SRV_UAV) , m_numDescriptors(0) @@ -100,7 +116,9 @@ D3D12_DESCRIPTOR_HEAP_DESC D3D12DescriptorHeap::CreateDesc(DescriptorHeapType ty RHIDescriptorSet* D3D12DescriptorHeap::AllocateSet(const DescriptorSetLayoutDesc& layout) { uint32_t requiredDescriptors = 0; for (uint32_t i = 0; i < layout.bindingCount; ++i) { - requiredDescriptors += layout.bindings[i].count; + if (UsesDescriptorHeap(m_type, static_cast(layout.bindings[i].type))) { + requiredDescriptors += layout.bindings[i].count; + } } if (m_nextFreeOffset + requiredDescriptors > m_numDescriptors) { diff --git a/engine/src/RHI/D3D12/D3D12DescriptorSet.cpp b/engine/src/RHI/D3D12/D3D12DescriptorSet.cpp index 44566e28..9c75d46c 100644 --- a/engine/src/RHI/D3D12/D3D12DescriptorSet.cpp +++ b/engine/src/RHI/D3D12/D3D12DescriptorSet.cpp @@ -4,6 +4,9 @@ #include "XCEngine/RHI/D3D12/D3D12ResourceView.h" #include "XCEngine/RHI/D3D12/D3D12Sampler.h" +#include +#include + namespace XCEngine { namespace RHI { @@ -14,6 +17,18 @@ uint64_t AlignConstantBufferSize(size_t size) { return (minSize + 255ull) & ~255ull; } +bool UsesDescriptorHeap(DescriptorHeapType heapType, DescriptorType bindingType) { + if (heapType == DescriptorHeapType::CBV_SRV_UAV) { + return bindingType == DescriptorType::SRV || bindingType == DescriptorType::UAV; + } + + if (heapType == DescriptorHeapType::Sampler) { + return bindingType == DescriptorType::Sampler; + } + + return false; +} + } // namespace D3D12DescriptorSet::D3D12DescriptorSet() @@ -21,7 +36,7 @@ D3D12DescriptorSet::D3D12DescriptorSet() , m_offset(0) , m_count(0) , m_bindingCount(0) - , m_bindings(nullptr) { + , m_bindings() { } D3D12DescriptorSet::~D3D12DescriptorSet() { @@ -33,14 +48,57 @@ bool D3D12DescriptorSet::Initialize(D3D12DescriptorHeap* heap, uint32_t offset, m_offset = offset; m_count = count; m_bindingCount = layout.bindingCount; - + + m_bindings.clear(); + m_bindingRecords.clear(); + m_bindingToRecordIndex.clear(); + if (layout.bindingCount > 0 && layout.bindings != nullptr) { - m_bindings = new DescriptorSetLayoutBinding[layout.bindingCount]; + m_bindings.reserve(layout.bindingCount); + m_bindingRecords.reserve(layout.bindingCount); for (uint32_t i = 0; i < layout.bindingCount; ++i) { - m_bindings[i] = layout.bindings[i]; + m_bindings.push_back(layout.bindings[i]); + + BindingRecord record = {}; + record.layout = layout.bindings[i]; + m_bindingRecords.push_back(std::move(record)); + m_bindingToRecordIndex[layout.bindings[i].binding] = i; + } + + uint32_t nextDescriptorIndex = 0; + const DescriptorType orderedTypes[] = { + DescriptorType::SRV, + DescriptorType::UAV, + DescriptorType::Sampler, + }; + + for (DescriptorType type : orderedTypes) { + if (!UsesDescriptorHeap(m_heap->GetType(), type)) { + continue; + } + + std::vector matchingIndices; + matchingIndices.reserve(m_bindingRecords.size()); + for (uint32_t i = 0; i < m_bindingRecords.size(); ++i) { + if (static_cast(m_bindingRecords[i].layout.type) == type) { + matchingIndices.push_back(i); + } + } + + std::sort( + matchingIndices.begin(), + matchingIndices.end(), + [this](uint32_t left, uint32_t right) { + return m_bindingRecords[left].layout.binding < m_bindingRecords[right].layout.binding; + }); + + for (uint32_t recordIndex : matchingIndices) { + m_bindingRecords[recordIndex].descriptorIndex = nextDescriptorIndex; + nextDescriptorIndex += m_bindingRecords[recordIndex].layout.count; + } } } - + return true; } @@ -49,12 +107,9 @@ void D3D12DescriptorSet::Shutdown() { m_offset = 0; m_count = 0; m_bindingCount = 0; - m_constantBuffer.reset(); - m_constantBufferCapacity = 0; - if (m_bindings != nullptr) { - delete[] m_bindings; - m_bindings = nullptr; - } + m_bindings.clear(); + m_bindingRecords.clear(); + m_bindingToRecordIndex.clear(); } void D3D12DescriptorSet::Bind() { @@ -68,17 +123,8 @@ void D3D12DescriptorSet::Update(uint32_t offset, RHIResourceView* view) { return; } - uint32_t descriptorOffset = 0; - bool foundBinding = false; - for (uint32_t i = 0; i < m_bindingCount; ++i) { - if (m_bindings[i].binding == offset) { - foundBinding = true; - break; - } - descriptorOffset += m_bindings[i].count; - } - - if (!foundBinding) { + const uint32_t descriptorOffset = GetDescriptorIndexForBinding(offset); + if (descriptorOffset == UINT32_MAX) { return; } @@ -101,17 +147,8 @@ void D3D12DescriptorSet::UpdateSampler(uint32_t offset, RHISampler* sampler) { return; } - uint32_t descriptorOffset = 0; - bool foundBinding = false; - for (uint32_t i = 0; i < m_bindingCount; ++i) { - if (m_bindings[i].binding == offset) { - foundBinding = true; - break; - } - descriptorOffset += m_bindings[i].count; - } - - if (!foundBinding) { + const uint32_t descriptorOffset = GetDescriptorIndexForBinding(offset); + if (descriptorOffset == UINT32_MAX) { return; } @@ -122,26 +159,75 @@ void D3D12DescriptorSet::UpdateSampler(uint32_t offset, RHISampler* sampler) { } D3D12_GPU_DESCRIPTOR_HANDLE D3D12DescriptorSet::GetGPUHandle(uint32_t index) const { - if (m_heap == nullptr) { + if (m_heap == nullptr || index >= m_count) { return D3D12_GPU_DESCRIPTOR_HANDLE{0}; } GPUDescriptorHandle handle = m_heap->GetGPUDescriptorHandle(m_offset + index); return D3D12_GPU_DESCRIPTOR_HANDLE{ handle.ptr }; } -void D3D12DescriptorSet::WriteConstant(uint32_t binding, const void* data, size_t size, size_t offset) { - (void)binding; - size_t requiredSize = offset + size; - if (m_constantBufferData.size() < requiredSize) { - m_constantBufferData.resize(requiredSize); +D3D12_GPU_DESCRIPTOR_HANDLE D3D12DescriptorSet::GetGPUHandleForBinding(uint32_t binding) const { + const uint32_t descriptorIndex = GetDescriptorIndexForBinding(binding); + if (descriptorIndex == UINT32_MAX) { + return D3D12_GPU_DESCRIPTOR_HANDLE{0}; + } + + return GetGPUHandle(descriptorIndex); +} + +void D3D12DescriptorSet::WriteConstant(uint32_t binding, const void* data, size_t size, size_t offset) { + BindingRecord* bindingRecord = FindBindingRecord(binding); + if (bindingRecord == nullptr || + static_cast(bindingRecord->layout.type) != DescriptorType::CBV || + data == nullptr || + size == 0) { + return; + } + + size_t requiredSize = offset + size; + if (bindingRecord->constantBufferData.size() < requiredSize) { + bindingRecord->constantBufferData.resize(requiredSize); + } + memcpy(bindingRecord->constantBufferData.data() + offset, data, size); + bindingRecord->constantBufferDirty = true; +} + +void* D3D12DescriptorSet::GetConstantBufferData() { + BindingRecord* bindingRecord = FindFirstBindingRecordOfType(DescriptorType::CBV); + if (bindingRecord == nullptr || bindingRecord->constantBufferData.empty()) { + return nullptr; + } + + return bindingRecord->constantBufferData.data(); +} + +size_t D3D12DescriptorSet::GetConstantBufferSize() const { + const BindingRecord* bindingRecord = FindFirstBindingRecordOfType(DescriptorType::CBV); + return bindingRecord != nullptr ? bindingRecord->constantBufferData.size() : 0; +} + +bool D3D12DescriptorSet::IsConstantDirty() const { + for (const BindingRecord& bindingRecord : m_bindingRecords) { + if (static_cast(bindingRecord.layout.type) == DescriptorType::CBV && + bindingRecord.constantBufferDirty) { + return true; + } + } + + return false; +} + +void D3D12DescriptorSet::MarkConstantClean() { + for (BindingRecord& bindingRecord : m_bindingRecords) { + if (static_cast(bindingRecord.layout.type) == DescriptorType::CBV) { + bindingRecord.constantBufferDirty = false; + } } - memcpy(m_constantBufferData.data() + offset, data, size); - m_constantBufferDirty = true; } bool D3D12DescriptorSet::HasBindingType(DescriptorType type) const { - for (uint32_t i = 0; i < m_bindingCount; ++i) { - if (static_cast(m_bindings[i].type) == type) { + for (const BindingRecord& bindingRecord : m_bindingRecords) { + if (static_cast(bindingRecord.layout.type) == type) { return true; } } @@ -149,21 +235,42 @@ bool D3D12DescriptorSet::HasBindingType(DescriptorType type) const { } uint32_t D3D12DescriptorSet::GetFirstBindingOfType(DescriptorType type) const { - for (uint32_t i = 0; i < m_bindingCount; ++i) { - if (static_cast(m_bindings[i].type) == type) { - return m_bindings[i].binding; + uint32_t firstBinding = UINT32_MAX; + for (const BindingRecord& bindingRecord : m_bindingRecords) { + if (static_cast(bindingRecord.layout.type) == type) { + if (bindingRecord.layout.binding < firstBinding) { + firstBinding = bindingRecord.layout.binding; + } } } - return UINT32_MAX; + + return firstBinding; } -bool D3D12DescriptorSet::UploadConstantBuffer() { - if (!HasBindingType(DescriptorType::CBV) || m_heap == nullptr || m_heap->GetDevice() == nullptr) { +uint32_t D3D12DescriptorSet::GetDescriptorIndexForBinding(uint32_t binding) const { + const BindingRecord* bindingRecord = FindBindingRecord(binding); + if (bindingRecord == nullptr || m_heap == nullptr) { + return UINT32_MAX; + } + + if (!UsesDescriptorHeap(m_heap->GetType(), static_cast(bindingRecord->layout.type))) { + return UINT32_MAX; + } + + return bindingRecord->descriptorIndex; +} + +bool D3D12DescriptorSet::UploadConstantBuffer(uint32_t binding) { + BindingRecord* bindingRecord = FindBindingRecord(binding); + if (bindingRecord == nullptr || + static_cast(bindingRecord->layout.type) != DescriptorType::CBV || + m_heap == nullptr || + m_heap->GetDevice() == nullptr) { return false; } - const uint64_t alignedSize = AlignConstantBufferSize(m_constantBufferData.size()); - if (!m_constantBuffer || m_constantBufferCapacity < alignedSize) { + const uint64_t alignedSize = AlignConstantBufferSize(bindingRecord->constantBufferData.size()); + if (!bindingRecord->constantBuffer || bindingRecord->constantBufferCapacity < alignedSize) { auto constantBuffer = std::make_unique(); if (!constantBuffer->Initialize( m_heap->GetDevice(), @@ -175,21 +282,66 @@ bool D3D12DescriptorSet::UploadConstantBuffer() { constantBuffer->SetBufferType(BufferType::Constant); constantBuffer->SetStride(static_cast(alignedSize)); - m_constantBuffer = std::move(constantBuffer); - m_constantBufferCapacity = alignedSize; - m_constantBufferDirty = true; + bindingRecord->constantBuffer = std::move(constantBuffer); + bindingRecord->constantBufferCapacity = alignedSize; + bindingRecord->constantBufferDirty = true; } - if (m_constantBufferDirty && !m_constantBufferData.empty()) { - m_constantBuffer->SetData(m_constantBufferData.data(), m_constantBufferData.size()); - m_constantBufferDirty = false; + if (bindingRecord->constantBufferDirty && !bindingRecord->constantBufferData.empty()) { + bindingRecord->constantBuffer->SetData( + bindingRecord->constantBufferData.data(), + bindingRecord->constantBufferData.size()); + bindingRecord->constantBufferDirty = false; } - return m_constantBuffer != nullptr; + return bindingRecord->constantBuffer != nullptr; } -D3D12_GPU_VIRTUAL_ADDRESS D3D12DescriptorSet::GetConstantBufferGPUAddress() const { - return m_constantBuffer ? m_constantBuffer->GetGPUVirtualAddress() : 0; +D3D12_GPU_VIRTUAL_ADDRESS D3D12DescriptorSet::GetConstantBufferGPUAddress(uint32_t binding) const { + const BindingRecord* bindingRecord = FindBindingRecord(binding); + if (bindingRecord == nullptr || bindingRecord->constantBuffer == nullptr) { + return 0; + } + + return bindingRecord->constantBuffer->GetGPUVirtualAddress(); +} + +D3D12DescriptorSet::BindingRecord* D3D12DescriptorSet::FindBindingRecord(uint32_t binding) { + auto it = m_bindingToRecordIndex.find(binding); + if (it == m_bindingToRecordIndex.end()) { + return nullptr; + } + + return &m_bindingRecords[it->second]; +} + +const D3D12DescriptorSet::BindingRecord* D3D12DescriptorSet::FindBindingRecord(uint32_t binding) const { + auto it = m_bindingToRecordIndex.find(binding); + if (it == m_bindingToRecordIndex.end()) { + return nullptr; + } + + return &m_bindingRecords[it->second]; +} + +D3D12DescriptorSet::BindingRecord* D3D12DescriptorSet::FindFirstBindingRecordOfType(DescriptorType type) { + for (BindingRecord& bindingRecord : m_bindingRecords) { + if (static_cast(bindingRecord.layout.type) == type) { + return &bindingRecord; + } + } + + return nullptr; +} + +const D3D12DescriptorSet::BindingRecord* D3D12DescriptorSet::FindFirstBindingRecordOfType(DescriptorType type) const { + for (const BindingRecord& bindingRecord : m_bindingRecords) { + if (static_cast(bindingRecord.layout.type) == type) { + return &bindingRecord; + } + } + + return nullptr; } } // namespace RHI diff --git a/engine/src/RHI/D3D12/D3D12PipelineLayout.cpp b/engine/src/RHI/D3D12/D3D12PipelineLayout.cpp index fc3e9b29..3654dd6d 100644 --- a/engine/src/RHI/D3D12/D3D12PipelineLayout.cpp +++ b/engine/src/RHI/D3D12/D3D12PipelineLayout.cpp @@ -26,14 +26,19 @@ bool D3D12PipelineLayout::InitializeInternal(D3D12Device* device, const RHIPipel m_rootParameters.clear(); m_descriptorRanges.clear(); - m_registerToRootIndex.clear(); + m_constantBufferRootIndices.clear(); + m_shaderResourceTableRootIndex = UINT32_MAX; + m_unorderedAccessTableRootIndex = UINT32_MAX; + m_samplerTableRootIndex = UINT32_MAX; const uint32_t rootParameterCount = desc.constantBufferCount + (desc.textureCount > 0 ? 1u : 0u) + + (desc.uavCount > 0 ? 1u : 0u) + (desc.samplerCount > 0 ? 1u : 0u); const uint32_t descriptorRangeCount = (desc.textureCount > 0 ? 1u : 0u) + + (desc.uavCount > 0 ? 1u : 0u) + (desc.samplerCount > 0 ? 1u : 0u); m_rootParameters.reserve(rootParameterCount); @@ -44,7 +49,7 @@ bool D3D12PipelineLayout::InitializeInternal(D3D12Device* device, const RHIPipel for (uint32_t i = 0; i < desc.constantBufferCount; ++i) { D3D12_ROOT_PARAMETER param = D3D12RootSignature::CreateCBV(i, ShaderVisibility::All, 0); m_rootParameters.push_back(param); - m_registerToRootIndex[i] = rootIndex; + m_constantBufferRootIndices[i] = rootIndex; rootIndex++; } @@ -54,9 +59,17 @@ bool D3D12PipelineLayout::InitializeInternal(D3D12Device* device, const RHIPipel D3D12_ROOT_PARAMETER param = D3D12RootSignature::CreateDescriptorTable( 1, &m_descriptorRanges.back(), ShaderVisibility::All); m_rootParameters.push_back(param); - for (uint32_t i = 0; i < desc.textureCount; ++i) { - m_registerToRootIndex[100 + i] = rootIndex; - } + m_shaderResourceTableRootIndex = rootIndex; + rootIndex++; + } + + if (desc.uavCount > 0) { + m_descriptorRanges.push_back(D3D12RootSignature::CreateDescriptorRange( + D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 0, desc.uavCount, 0)); + D3D12_ROOT_PARAMETER param = D3D12RootSignature::CreateDescriptorTable( + 1, &m_descriptorRanges.back(), ShaderVisibility::All); + m_rootParameters.push_back(param); + m_unorderedAccessTableRootIndex = rootIndex; rootIndex++; } @@ -66,9 +79,7 @@ bool D3D12PipelineLayout::InitializeInternal(D3D12Device* device, const RHIPipel D3D12_ROOT_PARAMETER param = D3D12RootSignature::CreateDescriptorTable( 1, &m_descriptorRanges.back(), ShaderVisibility::All); m_rootParameters.push_back(param); - for (uint32_t i = 0; i < desc.samplerCount; ++i) { - m_registerToRootIndex[200 + i] = rootIndex; - } + m_samplerTableRootIndex = rootIndex; rootIndex++; } @@ -109,20 +120,47 @@ void D3D12PipelineLayout::Shutdown() { m_desc = {}; m_rootParameters.clear(); m_descriptorRanges.clear(); - m_registerToRootIndex.clear(); + m_constantBufferRootIndices.clear(); + m_shaderResourceTableRootIndex = UINT32_MAX; + m_unorderedAccessTableRootIndex = UINT32_MAX; + m_samplerTableRootIndex = UINT32_MAX; m_device = nullptr; } -uint32_t D3D12PipelineLayout::GetRootParameterIndex(uint32_t shaderRegister) const { - auto it = m_registerToRootIndex.find(shaderRegister); - if (it != m_registerToRootIndex.end()) { +bool D3D12PipelineLayout::HasConstantBufferBinding(uint32_t binding) const { + return m_constantBufferRootIndices.find(binding) != m_constantBufferRootIndices.end(); +} + +uint32_t D3D12PipelineLayout::GetConstantBufferRootParameterIndex(uint32_t binding) const { + auto it = m_constantBufferRootIndices.find(binding); + if (it != m_constantBufferRootIndices.end()) { return it->second; } return UINT32_MAX; } -bool D3D12PipelineLayout::HasRootParameter(uint32_t shaderRegister) const { - return m_registerToRootIndex.find(shaderRegister) != m_registerToRootIndex.end(); +bool D3D12PipelineLayout::HasShaderResourceTable() const { + return m_shaderResourceTableRootIndex != UINT32_MAX; +} + +uint32_t D3D12PipelineLayout::GetShaderResourceTableRootParameterIndex() const { + return m_shaderResourceTableRootIndex; +} + +bool D3D12PipelineLayout::HasUnorderedAccessTable() const { + return m_unorderedAccessTableRootIndex != UINT32_MAX; +} + +uint32_t D3D12PipelineLayout::GetUnorderedAccessTableRootParameterIndex() const { + return m_unorderedAccessTableRootIndex; +} + +bool D3D12PipelineLayout::HasSamplerTable() const { + return m_samplerTableRootIndex != UINT32_MAX; +} + +uint32_t D3D12PipelineLayout::GetSamplerTableRootParameterIndex() const { + return m_samplerTableRootIndex; } } // namespace RHI diff --git a/tests/RHI/unit/test_descriptor_set.cpp b/tests/RHI/unit/test_descriptor_set.cpp index a483e884..3e3395dc 100644 --- a/tests/RHI/unit/test_descriptor_set.cpp +++ b/tests/RHI/unit/test_descriptor_set.cpp @@ -407,6 +407,105 @@ TEST_P(RHITestFixture, DescriptorSet_MultipleAllocations_AdvanceDescriptorOffset delete pool; } +TEST_P(RHITestFixture, DescriptorSet_D3D12MixedBindings_AssignDescriptorIndicesByType) { + if (GetBackendType() != RHIType::D3D12) { + GTEST_SKIP() << "D3D12-specific descriptor index verification"; + } + + DescriptorPoolDesc poolDesc = {}; + poolDesc.type = DescriptorHeapType::CBV_SRV_UAV; + poolDesc.descriptorCount = 4; + poolDesc.shaderVisible = true; + + RHIDescriptorPool* pool = GetDevice()->CreateDescriptorPool(poolDesc); + ASSERT_NE(pool, nullptr); + + DescriptorSetLayoutBinding bindings[3] = {}; + bindings[0].binding = 2; + bindings[0].type = static_cast(DescriptorType::UAV); + bindings[0].count = 1; + bindings[1].binding = 0; + bindings[1].type = static_cast(DescriptorType::CBV); + bindings[1].count = 1; + bindings[2].binding = 1; + bindings[2].type = static_cast(DescriptorType::SRV); + bindings[2].count = 1; + + DescriptorSetLayoutDesc layoutDesc = {}; + layoutDesc.bindings = bindings; + layoutDesc.bindingCount = 3; + + RHIDescriptorSet* firstSet = pool->AllocateSet(layoutDesc); + RHIDescriptorSet* secondSet = pool->AllocateSet(layoutDesc); + ASSERT_NE(firstSet, nullptr); + ASSERT_NE(secondSet, nullptr); + + auto* firstD3D12Set = static_cast(firstSet); + auto* secondD3D12Set = static_cast(secondSet); + EXPECT_EQ(firstD3D12Set->GetCount(), 2u); + EXPECT_EQ(firstD3D12Set->GetDescriptorIndexForBinding(1), 0u); + EXPECT_EQ(firstD3D12Set->GetDescriptorIndexForBinding(2), 1u); + EXPECT_EQ(firstD3D12Set->GetDescriptorIndexForBinding(0), UINT32_MAX); + EXPECT_EQ(firstD3D12Set->GetOffset(), 0u); + EXPECT_EQ(secondD3D12Set->GetOffset(), 2u); + + firstSet->Shutdown(); + delete firstSet; + secondSet->Shutdown(); + delete secondSet; + pool->Shutdown(); + delete pool; +} + +TEST_P(RHITestFixture, DescriptorSet_D3D12MultipleConstantBuffersUploadIndependently) { + if (GetBackendType() != RHIType::D3D12) { + GTEST_SKIP() << "D3D12-specific constant buffer verification"; + } + + DescriptorPoolDesc poolDesc = {}; + poolDesc.type = DescriptorHeapType::CBV_SRV_UAV; + poolDesc.descriptorCount = 1; + poolDesc.shaderVisible = false; + + RHIDescriptorPool* pool = GetDevice()->CreateDescriptorPool(poolDesc); + ASSERT_NE(pool, nullptr); + + DescriptorSetLayoutBinding bindings[2] = {}; + bindings[0].binding = 0; + bindings[0].type = static_cast(DescriptorType::CBV); + bindings[0].count = 1; + bindings[1].binding = 1; + bindings[1].type = static_cast(DescriptorType::CBV); + bindings[1].count = 1; + + DescriptorSetLayoutDesc layoutDesc = {}; + layoutDesc.bindings = bindings; + layoutDesc.bindingCount = 2; + + RHIDescriptorSet* set = pool->AllocateSet(layoutDesc); + ASSERT_NE(set, nullptr); + + auto* d3d12Set = static_cast(set); + const float firstData[16] = { 1.0f, 0.0f, 0.0f, 0.0f }; + const float secondData[16] = { 2.0f, 0.0f, 0.0f, 0.0f }; + d3d12Set->WriteConstant(0, firstData, sizeof(firstData)); + d3d12Set->WriteConstant(1, secondData, sizeof(secondData)); + + ASSERT_TRUE(d3d12Set->UploadConstantBuffer(0)); + ASSERT_TRUE(d3d12Set->UploadConstantBuffer(1)); + + const D3D12_GPU_VIRTUAL_ADDRESS firstAddress = d3d12Set->GetConstantBufferGPUAddress(0); + const D3D12_GPU_VIRTUAL_ADDRESS secondAddress = d3d12Set->GetConstantBufferGPUAddress(1); + EXPECT_NE(firstAddress, 0u); + EXPECT_NE(secondAddress, 0u); + EXPECT_NE(firstAddress, secondAddress); + + set->Shutdown(); + delete set; + pool->Shutdown(); + delete pool; +} + TEST_P(RHITestFixture, DescriptorSet_Update_UsesBindingNumberOnOpenGL) { if (GetBackendType() != RHIType::OpenGL) { GTEST_SKIP() << "OpenGL-specific descriptor binding verification"; diff --git a/tests/RHI/unit/test_pipeline_layout.cpp b/tests/RHI/unit/test_pipeline_layout.cpp index b7750f3a..68bfee8f 100644 --- a/tests/RHI/unit/test_pipeline_layout.cpp +++ b/tests/RHI/unit/test_pipeline_layout.cpp @@ -146,7 +146,7 @@ TEST_P(RHITestFixture, PipelineLayout_DescriptorSetAllocation) { delete layout; } -TEST_P(RHITestFixture, PipelineLayout_D3D12ConstantBuffers_MapToDistinctRootParameters) { +TEST_P(RHITestFixture, PipelineLayout_D3D12TracksDistinctBindingClasses) { if (GetBackendType() != RHIType::D3D12) { GTEST_SKIP() << "D3D12-specific root parameter verification"; } @@ -154,17 +154,30 @@ TEST_P(RHITestFixture, PipelineLayout_D3D12ConstantBuffers_MapToDistinctRootPara RHIPipelineLayoutDesc desc = {}; desc.constantBufferCount = 2; desc.textureCount = 1; + desc.uavCount = 1; desc.samplerCount = 1; RHIPipelineLayout* layout = GetDevice()->CreatePipelineLayout(desc); ASSERT_NE(layout, nullptr); auto* d3d12Layout = static_cast(layout); - EXPECT_TRUE(d3d12Layout->HasRootParameter(0)); - EXPECT_TRUE(d3d12Layout->HasRootParameter(1)); - EXPECT_TRUE(d3d12Layout->HasRootParameter(100)); - EXPECT_TRUE(d3d12Layout->HasRootParameter(200)); - EXPECT_NE(d3d12Layout->GetRootParameterIndex(0), d3d12Layout->GetRootParameterIndex(1)); + EXPECT_TRUE(d3d12Layout->HasConstantBufferBinding(0)); + EXPECT_TRUE(d3d12Layout->HasConstantBufferBinding(1)); + EXPECT_FALSE(d3d12Layout->HasConstantBufferBinding(2)); + + EXPECT_TRUE(d3d12Layout->HasShaderResourceTable()); + EXPECT_TRUE(d3d12Layout->HasUnorderedAccessTable()); + EXPECT_TRUE(d3d12Layout->HasSamplerTable()); + + EXPECT_NE( + d3d12Layout->GetConstantBufferRootParameterIndex(0), + d3d12Layout->GetConstantBufferRootParameterIndex(1)); + EXPECT_NE( + d3d12Layout->GetShaderResourceTableRootParameterIndex(), + d3d12Layout->GetUnorderedAccessTableRootParameterIndex()); + EXPECT_NE( + d3d12Layout->GetUnorderedAccessTableRootParameterIndex(), + d3d12Layout->GetSamplerTableRootParameterIndex()); layout->Shutdown(); delete layout;