spirv: Fix PointCoord input using vec4 instead of vec2 default tip
authorMartin Krošlák <kroslakma@gmail.com>
Wed, 11 Nov 2020 15:09:32 -0500
changeset 1318 ff4eb6d9c9c2
parent 1317 752092c8f284
spirv: Fix PointCoord input using vec4 instead of vec2
mojoshader_common.c
mojoshader_internal.h
profiles/mojoshader_profile_spirv.c
profiles/mojoshader_profile_spirv.h
--- a/mojoshader_common.c	Tue Nov 10 19:17:45 2020 -0500
+++ b/mojoshader_common.c	Wed Nov 11 15:09:32 2020 -0500
@@ -1135,13 +1135,17 @@
     {
         if (vTable->attrib_offsets[MOJOSHADER_USAGE_POINTSIZE][0] > 0)
         {
+            ((uint32 *) pixel->output)[pTable->pointcoord_var_offset + 1] = pTable->tid_pvec2i;
+            ((uint32 *) pixel->output)[pTable->pointcoord_load_offset + 1] = pTable->tid_vec2;
             ((uint32 *) pixel->output)[texcoord0Loc - 1] = SpvDecorationBuiltIn;
             ((uint32 *) pixel->output)[texcoord0Loc] = SpvBuiltInPointCoord;
         } // if
         else
         {
+            ((uint32 *) pixel->output)[pTable->pointcoord_var_offset + 1] = pTable->tid_pvec4i;
+            ((uint32 *) pixel->output)[pTable->pointcoord_load_offset + 1] = pTable->tid_vec4;
+            ((uint32 *) pixel->output)[texcoord0Loc - 1] = SpvDecorationLocation;
             // texcoord0Loc should already have attr_loc from the above work!
-            ((uint32 *) pixel->output)[texcoord0Loc - 1] = SpvDecorationLocation;
         } // else
     } // if
 } // MOJOSHADER_spirv_link_attributes
--- a/mojoshader_internal.h	Tue Nov 10 19:17:45 2020 -0500
+++ b/mojoshader_internal.h	Wed Nov 11 15:09:32 2020 -0500
@@ -753,6 +753,16 @@
     SpirvPatchEntry samplers[16];
     int32 location_count;
 
+    // TEXCOORD0 is patched to PointCoord if VS outputs PointSize.
+    // In `helpers`: [OpDecorate|id|Location|0xDEADBEEF] -> [OpDecorate|id|BuiltIn|PointCoord]
+    // Offset derived from attrib_offsets[TEXCOORD][0].
+    uint32 pointcoord_var_offset; // in `mainline_intro`, [OpVariable|tid|id|StorageClass], patch tid to pvec2i
+    uint32 pointcoord_load_offset; // in `mainline_top`, [OpLoad|tid|id|src_id], patch tid to vec2
+    uint32 tid_pvec2i;
+    uint32 tid_vec2;
+    uint32 tid_pvec4i;
+    uint32 tid_vec4;
+
     // Patches for linking vertex output/pixel input
     uint32 attrib_offsets[MOJOSHADER_USAGE_TOTAL][16];
     uint32 output_offsets[16];
--- a/profiles/mojoshader_profile_spirv.c	Tue Nov 10 19:17:45 2020 -0500
+++ b/profiles/mojoshader_profile_spirv.c	Wed Nov 11 15:09:32 2020 -0500
@@ -403,6 +403,12 @@
         tid = spv_bumpid(ctx);
         spv_emit(ctx, 4, SpvOpTypePointer, tid, SpvStorageClassUniformConstant, tid_image);
     } // else if
+    else if (tidx == STI_PTR_VEC2_I)
+    {
+        uint32 tid_base = spv_get_type(ctx, STI_VEC2);
+        tid = spv_bumpid(ctx);
+        spv_emit(ctx, 4, SpvOpTypePointer, tid, SpvStorageClassInput, tid_base);
+    } // else if
     else
         assert(!"Unexpected value of type index.");
     pop_output(ctx);
