spv_shader_gen: Add SanitizeVertex

This commit is contained in:
Wunkolo 2024-03-03 09:43:48 -08:00
parent 3fb681bcc6
commit 319a278bf9
2 changed files with 73 additions and 1 deletions

View File

@ -1,3 +1,4 @@
#include "spv_shader_gen.h"
// Copyright 2024 Citra Emulator Project // Copyright 2024 Citra Emulator Project
// Licensed under GPLv2 or any later version // Licensed under GPLv2 or any later version
// Refer to the license.txt file included. // Refer to the license.txt file included.
@ -16,6 +17,9 @@ constexpr u32 SPIRV_VERSION_1_3 = 0x00010300;
VertexModule::VertexModule() : Sirit::Module{SPIRV_VERSION_1_3} { VertexModule::VertexModule() : Sirit::Module{SPIRV_VERSION_1_3} {
DefineArithmeticTypes(); DefineArithmeticTypes();
DefineInterface(); DefineInterface();
ids.sanitize_vertex = WriteFuncSanitizeVertex();
DefineEntryPoint(); DefineEntryPoint();
} }
@ -97,6 +101,66 @@ void VertexModule::DefineInterface() {
Decorate(ids.gl_position, spv::Decoration::BuiltIn, spv::BuiltIn::Position); Decorate(ids.gl_position, spv::Decoration::BuiltIn, spv::BuiltIn::Position);
} }
Id VertexModule::WriteFuncSanitizeVertex() {
const Id func_type = TypeFunction(ids.vec_ids.Get(4), ids.vec_ids.Get(4));
const Id func = Name(OpFunction(ids.vec_ids.Get(4), spv::FunctionControlMask::MaskNone, func_type), "SanitizeVertex");
const Id arg_pos = OpFunctionParameter(ids.vec_ids.Get(4));
AddLabel(OpLabel());
const Id result = AddLocalVariable(TypePointer(spv::StorageClass::Function, ids.vec_ids.Get(4)), spv::StorageClass::Function);
OpStore(result, arg_pos);
const Id pos_z = OpCompositeExtract(ids.f32_id, arg_pos, 2);
const Id pos_w = OpCompositeExtract(ids.f32_id, arg_pos, 3);
const Id ndc_z = OpFDiv(ids.f32_id, pos_z, pos_w);
// if (ndc_z > 0.f && ndc_z < 0.000001f)
const Id test_1 =
OpLogicalAnd(ids.bool_id, OpFOrdGreaterThan(ids.bool_id, ndc_z, Constant(ids.f32_id, 0.0f)),
OpFOrdLessThan(ids.bool_id, ndc_z, Constant(ids.f32_id, 0.000001f)));
{
const Id true_label = OpLabel();
const Id end_label = OpLabel();
OpSelectionMerge(end_label, spv::SelectionControlMask::MaskNone);
OpBranchConditional(test_1, true_label, end_label);
AddLabel(true_label);
// .z = 0.0f;
OpStore(result, OpCompositeInsert(ids.vec_ids.Get(4), ConstantNull(ids.f32_id), arg_pos, 2));
OpBranch(end_label);
AddLabel(end_label);
}
// if (ndc_z < -1.f && ndc_z > -1.00001f)
const Id test_2 =
OpLogicalAnd(ids.bool_id, OpFOrdLessThan(ids.bool_id, ndc_z, Constant(ids.f32_id, -1.0f)),
OpFOrdGreaterThan(ids.bool_id, ndc_z, Constant(ids.f32_id, -1.00001f)));
{
const Id true_label = OpLabel();
const Id end_label = OpLabel();
OpSelectionMerge(end_label, spv::SelectionControlMask::MaskNone);
OpBranchConditional(test_2, true_label, end_label);
AddLabel(true_label);
// .z = -.w;
const Id neg_w = OpFNegate(ids.f32_id, OpCompositeExtract(ids.f32_id, arg_pos, 3));
OpStore(result, OpCompositeInsert(ids.vec_ids.Get(4), neg_w, arg_pos, 2));
OpBranch(end_label);
AddLabel(end_label);
}
OpReturnValue(OpLoad(ids.vec_ids.Get(4), result));
OpFunctionEnd();
return func;
}
void VertexModule::Generate(Common::UniqueFunction<void, Sirit::Module&, const EmitterIDs&> proc) { void VertexModule::Generate(Common::UniqueFunction<void, Sirit::Module&, const EmitterIDs&> proc) {
AddLabel(OpLabel()); AddLabel(OpLabel());
proc(*this, ids); proc(*this, ids);
@ -113,7 +177,10 @@ void VertexModule::Generate(const PicaVSConfig& config, const Profile& profile)
std::vector<u32> GenerateTrivialVertexShader(bool use_clip_planes) { std::vector<u32> GenerateTrivialVertexShader(bool use_clip_planes) {
VertexModule module; VertexModule module;
module.Generate([](Sirit::Module& code, const VertexModule::EmitterIDs& ids) -> void { module.Generate([](Sirit::Module& code, const VertexModule::EmitterIDs& ids) -> void {
code.OpStore(ids.gl_position, code.OpLoad(ids.vec_ids.Get(4), ids.vert_in_position_id)); const Id pos_sanitized =
code.OpFunctionCall(ids.vec_ids.Get(4), ids.sanitize_vertex,
code.OpLoad(ids.vec_ids.Get(4), ids.vert_in_position_id));
code.OpStore(ids.gl_position, pos_sanitized);
// Negate Z // Negate Z
const Id pos_z = code.OpAccessChain(code.TypePointer(spv::StorageClass::Output, ids.f32_id), const Id pos_z = code.OpAccessChain(code.TypePointer(spv::StorageClass::Output, ids.f32_id),

View File

@ -66,6 +66,8 @@ private:
void DefineEntryPoint(); void DefineEntryPoint();
void DefineInterface(); void DefineInterface();
Id WriteFuncSanitizeVertex();
public: public:
struct EmitterIDs { struct EmitterIDs {
Id void_id{}; Id void_id{};
@ -100,6 +102,9 @@ public:
// Built-ins // Built-ins
Id gl_position; Id gl_position;
// Functions
Id sanitize_vertex;
} ids; } ids;
/// Generate code using the provided SPIRV emitter context /// Generate code using the provided SPIRV emitter context