Add formal compute pipeline creation API
This commit is contained in:
@@ -72,6 +72,7 @@ public:
|
||||
RHICommandQueue* CreateCommandQueue(const CommandQueueDesc& desc) override;
|
||||
RHIShader* CreateShader(const ShaderCompileDesc& desc) override;
|
||||
RHIPipelineState* CreatePipelineState(const GraphicsPipelineDesc& desc) override;
|
||||
RHIPipelineState* CreateComputePipelineState(const ComputePipelineDesc& desc) override;
|
||||
RHIPipelineLayout* CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) override;
|
||||
RHIFence* CreateFence(const FenceDesc& desc) override;
|
||||
RHISampler* CreateSampler(const SamplerDesc& desc) override;
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <d3d12.h>
|
||||
#include <dxgi1_4.h>
|
||||
#include <wrl/client.h>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "../RHIPipelineState.h"
|
||||
@@ -31,7 +32,9 @@ public:
|
||||
void SetTopology(uint32_t topologyType) override;
|
||||
void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override;
|
||||
void SetSampleCount(uint32_t count) override;
|
||||
void SetSampleQuality(uint32_t quality) override;
|
||||
void SetComputeShader(RHIShader* shader) override;
|
||||
void SetOwnedComputeShader(std::unique_ptr<class D3D12Shader> shader);
|
||||
void SetRootSignature(ID3D12RootSignature* rootSignature);
|
||||
|
||||
// State query
|
||||
@@ -95,6 +98,7 @@ private:
|
||||
uint32_t m_renderTargetFormats[8] = { 0 };
|
||||
uint32_t m_depthStencilFormat = 0;
|
||||
uint32_t m_sampleCount = 1;
|
||||
uint32_t m_sampleQuality = 0;
|
||||
|
||||
// Shader bytecodes (set externally)
|
||||
D3D12_SHADER_BYTECODE m_vsBytecode = {};
|
||||
@@ -102,6 +106,7 @@ private:
|
||||
D3D12_SHADER_BYTECODE m_gsBytecode = {};
|
||||
D3D12_SHADER_BYTECODE m_csBytecode = {};
|
||||
class RHIShader* m_computeShader = nullptr;
|
||||
std::unique_ptr<class D3D12Shader> m_ownedComputeShader;
|
||||
ComPtr<ID3D12RootSignature> m_rootSignature;
|
||||
|
||||
// D3D12 resources
|
||||
|
||||
@@ -42,6 +42,7 @@ public:
|
||||
RHICommandQueue* CreateCommandQueue(const CommandQueueDesc& desc) override;
|
||||
RHIShader* CreateShader(const ShaderCompileDesc& desc) override;
|
||||
RHIPipelineState* CreatePipelineState(const GraphicsPipelineDesc& desc) override;
|
||||
RHIPipelineState* CreateComputePipelineState(const ComputePipelineDesc& desc) override;
|
||||
RHIPipelineLayout* CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) override;
|
||||
RHIFence* CreateFence(const FenceDesc& desc) override;
|
||||
RHISampler* CreateSampler(const SamplerDesc& desc) override;
|
||||
|
||||
@@ -84,7 +84,9 @@ public:
|
||||
void SetTopology(uint32_t topologyType) override;
|
||||
void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override;
|
||||
void SetSampleCount(uint32_t count) override;
|
||||
void SetSampleQuality(uint32_t quality) override;
|
||||
void SetComputeShader(RHIShader* shader) override;
|
||||
void SetOwnedComputeShader(std::unique_ptr<class OpenGLShader> shader);
|
||||
|
||||
// State query
|
||||
const RasterizerDesc& GetRasterizerState() const override { return m_rasterizerDesc; }
|
||||
@@ -142,6 +144,7 @@ private:
|
||||
unsigned int m_computeProgram = 0;
|
||||
class RHIShader* m_computeShader = nullptr;
|
||||
std::unique_ptr<class OpenGLShader> m_graphicsShader;
|
||||
std::unique_ptr<class OpenGLShader> m_ownedComputeShader;
|
||||
bool m_programAttached = false;
|
||||
|
||||
// OpenGL specific state
|
||||
|
||||
@@ -77,6 +77,10 @@ public:
|
||||
virtual RHICommandQueue* CreateCommandQueue(const CommandQueueDesc& desc) = 0;
|
||||
virtual RHIShader* CreateShader(const ShaderCompileDesc& desc) = 0;
|
||||
virtual RHIPipelineState* CreatePipelineState(const GraphicsPipelineDesc& desc) = 0;
|
||||
virtual RHIPipelineState* CreateComputePipelineState(const ComputePipelineDesc& desc) {
|
||||
(void)desc;
|
||||
return nullptr;
|
||||
}
|
||||
virtual RHIPipelineLayout* CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) = 0;
|
||||
virtual RHIFence* CreateFence(const FenceDesc& desc) = 0;
|
||||
virtual RHISampler* CreateSampler(const SamplerDesc& desc) = 0;
|
||||
|
||||
@@ -46,6 +46,12 @@ struct ShaderCompileMacro {
|
||||
};
|
||||
|
||||
enum class ShaderLanguage : uint8_t;
|
||||
enum class ShaderBinaryBackend : uint8_t {
|
||||
Unknown = 0,
|
||||
D3D12,
|
||||
OpenGL,
|
||||
Vulkan
|
||||
};
|
||||
|
||||
struct ShaderCompileDesc {
|
||||
std::wstring fileName;
|
||||
@@ -55,6 +61,8 @@ struct ShaderCompileDesc {
|
||||
std::wstring entryPoint;
|
||||
std::wstring profile;
|
||||
std::vector<ShaderCompileMacro> macros;
|
||||
ShaderBinaryBackend compiledBinaryBackend = ShaderBinaryBackend::Unknown;
|
||||
std::vector<uint8_t> compiledBinary;
|
||||
};
|
||||
|
||||
struct InputElementDesc {
|
||||
@@ -327,6 +335,11 @@ struct GraphicsPipelineDesc {
|
||||
uint32_t sampleQuality = 0;
|
||||
};
|
||||
|
||||
struct ComputePipelineDesc {
|
||||
ShaderCompileDesc computeShader;
|
||||
RHIPipelineLayout* pipelineLayout = nullptr;
|
||||
};
|
||||
|
||||
struct RHIDeviceDesc {
|
||||
bool enableDebugLayer = false;
|
||||
bool enableGPUValidation = false;
|
||||
|
||||
@@ -25,6 +25,7 @@ public:
|
||||
RHICommandQueue* CreateCommandQueue(const CommandQueueDesc& desc) override;
|
||||
RHIShader* CreateShader(const ShaderCompileDesc& desc) override;
|
||||
RHIPipelineState* CreatePipelineState(const GraphicsPipelineDesc& desc) override;
|
||||
RHIPipelineState* CreateComputePipelineState(const ComputePipelineDesc& desc) override;
|
||||
RHIPipelineLayout* CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) override;
|
||||
RHIFence* CreateFence(const FenceDesc& desc) override;
|
||||
RHISampler* CreateSampler(const SamplerDesc& desc) override;
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
#include "XCEngine/RHI/RHIPipelineState.h"
|
||||
#include "XCEngine/RHI/Vulkan/VulkanCommon.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace XCEngine {
|
||||
namespace RHI {
|
||||
|
||||
@@ -15,6 +17,10 @@ public:
|
||||
~VulkanPipelineState() override;
|
||||
|
||||
bool Initialize(VulkanDevice* device, const GraphicsPipelineDesc& desc);
|
||||
bool InitializeCompute(
|
||||
VulkanDevice* device,
|
||||
RHIPipelineLayout* pipelineLayout,
|
||||
std::unique_ptr<VulkanShader> shader);
|
||||
|
||||
void SetInputLayout(const InputLayoutDesc& layout) override;
|
||||
void SetRasterizerState(const RasterizerDesc& state) override;
|
||||
@@ -23,7 +29,9 @@ public:
|
||||
void SetTopology(uint32_t topologyType) override;
|
||||
void SetRenderTargetFormats(uint32_t count, const uint32_t* formats, uint32_t depthFormat) override;
|
||||
void SetSampleCount(uint32_t count) override;
|
||||
void SetSampleQuality(uint32_t quality) override;
|
||||
void SetComputeShader(RHIShader* shader) override;
|
||||
void SetOwnedComputeShader(std::unique_ptr<VulkanShader> shader);
|
||||
|
||||
const RasterizerDesc& GetRasterizerState() const override { return m_rasterizerDesc; }
|
||||
const BlendDesc& GetBlendState() const override { return m_blendDesc; }
|
||||
@@ -31,7 +39,7 @@ public:
|
||||
const InputLayoutDesc& GetInputLayout() const override { return m_inputLayoutDesc; }
|
||||
PipelineStateHash GetHash() const override;
|
||||
RHIShader* GetComputeShader() const override { return m_computeShader; }
|
||||
bool HasComputeShader() const override { return m_computeShader != nullptr; }
|
||||
bool HasComputeShader() const override { return m_hasComputeShader; }
|
||||
|
||||
bool IsValid() const override { return m_isConfigured; }
|
||||
void EnsureValid() override;
|
||||
@@ -52,6 +60,7 @@ public:
|
||||
|
||||
private:
|
||||
bool EnsurePipelineLayout(const GraphicsPipelineDesc& desc);
|
||||
bool EnsurePipelineLayout(RHIPipelineLayout* pipelineLayout);
|
||||
bool CreateGraphicsPipeline(const GraphicsPipelineDesc& desc);
|
||||
bool CreateComputePipeline();
|
||||
|
||||
@@ -71,6 +80,8 @@ private:
|
||||
uint32_t m_depthStencilFormat = 0;
|
||||
uint32_t m_sampleCount = 1;
|
||||
RHIShader* m_computeShader = nullptr;
|
||||
std::unique_ptr<VulkanShader> m_ownedComputeShader;
|
||||
bool m_hasComputeShader = false;
|
||||
bool m_isConfigured = false;
|
||||
};
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ uint64_t GetVolumeTraceSteadyMs();
|
||||
void LogVolumeTraceRendering(const std::string& message);
|
||||
|
||||
bool HasShaderPayload(const ShaderCompileDesc& desc) {
|
||||
return !desc.source.empty() || !desc.fileName.empty();
|
||||
return !desc.source.empty() || !desc.fileName.empty() || !desc.compiledBinary.empty();
|
||||
}
|
||||
|
||||
bool ShouldTraceVolumetricShaderCompile(const ShaderCompileDesc& desc) {
|
||||
@@ -94,6 +94,23 @@ bool CompileD3D12Shader(const ShaderCompileDesc& desc, D3D12Shader& shader) {
|
||||
const char* entryPointPtr = entryPoint.empty() ? nullptr : entryPoint.c_str();
|
||||
const char* profilePtr = profile.empty() ? nullptr : profile.c_str();
|
||||
|
||||
if (desc.compiledBinaryBackend == ShaderBinaryBackend::D3D12 &&
|
||||
!desc.compiledBinary.empty()) {
|
||||
const bool compiled = shader.InitializeFromBytecode(
|
||||
desc.compiledBinary.data(),
|
||||
desc.compiledBinary.size(),
|
||||
profilePtr);
|
||||
if (traceShaderCompile) {
|
||||
const uint64_t compileEndMs = GetVolumeTraceSteadyMs();
|
||||
LogVolumeTraceRendering(
|
||||
std::string("D3D12 shader compile ") + (compiled ? "cache-hit" : "cache-hit-failed") +
|
||||
" steady_ms=" + std::to_string(compileEndMs) +
|
||||
" total_ms=" + std::to_string(compileEndMs - compileStartMs) + " " +
|
||||
DescribeShaderCompileDesc(desc));
|
||||
}
|
||||
return compiled;
|
||||
}
|
||||
|
||||
if (!desc.source.empty()) {
|
||||
std::vector<std::string> macroNames;
|
||||
std::vector<std::string> macroDefinitions;
|
||||
@@ -1228,25 +1245,7 @@ RHITexture* D3D12Device::CreateTexture(const TextureDesc& desc, const void* init
|
||||
|
||||
RHIShader* D3D12Device::CreateShader(const ShaderCompileDesc& desc) {
|
||||
auto* shader = new D3D12Shader();
|
||||
const std::string entryPoint = NarrowAscii(desc.entryPoint);
|
||||
const std::string profile = NarrowAscii(desc.profile);
|
||||
const char* entryPointPtr = entryPoint.empty() ? nullptr : entryPoint.c_str();
|
||||
const char* profilePtr = profile.empty() ? nullptr : profile.c_str();
|
||||
|
||||
bool success = false;
|
||||
if (!desc.source.empty()) {
|
||||
success = shader->Compile(
|
||||
desc.source.data(),
|
||||
desc.source.size(),
|
||||
desc.fileName.empty() ? nullptr : desc.fileName.c_str(),
|
||||
nullptr,
|
||||
entryPointPtr,
|
||||
profilePtr);
|
||||
} else if (!desc.fileName.empty()) {
|
||||
success = shader->CompileFromFile(desc.fileName.c_str(), entryPointPtr, profilePtr);
|
||||
}
|
||||
|
||||
if (success) {
|
||||
if (CompileD3D12Shader(desc, *shader)) {
|
||||
return shader;
|
||||
}
|
||||
delete shader;
|
||||
@@ -1524,6 +1523,33 @@ RHIPipelineState* D3D12Device::CreatePipelineState(const GraphicsPipelineDesc& d
|
||||
return pso;
|
||||
}
|
||||
|
||||
RHIPipelineState* D3D12Device::CreateComputePipelineState(const ComputePipelineDesc& desc) {
|
||||
if (!HasShaderPayload(desc.computeShader)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto* pso = new D3D12PipelineState(m_device.Get());
|
||||
if (desc.pipelineLayout != nullptr) {
|
||||
auto* pipelineLayout = static_cast<D3D12PipelineLayout*>(desc.pipelineLayout);
|
||||
pso->SetRootSignature(pipelineLayout->GetRootSignature());
|
||||
}
|
||||
|
||||
D3D12Shader computeShader;
|
||||
if (!CompileD3D12Shader(desc.computeShader, computeShader)) {
|
||||
delete pso;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
pso->SetComputeShaderBytecodes(computeShader.GetD3D12Bytecode());
|
||||
pso->EnsureValid();
|
||||
if (!pso->IsValid()) {
|
||||
delete pso;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return pso;
|
||||
}
|
||||
|
||||
RHIPipelineLayout* D3D12Device::CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) {
|
||||
auto* pipelineLayout = new D3D12PipelineLayout();
|
||||
if (!pipelineLayout->InitializeWithDevice(this, desc)) {
|
||||
|
||||
@@ -109,6 +109,7 @@ bool D3D12PipelineState::Initialize(ID3D12Device* device, const D3D12_GRAPHICS_P
|
||||
}
|
||||
m_depthStencilFormat = static_cast<uint32_t>(desc.DSVFormat);
|
||||
m_sampleCount = desc.SampleDesc.Count;
|
||||
m_sampleQuality = desc.SampleDesc.Quality;
|
||||
|
||||
// Set shader bytecodes
|
||||
m_vsBytecode = desc.VS;
|
||||
@@ -173,13 +174,21 @@ void D3D12PipelineState::SetSampleCount(uint32_t count) {
|
||||
m_sampleCount = count;
|
||||
}
|
||||
|
||||
void D3D12PipelineState::SetSampleQuality(uint32_t quality) {
|
||||
m_sampleQuality = quality;
|
||||
}
|
||||
|
||||
PipelineStateHash D3D12PipelineState::GetHash() const {
|
||||
PipelineStateHash hash = {};
|
||||
hash.blendStateHash = std::hash<uint64_t>{}(*reinterpret_cast<const uint64_t*>(&m_blendDesc));
|
||||
hash.depthStateHash = std::hash<uint64_t>{}(*reinterpret_cast<const uint64_t*>(&m_depthStencilDesc));
|
||||
hash.rasterizerStateHash = std::hash<uint64_t>{}(*reinterpret_cast<const uint64_t*>(&m_rasterizerDesc));
|
||||
hash.topologyHash = m_topologyType;
|
||||
hash.renderTargetHash = m_renderTargetCount | (m_depthStencilFormat << 8);
|
||||
hash.renderTargetHash =
|
||||
m_renderTargetCount |
|
||||
(static_cast<uint64_t>(m_depthStencilFormat) << 8) |
|
||||
(static_cast<uint64_t>(m_sampleCount) << 32) |
|
||||
(static_cast<uint64_t>(m_sampleQuality) << 48);
|
||||
return hash;
|
||||
}
|
||||
|
||||
@@ -190,6 +199,9 @@ void D3D12PipelineState::SetShaderBytecodes(const D3D12_SHADER_BYTECODE& vs, con
|
||||
}
|
||||
|
||||
void D3D12PipelineState::SetComputeShader(RHIShader* shader) {
|
||||
if (shader != m_ownedComputeShader.get()) {
|
||||
m_ownedComputeShader.reset();
|
||||
}
|
||||
m_computeShader = shader;
|
||||
m_csBytecode = {};
|
||||
m_computePipelineState.Reset();
|
||||
@@ -204,6 +216,11 @@ void D3D12PipelineState::SetComputeShader(RHIShader* shader) {
|
||||
}
|
||||
}
|
||||
|
||||
void D3D12PipelineState::SetOwnedComputeShader(std::unique_ptr<D3D12Shader> shader) {
|
||||
m_ownedComputeShader = std::move(shader);
|
||||
SetComputeShader(m_ownedComputeShader.get());
|
||||
}
|
||||
|
||||
void D3D12PipelineState::SetRootSignature(ID3D12RootSignature* rootSignature) {
|
||||
m_rootSignature = rootSignature;
|
||||
}
|
||||
@@ -275,7 +292,7 @@ bool D3D12PipelineState::CreateD3D12PSO() {
|
||||
}
|
||||
desc.DSVFormat = ToD3D12(static_cast<Format>(m_depthStencilFormat));
|
||||
desc.SampleDesc.Count = m_sampleCount;
|
||||
desc.SampleDesc.Quality = 0;
|
||||
desc.SampleDesc.Quality = m_sampleQuality;
|
||||
desc.SampleMask = 0xffffffff;
|
||||
desc.PrimitiveTopologyType = static_cast<D3D12_PRIMITIVE_TOPOLOGY_TYPE>(m_topologyType);
|
||||
|
||||
@@ -349,6 +366,7 @@ void D3D12PipelineState::Shutdown() {
|
||||
m_pipelineState.Reset();
|
||||
m_computePipelineState.Reset();
|
||||
m_rootSignature.Reset();
|
||||
m_ownedComputeShader.reset();
|
||||
m_vsBytecode = {};
|
||||
m_psBytecode = {};
|
||||
m_gsBytecode = {};
|
||||
|
||||
@@ -136,7 +136,7 @@ uint64_t QuerySharedSystemMemoryBytes() {
|
||||
}
|
||||
|
||||
bool HasShaderPayload(const ShaderCompileDesc& desc) {
|
||||
return !desc.source.empty() || !desc.fileName.empty();
|
||||
return !desc.source.empty() || !desc.fileName.empty() || !desc.compiledBinary.empty();
|
||||
}
|
||||
|
||||
std::string SourceToString(const ShaderCompileDesc& desc) {
|
||||
@@ -594,8 +594,20 @@ bool CompileOpenGLShaderHandleFromSpirv(const ShaderCompileDesc& desc,
|
||||
GLuint& outShaderHandle,
|
||||
ShaderType& outShaderType,
|
||||
std::string* errorMessage) {
|
||||
ShaderCompileDesc spirvDesc = desc;
|
||||
SpirvTargetEnvironment targetEnvironment = SpirvTargetEnvironment::Vulkan;
|
||||
if (desc.compiledBinaryBackend == ShaderBinaryBackend::OpenGL &&
|
||||
!desc.compiledBinary.empty()) {
|
||||
spirvDesc.source = desc.compiledBinary;
|
||||
spirvDesc.sourceLanguage = ShaderLanguage::SPIRV;
|
||||
spirvDesc.fileName.clear();
|
||||
spirvDesc.compiledBinary.clear();
|
||||
spirvDesc.compiledBinaryBackend = ShaderBinaryBackend::Unknown;
|
||||
targetEnvironment = SpirvTargetEnvironment::OpenGL;
|
||||
}
|
||||
|
||||
CompiledSpirvShader compiledShader = {};
|
||||
if (!CompileSpirvShader(desc, SpirvTargetEnvironment::Vulkan, compiledShader, errorMessage)) {
|
||||
if (!CompileSpirvShader(spirvDesc, targetEnvironment, compiledShader, errorMessage)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1326,7 +1338,7 @@ RHIShader* OpenGLDevice::CreateShader(const ShaderCompileDesc& desc) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (desc.source.empty() && desc.fileName.empty()) {
|
||||
if (!HasShaderPayload(desc)) {
|
||||
delete shader;
|
||||
return nullptr;
|
||||
}
|
||||
@@ -1378,6 +1390,7 @@ RHIPipelineState* OpenGLDevice::CreatePipelineState(const GraphicsPipelineDesc&
|
||||
pso->SetTopology(desc.topologyType);
|
||||
pso->SetRenderTargetFormats(desc.renderTargetCount, desc.renderTargetFormats, desc.depthStencilFormat);
|
||||
pso->SetSampleCount(desc.sampleCount);
|
||||
pso->SetSampleQuality(desc.sampleQuality);
|
||||
|
||||
const bool hasVertexShader = HasShaderPayload(desc.vertexShader);
|
||||
const bool hasFragmentShader = HasShaderPayload(desc.fragmentShader);
|
||||
@@ -1422,6 +1435,21 @@ RHIPipelineState* OpenGLDevice::CreatePipelineState(const GraphicsPipelineDesc&
|
||||
return pso;
|
||||
}
|
||||
|
||||
RHIPipelineState* OpenGLDevice::CreateComputePipelineState(const ComputePipelineDesc& desc) {
|
||||
if (!HasShaderPayload(desc.computeShader) || !MakeContextCurrent()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto computeShader = std::unique_ptr<OpenGLShader>(static_cast<OpenGLShader*>(CreateShader(desc.computeShader)));
|
||||
if (computeShader == nullptr || computeShader->GetType() != ShaderType::Compute) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto* pso = new OpenGLPipelineState();
|
||||
pso->SetOwnedComputeShader(std::move(computeShader));
|
||||
return pso;
|
||||
}
|
||||
|
||||
RHIPipelineLayout* OpenGLDevice::CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) {
|
||||
auto* layout = new OpenGLPipelineLayout();
|
||||
if (!layout->Initialize(desc)) {
|
||||
|
||||
@@ -72,7 +72,13 @@ void OpenGLPipelineState::SetRenderTargetFormats(uint32_t count, const uint32_t*
|
||||
void OpenGLPipelineState::SetSampleCount(uint32_t count) {
|
||||
}
|
||||
|
||||
void OpenGLPipelineState::SetSampleQuality(uint32_t quality) {
|
||||
}
|
||||
|
||||
void OpenGLPipelineState::SetComputeShader(RHIShader* shader) {
|
||||
if (shader != m_ownedComputeShader.get()) {
|
||||
m_ownedComputeShader.reset();
|
||||
}
|
||||
m_computeShader = shader;
|
||||
if (shader) {
|
||||
OpenGLShader* glShader = static_cast<OpenGLShader*>(shader);
|
||||
@@ -82,6 +88,12 @@ void OpenGLPipelineState::SetComputeShader(RHIShader* shader) {
|
||||
}
|
||||
}
|
||||
|
||||
void OpenGLPipelineState::SetOwnedComputeShader(std::unique_ptr<OpenGLShader> shader) {
|
||||
m_ownedComputeShader = std::move(shader);
|
||||
m_computeShader = m_ownedComputeShader.get();
|
||||
m_computeProgram = m_ownedComputeShader != nullptr ? m_ownedComputeShader->GetID() : 0;
|
||||
}
|
||||
|
||||
PipelineStateHash OpenGLPipelineState::GetHash() const {
|
||||
PipelineStateHash hash = {};
|
||||
return hash;
|
||||
@@ -89,6 +101,7 @@ PipelineStateHash OpenGLPipelineState::GetHash() const {
|
||||
|
||||
void OpenGLPipelineState::Shutdown() {
|
||||
m_graphicsShader.reset();
|
||||
m_ownedComputeShader.reset();
|
||||
m_program = 0;
|
||||
m_computeProgram = 0;
|
||||
m_computeShader = nullptr;
|
||||
|
||||
@@ -677,6 +677,21 @@ RHIPipelineState* VulkanDevice::CreatePipelineState(const GraphicsPipelineDesc&
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
RHIPipelineState* VulkanDevice::CreateComputePipelineState(const ComputePipelineDesc& desc) {
|
||||
auto computeShader = std::unique_ptr<VulkanShader>(static_cast<VulkanShader*>(CreateShader(desc.computeShader)));
|
||||
if (computeShader == nullptr || computeShader->GetType() != ShaderType::Compute) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto* pipelineState = new VulkanPipelineState();
|
||||
if (pipelineState->InitializeCompute(this, desc.pipelineLayout, std::move(computeShader))) {
|
||||
return pipelineState;
|
||||
}
|
||||
|
||||
delete pipelineState;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
RHIPipelineLayout* VulkanDevice::CreatePipelineLayout(const RHIPipelineLayoutDesc& desc) {
|
||||
auto* pipelineLayout = new VulkanPipelineLayout();
|
||||
if (pipelineLayout->Initialize(m_device, desc)) {
|
||||
|
||||
@@ -17,7 +17,7 @@ namespace RHI {
|
||||
namespace {
|
||||
|
||||
bool HasShaderPayload(const ShaderCompileDesc& desc) {
|
||||
return !desc.source.empty() || !desc.fileName.empty();
|
||||
return !desc.source.empty() || !desc.fileName.empty() || !desc.compiledBinary.empty();
|
||||
}
|
||||
|
||||
VkShaderModule CreateShaderModule(VkDevice device, const std::vector<uint32_t>& words) {
|
||||
@@ -91,6 +91,9 @@ bool VulkanPipelineState::Initialize(VulkanDevice* device, const GraphicsPipelin
|
||||
m_renderTargetCount = desc.renderTargetCount;
|
||||
m_depthStencilFormat = desc.depthStencilFormat;
|
||||
m_sampleCount = desc.sampleCount > 0 ? desc.sampleCount : 1;
|
||||
m_computeShader = nullptr;
|
||||
m_ownedComputeShader.reset();
|
||||
m_hasComputeShader = false;
|
||||
for (uint32_t i = 0; i < 8; ++i) {
|
||||
m_renderTargetFormats[i] = desc.renderTargetFormats[i];
|
||||
}
|
||||
@@ -125,8 +128,41 @@ bool VulkanPipelineState::Initialize(VulkanDevice* device, const GraphicsPipelin
|
||||
return true;
|
||||
}
|
||||
|
||||
bool VulkanPipelineState::InitializeCompute(
|
||||
VulkanDevice* device,
|
||||
RHIPipelineLayout* pipelineLayout,
|
||||
std::unique_ptr<VulkanShader> shader) {
|
||||
if (device == nullptr ||
|
||||
device->GetDevice() == VK_NULL_HANDLE ||
|
||||
shader == nullptr ||
|
||||
!shader->IsValid() ||
|
||||
shader->GetType() != ShaderType::Compute) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Shutdown();
|
||||
|
||||
m_deviceOwner = device;
|
||||
m_device = device->GetDevice();
|
||||
m_ownedComputeShader = std::move(shader);
|
||||
m_computeShader = m_ownedComputeShader.get();
|
||||
m_hasComputeShader = m_computeShader != nullptr;
|
||||
|
||||
if (!EnsurePipelineLayout(pipelineLayout) || !CreateComputePipeline()) {
|
||||
Shutdown();
|
||||
return false;
|
||||
}
|
||||
|
||||
m_isConfigured = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool VulkanPipelineState::EnsurePipelineLayout(const GraphicsPipelineDesc& desc) {
|
||||
auto* externalPipelineLayout = static_cast<VulkanPipelineLayout*>(desc.pipelineLayout);
|
||||
return EnsurePipelineLayout(desc.pipelineLayout);
|
||||
}
|
||||
|
||||
bool VulkanPipelineState::EnsurePipelineLayout(RHIPipelineLayout* pipelineLayout) {
|
||||
auto* externalPipelineLayout = static_cast<VulkanPipelineLayout*>(pipelineLayout);
|
||||
if (externalPipelineLayout != nullptr) {
|
||||
m_pipelineLayout = externalPipelineLayout->GetPipelineLayout();
|
||||
m_ownsPipelineLayout = false;
|
||||
@@ -425,15 +461,33 @@ void VulkanPipelineState::SetSampleCount(uint32_t count) {
|
||||
m_sampleCount = count > 0 ? count : 1;
|
||||
}
|
||||
|
||||
void VulkanPipelineState::SetSampleQuality(uint32_t quality) {
|
||||
(void)quality;
|
||||
}
|
||||
|
||||
void VulkanPipelineState::SetComputeShader(RHIShader* shader) {
|
||||
if (m_pipeline != VK_NULL_HANDLE && m_device != VK_NULL_HANDLE) {
|
||||
vkDestroyPipeline(m_device, m_pipeline, nullptr);
|
||||
m_pipeline = VK_NULL_HANDLE;
|
||||
}
|
||||
m_computeShader = shader;
|
||||
if (m_computeShader != nullptr && m_pipelineLayout != VK_NULL_HANDLE) {
|
||||
m_isConfigured = true;
|
||||
if (shader != m_ownedComputeShader.get()) {
|
||||
m_ownedComputeShader.reset();
|
||||
}
|
||||
m_computeShader = shader;
|
||||
m_hasComputeShader = m_computeShader != nullptr;
|
||||
m_isConfigured = m_hasComputeShader && m_pipelineLayout != VK_NULL_HANDLE;
|
||||
}
|
||||
|
||||
void VulkanPipelineState::SetOwnedComputeShader(std::unique_ptr<VulkanShader> shader) {
|
||||
if (m_pipeline != VK_NULL_HANDLE && m_device != VK_NULL_HANDLE) {
|
||||
vkDestroyPipeline(m_device, m_pipeline, nullptr);
|
||||
m_pipeline = VK_NULL_HANDLE;
|
||||
}
|
||||
|
||||
m_ownedComputeShader = std::move(shader);
|
||||
m_computeShader = m_ownedComputeShader.get();
|
||||
m_hasComputeShader = m_computeShader != nullptr;
|
||||
m_isConfigured = m_hasComputeShader && m_pipelineLayout != VK_NULL_HANDLE;
|
||||
}
|
||||
|
||||
PipelineStateHash VulkanPipelineState::GetHash() const {
|
||||
@@ -499,7 +553,9 @@ void VulkanPipelineState::Shutdown() {
|
||||
m_deviceOwner = nullptr;
|
||||
m_device = VK_NULL_HANDLE;
|
||||
m_ownsPipelineLayout = false;
|
||||
m_ownedComputeShader.reset();
|
||||
m_computeShader = nullptr;
|
||||
m_hasComputeShader = false;
|
||||
m_isConfigured = false;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user