Add formal compute pipeline creation API

This commit is contained in:
2026-04-11 02:27:33 +08:00
parent d9bc0f1457
commit 5191bb1149
17 changed files with 324 additions and 71 deletions

View File

@@ -72,6 +72,7 @@ public:
RHICommandQueue* CreateCommandQueue(const CommandQueueDesc& desc) override; RHICommandQueue* CreateCommandQueue(const CommandQueueDesc& desc) override;
RHIShader* CreateShader(const ShaderCompileDesc& desc) override; RHIShader* CreateShader(const ShaderCompileDesc& desc) override;
RHIPipelineState* CreatePipelineState(const GraphicsPipelineDesc& desc) override; RHIPipelineState* CreatePipelineState(const GraphicsPipelineDesc& desc) override;
RHIPipelineState* CreateComputePipelineState(const ComputePipelineDesc& desc) override;
RHIPipelineLayout* CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) override; RHIPipelineLayout* CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) override;
RHIFence* CreateFence(const FenceDesc& desc) override; RHIFence* CreateFence(const FenceDesc& desc) override;
RHISampler* CreateSampler(const SamplerDesc& desc) override; RHISampler* CreateSampler(const SamplerDesc& desc) override;

View File

@@ -3,6 +3,7 @@
#include <d3d12.h> #include <d3d12.h>
#include <dxgi1_4.h> #include <dxgi1_4.h>
#include <wrl/client.h> #include <wrl/client.h>
#include <memory>
#include <vector> #include <vector>
#include "../RHIPipelineState.h" #include "../RHIPipelineState.h"
@@ -31,7 +32,9 @@ public:
void SetTopology(uint32_t topologyType) override; void SetTopology(uint32_t topologyType) override;
void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override; void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override;
void SetSampleCount(uint32_t count) override; void SetSampleCount(uint32_t count) override;
void SetSampleQuality(uint32_t quality) override;
void SetComputeShader(RHIShader* shader) override; void SetComputeShader(RHIShader* shader) override;
void SetOwnedComputeShader(std::unique_ptr<class D3D12Shader> shader);
void SetRootSignature(ID3D12RootSignature* rootSignature); void SetRootSignature(ID3D12RootSignature* rootSignature);
// State query // State query
@@ -95,6 +98,7 @@ private:
uint32_t m_renderTargetFormats[8] = { 0 }; uint32_t m_renderTargetFormats[8] = { 0 };
uint32_t m_depthStencilFormat = 0; uint32_t m_depthStencilFormat = 0;
uint32_t m_sampleCount = 1; uint32_t m_sampleCount = 1;
uint32_t m_sampleQuality = 0;
// Shader bytecodes (set externally) // Shader bytecodes (set externally)
D3D12_SHADER_BYTECODE m_vsBytecode = {}; D3D12_SHADER_BYTECODE m_vsBytecode = {};
@@ -102,6 +106,7 @@ private:
D3D12_SHADER_BYTECODE m_gsBytecode = {}; D3D12_SHADER_BYTECODE m_gsBytecode = {};
D3D12_SHADER_BYTECODE m_csBytecode = {}; D3D12_SHADER_BYTECODE m_csBytecode = {};
class RHIShader* m_computeShader = nullptr; class RHIShader* m_computeShader = nullptr;
std::unique_ptr<class D3D12Shader> m_ownedComputeShader;
ComPtr<ID3D12RootSignature> m_rootSignature; ComPtr<ID3D12RootSignature> m_rootSignature;
// D3D12 resources // D3D12 resources

View File

@@ -42,6 +42,7 @@ public:
RHICommandQueue* CreateCommandQueue(const CommandQueueDesc& desc) override; RHICommandQueue* CreateCommandQueue(const CommandQueueDesc& desc) override;
RHIShader* CreateShader(const ShaderCompileDesc& desc) override; RHIShader* CreateShader(const ShaderCompileDesc& desc) override;
RHIPipelineState* CreatePipelineState(const GraphicsPipelineDesc& desc) override; RHIPipelineState* CreatePipelineState(const GraphicsPipelineDesc& desc) override;
RHIPipelineState* CreateComputePipelineState(const ComputePipelineDesc& desc) override;
RHIPipelineLayout* CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) override; RHIPipelineLayout* CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) override;
RHIFence* CreateFence(const FenceDesc& desc) override; RHIFence* CreateFence(const FenceDesc& desc) override;
RHISampler* CreateSampler(const SamplerDesc& desc) override; RHISampler* CreateSampler(const SamplerDesc& desc) override;

View File

@@ -84,7 +84,9 @@ public:
void SetTopology(uint32_t topologyType) override; void SetTopology(uint32_t topologyType) override;
void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override; void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override;
void SetSampleCount(uint32_t count) override; void SetSampleCount(uint32_t count) override;
void SetSampleQuality(uint32_t quality) override;
void SetComputeShader(RHIShader* shader) override; void SetComputeShader(RHIShader* shader) override;
void SetOwnedComputeShader(std::unique_ptr<class OpenGLShader> shader);
// State query // State query
const RasterizerDesc& GetRasterizerState() const override { return m_rasterizerDesc; } const RasterizerDesc& GetRasterizerState() const override { return m_rasterizerDesc; }
@@ -142,6 +144,7 @@ private:
unsigned int m_computeProgram = 0; unsigned int m_computeProgram = 0;
class RHIShader* m_computeShader = nullptr; class RHIShader* m_computeShader = nullptr;
std::unique_ptr<class OpenGLShader> m_graphicsShader; std::unique_ptr<class OpenGLShader> m_graphicsShader;
std::unique_ptr<class OpenGLShader> m_ownedComputeShader;
bool m_programAttached = false; bool m_programAttached = false;
// OpenGL specific state // OpenGL specific state

View File

