refactor(RHI): 完成 Shader uniform 设置迁移到 CommandList

- 删除 RHIShader 的 OpenGL 风格 SetMat4/SetVec3/SetInt 等方法
- 添加 UniformInfo 结构体和 GetUniformInfos/GetUniformInfo 接口
- D3D12Shader 和 OpenGLShader 实现 CacheUniformInfos
- RHICommandList 添加 SetUniform*/SetGlobal* 统一接口
- D3D12 实现 D3D12PipelineLayout 管理 root signature 映射
- 修复 D3D12CommandList::SetPipelineStateInternal 在 Reset 后未重新应用 root signature 的问题
- 更新 OpenGL 集成测试使用新的 SetUniform* API
- 所有单元测试和集成测试通过 (8/8 integration tests)
This commit is contained in:
2026-03-24 19:47:22 +08:00
parent 135fe9145b
commit 0f5d018c1a
21 changed files with 578 additions and 54 deletions

View File

@@ -1,6 +1,8 @@
#include "XCEngine/RHI/D3D12/D3D12CommandList.h"
#include "XCEngine/RHI/D3D12/D3D12ResourceView.h"
#include "XCEngine/RHI/D3D12/D3D12PipelineState.h"
#include "XCEngine/RHI/D3D12/D3D12Shader.h"
#include "XCEngine/RHI/D3D12/D3D12PipelineLayout.h"
namespace XCEngine {
namespace RHI {
@@ -10,7 +12,9 @@ D3D12CommandList::D3D12CommandList()
, m_currentTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST)
, m_currentPipelineState(nullptr)
, m_currentRootSignature(nullptr)
, m_currentDescriptorHeap(nullptr) {
, m_currentDescriptorHeap(nullptr)
, m_currentShader(nullptr)
, m_currentPipelineLayout(nullptr) {
}
D3D12CommandList::~D3D12CommandList() {
@@ -55,6 +59,13 @@ void D3D12CommandList::Shutdown() {
m_rtvHeap.Reset();
m_resourceStateMap.clear();
m_trackedResources.clear();
m_currentShader = nullptr;
m_globalIntCache.clear();
m_globalFloatCache.clear();
m_globalVec3Cache.clear();
m_globalVec4Cache.clear();
m_globalMat4Cache.clear();
m_globalTextureCache.clear();
}
void D3D12CommandList::Reset() {
@@ -64,6 +75,7 @@ void D3D12CommandList::Reset() {
m_currentTopology = D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST;
m_currentPipelineState = nullptr;
m_currentRootSignature = nullptr;
m_currentPipelineLayout = nullptr;
m_currentDescriptorHeap = nullptr;
m_resourceStateMap.clear();
m_trackedResources.clear();
@@ -75,6 +87,103 @@ void D3D12CommandList::Close() {
m_commandList->Close();
}
void D3D12CommandList::SetShader(RHIShader* shader) {
m_currentShader = static_cast<D3D12Shader*>(shader);
}
void D3D12CommandList::SetPipelineLayout(D3D12PipelineLayout* layout) {
m_currentPipelineLayout = layout;
if (layout) {
SetRootSignature(layout->GetRootSignature());
}
}
void D3D12CommandList::SetUniformInt(const char* name, int value) {
if (m_currentShader && m_currentPipelineLayout) {
const RHIShader::UniformInfo* info = m_currentShader->GetUniformInfo(name);
if (info) {
uint32_t rootIndex = m_currentPipelineLayout->GetRootParameterIndex(info->bindPoint);
if (rootIndex != UINT32_MAX) {
m_commandList->SetGraphicsRoot32BitConstants(rootIndex, 1, &value, 0);
}
}
}
}
void D3D12CommandList::SetUniformFloat(const char* name, float value) {
if (m_currentShader && m_currentPipelineLayout) {
const RHIShader::UniformInfo* info = m_currentShader->GetUniformInfo(name);
if (info) {
uint32_t rootIndex = m_currentPipelineLayout->GetRootParameterIndex(info->bindPoint);
if (rootIndex != UINT32_MAX) {
m_commandList->SetGraphicsRoot32BitConstants(rootIndex, 1, &value, 0);
}
}
}
}
void D3D12CommandList::SetUniformVec3(const char* name, float x, float y, float z) {
if (m_currentShader && m_currentPipelineLayout) {
const RHIShader::UniformInfo* info = m_currentShader->GetUniformInfo(name);
if (info) {
uint32_t rootIndex = m_currentPipelineLayout->GetRootParameterIndex(info->bindPoint);
if (rootIndex != UINT32_MAX) {
float values[3] = { x, y, z };
m_commandList->SetGraphicsRoot32BitConstants(rootIndex, 3, values, 0);
}
}
}
}
void D3D12CommandList::SetUniformVec4(const char* name, float x, float y, float z, float w) {
if (m_currentShader && m_currentPipelineLayout) {
const RHIShader::UniformInfo* info = m_currentShader->GetUniformInfo(name);
if (info) {
uint32_t rootIndex = m_currentPipelineLayout->GetRootParameterIndex(info->bindPoint);
if (rootIndex != UINT32_MAX) {
float values[4] = { x, y, z, w };
m_commandList->SetGraphicsRoot32BitConstants(rootIndex, 4, values, 0);
}
}
}
}
void D3D12CommandList::SetUniformMat4(const char* name, const float* value) {
if (m_currentShader && m_currentPipelineLayout) {
const RHIShader::UniformInfo* info = m_currentShader->GetUniformInfo(name);
if (info) {
uint32_t rootIndex = m_currentPipelineLayout->GetRootParameterIndex(info->bindPoint);
if (rootIndex != UINT32_MAX) {
m_commandList->SetGraphicsRoot32BitConstants(rootIndex, 16, value, 0);
}
}
}
}
void D3D12CommandList::SetGlobalInt(const char* name, int value) {
m_globalIntCache[name] = value;
}
void D3D12CommandList::SetGlobalFloat(const char* name, float value) {
m_globalFloatCache[name] = value;
}
void D3D12CommandList::SetGlobalVec3(const char* name, float x, float y, float z) {
m_globalVec3Cache[name] = { x, y, z };
}
void D3D12CommandList::SetGlobalVec4(const char* name, float x, float y, float z, float w) {
m_globalVec4Cache[name] = { x, y, z, w };
}
void D3D12CommandList::SetGlobalMat4(const char* name, const float* value) {
m_globalMat4Cache[name] = std::vector<float>(value, value + 16);
}
void D3D12CommandList::SetGlobalTexture(const char* name, RHIResourceView* texture) {
m_globalTextureCache[name] = texture;
}
void D3D12CommandList::TransitionBarrier(RHIResourceView* resource, ResourceStates stateBefore, ResourceStates stateAfter) {
if (!resource || !resource->IsValid()) return;
D3D12ResourceView* d3d12View = static_cast<D3D12ResourceView*>(resource);
@@ -139,6 +248,9 @@ void D3D12CommandList::AliasBarrierInternal(ID3D12Resource* beforeResource, ID3D
void D3D12CommandList::SetPipelineStateInternal(ID3D12PipelineState* pso) {
m_commandList->SetPipelineState(pso);
m_currentPipelineState = pso;
if (m_currentRootSignature) {
m_commandList->SetGraphicsRootSignature(m_currentRootSignature);
}
}
void D3D12CommandList::SetRootSignature(ID3D12RootSignature* signature) {

View File

@@ -6,6 +6,7 @@
#include "XCEngine/RHI/D3D12/D3D12DescriptorHeap.h"
#include "XCEngine/RHI/D3D12/D3D12QueryHeap.h"
#include "XCEngine/RHI/D3D12/D3D12RootSignature.h"
#include "XCEngine/RHI/D3D12/D3D12PipelineLayout.h"
#include "XCEngine/RHI/D3D12/D3D12PipelineState.h"
#include "XCEngine/RHI/D3D12/D3D12Sampler.h"
#include "XCEngine/RHI/D3D12/D3D12Texture.h"
@@ -380,6 +381,15 @@ RHIPipelineState* D3D12Device::CreatePipelineState(const GraphicsPipelineDesc& d
return pso;
}
RHIPipelineLayout* D3D12Device::CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) {
auto* pipelineLayout = new D3D12PipelineLayout();
if (!pipelineLayout->InitializeWithDevice(this, desc)) {
delete pipelineLayout;
return nullptr;
}
return pipelineLayout;
}
RHIResourceView* D3D12Device::CreateRenderTargetView(RHITexture* texture, const ResourceViewDesc& desc) {
auto* view = new D3D12ResourceView();
auto* d3d12Texture = static_cast<D3D12Texture*>(texture);

View File

@@ -0,0 +1,119 @@
#include "XCEngine/RHI/D3D12/D3D12PipelineLayout.h"
#include "XCEngine/RHI/D3D12/D3D12Device.h"
namespace XCEngine {
namespace RHI {
D3D12PipelineLayout::D3D12PipelineLayout()
: m_device(nullptr) {
}
D3D12PipelineLayout::~D3D12PipelineLayout() {
Shutdown();
}
bool D3D12PipelineLayout::InitializeWithDevice(D3D12Device* device, const RHIPipelineLayoutDesc& desc) {
return InitializeInternal(device, desc);
}
bool D3D12PipelineLayout::InitializeInternal(D3D12Device* device, const RHIPipelineLayoutDesc& desc) {
if (!device) {
return false;
}
m_device = device;
m_rootParameters.clear();
m_registerToRootIndex.clear();
uint32_t rootIndex = 0;
if (desc.constantBufferCount > 0) {
D3D12_ROOT_PARAMETER param = D3D12RootSignature::Create32BitConstants(
0, desc.constantBufferCount * 16, ShaderVisibility::All, 0);
m_rootParameters.push_back(param);
for (uint32_t i = 0; i < desc.constantBufferCount; ++i) {
m_registerToRootIndex[i] = rootIndex;
}
rootIndex++;
}
if (desc.textureCount > 0) {
D3D12_DESCRIPTOR_RANGE range = D3D12RootSignature::CreateDescriptorRange(
D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 0, desc.textureCount, 0);
D3D12_ROOT_PARAMETER param = D3D12RootSignature::CreateDescriptorTable(1, &range, ShaderVisibility::All);
m_rootParameters.push_back(param);
for (uint32_t i = 0; i < desc.textureCount; ++i) {
m_registerToRootIndex[100 + i] = rootIndex;
}
rootIndex++;
}
if (desc.samplerCount > 0) {
D3D12_DESCRIPTOR_RANGE range = D3D12RootSignature::CreateDescriptorRange(
D3D12_DESCRIPTOR_RANGE_TYPE_SAMPLER, 0, desc.samplerCount, 0);
D3D12_ROOT_PARAMETER param = D3D12RootSignature::CreateDescriptorTable(1, &range, ShaderVisibility::All);
m_rootParameters.push_back(param);
for (uint32_t i = 0; i < desc.samplerCount; ++i) {
m_registerToRootIndex[200 + i] = rootIndex;
}
rootIndex++;
}
if (m_rootParameters.empty()) {
return false;
}
D3D12_ROOT_SIGNATURE_DESC rootSigDesc = D3D12RootSignature::CreateDesc(
m_rootParameters.data(),
static_cast<uint32_t>(m_rootParameters.size()),
nullptr, 0,
D3D12_ROOT_SIGNATURE_FLAG_ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT);
ID3D12Device* d3d12Device = device->GetDevice();
if (!d3d12Device) {
return false;
}
ID3DBlob* signature = nullptr;
ID3DBlob* error = nullptr;
HRESULT hr = D3D12SerializeRootSignature(&rootSigDesc, D3D_ROOT_SIGNATURE_VERSION_1, &signature, &error);
if (FAILED(hr)) {
if (error) {
error->Release();
}
return false;
}
hr = d3d12Device->CreateRootSignature(0, signature->GetBufferPointer(), signature->GetBufferSize(), IID_PPV_ARGS(&m_rootSignature));
signature->Release();
if (FAILED(hr)) {
return false;
}
return true;
}
void D3D12PipelineLayout::Shutdown() {
m_rootSignature.Reset();
m_rootParameters.clear();
m_registerToRootIndex.clear();
m_device = nullptr;
}
uint32_t D3D12PipelineLayout::GetRootParameterIndex(uint32_t shaderRegister) const {
auto it = m_registerToRootIndex.find(shaderRegister);
if (it != m_registerToRootIndex.end()) {
return it->second;
}
return UINT32_MAX;
}
bool D3D12PipelineLayout::HasRootParameter(uint32_t shaderRegister) const {
return m_registerToRootIndex.find(shaderRegister) != m_registerToRootIndex.end();
}
} // namespace RHI
} // namespace XCEngine

View File

@@ -1,11 +1,15 @@
#include "XCEngine/RHI/D3D12/D3D12Shader.h"
#include <d3dcompiler.h>
static const IID IID_ID3D12ShaderReflection = {
0x8e5c5a69, 0x5c6a, 0x427d, {0xb0, 0xdc, 0x27, 0x63, 0xae, 0xac, 0xe3, 0x75}
};
namespace XCEngine {
namespace RHI {
D3D12Shader::D3D12Shader()
: m_type(ShaderType::Vertex) {
: m_type(ShaderType::Vertex), m_uniformsCached(false) {
}
D3D12Shader::~D3D12Shader() {
@@ -35,6 +39,7 @@ bool D3D12Shader::CompileFromFile(const wchar_t* filePath, const char* entryPoin
m_type = ShaderType::Compute;
}
m_uniformsCached = false;
return true;
}
@@ -50,12 +55,87 @@ bool D3D12Shader::Compile(const void* sourceData, size_t sourceSize, const char*
return false;
}
m_uniformsCached = false;
return true;
}
void D3D12Shader::Shutdown() {
m_bytecode.Reset();
m_error.Reset();
m_uniformInfos.clear();
m_uniformsCached = false;
}
void D3D12Shader::CacheUniformInfos() const {
if (m_uniformsCached || !m_bytecode) {
return;
}
m_uniformInfos.clear();
ComPtr<ID3D12ShaderReflection> pReflection;
HRESULT hr = D3DReflect(m_bytecode->GetBufferPointer(), m_bytecode->GetBufferSize(),
IID_ID3D12ShaderReflection, (void**)&pReflection);
if (FAILED(hr)) {
return;
}
D3D12_SHADER_DESC shaderDesc;
pReflection->GetDesc(&shaderDesc);
for (UINT i = 0; i < shaderDesc.BoundResources; ++i) {
D3D12_SHADER_INPUT_BIND_DESC bindDesc;
pReflection->GetResourceBindingDesc(i, &bindDesc);
UniformInfo info;
info.name = bindDesc.Name;
info.bindPoint = bindDesc.BindPoint;
info.arraySize = bindDesc.NumSamples;
switch (bindDesc.Type) {
case D3D_SIT_CBUFFER:
info.type = static_cast<uint32_t>(D3D_SIT_CBUFFER);
break;
case D3D_SIT_TEXTURE:
info.type = static_cast<uint32_t>(D3D_SIT_TEXTURE);
break;
case D3D_SIT_SAMPLER:
info.type = static_cast<uint32_t>(D3D_SIT_SAMPLER);
break;
default:
info.type = static_cast<uint32_t>(bindDesc.Type);
break;
}
D3D12_SHADER_BUFFER_DESC bufferDesc;
if (bindDesc.Type == D3D_SIT_CBUFFER) {
ID3D12ShaderReflectionConstantBuffer* pCB = pReflection->GetConstantBufferByIndex(bindDesc.BindPoint);
pCB->GetDesc(&bufferDesc);
info.size = bufferDesc.Size;
} else {
info.size = 0;
}
m_uniformInfos.push_back(info);
}
m_uniformsCached = true;
}
const std::vector<RHIShader::UniformInfo>& D3D12Shader::GetUniformInfos() const {
CacheUniformInfos();
return m_uniformInfos;
}
const RHIShader::UniformInfo* D3D12Shader::GetUniformInfo(const char* name) const {
CacheUniformInfos();
for (const auto& info : m_uniformInfos) {
if (info.name == name) {
return &info;
}
}
return nullptr;
}
const D3D12_SHADER_BYTECODE D3D12Shader::GetD3D12Bytecode() const {
@@ -90,4 +170,4 @@ ShaderType D3D12Shader::GetType() const {
}
} // namespace RHI
} // namespace XCEngine
} // namespace XCEngine

View File

@@ -1,6 +1,7 @@
#include "XCEngine/RHI/OpenGL/OpenGLCommandList.h"
#include "XCEngine/RHI/OpenGL/OpenGLResourceView.h"
#include "XCEngine/RHI/OpenGL/OpenGLPipelineState.h"
#include "XCEngine/RHI/OpenGL/OpenGLShader.h"
#include <glad/glad.h>
namespace XCEngine {
@@ -24,7 +25,8 @@ static unsigned int ToGLPrimitiveType(PrimitiveType type) {
OpenGLCommandList::OpenGLCommandList()
: m_primitiveType(GL_TRIANGLES)
, m_currentVAO(0)
, m_currentProgram(0) {
, m_currentProgram(0)
, m_currentShader(nullptr) {
}
OpenGLCommandList::~OpenGLCommandList() {
@@ -423,6 +425,13 @@ void OpenGLCommandList::PopDebugGroup() {
}
void OpenGLCommandList::Shutdown() {
m_currentShader = nullptr;
m_globalIntCache.clear();
m_globalFloatCache.clear();
m_globalVec3Cache.clear();
m_globalVec4Cache.clear();
m_globalMat4Cache.clear();
m_globalTextureCache.clear();
}
void OpenGLCommandList::Reset() {
@@ -431,6 +440,67 @@ void OpenGLCommandList::Reset() {
void OpenGLCommandList::Close() {
}
void OpenGLCommandList::SetShader(RHIShader* shader) {
m_currentShader = static_cast<OpenGLShader*>(shader);
if (m_currentShader) {
UseShader(m_currentShader->GetID());
}
}
void OpenGLCommandList::SetUniformInt(const char* name, int value) {
if (m_currentShader) {
glUniform1i(glGetUniformLocation(m_currentShader->GetID(), name), value);
}
}
void OpenGLCommandList::SetUniformFloat(const char* name, float value) {
if (m_currentShader) {
glUniform1f(glGetUniformLocation(m_currentShader->GetID(), name), value);
}
}
void OpenGLCommandList::SetUniformVec3(const char* name, float x, float y, float z) {
if (m_currentShader) {
glUniform3f(glGetUniformLocation(m_currentShader->GetID(), name), x, y, z);
}
}
void OpenGLCommandList::SetUniformVec4(const char* name, float x, float y, float z, float w) {
if (m_currentShader) {
glUniform4f(glGetUniformLocation(m_currentShader->GetID(), name), x, y, z, w);
}
}
void OpenGLCommandList::SetUniformMat4(const char* name, const float* value) {
if (m_currentShader) {
glUniformMatrix4fv(glGetUniformLocation(m_currentShader->GetID(), name), 1, GL_FALSE, value);
}
}
void OpenGLCommandList::SetGlobalInt(const char* name, int value) {
m_globalIntCache[name] = value;
}
void OpenGLCommandList::SetGlobalFloat(const char* name, float value) {
m_globalFloatCache[name] = value;
}
void OpenGLCommandList::SetGlobalVec3(const char* name, float x, float y, float z) {
m_globalVec3Cache[name] = {x, y, z};
}
void OpenGLCommandList::SetGlobalVec4(const char* name, float x, float y, float z, float w) {
m_globalVec4Cache[name] = {x, y, z, w};
}
void OpenGLCommandList::SetGlobalMat4(const char* name, const float* value) {
m_globalMat4Cache[name] = std::vector<float>(value, value + 16);
}
void OpenGLCommandList::SetGlobalTexture(const char* name, RHIResourceView* texture) {
m_globalTextureCache[name] = texture;
}
void OpenGLCommandList::TransitionBarrier(RHIResourceView* resource, ResourceStates stateBefore, ResourceStates stateAfter) {
(void)resource;
(void)stateBefore;

View File

@@ -456,6 +456,10 @@ RHIPipelineState* OpenGLDevice::CreatePipelineState(const GraphicsPipelineDesc&
return pso;
}
RHIPipelineLayout* OpenGLDevice::CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) {
return nullptr;
}
RHIFence* OpenGLDevice::CreateFence(const FenceDesc& desc) {
auto* fence = new OpenGLFence();
fence->Initialize(desc.initialValue > 0);

View File

@@ -260,8 +260,7 @@ void OpenGLShader::CacheUniformInfos() const {
info.offset = static_cast<uint32_t>(values[2]);
info.arraySize = static_cast<uint32_t>(values[3]);
GLint size = 0;
glGetActiveUniformsiv(m_program, 1, &i, GL_SIZE, &size);
GLint size = values[3];
switch (values[1]) {
case GL_FLOAT: info.size = sizeof(GLfloat) * size; break;
case GL_FLOAT_VEC2: info.size = sizeof(GLfloat) * 2 * size; break;