Files
XCEngine/engine/src/RHI/D3D12/D3D12DescriptorSet.cpp

368 lines
12 KiB
C++

#include "XCEngine/RHI/D3D12/D3D12DescriptorSet.h"
#include "XCEngine/RHI/D3D12/D3D12Buffer.h"
#include "XCEngine/RHI/D3D12/D3D12DescriptorHeap.h"
#include "XCEngine/RHI/D3D12/D3D12ResourceView.h"
#include "XCEngine/RHI/D3D12/D3D12Sampler.h"
#include <algorithm>
#include <cstring>
namespace XCEngine {
namespace RHI {
namespace {
// Rounds a constant-buffer byte size up to the next multiple of 256.
// A zero-byte request is treated as one byte, so it still yields a single 256-byte slot.
uint64_t AlignConstantBufferSize(size_t size) {
    constexpr uint64_t kAlignment = 256ull;
    const uint64_t requested = (size == 0) ? 1ull : static_cast<uint64_t>(size);
    return ((requested + kAlignment - 1ull) / kAlignment) * kAlignment;
}
// Reports whether a binding of `bindingType` occupies a slot in a heap of `heapType`.
// In a CBV_SRV_UAV heap only SRV and UAV bindings take slots; CBV bindings are not
// heap-backed in this file (they use per-binding buffers instead).
// A Sampler heap holds only Sampler bindings. Any other heap type holds none.
bool UsesDescriptorHeap(DescriptorHeapType heapType, DescriptorType bindingType) {
    switch (heapType) {
    case DescriptorHeapType::CBV_SRV_UAV:
        return bindingType == DescriptorType::SRV || bindingType == DescriptorType::UAV;
    case DescriptorHeapType::Sampler:
        return bindingType == DescriptorType::Sampler;
    default:
        return false;
    }
}
} // namespace
D3D12DescriptorSet::D3D12DescriptorSet()
: m_heap(nullptr)
, m_offset(0)
, m_count(0)
, m_bindingCount(0)
, m_bindings() {
}
// Drops the heap reference and all cached binding state.
// The set does not release the heap itself; Shutdown only nulls the pointer.
D3D12DescriptorSet::~D3D12DescriptorSet() {
    this->Shutdown();
}
// Binds this set to the descriptor window [offset, offset + count) of `heap` and
// caches the layout's bindings.
//
// Heap-backed bindings (see UsesDescriptorHeap) get consecutive descriptor slots,
// grouped by type in the order SRV, UAV, Sampler, and sorted by binding number
// within each type. A binding whose declared count is 0 still consumes one slot,
// matching GetDescriptorCountForBinding.
//
// `heap` may be null. In that case bindings are recorded but no descriptor slots
// are assigned, and GetDescriptorIndexForBinding reports none.
// Always returns true.
bool D3D12DescriptorSet::Initialize(D3D12DescriptorHeap* heap, uint32_t offset, uint32_t count, const DescriptorSetLayoutDesc& layout) {
    m_heap = heap;
    m_offset = offset;
    m_count = count;
    m_bindings.clear();
    m_bindingRecords.clear();
    m_bindingToRecordIndex.clear();

    const bool hasBindings = layout.bindingCount > 0 && layout.bindings != nullptr;
    // Only report bindings that were actually copied, so m_bindingCount always
    // matches m_bindings.size().
    m_bindingCount = hasBindings ? layout.bindingCount : 0u;
    if (!hasBindings) {
        return true;
    }

    m_bindings.reserve(layout.bindingCount);
    m_bindingRecords.reserve(layout.bindingCount);
    for (uint32_t i = 0; i < layout.bindingCount; ++i) {
        m_bindings.push_back(layout.bindings[i]);
        BindingRecord record = {};
        record.layout = layout.bindings[i];
        m_bindingRecords.push_back(std::move(record));
        // If a binding number repeats, the later record wins the lookup.
        m_bindingToRecordIndex[layout.bindings[i].binding] = i;
    }

    // Slot assignment depends on the heap type; without a heap there is nothing
    // to assign (the original dereferenced a null heap here).
    if (m_heap == nullptr) {
        return true;
    }

    uint32_t nextDescriptorIndex = 0;
    const DescriptorType orderedTypes[] = {
        DescriptorType::SRV,
        DescriptorType::UAV,
        DescriptorType::Sampler,
    };
    std::vector<uint32_t> matchingIndices;
    matchingIndices.reserve(m_bindingRecords.size());
    for (DescriptorType type : orderedTypes) {
        if (!UsesDescriptorHeap(m_heap->GetType(), type)) {
            continue;
        }
        matchingIndices.clear();
        for (uint32_t i = 0; i < m_bindingRecords.size(); ++i) {
            if (static_cast<DescriptorType>(m_bindingRecords[i].layout.type) == type) {
                matchingIndices.push_back(i);
            }
        }
        std::sort(
            matchingIndices.begin(),
            matchingIndices.end(),
            [this](uint32_t left, uint32_t right) {
                return m_bindingRecords[left].layout.binding < m_bindingRecords[right].layout.binding;
            });
        for (uint32_t recordIndex : matchingIndices) {
            m_bindingRecords[recordIndex].descriptorIndex = nextDescriptorIndex;
            const auto declaredCount = m_bindingRecords[recordIndex].layout.count;
            // A count of 0 is reported as 1 descriptor elsewhere. Reserve a slot
            // for it so it cannot alias the next binding's first descriptor.
            nextDescriptorIndex += declaredCount > 0 ? declaredCount : 1u;
        }
    }
    // NOTE(review): nextDescriptorIndex is not validated against m_count here.
    // Out-of-window slots are rejected at write/lookup time instead.
    return true;
}
// Returns the set to its uninitialized state.
// Cached binding records are discarded, including any per-binding constant
// buffers they own. The heap is only forgotten, not released.
void D3D12DescriptorSet::Shutdown() {
    m_bindingToRecordIndex.clear();
    m_bindingRecords.clear();
    m_bindings.clear();
    m_bindingCount = 0;
    m_count = 0;
    m_offset = 0;
    m_heap = nullptr;
}
void D3D12DescriptorSet::Bind() {
    // Intentionally empty: this implementation performs no work on Bind.
}
void D3D12DescriptorSet::Unbind() {
    // Intentionally empty: this implementation performs no work on Unbind.
}
void D3D12DescriptorSet::Update(uint32_t offset, RHIResourceView* view) {
if (m_heap == nullptr || view == nullptr || m_heap->GetType() != DescriptorHeapType::CBV_SRV_UAV) {
return;
}
const uint32_t descriptorOffset = GetDescriptorIndexForBinding(offset);
if (descriptorOffset == UINT32_MAX) {
return;
}
D3D12ResourceView* d3d12View = static_cast<D3D12ResourceView*>(view);
if (!d3d12View->IsValid()) {
return;
}
CPUDescriptorHandle dstHandle = m_heap->GetCPUDescriptorHandle(m_offset + descriptorOffset);
D3D12_CPU_DESCRIPTOR_HANDLE dst = { dstHandle.ptr };
m_heap->GetDevice()->CopyDescriptorsSimple(
1,
dst,
d3d12View->GetCPUHandle(),
D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
}
void D3D12DescriptorSet::UpdateSampler(uint32_t offset, RHISampler* sampler) {
if (m_heap == nullptr || sampler == nullptr || m_heap->GetType() != DescriptorHeapType::Sampler) {
return;
}
const uint32_t descriptorOffset = GetDescriptorIndexForBinding(offset);
if (descriptorOffset == UINT32_MAX) {
return;
}
D3D12Sampler* d3d12Sampler = static_cast<D3D12Sampler*>(sampler);
CPUDescriptorHandle dstHandle = m_heap->GetCPUDescriptorHandle(m_offset + descriptorOffset);
D3D12_CPU_DESCRIPTOR_HANDLE dst = { dstHandle.ptr };
m_heap->GetDevice()->CreateSampler(&d3d12Sampler->GetDesc(), dst);
}
// GPU handle of the `index`-th descriptor in this set's window.
// Returns a null handle (ptr == 0) when there is no heap or `index` is out of range.
D3D12_GPU_DESCRIPTOR_HANDLE D3D12DescriptorSet::GetGPUHandle(uint32_t index) const {
    D3D12_GPU_DESCRIPTOR_HANDLE result = {};
    const bool inRange = m_heap != nullptr && index < m_count;
    if (inRange) {
        result.ptr = m_heap->GetGPUDescriptorHandle(m_offset + index).ptr;
    }
    return result;
}
// GPU handle of the first descriptor assigned to `binding`.
// Returns a null handle when the binding has no heap slot.
D3D12_GPU_DESCRIPTOR_HANDLE D3D12DescriptorSet::GetGPUHandleForBinding(uint32_t binding) const {
    const uint32_t slot = GetDescriptorIndexForBinding(binding);
    return slot != UINT32_MAX ? GetGPUHandle(slot) : D3D12_GPU_DESCRIPTOR_HANDLE{0};
}
// CPU handle of the first descriptor assigned to `binding`.
// Returns a null handle (ptr == 0) when any of the following hold:
//   - the binding has no heap slot;
//   - there is no heap;
//   - the slot falls outside this set's window.
// The window check mirrors GetGPUHandle so that CPU and GPU lookups agree.
D3D12_CPU_DESCRIPTOR_HANDLE D3D12DescriptorSet::GetCPUHandleForBinding(uint32_t binding) const {
    const uint32_t descriptorIndex = GetDescriptorIndexForBinding(binding);
    if (descriptorIndex == UINT32_MAX || m_heap == nullptr || descriptorIndex >= m_count) {
        return D3D12_CPU_DESCRIPTOR_HANDLE{0};
    }
    const CPUDescriptorHandle handle = m_heap->GetCPUDescriptorHandle(m_offset + descriptorIndex);
    return D3D12_CPU_DESCRIPTOR_HANDLE{handle.ptr};
}
// Number of descriptors declared for `binding`.
// A declared count of 0 is reported as 1; an unknown binding reports 0.
uint32_t D3D12DescriptorSet::GetDescriptorCountForBinding(uint32_t binding) const {
    const BindingRecord* record = FindBindingRecord(binding);
    if (record != nullptr) {
        const auto declared = record->layout.count;
        return declared > 0 ? declared : 1u;
    }
    return 0;
}
// Copies `size` bytes from `data` into the CPU-side staging bytes of CBV `binding`,
// starting at byte `offset`. The staging buffer grows as needed, and the binding
// is marked dirty so the next UploadConstantBuffer pushes it to the GPU.
//
// This is a no-op when any of the following hold:
//   - the binding is unknown or is not a CBV;
//   - `data` is null or `size` is 0;
//   - `offset + size` overflows size_t (previously that wrapped, producing a too-small
//     resize followed by an out-of-bounds copy).
void D3D12DescriptorSet::WriteConstant(uint32_t binding, const void* data, size_t size, size_t offset) {
    BindingRecord* bindingRecord = FindBindingRecord(binding);
    if (bindingRecord == nullptr ||
        static_cast<DescriptorType>(bindingRecord->layout.type) != DescriptorType::CBV ||
        data == nullptr ||
        size == 0) {
        return;
    }
    const size_t requiredSize = offset + size;
    if (requiredSize < offset) {
        // Unsigned wrap-around: the requested range cannot be represented.
        return;
    }
    if (bindingRecord->constantBufferData.size() < requiredSize) {
        bindingRecord->constantBufferData.resize(requiredSize);
    }
    std::memcpy(bindingRecord->constantBufferData.data() + offset, data, size);
    bindingRecord->constantBufferDirty = true;
}
void* D3D12DescriptorSet::GetConstantBufferData() {
BindingRecord* bindingRecord = FindFirstBindingRecordOfType(DescriptorType::CBV);
if (bindingRecord == nullptr || bindingRecord->constantBufferData.empty()) {
return nullptr;
}
return bindingRecord->constantBufferData.data();
}
size_t D3D12DescriptorSet::GetConstantBufferSize() const {
const BindingRecord* bindingRecord = FindFirstBindingRecordOfType(DescriptorType::CBV);
return bindingRecord != nullptr ? bindingRecord->constantBufferData.size() : 0;
}
bool D3D12DescriptorSet::IsConstantDirty() const {
for (const BindingRecord& bindingRecord : m_bindingRecords) {
if (static_cast<DescriptorType>(bindingRecord.layout.type) == DescriptorType::CBV &&
bindingRecord.constantBufferDirty) {
return true;
}
}
return false;
}
void D3D12DescriptorSet::MarkConstantClean() {
for (BindingRecord& bindingRecord : m_bindingRecords) {
if (static_cast<DescriptorType>(bindingRecord.layout.type) == DescriptorType::CBV) {
bindingRecord.constantBufferDirty = false;
}
}
}
// True if the layout contains at least one binding of `type`.
bool D3D12DescriptorSet::HasBindingType(DescriptorType type) const {
    return FindFirstBindingRecordOfType(type) != nullptr;
}
// Smallest binding number among bindings of `type`, or UINT32_MAX if there are none.
uint32_t D3D12DescriptorSet::GetFirstBindingOfType(DescriptorType type) const {
    uint32_t lowest = UINT32_MAX;
    for (const BindingRecord& record : m_bindingRecords) {
        if (static_cast<DescriptorType>(record.layout.type) != type) {
            continue;
        }
        lowest = std::min<uint32_t>(lowest, record.layout.binding);
    }
    return lowest;
}
uint32_t D3D12DescriptorSet::GetDescriptorIndexForBinding(uint32_t binding) const {
const BindingRecord* bindingRecord = FindBindingRecord(binding);
if (bindingRecord == nullptr || m_heap == nullptr) {
return UINT32_MAX;
}
if (!UsesDescriptorHeap(m_heap->GetType(), static_cast<DescriptorType>(bindingRecord->layout.type))) {
return UINT32_MAX;
}
return bindingRecord->descriptorIndex;
}
bool D3D12DescriptorSet::UploadConstantBuffer(uint32_t binding) {
BindingRecord* bindingRecord = FindBindingRecord(binding);
if (bindingRecord == nullptr ||
static_cast<DescriptorType>(bindingRecord->layout.type) != DescriptorType::CBV ||
m_heap == nullptr ||
m_heap->GetDevice() == nullptr) {
return false;
}
const uint64_t alignedSize = AlignConstantBufferSize(bindingRecord->constantBufferData.size());
if (!bindingRecord->constantBuffer || bindingRecord->constantBufferCapacity < alignedSize) {
auto constantBuffer = std::make_unique<D3D12Buffer>();
if (!constantBuffer->Initialize(
m_heap->GetDevice(),
alignedSize,
D3D12_RESOURCE_STATE_GENERIC_READ,
D3D12_HEAP_TYPE_UPLOAD)) {
return false;
}
constantBuffer->SetBufferType(BufferType::Constant);
constantBuffer->SetStride(static_cast<uint32_t>(alignedSize));
bindingRecord->constantBuffer = std::move(constantBuffer);
bindingRecord->constantBufferCapacity = alignedSize;
bindingRecord->constantBufferDirty = true;
}
if (bindingRecord->constantBufferDirty && !bindingRecord->constantBufferData.empty()) {
bindingRecord->constantBuffer->SetData(
bindingRecord->constantBufferData.data(),
bindingRecord->constantBufferData.size());
bindingRecord->constantBufferDirty = false;
}
return bindingRecord->constantBuffer != nullptr;
}
// GPU virtual address of the upload buffer backing CBV `binding`.
// Returns 0 if the binding is unknown or no buffer has been created yet.
D3D12_GPU_VIRTUAL_ADDRESS D3D12DescriptorSet::GetConstantBufferGPUAddress(uint32_t binding) const {
    const BindingRecord* record = FindBindingRecord(binding);
    const bool uploaded = record != nullptr && record->constantBuffer != nullptr;
    if (!uploaded) {
        return 0;
    }
    return record->constantBuffer->GetGPUVirtualAddress();
}
// Mutable record for binding number `binding`, or null if the layout has no such binding.
D3D12DescriptorSet::BindingRecord* D3D12DescriptorSet::FindBindingRecord(uint32_t binding) {
    const auto found = m_bindingToRecordIndex.find(binding);
    return found != m_bindingToRecordIndex.end() ? &m_bindingRecords[found->second] : nullptr;
}
// Read-only record for binding number `binding`, or null if the layout has no such binding.
const D3D12DescriptorSet::BindingRecord* D3D12DescriptorSet::FindBindingRecord(uint32_t binding) const {
    const auto found = m_bindingToRecordIndex.find(binding);
    return found != m_bindingToRecordIndex.end() ? &m_bindingRecords[found->second] : nullptr;
}
// First record (in layout order) whose type is `type`, or null if none exists.
D3D12DescriptorSet::BindingRecord* D3D12DescriptorSet::FindFirstBindingRecordOfType(DescriptorType type) {
    auto match = std::find_if(
        m_bindingRecords.begin(),
        m_bindingRecords.end(),
        [type](const BindingRecord& record) {
            return static_cast<DescriptorType>(record.layout.type) == type;
        });
    return match == m_bindingRecords.end() ? nullptr : &*match;
}
// Read-only first record (in layout order) whose type is `type`, or null if none exists.
const D3D12DescriptorSet::BindingRecord* D3D12DescriptorSet::FindFirstBindingRecordOfType(DescriptorType type) const {
    auto match = std::find_if(
        m_bindingRecords.begin(),
        m_bindingRecords.end(),
        [type](const BindingRecord& record) {
            return static_cast<DescriptorType>(record.layout.type) == type;
        });
    return match == m_bindingRecords.end() ? nullptr : &*match;
}
} // namespace RHI
} // namespace XCEngine