#include "XCEngine/RHI/D3D12/D3D12Shader.h"

#include <d3dcompiler.h>

#include <cstdio>
#include <cstring>

namespace XCEngine {
namespace RHI {

namespace {

/// Flags shared by both compile entry points: debug info on, optimizer off.
constexpr UINT kShaderCompileFlags = D3DCOMPILE_DEBUG | D3DCOMPILE_SKIP_OPTIMIZATION;

/// Sends the compiler's diagnostic text to the debugger output and to stderr.
/// Does nothing when the compiler produced no error blob.
void ReportCompileError(ID3DBlob* errorBlob) {
    if (errorBlob == nullptr) {
        return;
    }
    const char* message = static_cast<const char*>(errorBlob->GetBufferPointer());
    if (message == nullptr) {
        return;
    }
    OutputDebugStringA(message);
    fprintf(stderr, "[SHADER ERROR] %s\n", message);
}

/// Updates `type` from a target profile string such as "vs_5_0".
/// `type` is left untouched when the profile is null or not one of the
/// recognised vs_/ps_/gs_/cs_ families.
void ApplyTargetProfile(const char* target, ShaderType& type) {
    if (target == nullptr) {
        return;
    }
    if (strstr(target, "vs_")) {
        type = ShaderType::Vertex;
    } else if (strstr(target, "ps_")) {
        type = ShaderType::Fragment;
    } else if (strstr(target, "gs_")) {
        type = ShaderType::Geometry;
    } else if (strstr(target, "cs_")) {
        type = ShaderType::Compute;
    }
}

} // namespace

D3D12Shader::D3D12Shader()
    : m_type(ShaderType::Vertex)
    , m_uniformsCached(false) {
}

D3D12Shader::~D3D12Shader() {
    Shutdown();
}

/// Compiles an HLSL file into bytecode held by this object.
///
/// @param filePath   Path of the HLSL source file.
/// @param entryPoint Name of the entry-point function.
/// @param target     Target profile (e.g. "vs_5_0"); also used to set the shader type.
/// @return true on success; false on failure, after reporting the compiler's
///         error text (if any) to the debugger output and stderr.
bool D3D12Shader::CompileFromFile(const wchar_t* filePath, const char* entryPoint, const char* target) {
    // Reflection data describes the previous bytecode; drop it up front so a
    // failed recompile cannot leave stale uniform info marked as valid.
    m_uniformInfos.clear();
    m_uniformsCached = false;

    const HRESULT hr = D3DCompileFromFile(filePath, nullptr, nullptr, entryPoint, target,
                                          kShaderCompileFlags, 0, &m_bytecode, &m_error);
    if (FAILED(hr)) {
        ReportCompileError(m_error.Get());
        return false;
    }

    ApplyTargetProfile(target, m_type);
    return true;
}

/// Compiles HLSL source held in memory into bytecode held by this object.
///
/// @param sourceData Pointer to the HLSL source text.
/// @param sourceSize Size of the source in bytes.
/// @param entryPoint Name of the entry-point function.
/// @param target     Target profile (e.g. "ps_5_0"); also used to set the shader type.
/// @return true on success; false on failure, after reporting the compiler's
///         error text (if any) to the debugger output and stderr.
bool D3D12Shader::Compile(const void* sourceData, size_t sourceSize, const char* entryPoint, const char* target) {
    // See CompileFromFile: invalidate cached reflection before touching the bytecode.
    m_uniformInfos.clear();
    m_uniformsCached = false;

    const HRESULT hr = D3DCompile(sourceData, sourceSize, nullptr, nullptr, nullptr, entryPoint, target,
                                  kShaderCompileFlags, 0, &m_bytecode, &m_error);
    if (FAILED(hr)) {
        ReportCompileError(m_error.Get());
        return false;
    }

    ApplyTargetProfile(target, m_type);
    return true;
}

/// Releases the bytecode and error blobs and discards cached reflection data.
void D3D12Shader::Shutdown() {
    m_bytecode.Reset();
    m_error.Reset();
    m_uniformInfos.clear();
    m_uniformsCached = false;
}

/// Reflects the compiled bytecode and fills m_uniformInfos with one entry per
/// bound resource. No-op when the cache is already valid or nothing is compiled.
/// If reflection itself fails the cache stays invalid, so a later call retries.
void D3D12Shader::CacheUniformInfos() const {
    if (m_uniformsCached || !m_bytecode) {
        return;
    }
    m_uniformInfos.clear();

    // IID_PPV_ARGS derives the interface ID from the pointer type, so it cannot
    // drift from the SDK headers the way a hand-copied GUID constant can.
    ComPtr<ID3D12ShaderReflection> reflection;
    const HRESULT hr = D3DReflect(m_bytecode->GetBufferPointer(), m_bytecode->GetBufferSize(),
                                  IID_PPV_ARGS(&reflection));
    if (FAILED(hr)) {
        return;
    }

    D3D12_SHADER_DESC shaderDesc = {};
    if (FAILED(reflection->GetDesc(&shaderDesc))) {
        return;
    }

    m_uniformInfos.reserve(shaderDesc.BoundResources);
    for (UINT i = 0; i < shaderDesc.BoundResources; ++i) {
        D3D12_SHADER_INPUT_BIND_DESC bindDesc = {};
        if (FAILED(reflection->GetResourceBindingDesc(i, &bindDesc)) || bindDesc.Name == nullptr) {
            continue;  // Skip entries the reflector could not describe.
        }

        UniformInfo info;
        info.name = bindDesc.Name;
        info.bindPoint = bindDesc.BindPoint;
        // BindCount is the number of contiguous slots (the array extent);
        // NumSamples is the MSAA sample count and is unrelated to arrays.
        info.arraySize = bindDesc.BindCount;
        info.type = static_cast<decltype(info.type)>(bindDesc.Type);
        info.size = 0;

        if (bindDesc.Type == D3D_SIT_CBUFFER) {
            // Look the buffer up by name: GetConstantBufferByIndex expects a
            // position in the shader's cbuffer list, not a register bind point.
            ID3D12ShaderReflectionConstantBuffer* constantBuffer =
                reflection->GetConstantBufferByName(bindDesc.Name);
            D3D12_SHADER_BUFFER_DESC bufferDesc = {};
            if (constantBuffer != nullptr && SUCCEEDED(constantBuffer->GetDesc(&bufferDesc))) {
                info.size = bufferDesc.Size;
            }
        }

        m_uniformInfos.push_back(info);
    }

    m_uniformsCached = true;
}

/// Returns reflection info for every bound resource, reflecting on first use.
const std::vector<RHIShader::UniformInfo>& D3D12Shader::GetUniformInfos() const {
    CacheUniformInfos();
    return m_uniformInfos;
}

/// Finds the reflection entry whose name equals `name`.
/// @return Pointer into the internal cache, or nullptr when `name` is null or
///         no entry matches. The pointer is invalidated by the next compile
///         or Shutdown, both of which clear the cache.
const RHIShader::UniformInfo* D3D12Shader::GetUniformInfo(const char* name) const {
    if (name == nullptr) {
        return nullptr;
    }

    CacheUniformInfos();
    for (const auto& info : m_uniformInfos) {
        if (info.name == name) {
            return &info;
        }
    }
    return nullptr;
}

/// Returns a D3D12_SHADER_BYTECODE view of the compiled blob; zeroed when
/// nothing has been compiled. The view does not own the memory.
const D3D12_SHADER_BYTECODE D3D12Shader::GetD3D12Bytecode() const {
    D3D12_SHADER_BYTECODE bytecode = {};
    if (m_bytecode) {
        bytecode.pShaderBytecode = m_bytecode->GetBufferPointer();
        bytecode.BytecodeLength = m_bytecode->GetBufferSize();
    }
    return bytecode;
}

/// Returns a pointer to the compiled bytecode, or nullptr when nothing is compiled.
const void* D3D12Shader::GetBytecode() const {
    return m_bytecode ? m_bytecode->GetBufferPointer() : nullptr;
}

/// Returns the compiled bytecode size in bytes, or 0 when nothing is compiled.
size_t D3D12Shader::GetBytecodeSize() const {
    return m_bytecode ? m_bytecode->GetBufferSize() : 0;
}

/// Returns the input layout description stored on this shader.
const InputLayoutDesc& D3D12Shader::GetInputLayout() const {
    return m_inputLayout;
}

/// Returns the shader stage last deduced from a compile target profile
/// (Vertex until a recognised profile has been compiled).
ShaderType D3D12Shader::GetType() const {
    return m_type;
}

} // namespace RHI
} // namespace XCEngine