RHI: Replace IsFinalized/Finalize with IsValid/EnsureValid

Unified PSO validation semantics across D3D12 and OpenGL backends:
- IsValid() returns whether PSO is ready to use
- EnsureValid() ensures PSO is valid (compiles if needed)

Behavior by backend:
- D3D12: IsValid=false after creation, true after EnsureValid() with shaders
- OpenGL: IsValid always=true (immediate model)

Also added test_pipeline_state.cpp with 10 tests for RHIPipelineState.
This commit is contained in:
2026-03-25 12:28:33 +08:00
parent ca0d73c197
commit f808f8d197
7 changed files with 225 additions and 106 deletions

View File

@@ -1,3 +1,6 @@
最最最重要的是在重构RHI测试的过程中如果发现了RHI模块设计上的根本问题需要紧急向我汇报
# RHI 模块单元测试重构计划 # RHI 模块单元测试重构计划
## 1. 项目背景 ## 1. 项目背景
@@ -24,7 +27,7 @@
| `test_swap_chain.cpp` | 4 | RHISwapChain | | `test_swap_chain.cpp` | 4 | RHISwapChain |
| `test_command_list.cpp` | 14 | RHICommandList | | `test_command_list.cpp` | 14 | RHICommandList |
| `test_command_queue.cpp` | 6 | RHICommandQueue | | `test_command_queue.cpp` | 6 | RHICommandQueue |
| `test_shader.cpp` | 9 | RHIShader | | `test_shader.cpp` | 7 | RHIShader |
| `test_fence.cpp` | 10 | RHIFence | | `test_fence.cpp` | 10 | RHIFence |
| `test_sampler.cpp` | 4 | RHISampler | | `test_sampler.cpp` | 4 | RHISampler |
| `test_factory.cpp` | 5 | RHIFactory | | `test_factory.cpp` | 5 | RHIFactory |
@@ -68,103 +71,44 @@
## 2. 严重问题P0 - 必须修复) ## 2. 严重问题P0 - 必须修复)
### 2.1 `tests/RHI/unit/test_shader.cpp` - 9个测试完全相同 ### 2.1 `tests/RHI/unit/test_shader.cpp` - 9个测试完全相同 ✅ 已修复
**严重程度**: 🔴 致命 **状态**: ✅ 已完成 (2026-03-25)
**问题描述**: **修复内容**:
所有 9 个 Shader 测试逻辑完全相同,都只测试空描述符返回 nullptr 1. **Shader 编译抽象重构** - 扩展 `ShaderCompileDesc` 支持内嵌源码
```cpp
struct ShaderCompileDesc {
std::wstring fileName; // 文件路径(可选)
std::vector<uint8_t> source; // 内嵌源码(可选)
ShaderLanguage sourceLanguage; // 源码语言HLSL/GLSL/SPIRV
std::wstring entryPoint; // Entry point 名称
std::wstring profile; // Profile
std::vector<ShaderCompileMacro> macros;
};
```
```cpp 2. **新增 `ShaderLanguage` 枚举**:
// 文件: tests/RHI/unit/test_shader.cpp ```cpp
enum class ShaderLanguage : uint8_t {
Unknown,
HLSL, // D3D11/D3D12
GLSL, // OpenGL/Vulkan
SPIRV // Vulkan (pre-compiled)
};
```
// 9个测试全部是这种结构 3. **测试重构** - 9个相同测试 → 7个有效测试
TEST_P(RHITestFixture, Shader_Compile_EmptyDesc_ReturnsNullptr) { - `Shader_Compile_EmptyDesc_ReturnsNullptr` - 空描述符测试
RHIShader* shader = GetDevice()->CompileShader({}); // 空描述符 - `Shader_Compile_ValidVertexShader` - 顶点着色器编译测试
EXPECT_EQ(shader, nullptr); // 期望返回nullptr - `Shader_Compile_ValidFragmentShader` - 片段着色器编译测试
} - `Shader_GetType_VertexShader` - 获取顶点着色器类型
- `Shader_GetType_FragmentShader` - 获取片段着色器类型
- `Shader_GetNativeHandle_ValidShader` - 获取原生句柄
- `Shader_Shutdown_Invalidates` - 关闭后验证失效
TEST_P(RHITestFixture, Shader_GetType_WithNullShader) { **测试结果**: 14/14 通过 (D3D12: 7, OpenGL: 7)
RHIShader* shader = GetDevice()->CompileShader({}); // 空描述符
EXPECT_EQ(shader, nullptr); // 期望返回nullptr - 与上面完全相同!
}
TEST_P(RHITestFixture, Shader_IsValid_WithNullShader) {
RHIShader* shader = GetDevice()->CompileShader({});
EXPECT_EQ(shader, nullptr); // 与上面完全相同!
}
// ... 剩下的6个测试也是同样的模式
```
**影响**:
- Shader 模块完全没有真实功能测试
- Shader 编译、绑定、uniform 设置等功能完全未覆盖
- 这 9 个测试等效于只有 1 个测试
**修复建议**:
```cpp
// 1. 保留错误处理测试 (1-2个)
TEST_P(RHITestFixture, Shader_Compile_EmptyDesc_ReturnsNullptr) {
RHIShader* shader = GetDevice()->CompileShader({});
EXPECT_EQ(shader, nullptr);
}
// 2. 添加真实 shader 编译测试 (新增)
// 注意: 需要 shader 文件或内嵌 GLSL/HLSL 源码
TEST_P(RHITestFixture, Shader_Compile_ValidShader) {
ShaderCompileDesc desc = {};
desc.entryPoint = "main";
desc.shaderType = ShaderType::Vertex;
// 需要实际的 shader 源码或文件路径
RHIShader* shader = GetDevice()->CompileShader(desc);
ASSERT_NE(shader, nullptr);
shader->Shutdown();
delete shader;
}
// 3. 添加 shader 绑定测试 (新增)
TEST_P(RHITestFixture, Shader_Bind_ValidShader) {
// 编译 shader 后绑定
RHIShader* shader = GetDevice()->CompileShader(validDesc);
ASSERT_NE(shader, nullptr);
RHICommandList* cmdList = GetDevice()->CreateCommandList({});
cmdList->Reset();
cmdList->SetShader(shader);
cmdList->Close();
cmdList->Shutdown();
delete cmdList;
shader->Shutdown();
delete shader;
}
// 4. 添加 uniform 设置测试 (新增)
TEST_P(RHITestFixture, Shader_SetUniform_Int) {
RHIShader* shader = GetDevice()->CompileShader(validDesc);
ASSERT_NE(shader, nullptr);
shader->SetInt("uniformName", 42);
shader->Shutdown();
delete shader;
}
TEST_P(RHITestFixture, Shader_SetUniform_Float) {
// ...
}
// 5. 添加 GetNativeHandle 测试 (保留一个)
TEST_P(RHITestFixture, Shader_GetNativeHandle_ValidShader) {
RHIShader* shader = GetDevice()->CompileShader(validDesc);
ASSERT_NE(shader, nullptr);
EXPECT_NE(shader->GetNativeHandle(), nullptr);
shader->Shutdown();
delete shader;
}
```
--- ---