@@ -77,6 +77,10 @@ public:
virtual RHICommandQueue* CreateCommandQueue(const CommandQueueDesc& desc) = 0; virtual RHICommandQueue* CreateCommandQueue(const CommandQueueDesc& desc) = 0;
virtual RHIShader* CreateShader(const ShaderCompileDesc& desc) = 0; virtual RHIShader* CreateShader(const ShaderCompileDesc& desc) = 0;
virtual RHIPipelineState* CreatePipelineState(const GraphicsPipelineDesc& desc) = 0; virtual RHIPipelineState* CreatePipelineState(const GraphicsPipelineDesc& desc) = 0;
virtual RHIPipelineState* CreateComputePipelineState(const ComputePipelineDesc& desc) {
(void)desc;
return nullptr;
}
virtual RHIPipelineLayout* CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) = 0; virtual RHIPipelineLayout* CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) = 0;
virtual RHIFence* CreateFence(const FenceDesc& desc) = 0; virtual RHIFence* CreateFence(const FenceDesc& desc) = 0;
virtual RHISampler* CreateSampler(const SamplerDesc& desc) = 0; virtual RHISampler* CreateSampler(const SamplerDesc& desc) = 0;

View File

@@ -46,6 +46,12 @@ struct ShaderCompileMacro {
}; };
enum class ShaderLanguage : uint8_t; enum class ShaderLanguage : uint8_t;
enum class ShaderBinaryBackend : uint8_t {
Unknown = 0,
D3D12,
OpenGL,
Vulkan
};
struct ShaderCompileDesc { struct ShaderCompileDesc {
std::wstring fileName; std::wstring fileName;
@@ -55,6 +61,8 @@ struct ShaderCompileDesc {
std::wstring entryPoint; std::wstring entryPoint;
std::wstring profile; std::wstring profile;
std::vector<ShaderCompileMacro> macros; std::vector<ShaderCompileMacro> macros;
ShaderBinaryBackend compiledBinaryBackend = ShaderBinaryBackend::Unknown;
std::vector<uint8_t> compiledBinary;
}; };
struct InputElementDesc { struct InputElementDesc {
@@ -327,6 +335,11 @@ struct GraphicsPipelineDesc {
uint32_t sampleQuality = 0; uint32_t sampleQuality = 0;
}; };
struct ComputePipelineDesc {
ShaderCompileDesc computeShader;
RHIPipelineLayout* pipelineLayout = nullptr;
};
struct RHIDeviceDesc { struct RHIDeviceDesc {
bool enableDebugLayer = false; bool enableDebugLayer = false;
bool enableGPUValidation = false; bool enableGPUValidation = false;

View File

@@ -25,6 +25,7 @@ public:
RHICommandQueue* CreateCommandQueue(const CommandQueueDesc& desc) override; RHICommandQueue* CreateCommandQueue(const CommandQueueDesc& desc) override;
RHIShader* CreateShader(const ShaderCompileDesc& desc) override; RHIShader* CreateShader(const ShaderCompileDesc& desc) override;
RHIPipelineState* CreatePipelineState(const GraphicsPipelineDesc& desc) override; RHIPipelineState* CreatePipelineState(const GraphicsPipelineDesc& desc) override;
RHIPipelineState* CreateComputePipelineState(const ComputePipelineDesc& desc) override;
RHIPipelineLayout* CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) override; RHIPipelineLayout* CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) override;
RHIFence* CreateFence(const FenceDesc& desc) override; RHIFence* CreateFence(const FenceDesc& desc) override;
RHISampler* CreateSampler(const SamplerDesc& desc) override; RHISampler* CreateSampler(const SamplerDesc& desc) override;

View File

@@ -3,6 +3,8 @@
#include "XCEngine/RHI/RHIPipelineState.h" #include "XCEngine/RHI/RHIPipelineState.h"
#include "XCEngine/RHI/Vulkan/VulkanCommon.h" #include "XCEngine/RHI/Vulkan/VulkanCommon.h"
#include <memory>
namespace XCEngine { namespace XCEngine {
namespace RHI { namespace RHI {
@@ -15,6 +17,10 @@ public:
~VulkanPipelineState() override; ~VulkanPipelineState() override;
bool Initialize(VulkanDevice* device, const GraphicsPipelineDesc& desc); bool Initialize(VulkanDevice* device, const GraphicsPipelineDesc& desc);
bool InitializeCompute(
VulkanDevice* device,
RHIPipelineLayout* pipelineLayout,
std::unique_ptr<VulkanShader> shader);
void SetInputLayout(const InputLayoutDesc& layout) override; void SetInputLayout(const InputLayoutDesc& layout) override;
void SetRasterizerState(const RasterizerDesc& state) override; void SetRasterizerState(const RasterizerDesc& state) override;
@@ -23,7 +29,9 @@ public:
void SetTopology(uint32_t topologyType) override; void SetTopology(uint32_t topologyType) override;
void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override; void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override;
void SetSampleCount(uint32_t count) override; void SetSampleCount(uint32_t count) override;
void SetSampleQuality(uint32_t quality) override;
void SetComputeShader(RHIShader* shader) override; void SetComputeShader(RHIShader* shader) override;
void SetOwnedComputeShader(std::unique_ptr<VulkanShader> shader);
const RasterizerDesc& GetRasterizerState() const override { return m_rasterizerDesc; } const RasterizerDesc& GetRasterizerState() const override { return m_rasterizerDesc; }
const BlendDesc& GetBlendState() const override { return m_blendDesc; } const BlendDesc& GetBlendState() const override { return m_blendDesc; }
@@ -31,7 +39,7 @@ public:
const InputLayoutDesc& GetInputLayout() const override { return m_inputLayoutDesc; } const InputLayoutDesc& GetInputLayout() const override { return m_inputLayoutDesc; }
PipelineStateHash GetHash() const override; PipelineStateHash GetHash() const override;
RHIShader* GetComputeShader() const override { return m_computeShader; } RHIShader* GetComputeShader() const override { return m_computeShader; }
bool HasComputeShader() const override { return m_computeShader != nullptr; } bool HasComputeShader() const override { return m_hasComputeShader; }
bool IsValid() const override { return m_isConfigured; } bool IsValid() const override { return m_isConfigured; }
void EnsureValid() override; void EnsureValid() override;
@@ -52,6 +60,7 @@ public:
private: private:
bool EnsurePipelineLayout(const GraphicsPipelineDesc& desc); bool EnsurePipelineLayout(const GraphicsPipelineDesc& desc);
bool EnsurePipelineLayout(RHIPipelineLayout* pipelineLayout);
bool CreateGraphicsPipeline(const GraphicsPipelineDesc& desc); bool CreateGraphicsPipeline(const GraphicsPipelineDesc& desc);
bool CreateComputePipeline(); bool CreateComputePipeline();
@@ -71,6 +80,8 @@ private:
uint32_t m_depthStencilFormat = 0; uint32_t m_depthStencilFormat = 0;
uint32_t m_sampleCount = 1; uint32_t m_sampleCount = 1;
RHIShader* m_computeShader = nullptr; RHIShader* m_computeShader = nullptr;
std::unique_ptr<VulkanShader> m_ownedComputeShader;
bool m_hasComputeShader = false;
bool m_isConfigured = false; bool m_isConfigured = false;
}; };

View File

@@ -48,7 +48,7 @@ uint64_t GetVolumeTraceSteadyMs();
void LogVolumeTraceRendering(const std::string& message); void LogVolumeTraceRendering(const std::string& message);
bool HasShaderPayload(const ShaderCompileDesc& desc) { bool HasShaderPayload(const ShaderCompileDesc& desc) {
return !desc.source.empty() || !desc.fileName.empty(); return !desc.source.empty() || !desc.fileName.empty() || !desc.compiledBinary.empty();
} }
bool ShouldTraceVolumetricShaderCompile(const ShaderCompileDesc& desc) { bool ShouldTraceVolumetricShaderCompile(const ShaderCompileDesc& desc) {
@@ -94,6 +94,23 @@ bool CompileD3D12Shader(const ShaderCompileDesc& desc, D3D12Shader& shader) {
const char* entryPointPtr = entryPoint.empty() ? nullptr : entryPoint.c_str(); const char* entryPointPtr = entryPoint.empty() ? nullptr : entryPoint.c_str();
const char* profilePtr = profile.empty() ? nullptr : profile.c_str(); const char* profilePtr = profile.empty() ? nullptr : profile.c_str();
if (desc.compiledBinaryBackend == ShaderBinaryBackend::D3D12 &&
!desc.compiledBinary.empty()) {
const bool compiled = shader.InitializeFromBytecode(
desc.compiledBinary.data(),
desc.compiledBinary.size(),
profilePtr);
if (traceShaderCompile) {
const uint64_t compileEndMs = GetVolumeTraceSteadyMs();
LogVolumeTraceRendering(
std::string("D3D12 shader compile ") + (compiled ? "cache-hit" : "cache-hit-failed") +
" steady_ms=" + std::to_string(compileEndMs) +
" total_ms=" + std::to_string(compileEndMs - compileStartMs) + " " +
DescribeShaderCompileDesc(desc));
}
return compiled;
}
if (!desc.source.empty()) { if (!desc.source.empty()) {
std::vector<std::string> macroNames; std::vector<std::string> macroNames;
std::vector<std::string> macroDefinitions; std::vector<std::string> macroDefinitions;
@@ -1228,25 +1245,7 @@ RHITexture* D3D12Device::CreateTexture(const TextureDesc& desc, const void* init
RHIShader* D3D12Device::CreateShader(const ShaderCompileDesc& desc) { RHIShader* D3D12Device::CreateShader(const ShaderCompileDesc& desc) {
auto* shader = new D3D12Shader(); auto* shader = new D3D12Shader();
const std::string entryPoint = NarrowAscii(desc.entryPoint); if (CompileD3D12Shader(desc, *shader)) {
const std::string profile = NarrowAscii(desc.profile);
const char* entryPointPtr = entryPoint.empty() ? nullptr : entryPoint.c_str();
const char* profilePtr = profile.empty() ? nullptr : profile.c_str();
bool success = false;
if (!desc.source.empty()) {
success = shader->Compile(
desc.source.data(),
desc.source.size(),
desc.fileName.empty() ? nullptr : desc.fileName.c_str(),
nullptr,
entryPointPtr,
profilePtr);
} else if (!desc.fileName.empty()) {
success = shader->CompileFromFile(desc.fileName.c_str(), entryPointPtr, profilePtr);
}
if (success) {
return shader; return shader;
} }
delete shader; delete shader;
@@ -1524,6 +1523,33 @@ RHIPipelineState* D3D12Device::CreatePipelineState(const GraphicsPipelineDesc& d
return pso; return pso;
} }
RHIPipelineState* D3D12Device::CreateComputePipelineState(const ComputePipelineDesc& desc) {
if (!HasShaderPayload(desc.computeShader)) {
return nullptr;
}
auto* pso = new D3D12PipelineState(m_device.Get());
if (desc.pipelineLayout != nullptr) {
auto* pipelineLayout = static_cast<D3D12PipelineLayout*>(desc.pipelineLayout);
pso->SetRootSignature(pipelineLayout->GetRootSignature());
}
D3D12Shader computeShader;
if (!CompileD3D12Shader(desc.computeShader, computeShader)) {
delete pso;
return nullptr;
}
pso->SetComputeShaderBytecodes(computeShader.GetD3D12Bytecode());
pso->EnsureValid();
if (!pso->IsValid()) {
delete pso;
return nullptr;
}
return pso;
}
RHIPipelineLayout* D3D12Device::CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) { RHIPipelineLayout* D3D12Device::CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) {
auto* pipelineLayout = new D3D12PipelineLayout(); auto* pipelineLayout = new D3D12PipelineLayout();
if (!pipelineLayout->InitializeWithDevice(this, desc)) { if (!pipelineLayout->InitializeWithDevice(this, desc)) {

View File

@@ -109,6 +109,7 @@ bool D3D12PipelineState::Initialize(ID3D12Device* device, const D3D12_GRAPHICS_P
} }
m_depthStencilFormat = static_cast<uint32_t>(desc.DSVFormat); m_depthStencilFormat = static_cast<uint32_t>(desc.DSVFormat);
m_sampleCount = desc.SampleDesc.Count; m_sampleCount = desc.SampleDesc.Count;
m_sampleQuality = desc.SampleDesc.Quality;
// Set shader bytecodes // Set shader bytecodes
m_vsBytecode = desc.VS; m_vsBytecode = desc.VS;
@@ -173,13 +174,21 @@ void D3D12PipelineState::SetSampleCount(uint32_t count) {
m_sampleCount = count; m_sampleCount = count;
} }
void D3D12PipelineState::SetSampleQuality(uint32_t quality) {
m_sampleQuality = quality;
}
PipelineStateHash D3D12PipelineState::GetHash() const { PipelineStateHash D3D12PipelineState::GetHash() const {
PipelineStateHash hash = {}; PipelineStateHash hash = {};
hash.blendStateHash = std::hash<uint64_t>{}(*reinterpret_cast<const uint64_t*>(&m_blendDesc)); hash.blendStateHash = std::hash<uint64_t>{}(*reinterpret_cast<const uint64_t*>(&m_blendDesc));
hash.depthStateHash = std::hash<uint64_t>{}(*reinterpret_cast<const uint64_t*>(&m_depthStencilDesc)); hash.depthStateHash = std::hash<uint64_t>{}(*reinterpret_cast<const uint64_t*>(&m_depthStencilDesc));
hash.rasterizerStateHash = std::hash<uint64_t>{}(*reinterpret_cast<const uint64_t*>(&m_rasterizerDesc)); hash.rasterizerStateHash = std::hash<uint64_t>{}(*reinterpret_cast<const uint64_t*>(&m_rasterizerDesc));
hash.topologyHash = m_topologyType; hash.topologyHash = m_topologyType;
hash.renderTargetHash = m_renderTargetCount | (m_depthStencilFormat << 8); hash.renderTargetHash =
m_renderTargetCount |
(static_cast<uint64_t>(m_depthStencilFormat) << 8) |
(static_cast<uint64_t>(m_sampleCount) << 32) |
(static_cast<uint64_t>(m_sampleQuality) << 48);
return hash; return hash;
} }
@@ -190,6 +199,9 @@ void D3D12PipelineState::SetShaderBytecodes(const D3D12_SHADER_BYTECODE& vs, con
} }
void D3D12PipelineState::SetComputeShader(RHIShader* shader) { void D3D12PipelineState::SetComputeShader(RHIShader* shader) {
if (shader != m_ownedComputeShader.get()) {
m_ownedComputeShader.reset();
}
m_computeShader = shader; m_computeShader = shader;
m_csBytecode = {}; m_csBytecode = {};
m_computePipelineState.Reset(); m_computePipelineState.Reset();
@@ -204,6 +216,11 @@ void D3D12PipelineState::SetComputeShader(RHIShader* shader) {
} }
} }
void D3D12PipelineState::SetOwnedComputeShader(std::unique_ptr<D3D12Shader> shader) {
m_ownedComputeShader = std::move(shader);
SetComputeShader(m_ownedComputeShader.get());
}
void D3D12PipelineState::SetRootSignature(ID3D12RootSignature* rootSignature) { void D3D12PipelineState::SetRootSignature(ID3D12RootSignature* rootSignature) {
m_rootSignature = rootSignature; m_rootSignature = rootSignature;
} }
@@ -275,7 +292,7 @@ bool D3D12PipelineState::CreateD3D12PSO() {
} }
desc.DSVFormat = ToD3D12(static_cast<Format>(m_depthStencilFormat)); desc.DSVFormat = ToD3D12(static_cast<Format>(m_depthStencilFormat));
desc.SampleDesc.Count = m_sampleCount; desc.SampleDesc.Count = m_sampleCount;
desc.SampleDesc.Quality = 0; desc.SampleDesc.Quality = m_sampleQuality;
desc.SampleMask = 0xffffffff; desc.SampleMask = 0xffffffff;
desc.PrimitiveTopologyType = static_cast<D3D12_PRIMITIVE_TOPOLOGY_TYPE>(m_topologyType); desc.PrimitiveTopologyType = static_cast<D3D12_PRIMITIVE_TOPOLOGY_TYPE>(m_topologyType);
@@ -349,6 +366,7 @@ void D3D12PipelineState::Shutdown() {
m_pipelineState.Reset(); m_pipelineState.Reset();
m_computePipelineState.Reset(); m_computePipelineState.Reset();
m_rootSignature.Reset(); m_rootSignature.Reset();
m_ownedComputeShader.reset();
m_vsBytecode = {}; m_vsBytecode = {};
m_psBytecode = {}; m_psBytecode = {};
m_gsBytecode = {}; m_gsBytecode = {};

View File

@@ -136,7 +136,7 @@ uint64_t QuerySharedSystemMemoryBytes() {
} }
bool HasShaderPayload(const ShaderCompileDesc& desc) { bool HasShaderPayload(const ShaderCompileDesc& desc) {
return !desc.source.empty() || !desc.fileName.empty(); return !desc.source.empty() || !desc.fileName.empty() || !desc.compiledBinary.empty();
} }
std::string SourceToString(const ShaderCompileDesc& desc) { std::string SourceToString(const ShaderCompileDesc& desc) {
@@ -594,8 +594,20 @@ bool CompileOpenGLShaderHandleFromSpirv(const ShaderCompileDesc& desc,
GLuint& outShaderHandle, GLuint& outShaderHandle,
ShaderType& outShaderType, ShaderType& outShaderType,
std::string* errorMessage) { std::string* errorMessage) {
ShaderCompileDesc spirvDesc = desc;
SpirvTargetEnvironment targetEnvironment = SpirvTargetEnvironment::Vulkan;
if (desc.compiledBinaryBackend == ShaderBinaryBackend::OpenGL &&
!desc.compiledBinary.empty()) {
spirvDesc.source = desc.compiledBinary;
spirvDesc.sourceLanguage = ShaderLanguage::SPIRV;
spirvDesc.fileName.clear();
spirvDesc.compiledBinary.clear();
spirvDesc.compiledBinaryBackend = ShaderBinaryBackend::Unknown;
targetEnvironment = SpirvTargetEnvironment::OpenGL;
}
CompiledSpirvShader compiledShader = {}; CompiledSpirvShader compiledShader = {};
if (!CompileSpirvShader(desc, SpirvTargetEnvironment::Vulkan, compiledShader, errorMessage)) { if (!CompileSpirvShader(spirvDesc, targetEnvironment, compiledShader, errorMessage)) {
return false; return false;
} }
@@ -1326,7 +1338,7 @@ RHIShader* OpenGLDevice::CreateShader(const ShaderCompileDesc& desc) {
return nullptr; return nullptr;
} }
if (desc.source.empty() && desc.fileName.empty()) { if (!HasShaderPayload(desc)) {
delete shader; delete shader;
return nullptr; return nullptr;
} }
@@ -1378,6 +1390,7 @@ RHIPipelineState* OpenGLDevice::CreatePipelineState(const GraphicsPipelineDesc&
pso->SetTopology(desc.topologyType); pso->SetTopology(desc.topologyType);
pso->SetRenderTargetFormats(desc.renderTargetCount, desc.renderTargetFormats, desc.depthStencilFormat); pso->SetRenderTargetFormats(desc.renderTargetCount, desc.renderTargetFormats, desc.depthStencilFormat);
pso->SetSampleCount(desc.sampleCount); pso->SetSampleCount(desc.sampleCount);
pso->SetSampleQuality(desc.sampleQuality);
const bool hasVertexShader = HasShaderPayload(desc.vertexShader); const bool hasVertexShader = HasShaderPayload(desc.vertexShader);
const bool hasFragmentShader = HasShaderPayload(desc.fragmentShader); const bool hasFragmentShader = HasShaderPayload(desc.fragmentShader);
@@ -1422,6 +1435,21 @@ RHIPipelineState* OpenGLDevice::CreatePipelineState(const GraphicsPipelineDesc&
return pso; return pso;
} }
RHIPipelineState* OpenGLDevice::CreateComputePipelineState(const ComputePipelineDesc& desc) {
if (!HasShaderPayload(desc.computeShader) || !MakeContextCurrent()) {
return nullptr;
}
auto computeShader = std::unique_ptr<OpenGLShader>(static_cast<OpenGLShader*>(CreateShader(desc.computeShader)));
if (computeShader == nullptr || computeShader->GetType() != ShaderType::Compute) {
return nullptr;
}
auto* pso = new OpenGLPipelineState();
pso->SetOwnedComputeShader(std::move(computeShader));
return pso;
}
RHIPipelineLayout* OpenGLDevice::CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) { RHIPipelineLayout* OpenGLDevice::CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) {
auto* layout = new OpenGLPipelineLayout(); auto* layout = new OpenGLPipelineLayout();
if (!layout->Initialize(desc)) { if (!layout->Initialize(desc)) {

View File

@@ -72,7 +72,13 @@ void OpenGLPipelineState::SetRenderTargetFormats(uint32_t count, const uint32_t*
void OpenGLPipelineState::SetSampleCount(uint32_t count) { void OpenGLPipelineState::SetSampleCount(uint32_t count) {
} }
void OpenGLPipelineState::SetSampleQuality(uint32_t quality) {
}
void OpenGLPipelineState::SetComputeShader(RHIShader* shader) { void OpenGLPipelineState::SetComputeShader(RHIShader* shader) {
if (shader != m_ownedComputeShader.get()) {
m_ownedComputeShader.reset();
}
m_computeShader = shader; m_computeShader = shader;
if (shader) { if (shader) {
OpenGLShader* glShader = static_cast<OpenGLShader*>(shader); OpenGLShader* glShader = static_cast<OpenGLShader*>(shader);
@@ -82,6 +88,12 @@ void OpenGLPipelineState::SetComputeShader(RHIShader* shader) {
} }
} }
void OpenGLPipelineState::SetOwnedComputeShader(std::unique_ptr<OpenGLShader> shader) {
m_ownedComputeShader = std::move(shader);
m_computeShader = m_ownedComputeShader.get();
m_computeProgram = m_ownedComputeShader != nullptr ? m_ownedComputeShader->GetID() : 0;
}
PipelineStateHash OpenGLPipelineState::GetHash() const { PipelineStateHash OpenGLPipelineState::GetHash() const {
PipelineStateHash hash = {}; PipelineStateHash hash = {};
return hash; return hash;
@@ -89,6 +101,7 @@ PipelineStateHash OpenGLPipelineState::GetHash() const {
void OpenGLPipelineState::Shutdown() { void OpenGLPipelineState::Shutdown() {
m_graphicsShader.reset(); m_graphicsShader.reset();
m_ownedComputeShader.reset();
m_program = 0; m_program = 0;
m_computeProgram = 0; m_computeProgram = 0;
m_computeShader = nullptr; m_computeShader = nullptr;

View File

@@ -677,6 +677,21 @@ RHIPipelineState* VulkanDevice::CreatePipelineState(const GraphicsPipelineDesc&
return nullptr; return nullptr;
} }
RHIPipelineState* VulkanDevice::CreateComputePipelineState(const ComputePipelineDesc& desc) {
auto computeShader = std::unique_ptr<VulkanShader>(static_cast<VulkanShader*>(CreateShader(desc.computeShader)));
if (computeShader == nullptr || computeShader->GetType() != ShaderType::Compute) {
return nullptr;
}
auto* pipelineState = new VulkanPipelineState();
if (pipelineState->InitializeCompute(this, desc.pipelineLayout, std::move(computeShader))) {
return pipelineState;
}
delete pipelineState;
return nullptr;
}
RHIPipelineLayout* VulkanDevice::CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) { RHIPipelineLayout* VulkanDevice::CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) {
auto* pipelineLayout = new VulkanPipelineLayout(); auto* pipelineLayout = new VulkanPipelineLayout();
if (pipelineLayout->Initialize(m_device, desc)) { if (pipelineLayout->Initialize(m_device, desc)) {

View File

@@ -17,7 +17,7 @@ namespace RHI {
namespace { namespace {
bool HasShaderPayload(const ShaderCompileDesc& desc) { bool HasShaderPayload(const ShaderCompileDesc& desc) {
return !desc.source.empty() || !desc.fileName.empty(); return !desc.source.empty() || !desc.fileName.empty() || !desc.compiledBinary.empty();
} }
VkShaderModule CreateShaderModule(VkDevice device, const std::vector<uint32_t>& words) { VkShaderModule CreateShaderModule(VkDevice device, const std::vector<uint32_t>& words) {
@@ -91,6 +91,9 @@ bool VulkanPipelineState::Initialize(VulkanDevice* device, const GraphicsPipelin
m_renderTargetCount = desc.renderTargetCount; m_renderTargetCount = desc.renderTargetCount;
m_depthStencilFormat = desc.depthStencilFormat; m_depthStencilFormat = desc.depthStencilFormat;
m_sampleCount = desc.sampleCount > 0 ? desc.sampleCount : 1; m_sampleCount = desc.sampleCount > 0 ? desc.sampleCount : 1;
m_computeShader = nullptr;
m_ownedComputeShader.reset();
m_hasComputeShader = false;
for (uint32_t i = 0; i < 8; ++i) { for (uint32_t i = 0; i < 8; ++i) {
m_renderTargetFormats[i] = desc.renderTargetFormats[i]; m_renderTargetFormats[i] = desc.renderTargetFormats[i];
} }
@@ -125,8 +128,41 @@ bool VulkanPipelineState::Initialize(VulkanDevice* device, const GraphicsPipelin
return true; return true;
} }
bool VulkanPipelineState::InitializeCompute(
VulkanDevice* device,
RHIPipelineLayout* pipelineLayout,
std::unique_ptr<VulkanShader> shader) {
if (device == nullptr ||
device->GetDevice() == VK_NULL_HANDLE ||
shader == nullptr ||
!shader->IsValid() ||
shader->GetType() != ShaderType::Compute) {
return false;
}
Shutdown();
m_deviceOwner = device;
m_device = device->GetDevice();
m_ownedComputeShader = std::move(shader);
m_computeShader = m_ownedComputeShader.get();
m_hasComputeShader = m_computeShader != nullptr;
if (!EnsurePipelineLayout(pipelineLayout) || !CreateComputePipeline()) {
Shutdown();
return false;
}
m_isConfigured = true;
return true;
}
bool VulkanPipelineState::EnsurePipelineLayout(const GraphicsPipelineDesc& desc) { bool VulkanPipelineState::EnsurePipelineLayout(const GraphicsPipelineDesc& desc) {
auto* externalPipelineLayout = static_cast<VulkanPipelineLayout*>(desc.pipelineLayout); return EnsurePipelineLayout(desc.pipelineLayout);
}
bool VulkanPipelineState::EnsurePipelineLayout(RHIPipelineLayout* pipelineLayout) {
auto* externalPipelineLayout = static_cast<VulkanPipelineLayout*>(pipelineLayout);
if (externalPipelineLayout != nullptr) { if (externalPipelineLayout != nullptr) {
m_pipelineLayout = externalPipelineLayout->GetPipelineLayout(); m_pipelineLayout = externalPipelineLayout->GetPipelineLayout();
m_ownsPipelineLayout = false; m_ownsPipelineLayout = false;
@@ -425,15 +461,33 @@ void VulkanPipelineState::SetSampleCount(uint32_t count) {
m_sampleCount = count > 0 ? count : 1; m_sampleCount = count > 0 ? count : 1;
} }
void VulkanPipelineState::SetSampleQuality(uint32_t quality) {
(void)quality;
}
void VulkanPipelineState::SetComputeShader(RHIShader* shader) { void VulkanPipelineState::SetComputeShader(RHIShader* shader) {
if (m_pipeline != VK_NULL_HANDLE && m_device != VK_NULL_HANDLE) { if (m_pipeline != VK_NULL_HANDLE && m_device != VK_NULL_HANDLE) {
vkDestroyPipeline(m_device, m_pipeline, nullptr); vkDestroyPipeline(m_device, m_pipeline, nullptr);
m_pipeline = VK_NULL_HANDLE; m_pipeline = VK_NULL_HANDLE;
} }
m_computeShader = shader; if (shader != m_ownedComputeShader.get()) {
if (m_computeShader != nullptr && m_pipelineLayout != VK_NULL_HANDLE) { m_ownedComputeShader.reset();
m_isConfigured = true;
} }
m_computeShader = shader;
m_hasComputeShader = m_computeShader != nullptr;
m_isConfigured = m_hasComputeShader && m_pipelineLayout != VK_NULL_HANDLE;
}
void VulkanPipelineState::SetOwnedComputeShader(std::unique_ptr<VulkanShader> shader) {
if (m_pipeline != VK_NULL_HANDLE && m_device != VK_NULL_HANDLE) {
vkDestroyPipeline(m_device, m_pipeline, nullptr);
m_pipeline = VK_NULL_HANDLE;
}
m_ownedComputeShader = std::move(shader);
m_computeShader = m_ownedComputeShader.get();
m_hasComputeShader = m_computeShader != nullptr;
m_isConfigured = m_hasComputeShader && m_pipelineLayout != VK_NULL_HANDLE;
} }
PipelineStateHash VulkanPipelineState::GetHash() const { PipelineStateHash VulkanPipelineState::GetHash() const {
@@ -499,7 +553,9 @@ void VulkanPipelineState::Shutdown() {
m_deviceOwner = nullptr; m_deviceOwner = nullptr;
m_device = VK_NULL_HANDLE; m_device = VK_NULL_HANDLE;
m_ownsPipelineLayout = false; m_ownsPipelineLayout = false;
m_ownedComputeShader.reset();
m_computeShader = nullptr; m_computeShader = nullptr;
m_hasComputeShader = false;
m_isConfigured = false; m_isConfigured = false;
} }

View File

@@ -756,10 +756,6 @@ TEST_F(OpenGLTestFixture, CommandList_SetComputeDescriptorSets_UsesSetAwareImage
ASSERT_NE(descriptorSet, nullptr); ASSERT_NE(descriptorSet, nullptr);
descriptorSet->Update(0, uav); descriptorSet->Update(0, uav);
GraphicsPipelineDesc pipelineDesc = {};
RHIPipelineState* pipelineState = GetDevice()->CreatePipelineState(pipelineDesc);
ASSERT_NE(pipelineState, nullptr);
ShaderCompileDesc shaderDesc = {}; ShaderCompileDesc shaderDesc = {};
shaderDesc.sourceLanguage = ShaderLanguage::GLSL; shaderDesc.sourceLanguage = ShaderLanguage::GLSL;
static const char* computeSource = R"( static const char* computeSource = R"(
@@ -772,9 +768,11 @@ TEST_F(OpenGLTestFixture, CommandList_SetComputeDescriptorSets_UsesSetAwareImage
)"; )";
shaderDesc.source.assign(computeSource, computeSource + std::strlen(computeSource)); shaderDesc.source.assign(computeSource, computeSource + std::strlen(computeSource));
shaderDesc.profile = L"cs_4_30"; shaderDesc.profile = L"cs_4_30";
RHIShader* computeShader = GetDevice()->CreateShader(shaderDesc); ComputePipelineDesc pipelineDesc = {};
ASSERT_NE(computeShader, nullptr); pipelineDesc.pipelineLayout = pipelineLayout;
pipelineState->SetComputeShader(computeShader); pipelineDesc.computeShader = shaderDesc;
RHIPipelineState* pipelineState = GetDevice()->CreateComputePipelineState(pipelineDesc);
ASSERT_NE(pipelineState, nullptr);
CommandListDesc cmdDesc = {}; CommandListDesc cmdDesc = {};
cmdDesc.commandListType = static_cast<uint32_t>(CommandQueueType::Direct); cmdDesc.commandListType = static_cast<uint32_t>(CommandQueueType::Direct);
@@ -815,8 +813,6 @@ TEST_F(OpenGLTestFixture, CommandList_SetComputeDescriptorSets_UsesSetAwareImage
cmdList->Shutdown(); cmdList->Shutdown();
delete cmdList; delete cmdList;
computeShader->Shutdown();
delete computeShader;
pipelineState->Shutdown(); pipelineState->Shutdown();
delete pipelineState; delete pipelineState;
descriptorSet->Shutdown(); descriptorSet->Shutdown();

View File

@@ -8,12 +8,59 @@
#include "XCEngine/RHI/RHIResourceView.h" #include "XCEngine/RHI/RHIResourceView.h"
#include "XCEngine/RHI/Vulkan/VulkanTexture.h" #include "XCEngine/RHI/Vulkan/VulkanTexture.h"
#include <cstring>
#include <vector> #include <vector>
using namespace XCEngine::RHI; using namespace XCEngine::RHI;
namespace { namespace {
constexpr uint32_t kWriteRedComputeSpirv[] = {
0x07230203, 0x00010000, 0x0008000B, 0x00000017, 0x00000000, 0x00020011, 0x00000001, 0x0006000B,
0x00000001, 0x4C534C47, 0x6474732E, 0x3035342E, 0x00000000, 0x0003000E, 0x00000000, 0x00000001,
0x0005000F, 0x00000005, 0x00000004, 0x6E69616D, 0x00000000, 0x00060010, 0x00000004, 0x00000011,
0x00000001, 0x00000001, 0x00000001, 0x00030003, 0x00000002, 0x000001C2, 0x00040005, 0x00000004,
0x6E69616D, 0x00000000, 0x00040005, 0x00000009, 0x616D4975, 0x00006567, 0x00030047, 0x00000009,
0x00000019, 0x00040047, 0x00000009, 0x00000021, 0x00000000, 0x00040047, 0x00000009, 0x00000022,
0x00000000, 0x00040047, 0x00000016, 0x0000000B, 0x00000019, 0x00020013, 0x00000002, 0x00030021,
0x00000003, 0x00000002, 0x00030016, 0x00000006, 0x00000020, 0x00090019, 0x00000007, 0x00000006,
0x00000001, 0x00000000, 0x00000000, 0x00000000, 0x00000002, 0x00000004, 0x00040020, 0x00000008,
0x00000000, 0x00000007, 0x0004003B, 0x00000008, 0x00000009, 0x00000000, 0x00040015, 0x0000000B,
0x00000020, 0x00000001, 0x00040017, 0x0000000C, 0x0000000B, 0x00000002, 0x0004002B, 0x0000000B,
0x0000000D, 0x00000000, 0x0005002C, 0x0000000C, 0x0000000E, 0x0000000D, 0x0000000D, 0x00040017,
0x0000000F, 0x00000006, 0x0000000004, 0x0004002B, 0x00000006, 0x00000010, 0x3F800000, 0x0004002B,
0x00000006, 0x00000011, 0x00000000, 0x0007002C, 0x0000000F, 0x00000012, 0x00000010, 0x00000011,
0x00000011, 0x00000010, 0x00040015, 0x00000013, 0x00000020, 0x00000000, 0x00040017, 0x00000014,
0x00000013, 0x00000003, 0x0004002B, 0x00000013, 0x00000015, 0x00000001, 0x0006002C, 0x00000014,
0x00000016, 0x00000015, 0x00000015, 0x00000015, 0x00050036, 0x00000002, 0x00000004, 0x00000000,
0x00000003, 0x000200F8, 0x00000005, 0x0004003D, 0x00000007, 0x0000000A, 0x00000009, 0x00040063,
0x0000000A, 0x0000000E, 0x00000012, 0x000100FD, 0x00010038
};
ShaderCompileDesc MakeWriteRedComputeShaderDesc() {
ShaderCompileDesc shaderDesc = {};
shaderDesc.sourceLanguage = ShaderLanguage::SPIRV;
shaderDesc.profile = L"cs_6_0";
shaderDesc.source.resize(sizeof(kWriteRedComputeSpirv));
std::memcpy(shaderDesc.source.data(), kWriteRedComputeSpirv, sizeof(kWriteRedComputeSpirv));
return shaderDesc;
}
ShaderCompileDesc MakeWriteRedComputeShaderFromGlslDesc() {
static const char* computeSource = R"(#version 450
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
layout(set = 0, binding = 0, rgba8) uniform writeonly image2D uImage;
void main() {
imageStore(uImage, ivec2(0, 0), vec4(1.0, 0.0, 0.0, 1.0));
}
)";
ShaderCompileDesc shaderDesc = {};
shaderDesc.sourceLanguage = ShaderLanguage::GLSL;
shaderDesc.source.assign(computeSource, computeSource + std::strlen(computeSource));
return shaderDesc;
}
TEST_F(VulkanGraphicsFixture, CreateUnorderedAccessViewProducesValidView) { TEST_F(VulkanGraphicsFixture, CreateUnorderedAccessViewProducesValidView) {
RHITexture* texture = m_device->CreateTexture(CreateColorTextureDesc(4, 4)); RHITexture* texture = m_device->CreateTexture(CreateColorTextureDesc(4, 4));
ASSERT_NE(texture, nullptr); ASSERT_NE(texture, nullptr);
@@ -70,14 +117,11 @@ TEST_F(VulkanGraphicsFixture, DispatchWritesUavTexture) {
ASSERT_NE(descriptorSet, nullptr); ASSERT_NE(descriptorSet, nullptr);
descriptorSet->Update(0, uav); descriptorSet->Update(0, uav);
GraphicsPipelineDesc pipelineDesc = {}; ComputePipelineDesc pipelineDesc = {};
pipelineDesc.pipelineLayout = pipelineLayout; pipelineDesc.pipelineLayout = pipelineLayout;
RHIPipelineState* pipelineState = m_device->CreatePipelineState(pipelineDesc); pipelineDesc.computeShader = MakeWriteRedComputeShaderDesc();
RHIPipelineState* pipelineState = m_device->CreateComputePipelineState(pipelineDesc);
ASSERT_NE(pipelineState, nullptr); ASSERT_NE(pipelineState, nullptr);
RHIShader* shader = CreateWriteRedComputeShader();
ASSERT_NE(shader, nullptr);
pipelineState->SetComputeShader(shader);
EXPECT_TRUE(pipelineState->HasComputeShader()); EXPECT_TRUE(pipelineState->HasComputeShader());
EXPECT_EQ(pipelineState->GetType(), PipelineType::Compute); EXPECT_EQ(pipelineState->GetType(), PipelineType::Compute);
@@ -101,8 +145,6 @@ TEST_F(VulkanGraphicsFixture, DispatchWritesUavTexture) {
commandList->Shutdown(); commandList->Shutdown();
delete commandList; delete commandList;
shader->Shutdown();
delete shader;
pipelineState->Shutdown(); pipelineState->Shutdown();
delete pipelineState; delete pipelineState;
descriptorSet->Shutdown(); descriptorSet->Shutdown();
@@ -154,14 +196,11 @@ TEST_F(VulkanGraphicsFixture, DispatchWritesUavTextureWithGlslComputeShader) {
ASSERT_NE(descriptorSet, nullptr); ASSERT_NE(descriptorSet, nullptr);
descriptorSet->Update(0, uav); descriptorSet->Update(0, uav);
GraphicsPipelineDesc pipelineDesc = {}; ComputePipelineDesc pipelineDesc = {};
pipelineDesc.pipelineLayout = pipelineLayout; pipelineDesc.pipelineLayout = pipelineLayout;
RHIPipelineState* pipelineState = m_device->CreatePipelineState(pipelineDesc); pipelineDesc.computeShader = MakeWriteRedComputeShaderFromGlslDesc();
RHIPipelineState* pipelineState = m_device->CreateComputePipelineState(pipelineDesc);
ASSERT_NE(pipelineState, nullptr); ASSERT_NE(pipelineState, nullptr);
RHIShader* shader = CreateWriteRedComputeShaderFromGlsl();
ASSERT_NE(shader, nullptr);
pipelineState->SetComputeShader(shader);
EXPECT_TRUE(pipelineState->HasComputeShader()); EXPECT_TRUE(pipelineState->HasComputeShader());
EXPECT_EQ(pipelineState->GetType(), PipelineType::Compute); EXPECT_EQ(pipelineState->GetType(), PipelineType::Compute);
@@ -185,8 +224,6 @@ TEST_F(VulkanGraphicsFixture, DispatchWritesUavTextureWithGlslComputeShader) {
commandList->Shutdown(); commandList->Shutdown();
delete commandList; delete commandList;
shader->Shutdown();
delete shader;
pipelineState->Shutdown(); pipelineState->Shutdown();
delete pipelineState; delete pipelineState;
descriptorSet->Shutdown(); descriptorSet->Shutdown();

View File

@@ -165,28 +165,53 @@ TEST_P(RHITestFixture, PipelineState_EnsureValid_Compute) {
delete pso; delete pso;
} }
TEST_P(RHITestFixture, Device_CreateComputePipelineState_ReturnsValidComputePipeline) {
ComputePipelineDesc desc = {};
desc.computeShader = MakeComputeShaderDesc(GetBackendType());
RHIPipelineState* pso = GetDevice()->CreateComputePipelineState(desc);
ASSERT_NE(pso, nullptr);
EXPECT_TRUE(pso->IsValid());
EXPECT_TRUE(pso->HasComputeShader());
EXPECT_EQ(pso->GetType(), PipelineType::Compute);
pso->Shutdown();
delete pso;
}
TEST_P(RHITestFixture, Device_CreateComputePipelineState_UsesExternalPipelineLayout) {
RHIPipelineLayoutDesc pipelineLayoutDesc = {};
RHIPipelineLayout* pipelineLayout = GetDevice()->CreatePipelineLayout(pipelineLayoutDesc);
ASSERT_NE(pipelineLayout, nullptr);
ComputePipelineDesc desc = {};
desc.computeShader = MakeComputeShaderDesc(GetBackendType());
desc.pipelineLayout = pipelineLayout;
RHIPipelineState* pso = GetDevice()->CreateComputePipelineState(desc);
ASSERT_NE(pso, nullptr);
EXPECT_TRUE(pso->IsValid());
EXPECT_TRUE(pso->HasComputeShader());
EXPECT_EQ(pso->GetType(), PipelineType::Compute);
pso->Shutdown();
delete pso;
pipelineLayout->Shutdown();
delete pipelineLayout;
}
TEST_P(RHITestFixture, CommandList_Dispatch_Basic) { TEST_P(RHITestFixture, CommandList_Dispatch_Basic) {
RHICommandList* cmdList = GetDevice()->CreateCommandList({}); RHICommandList* cmdList = GetDevice()->CreateCommandList({});
ASSERT_NE(cmdList, nullptr); ASSERT_NE(cmdList, nullptr);
cmdList->Reset(); cmdList->Reset();
GraphicsPipelineDesc desc = {}; ComputePipelineDesc desc = {};
RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc); desc.computeShader = MakeComputeShaderDesc(GetBackendType());
RHIPipelineState* pso = GetDevice()->CreateComputePipelineState(desc);
if (pso != nullptr) { if (pso != nullptr) {
ShaderCompileDesc shaderDesc = MakeComputeShaderDesc(GetBackendType()); cmdList->SetPipelineState(pso);
cmdList->Dispatch(1, 1, 1);
RHIShader* computeShader = GetDevice()->CreateShader(shaderDesc);
if (computeShader != nullptr) {
pso->SetComputeShader(computeShader);
cmdList->SetPipelineState(pso);
cmdList->Dispatch(1, 1, 1);
computeShader->Shutdown();
delete computeShader;
}
pso->Shutdown(); pso->Shutdown();
delete pso; delete pso;
} }