@@ -2281,11 +2287,56 @@
                 // ps_1_1 is dealt with in emit_SPIRV_global().
                 if (usage != MOJOSHADER_USAGE_TEXCOORD || shader_version_atleast(ctx, 1, 4))
                 {
-                    spv_link_ps_attributes(ctx, r->spirv.iddecl, regtype, usage, index);
-                    push_output(ctx, &ctx->mainline_intro);
-                    tid = spv_get_type(ctx, STI_PTR_VEC4_I);
-                    spv_emit(ctx, 4, SpvOpVariable, tid, r->spirv.iddecl, SpvStorageClassInput);
-                    pop_output(ctx);
+                    if (usage == MOJOSHADER_USAGE_TEXCOORD && index == 0)
+                    {
+                        // This can be either BuiltInPointCoord (vec2) or normal TEXCOORD0 input (vec4).
+                        // To determine correct type, we need to wait until link-time when we can see
+                        // vertex shader outputs and then patch in correct types. To avoid having to
+                        // fix all loads from the input variable, we never access it directly, but
+                        // instead go through private variable that is always vec4.
+                        // Here we generate input and private variables and helper code that gets
+                        // patched at link-time. See SpirvPatchTable for details on patching.
+                        SpirvPatchTable* table = &ctx->spirv.patch_table;
+
+                        uint32 tid_pvec2i = spv_get_type(ctx, STI_PTR_VEC2_I);
+                        uint32 tid_pvec4i = spv_get_type(ctx, STI_PTR_VEC4_I);
+                        uint32 tid_pvec4p = spv_get_type(ctx, STI_PTR_VEC4_P);
+                        uint32 tid_vec2 = spv_get_type(ctx, STI_VEC2);
+                        uint32 tid_vec4 = spv_get_type(ctx, STI_VEC4);
+
+                        table->tid_pvec2i = tid_pvec2i;
+                        table->tid_vec2 = tid_vec2;
+                        table->tid_pvec4i = tid_pvec4i;
+                        table->tid_vec4 = tid_vec4;
+
+                        push_output(ctx, &ctx->mainline_intro);
+                        ctx->spirv.id_var_texcoord0_private = r->spirv.iddecl;
+                        ctx->spirv.id_var_texcoord0_input = spv_bumpid(ctx);
+                        table->pointcoord_var_offset = buffer_size(ctx->mainline_intro) >> 2;
+                        spv_emit(ctx, 4, SpvOpVariable, tid_pvec4i, ctx->spirv.id_var_texcoord0_input, SpvStorageClassInput);
+                        spv_emit(ctx, 4, SpvOpVariable, tid_pvec4p, ctx->spirv.id_var_texcoord0_private, SpvStorageClassPrivate);
+                        pop_output(ctx);
+
+                        spv_link_ps_attributes(ctx, ctx->spirv.id_var_texcoord0_input, regtype, usage, index);
+                        spv_output_name(ctx, ctx->spirv.id_var_texcoord0_input, "ps_PointCoordOrTexCoord0");
+
+                        push_output(ctx, &ctx->mainline_top);
+                        uint32 id_loaded = spv_bumpid(ctx);
+                        uint32 id_shuffled = spv_bumpid(ctx);
+                        table->pointcoord_load_offset = buffer_size(ctx->mainline_top) >> 2;
+                        spv_emit(ctx, 4, SpvOpLoad, tid_vec4, id_loaded, ctx->spirv.id_var_texcoord0_input);
+                        spv_emit(ctx, 9, SpvOpVectorShuffle, tid_vec4, id_shuffled, id_loaded, id_loaded, 0, 1, 2, 3);
+                        spv_emit(ctx, 3, SpvOpStore, ctx->spirv.id_var_texcoord0_private, id_shuffled);
+                        pop_output(ctx);
+                    } // if
+                    else
+                    {
+                        spv_link_ps_attributes(ctx, r->spirv.iddecl, regtype, usage, index);
+                        push_output(ctx, &ctx->mainline_intro);
+                        tid = spv_get_type(ctx, STI_PTR_VEC4_I);
+                        spv_emit(ctx, 4, SpvOpVariable, tid, r->spirv.iddecl, SpvStorageClassInput);
+                        pop_output(ctx);
+                    } // else
                 } // if
                 break;
             default:
@@ -2517,6 +2568,8 @@
                 spv_emit_word(ctx, ctx->spirv.id_var_fragcoord);
             else if (r->spirv.iddecl == ctx->spirv.id_var_vface)
                 spv_emit_word(ctx, ctx->spirv.id_var_frontfacing);
+            else if (r->spirv.iddecl == ctx->spirv.id_var_texcoord0_private)
+                spv_emit_word(ctx, ctx->spirv.id_var_texcoord0_input);
             else
                 spv_emit_word(ctx, r->spirv.iddecl);
         } // if
@@ -2630,6 +2683,22 @@
         if (table->output_offsets[i])
             table->output_offsets[i] += base_offset;
 
+    base_offset <<= 2;
+    if (ctx->helpers)     base_offset += buffer_size(ctx->helpers);
+    if (ctx->subroutines) base_offset += buffer_size(ctx->subroutines);
+    base_offset >>= 2;
+
+    if (table->pointcoord_var_offset)
+        table->pointcoord_var_offset += base_offset;
+
+    base_offset <<= 2;
+    if (ctx->mainline_intro)     base_offset += buffer_size(ctx->mainline_intro);
+    if (ctx->mainline_arguments) base_offset += buffer_size(ctx->mainline_arguments);
+    base_offset >>= 2;
+
+    if (table->pointcoord_load_offset)
+        table->pointcoord_load_offset += base_offset;
+
     push_output(ctx, &ctx->postflight);
     buffer_append(ctx->output, &ctx->spirv.patch_table, sizeof(ctx->spirv.patch_table));
     pop_output(ctx);
--- a/profiles/mojoshader_profile_spirv.h	Tue Nov 10 19:17:45 2020 -0500
+++ b/profiles/mojoshader_profile_spirv.h	Wed Nov 11 15:09:32 2020 -0500
@@ -75,8 +75,9 @@
     STI_PTR_IMAGE2D   = 6,
     STI_PTR_IMAGE3D   = 7,
     STI_PTR_IMAGECUBE = 8,
+    STI_PTR_VEC2_I    = 9, // special case, needed only for point coord input.
 
-    // 7 unused entries
+    // 6 unused entries
 
     // 4 base types * 4 vector sizes = 16 entries
     STI_FLOAT = (0 << 5) | (1 << 4) | (ST_FLOAT << 2) | 0,
@@ -175,6 +176,8 @@
     uint32 id_var_vpos;
     uint32 id_var_frontfacing;
     uint32 id_var_vface;
+    uint32 id_var_texcoord0_input;
+    uint32 id_var_texcoord0_private;
     // ids for types so we can reuse them after they're declared
     uint32 tid[STI_LENGTH_];
     uint32 idtrue;