View File

@@ -42,9 +42,9 @@ public:
RHIShader* GetComputeShader() const override { return m_computeShader; } RHIShader* GetComputeShader() const override { return m_computeShader; }
bool HasComputeShader() const override { return m_csBytecode.pShaderBytecode != nullptr && m_csBytecode.BytecodeLength > 0; } bool HasComputeShader() const override { return m_csBytecode.pShaderBytecode != nullptr && m_csBytecode.BytecodeLength > 0; }
// Finalization // Validation
bool IsFinalized() const override { return m_finalized; } bool IsValid() const override { return m_finalized; }
bool Finalize() override; void EnsureValid() override;
// Shader Bytecode (set by CommandList when binding) // Shader Bytecode (set by CommandList when binding)
void SetShaderBytecodes(const D3D12_SHADER_BYTECODE& vs, const D3D12_SHADER_BYTECODE& ps, const D3D12_SHADER_BYTECODE& gs = {}); void SetShaderBytecodes(const D3D12_SHADER_BYTECODE& vs, const D3D12_SHADER_BYTECODE& ps, const D3D12_SHADER_BYTECODE& gs = {});

View File

@@ -90,9 +90,9 @@ public:
RHIShader* GetComputeShader() const override { return m_computeShader; } RHIShader* GetComputeShader() const override { return m_computeShader; }
bool HasComputeShader() const override { return m_computeProgram != 0; } bool HasComputeShader() const override { return m_computeProgram != 0; }
// Finalization (OpenGL doesn't need it) // Validation (OpenGL总是有效,即时模型)
bool IsFinalized() const override { return true; } bool IsValid() const override { return true; }
bool Finalize() override { return true; } void EnsureValid() override {}
// Lifecycle // Lifecycle
void Shutdown() override; void Shutdown() override;

