Files
XCEngine/engine/src/RHI/D3D12/D3D12Shader.cpp
ssdfasd 720dd422d5 RHI: Add Compute/Dispatch unit tests (P1-7) and fix shader type bugs
Bug fixes:
- D3D12Shader::Compile: Set m_type based on target string (cs_/vs_/ps_/gs_)
- OpenGLShader::Compile: Parse target parameter to determine shader type
- OpenGLShader::CompileCompute: Set m_type = ShaderType::Compute
- D3D12CommandList::SetPipelineState: Use correct PSO handle for Compute

New tests (test_compute.cpp, 8 tests):
- ComputeShader_Compile_ValidShader
- ComputeShader_GetType_ReturnsCompute
- ComputeShader_Shutdown_Invalidates
- PipelineState_SetComputeShader
- PipelineState_HasComputeShader_ReturnsTrue
- PipelineState_GetType_Compute
- PipelineState_EnsureValid_Compute
- CommandList_Dispatch_Basic

Test results: 232/232 passed (D3D12: 116, OpenGL: 116)
2026-03-25 13:52:11 +08:00

183 lines
5.1 KiB
C++

#include "XCEngine/RHI/D3D12/D3D12Shader.h"
#include <d3dcompiler.h>
// Interface ID of ID3D12ShaderReflection, handed to D3DReflect in CacheUniformInfos().
// It must match the uuid declared for ID3D12ShaderReflection in d3d12shader.h,
// {5A58797D-A72C-478D-8BA2-EFC6B0EFE88E}; requesting any other IID makes D3DReflect fail,
// and then no uniform/resource bindings are ever reflected.
static const IID IID_ID3D12ShaderReflection = {
    0x5a58797d, 0xa72c, 0x478d, {0x8b, 0xa2, 0xef, 0xc6, 0xb0, 0xef, 0xe8, 0x8e}
};
namespace XCEngine {
namespace RHI {
D3D12Shader::D3D12Shader()
: m_type(ShaderType::Vertex), m_uniformsCached(false) {
}
// Releases the compiled bytecode, compiler log and cached reflection data.
// The call is qualified explicitly: during destruction this is the overrider that would
// run anyway, and spelling it out makes that obvious to the reader.
D3D12Shader::~D3D12Shader() {
    D3D12Shader::Shutdown();
}
bool D3D12Shader::CompileFromFile(const wchar_t* filePath, const char* entryPoint, const char* target) {
HRESULT hResult = D3DCompileFromFile(filePath, nullptr, nullptr, entryPoint, target,
D3DCOMPILE_DEBUG | D3DCOMPILE_SKIP_OPTIMIZATION, 0, &m_bytecode, &m_error);
if (FAILED(hResult)) {
if (m_error) {
const char* errorMsg = static_cast<const char*>(m_error->GetBufferPointer());
OutputDebugStringA(errorMsg);
fprintf(stderr, "[SHADER ERROR] %s\n", errorMsg);
}
return false;
}
if (strstr(target, "vs_")) {
m_type = ShaderType::Vertex;
} else if (strstr(target, "ps_")) {
m_type = ShaderType::Fragment;
} else if (strstr(target, "gs_")) {
m_type = ShaderType::Geometry;
} else if (strstr(target, "cs_")) {
m_type = ShaderType::Compute;
}
m_uniformsCached = false;
return true;
}
bool D3D12Shader::Compile(const void* sourceData, size_t sourceSize, const char* entryPoint, const char* target) {
HRESULT hResult = D3DCompile(sourceData, sourceSize, nullptr, nullptr, nullptr, entryPoint, target,
D3DCOMPILE_DEBUG | D3DCOMPILE_SKIP_OPTIMIZATION, 0, &m_bytecode, &m_error);
if (FAILED(hResult)) {
if (m_error) {
const char* errorMsg = static_cast<const char*>(m_error->GetBufferPointer());
OutputDebugStringA(errorMsg);
}
return false;
}
if (strstr(target, "vs_")) {
m_type = ShaderType::Vertex;
} else if (strstr(target, "ps_")) {
m_type = ShaderType::Fragment;
} else if (strstr(target, "gs_")) {
m_type = ShaderType::Geometry;
} else if (strstr(target, "cs_")) {
m_type = ShaderType::Compute;
}
m_uniformsCached = false;
return true;
}
// Returns the shader to its uncompiled state. The reflection cache is invalidated and
// emptied, and the compiler log and bytecode blobs are released.
// m_type is left as it was.
void D3D12Shader::Shutdown() {
    m_uniformsCached = false;
    m_uniformInfos.clear();
    m_error.Reset();
    m_bytecode.Reset();
}
void D3D12Shader::CacheUniformInfos() const {
if (m_uniformsCached || !m_bytecode) {
return;
}
m_uniformInfos.clear();
ComPtr<ID3D12ShaderReflection> pReflection;
HRESULT hr = D3DReflect(m_bytecode->GetBufferPointer(), m_bytecode->GetBufferSize(),
IID_ID3D12ShaderReflection, (void**)&pReflection);
if (FAILED(hr)) {
return;
}
D3D12_SHADER_DESC shaderDesc;
pReflection->GetDesc(&shaderDesc);
for (UINT i = 0; i < shaderDesc.BoundResources; ++i) {
D3D12_SHADER_INPUT_BIND_DESC bindDesc;
pReflection->GetResourceBindingDesc(i, &bindDesc);
UniformInfo info;
info.name = bindDesc.Name;
info.bindPoint = bindDesc.BindPoint;
info.arraySize = bindDesc.NumSamples;
switch (bindDesc.Type) {
case D3D_SIT_CBUFFER:
info.type = static_cast<uint32_t>(D3D_SIT_CBUFFER);
break;
case D3D_SIT_TEXTURE:
info.type = static_cast<uint32_t>(D3D_SIT_TEXTURE);
break;
case D3D_SIT_SAMPLER:
info.type = static_cast<uint32_t>(D3D_SIT_SAMPLER);
break;
default:
info.type = static_cast<uint32_t>(bindDesc.Type);
break;
}
D3D12_SHADER_BUFFER_DESC bufferDesc;
if (bindDesc.Type == D3D_SIT_CBUFFER) {
ID3D12ShaderReflectionConstantBuffer* pCB = pReflection->GetConstantBufferByIndex(bindDesc.BindPoint);
pCB->GetDesc(&bufferDesc);
info.size = bufferDesc.Size;
} else {
info.size = 0;
}
m_uniformInfos.push_back(info);
}
m_uniformsCached = true;
}
const std::vector<RHIShader::UniformInfo>& D3D12Shader::GetUniformInfos() const {
CacheUniformInfos();
return m_uniformInfos;
}
// Looks up a reflected resource binding by its HLSL name.
//   name - binding name to search for; nullptr is treated as "not found"
// Returns a pointer into m_uniformInfos, or nullptr when no binding matches.
// The pointer may be invalidated when the cache is rebuilt after a recompile, or when it
// is cleared by Shutdown().
const RHIShader::UniformInfo* D3D12Shader::GetUniformInfo(const char* name) const {
    // Comparing the stored name against a null C string is undefined behavior for
    // std::string, so reject it up front.
    if (name == nullptr) {
        return nullptr;
    }
    CacheUniformInfos();
    for (const auto& info : m_uniformInfos) {
        if (info.name == name) {
            return &info;
        }
    }
    return nullptr;
}
// Wraps the compiled blob in a D3D12_SHADER_BYTECODE (pointer + length).
// Both fields are null/zero when nothing has been compiled.
const D3D12_SHADER_BYTECODE D3D12Shader::GetD3D12Bytecode() const {
    if (!m_bytecode) {
        return D3D12_SHADER_BYTECODE{};
    }
    D3D12_SHADER_BYTECODE view{};
    view.pShaderBytecode = m_bytecode->GetBufferPointer();
    view.BytecodeLength = m_bytecode->GetBufferSize();
    return view;
}
// Start of the compiled bytecode blob, or nullptr when no shader is compiled.
const void* D3D12Shader::GetBytecode() const {
    return m_bytecode ? m_bytecode->GetBufferPointer() : nullptr;
}
// Size in bytes of the compiled bytecode blob, or 0 when no shader is compiled.
size_t D3D12Shader::GetBytecodeSize() const {
    return m_bytecode ? m_bytecode->GetBufferSize() : size_t{0};
}
// Input layout stored with this shader, returned by reference (no copy).
// NOTE(review): nothing in this file writes m_inputLayout; presumably it is populated
// elsewhere — confirm.
const InputLayoutDesc& D3D12Shader::GetInputLayout() const {
    return this->m_inputLayout;
}
// Pipeline stage as set by the most recent successful compile with a vs_/ps_/gs_/cs_
// profile. Before any such compile it is Vertex.
ShaderType D3D12Shader::GetType() const {
    return this->m_type;
}
} // namespace RHI
} // namespace XCEngine