diff --git a/src/video_core/shader/generator/spv_shader_gen.cpp b/src/video_core/shader/generator/spv_shader_gen.cpp index edf518e7a..03383d7b1 100644 --- a/src/video_core/shader/generator/spv_shader_gen.cpp +++ b/src/video_core/shader/generator/spv_shader_gen.cpp @@ -1,3 +1,4 @@ +#include "spv_shader_gen.h" // Copyright 2024 Citra Emulator Project // Licensed under GPLv2 or any later version // Refer to the license.txt file included. @@ -16,6 +17,9 @@ constexpr u32 SPIRV_VERSION_1_3 = 0x00010300; VertexModule::VertexModule() : Sirit::Module{SPIRV_VERSION_1_3} { DefineArithmeticTypes(); DefineInterface(); + + ids.sanitize_vertex = WriteFuncSanitizeVertex(); + DefineEntryPoint(); } @@ -97,6 +101,66 @@ void VertexModule::DefineInterface() { Decorate(ids.gl_position, spv::Decoration::BuiltIn, spv::BuiltIn::Position); } +Id VertexModule::WriteFuncSanitizeVertex() { + const Id func_type = TypeFunction(ids.vec_ids.Get(4), ids.vec_ids.Get(4)); + const Id func = Name(OpFunction(ids.vec_ids.Get(4), spv::FunctionControlMask::MaskNone, func_type), "SanitizeVertex"); + const Id arg_pos = OpFunctionParameter(ids.vec_ids.Get(4)); + + AddLabel(OpLabel()); + + const Id result = AddLocalVariable(TypePointer(spv::StorageClass::Function, ids.vec_ids.Get(4)), spv::StorageClass::Function); + OpStore(result, arg_pos); + + const Id pos_z = OpCompositeExtract(ids.f32_id, arg_pos, 2); + const Id pos_w = OpCompositeExtract(ids.f32_id, arg_pos, 3); + + const Id ndc_z = OpFDiv(ids.f32_id, pos_z, pos_w); + + // if (ndc_z > 0.f && ndc_z < 0.000001f) + const Id test_1 = + OpLogicalAnd(ids.bool_id, OpFOrdGreaterThan(ids.bool_id, ndc_z, Constant(ids.f32_id, 0.0f)), + OpFOrdLessThan(ids.bool_id, ndc_z, Constant(ids.f32_id, 0.000001f))); + + { + const Id true_label = OpLabel(); + const Id end_label = OpLabel(); + + OpSelectionMerge(end_label, spv::SelectionControlMask::MaskNone); + OpBranchConditional(test_1, true_label, end_label); + AddLabel(true_label); + + // .z = 0.0f; + OpStore(result, OpCompositeInsert(ids.vec_ids.Get(4), ConstantNull(ids.f32_id), arg_pos, 2)); + + OpBranch(end_label); + AddLabel(end_label); + } + + // if (ndc_z < -1.f && ndc_z > -1.00001f) + const Id test_2 = + OpLogicalAnd(ids.bool_id, OpFOrdLessThan(ids.bool_id, ndc_z, Constant(ids.f32_id, -1.0f)), + OpFOrdGreaterThan(ids.bool_id, ndc_z, Constant(ids.f32_id, -1.00001f))); + { + const Id true_label = OpLabel(); + const Id end_label = OpLabel(); + + OpSelectionMerge(end_label, spv::SelectionControlMask::MaskNone); + OpBranchConditional(test_2, true_label, end_label); + AddLabel(true_label); + + // .z = -.w; + const Id neg_w = OpFNegate(ids.f32_id, OpCompositeExtract(ids.f32_id, arg_pos, 3)); + OpStore(result, OpCompositeInsert(ids.vec_ids.Get(4), neg_w, arg_pos, 2)); + + OpBranch(end_label); + AddLabel(end_label); + } + + OpReturnValue(OpLoad(ids.vec_ids.Get(4), result)); + OpFunctionEnd(); + return func; +} + void VertexModule::Generate(Common::UniqueFunction proc) { AddLabel(OpLabel()); proc(*this, ids); @@ -113,7 +177,10 @@ void VertexModule::Generate(const PicaVSConfig& config, const Profile& profile) std::vector GenerateTrivialVertexShader(bool use_clip_planes) { VertexModule module; module.Generate([](Sirit::Module& code, const VertexModule::EmitterIDs& ids) -> void { - code.OpStore(ids.gl_position, code.OpLoad(ids.vec_ids.Get(4), ids.vert_in_position_id)); + const Id pos_sanitized = + code.OpFunctionCall(ids.vec_ids.Get(4), ids.sanitize_vertex, + code.OpLoad(ids.vec_ids.Get(4), ids.vert_in_position_id)); + code.OpStore(ids.gl_position, pos_sanitized); // Negate Z const Id pos_z = code.OpAccessChain(code.TypePointer(spv::StorageClass::Output, ids.f32_id), diff --git a/src/video_core/shader/generator/spv_shader_gen.h b/src/video_core/shader/generator/spv_shader_gen.h index 419051283..a5f9c77aa 100644 --- a/src/video_core/shader/generator/spv_shader_gen.h +++ b/src/video_core/shader/generator/spv_shader_gen.h @@ -66,6 +66,8 @@ private: void DefineEntryPoint(); void DefineInterface(); + Id WriteFuncSanitizeVertex(); + public: struct EmitterIDs { Id void_id{}; @@ -100,6 +102,9 @@ public: // Built-ins Id gl_position; + + // Functions + Id sanitize_vertex; } ids; /// Generate code using the provided SPIRV emitter context