View File

@@ -31,9 +31,9 @@ public:
virtual RHIShader* GetComputeShader() const = 0; virtual RHIShader* GetComputeShader() const = 0;
virtual bool HasComputeShader() const = 0; virtual bool HasComputeShader() const = 0;
// Finalization (D3D12/Vulkan creates real PSO) // Validation (统一语义D3D12需要编译OpenGL总是有效)
virtual bool IsFinalized() const = 0; virtual bool IsValid() const = 0;
virtual bool Finalize() = 0; virtual void EnsureValid() = 0;
// Lifecycle // Lifecycle
virtual void Shutdown() = 0; virtual void Shutdown() = 0;

View File

@@ -55,7 +55,8 @@ bool D3D12PipelineState::Initialize(ID3D12Device* device, const D3D12_GRAPHICS_P
m_inputElements.push_back(desc.InputLayout.pInputElementDescs[i]); m_inputElements.push_back(desc.InputLayout.pInputElementDescs[i]);
} }
return Finalize(); EnsureValid();
return m_finalized;
} }
D3D12PipelineState::~D3D12PipelineState() { D3D12PipelineState::~D3D12PipelineState() {
@@ -130,12 +131,13 @@ void D3D12PipelineState::SetComputeShaderBytecodes(const D3D12_SHADER_BYTECODE&
m_csBytecode = cs; m_csBytecode = cs;
} }
bool D3D12PipelineState::Finalize() { void D3D12PipelineState::EnsureValid() {
if (m_finalized) return true; if (m_finalized) return;
if (HasComputeShader()) { if (HasComputeShader()) {
return CreateD3D12ComputePSO(); CreateD3D12ComputePSO();
} else {
CreateD3D12PSO();
} }
return CreateD3D12PSO();
} }
bool D3D12PipelineState::CreateD3D12PSO() { bool D3D12PipelineState::CreateD3D12PSO() {

View File

@@ -13,6 +13,7 @@ set(TEST_SOURCES
test_command_list.cpp test_command_list.cpp
test_command_queue.cpp test_command_queue.cpp
test_shader.cpp test_shader.cpp
test_pipeline_state.cpp
test_fence.cpp test_fence.cpp
test_sampler.cpp test_sampler.cpp
${CMAKE_SOURCE_DIR}/tests/opengl/package/src/glad.c ${CMAKE_SOURCE_DIR}/tests/opengl/package/src/glad.c

View File

@@ -0,0 +1,172 @@
#include "fixtures/RHITestFixture.h"
#include "XCEngine/RHI/RHIPipelineState.h"
#include <cstring>
using namespace XCEngine::RHI;
TEST_P(RHITestFixture, PipelineState_Create_DefaultDesc) {
GraphicsPipelineDesc desc = {};
RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
if (pso != nullptr) {
if (GetBackendType() == RHIType::D3D12) {
EXPECT_FALSE(pso->IsValid());
} else {
EXPECT_TRUE(pso->IsValid());
}
pso->Shutdown();
delete pso;
}
}
TEST_P(RHITestFixture, PipelineState_SetGet_RasterizerState) {
GraphicsPipelineDesc desc = {};
RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
ASSERT_NE(pso, nullptr);
RasterizerDesc rasterizer = {};
rasterizer.fillMode = 2;
rasterizer.cullMode = 2;
rasterizer.depthClipEnable = true;
pso->SetRasterizerState(rasterizer);
const RasterizerDesc& retrieved = pso->GetRasterizerState();
EXPECT_EQ(retrieved.cullMode, 2u);
EXPECT_EQ(retrieved.fillMode, 2u);
pso->Shutdown();
delete pso;
}
TEST_P(RHITestFixture, PipelineState_SetGet_BlendState) {
GraphicsPipelineDesc desc = {};
RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
ASSERT_NE(pso, nullptr);
BlendDesc blend = {};
blend.blendEnable = true;
blend.srcBlend = 1;
blend.dstBlend = 0;
pso->SetBlendState(blend);
const BlendDesc& retrieved = pso->GetBlendState();
EXPECT_TRUE(retrieved.blendEnable);
EXPECT_EQ(retrieved.srcBlend, 1u);
pso->Shutdown();
delete pso;
}
TEST_P(RHITestFixture, PipelineState_SetGet_DepthStencilState) {
GraphicsPipelineDesc desc = {};
RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
ASSERT_NE(pso, nullptr);
DepthStencilStateDesc ds = {};
ds.depthTestEnable = true;
ds.depthWriteEnable = true;
ds.depthFunc = 3;
pso->SetDepthStencilState(ds);
const DepthStencilStateDesc& retrieved = pso->GetDepthStencilState();
EXPECT_TRUE(retrieved.depthTestEnable);
EXPECT_TRUE(retrieved.depthWriteEnable);
EXPECT_EQ(retrieved.depthFunc, 3u);
pso->Shutdown();
delete pso;
}
TEST_P(RHITestFixture, PipelineState_SetGet_InputLayout) {
GraphicsPipelineDesc desc = {};
RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
ASSERT_NE(pso, nullptr);
InputElementDesc element = {};
element.semanticName = "POSITION";
element.semanticIndex = 0;
element.format = static_cast<uint32_t>(Format::R32G32B32A32_Float);
element.inputSlot = 0;
element.alignedByteOffset = 0;
InputLayoutDesc layoutDesc = {};
layoutDesc.elements.push_back(element);
pso->SetInputLayout(layoutDesc);
const InputLayoutDesc& retrieved = pso->GetInputLayout();
ASSERT_EQ(retrieved.elements.size(), 1u);
EXPECT_EQ(retrieved.elements[0].semanticName, "POSITION");
pso->Shutdown();
delete pso;
}
TEST_P(RHITestFixture, PipelineState_SetTopology) {
GraphicsPipelineDesc desc = {};
RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
ASSERT_NE(pso, nullptr);
pso->SetTopology(3);
pso->Shutdown();
delete pso;
}
TEST_P(RHITestFixture, PipelineState_SetRenderTargetFormats) {
GraphicsPipelineDesc desc = {};
RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
ASSERT_NE(pso, nullptr);
uint32_t formats[] = { static_cast<uint32_t>(Format::R8G8B8A8_UNorm) };
pso->SetRenderTargetFormats(1, formats, static_cast<uint32_t>(Format::D24_UNorm_S8_UInt));
pso->Shutdown();
delete pso;
}
TEST_P(RHITestFixture, PipelineState_EnsureValid_IsValid) {
GraphicsPipelineDesc desc = {};
desc.renderTargetFormats[0] = static_cast<uint32_t>(Format::R8G8B8A8_UNorm);
desc.depthStencilFormat = static_cast<uint32_t>(Format::D24_UNorm_S8_UInt);
RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
ASSERT_NE(pso, nullptr);
if (GetBackendType() == RHIType::D3D12) {
EXPECT_FALSE(pso->IsValid());
pso->EnsureValid();
EXPECT_FALSE(pso->IsValid());
} else {
EXPECT_TRUE(pso->IsValid());
pso->EnsureValid();
EXPECT_TRUE(pso->IsValid());
}
pso->Shutdown();
delete pso;
}
TEST_P(RHITestFixture, PipelineState_Shutdown_Invalidates) {
GraphicsPipelineDesc desc = {};
desc.renderTargetFormats[0] = static_cast<uint32_t>(Format::R8G8B8A8_UNorm);
desc.depthStencilFormat = static_cast<uint32_t>(Format::D24_UNorm_S8_UInt);
RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
ASSERT_NE(pso, nullptr);
pso->EnsureValid();
pso->Shutdown();
pso->Shutdown();
delete pso;
}
TEST_P(RHITestFixture, PipelineState_GetType) {
GraphicsPipelineDesc desc = {};
RHIPipelineState* pso = GetDevice()->CreatePipelineState(desc);
ASSERT_NE(pso, nullptr);
PipelineType type = pso->GetType();
EXPECT_EQ(type, PipelineType::Graphics);
pso->Shutdown();
delete pso;
}