Add dynamic linking support for SPIR-V modules
authorEthan Lee <flibitijibibo@flibitijibibo.com>
Tue, 07 Jul 2020 17:19:45 -0400
changeset 1277 da61410edbc9
parent 1276 89c389e4112f
child 1278 be0f548f321f
Add dynamic linking support for SPIR-V modules
mojoshader.h
mojoshader_common.c
mojoshader_internal.h
mojoshader_opengl.c
mojoshader_vulkan.c
profiles/mojoshader_profile_spirv.c
--- a/mojoshader.h	Mon Jul 06 16:23:06 2020 -0400
+++ b/mojoshader.h	Tue Jul 07 17:19:45 2020 -0400
@@ -3558,6 +3558,7 @@
 VK_DEFINE_HANDLE(VkDevice)
 VK_DEFINE_HANDLE(VkPhysicalDevice)
 VK_DEFINE_NON_DISPATCHABLE_HANDLE(VkBuffer)
+VK_DEFINE_NON_DISPATCHABLE_HANDLE(VkShaderModule)
 
 #endif /* !NO_MOJOSHADER_VULKAN_TYPEDEFS */
 
@@ -3573,6 +3574,7 @@
 
 typedef struct MOJOSHADER_vkContext MOJOSHADER_vkContext;
 typedef struct MOJOSHADER_vkShader MOJOSHADER_vkShader;
+typedef struct MOJOSHADER_vkProgram MOJOSHADER_vkProgram;
 
 /*
  * Prepares a context to manage Vulkan shaders.
@@ -3671,7 +3673,7 @@
  * This function destroys the MOJOSHADER_vkContext you pass it. If it's the
  *  current context, then no context will be current upon return.
  */
-DECLSPEC void MOJOSHADER_vkDestroyContext();
+DECLSPEC void MOJOSHADER_vkDestroyContext(MOJOSHADER_vkContext *ctx);
 
 /*
  * Compile a buffer of Direct3D shader bytecode into a Vulkan shader module.
@@ -3728,6 +3730,53 @@
                                                 MOJOSHADER_vkShader *shader);
 
 /*
+ * Link a vertex and pixel shader into a working Vulkan shader program.
+ *  (vshader) or (pshader) can NOT be NULL, unlike OpenGL.
+ *
+ * You can reuse shaders in various combinations across
+ *  multiple programs, by relinking different pairs.
+ *
+ * It is illegal to give a vertex shader for (pshader) or a pixel shader
+ *  for (vshader).
+ *
+ * Once you have successfully linked a program, you may render with it.
+ *
+ * Returns NULL on error, or a program handle on success.
+ *
+ * This call requires a valid MOJOSHADER_vkContext to have been made current,
+ *  or it will crash your program. See MOJOSHADER_vkMakeContextCurrent().
+ */
+DECLSPEC MOJOSHADER_vkProgram *MOJOSHADER_vkLinkProgram(MOJOSHADER_vkShader *vshader,
+                                                        MOJOSHADER_vkShader *pshader);
+
+/*
+ * This binds the program to the active context, and does nothing particularly
+ * special until you start working with uniform buffers or shader modules.
+ *
+ * After binding a program, you should update any uniforms you care about
+ *  with MOJOSHADER_vkMapUniformBufferMemory() (etc), set any vertex arrays
+ *  using MOJOSHADER_vkGetVertexAttribLocation(), and finally call
+ *  MOJOSHADER_vkGetShaderModules() to get the final modules. Then you may
+ *  begin building your pipeline state objects.
+ *
+ * This call requires a valid MOJOSHADER_vkContext to have been made current,
+ *  or it will crash your program. See MOJOSHADER_vkMakeContextCurrent().
+ */
+DECLSPEC void MOJOSHADER_vkBindProgram(MOJOSHADER_vkProgram *program);
+
+/*
+ * Free the resources of a linked program. This will delete the shader modules
+ *  and free memory.
+ *
+ * If the program is currently bound by MOJOSHADER_vkBindProgram(), it will
+ *  be deleted as soon as it becomes unbound.
+ *
+ * This call requires a valid MOJOSHADER_vkContext to have been made current,
+ *  or it will crash your program. See MOJOSHADER_vkMakeContextCurrent().
+ */
+DECLSPEC void MOJOSHADER_vkDeleteProgram(MOJOSHADER_vkProgram *program);
+
+/*
  * This "binds" individual shaders, which effectively means the context
  *  will store these shaders for later retrieval. No actual binding or
  *  pipeline creation is performed.
@@ -3816,10 +3865,10 @@
                                                   int index);
 
 /*
- * Get the VkShaderModule from the given MOJOSHADER_vkShader.
+ * Get the VkShaderModules from the currently bound shader program.
  */
-DECLSPEC unsigned long long MOJOSHADER_vkGetShaderModule(
-                                                MOJOSHADER_vkShader *shader);
+DECLSPEC void MOJOSHADER_vkGetShaderModules(VkShaderModule *vmodule,
+                                            VkShaderModule *pmodule);
 
 /* D3D11 interface... */
 
--- a/mojoshader_common.c	Mon Jul 06 16:23:06 2020 -0400
+++ b/mojoshader_common.c	Tue Jul 07 17:19:45 2020 -0400
@@ -1055,5 +1055,68 @@
     return (text - textstart);
 } // MOJOSHADER_printFloat
 
+#if SUPPORT_PROFILE_SPIRV
+#include "spirv/spirv.h"
+#include "spirv/GLSL.std.450.h"
+void MOJOSHADER_spirv_link_attributes(const MOJOSHADER_parseData *vertex,
+                                      const MOJOSHADER_parseData *pixel)
+{
+    int i;
+    uint32 attr_loc = 1; // 0 is reserved for COLOR0
+    uint32 vOffset, pOffset;
+    int vDataLen = vertex->output_len - sizeof(SpirvPatchTable);
+    int pDataLen = pixel->output_len - sizeof(SpirvPatchTable);
+    SpirvPatchTable *vTable = (SpirvPatchTable *) &vertex->output[vDataLen];
+    SpirvPatchTable *pTable = (SpirvPatchTable *) &pixel->output[pDataLen];
+    const uint32 texcoord0Loc = pTable->attrib_offsets[MOJOSHADER_USAGE_TEXCOORD][0];
+
+    for (i = 0; i < pixel->attribute_count; i++)
+    {
+        const MOJOSHADER_attribute *pAttr = &pixel->attributes[i];
+        if (pAttr->usage == MOJOSHADER_USAGE_COLOR && pAttr->index == 0)
+            continue;
+
+        // The input may not exist in the output list!
+        pOffset = pTable->attrib_offsets[pAttr->usage][pAttr->index];
+        vOffset = vTable->attrib_offsets[pAttr->usage][pAttr->index];
+        ((uint32 *) pixel->output)[pOffset] = attr_loc;
+        if (vOffset)
+            ((uint32 *) vertex->output)[vOffset] = attr_loc;
+        attr_loc++;
+    } // for
+
+    // There may be outputs not present in the input list!
+    for (i = 0; i < vertex->output_count; i++)
+    {
+        const MOJOSHADER_attribute *vAttr = &vertex->outputs[i];
+        if (vAttr->usage == MOJOSHADER_USAGE_POSITION && vAttr->index == 0)
+            continue;
+        if (vAttr->usage == MOJOSHADER_USAGE_COLOR && vAttr->index == 0)
+            continue;
+
+        if (!pTable->attrib_offsets[vAttr->usage][vAttr->index])
+        {
+            vOffset = vTable->attrib_offsets[vAttr->usage][vAttr->index];
+            ((uint32 *) vertex->output)[vOffset] = attr_loc++;
+        } // if
+    } // while
+
+    // gl_PointCoord support
+    if (texcoord0Loc)
+    {
+        if (vTable->attrib_offsets[MOJOSHADER_USAGE_POINTSIZE][0] > 0)
+        {
+            ((uint32 *) pixel->output)[texcoord0Loc - 1] = SpvDecorationBuiltIn;
+            ((uint32 *) pixel->output)[texcoord0Loc] = SpvBuiltInPointCoord;
+        } // if
+        else
+        {
+            // texcoord0Loc should already have attr_loc from the above work!
+            ((uint32 *) pixel->output)[texcoord0Loc - 1] = SpvDecorationLocation;
+        } // else
+    } // if
+} // MOJOSHADER_spirv_link_attributes
+#endif
+
 // end of mojoshader_common.c ...
 
--- a/mojoshader_internal.h	Mon Jul 06 16:23:06 2020 -0400
+++ b/mojoshader_internal.h	Tue Jul 07 17:19:45 2020 -0400
@@ -743,23 +743,20 @@
 
 typedef struct SpirvPatchTable
 {
+    // Patches for uniforms
     SpirvPatchEntry vpflip;
     SpirvPatchEntry array_vec4;
     SpirvPatchEntry array_ivec4;
     SpirvPatchEntry array_bool;
     SpirvPatchEntry samplers[16];
     int32 location_count;
-    union
-    {
-        // VS only; non-0 when there is PSIZE output
-        uint32 vs_has_psize;
 
-        // PS only; offset to TEXCOORD0 location part of OpDecorate.
-        // Used to find OpDecorate and patch it to BuiltIn PointCoord when
-        // VS outputs PSIZE.
-        uint32 ps_texcoord0_offset;
-    };
+    // Patches for linking vertex output/pixel input
+    uint32 attrib_offsets[MOJOSHADER_USAGE_TOTAL][16];
 } SpirvPatchTable;
+
+void MOJOSHADER_spirv_link_attributes(const MOJOSHADER_parseData *vertex,
+                                      const MOJOSHADER_parseData *pixel);
 #endif
 
 #endif  // _INCLUDE_MOJOSHADER_INTERNAL_H_
--- a/mojoshader_opengl.c	Mon Jul 06 16:23:06 2020 -0400
+++ b/mojoshader_opengl.c	Tue Jul 07 17:19:45 2020 -0400
@@ -448,14 +448,14 @@
     return (const SpirvPatchTable *) (pd->output + table_offset);
 } // spv_getPatchTable
 
-static int spv_CompileShader(const MOJOSHADER_parseData *pd, int32 base_location, GLuint *s, int32 patch_pcoord)
+static int spv_CompileShader(const MOJOSHADER_parseData *pd, int32 base_location, GLuint *s)
 {
     GLint ok = 0;
 
     GLsizei data_len = pd->output_len - sizeof(SpirvPatchTable);
     const GLvoid* data = pd->output;
     uint32 *patched_data = NULL;
-    if (base_location || patch_pcoord)
+    if (base_location)
     {
         size_t i, max;
 
@@ -474,16 +474,6 @@
                 patched_data[entry.offset] += base_location;
         } // for
 
-        if (patch_pcoord && table->ps_texcoord0_offset)
-        {
-            // Subtract 3 to get from Location value offset to start of op.
-            uint32 op_base = table->ps_texcoord0_offset - 3;
-            assert(patched_data[op_base+0] == (SpvOpDecorate | (4 << 16)));
-            assert(patched_data[op_base+2] == SpvDecorationLocation);
-            patched_data[op_base+2] = SpvDecorationBuiltIn;
-            patched_data[op_base+3] = SpvBuiltInPointCoord;
-        } // if
-
         data = patched_data;
     } // if
 
@@ -521,27 +511,29 @@
                                      MOJOSHADER_glShader *pshader)
 {
     GLint ok = 0;
-
-    // Shader compilation postponed until linking due to uniform locations being global in program.
-    // To avoid overlap between VS and PS, we need to know about other shader stages to assign final
-    // uniform locations before compilation.
     GLuint vs_handle = 0;
     int32 base_location = 0;
-    int32 patch_pcoord = 0;
+
+    // Shader compilation postponed until linking due to locations being global
+    // in program. To avoid overlap between VS and PS, we need to know about
+    // other shader stages to assign final uniform/attrib locations before
+    // compilation.
+
+    MOJOSHADER_spirv_link_attributes(vshader->parseData, pshader->parseData);
+
     if (vshader)
     {
-        if (!spv_CompileShader(vshader->parseData, base_location, &vs_handle, patch_pcoord))
+        if (!spv_CompileShader(vshader->parseData, base_location, &vs_handle))
             return 0;
 
         const SpirvPatchTable* patch_table = spv_getPatchTable(vshader);
         base_location += patch_table->location_count;
-        patch_pcoord = patch_table->vs_has_psize;
     } // if
 
     GLuint ps_handle = 0;
     if (pshader)
     {
-        if (!spv_CompileShader(pshader->parseData, base_location, &ps_handle, patch_pcoord))
+        if (!spv_CompileShader(pshader->parseData, base_location, &ps_handle))
             return 0;
     } // if
 
--- a/mojoshader_vulkan.c	Mon Jul 06 16:23:06 2020 -0400
+++ b/mojoshader_vulkan.c	Tue Jul 07 17:19:45 2020 -0400
@@ -26,11 +26,19 @@
 
 typedef struct MOJOSHADER_vkShader
 {
-    VkShaderModule shaderModule;
     const MOJOSHADER_parseData *parseData;
+    uint16_t tag;
     uint32_t refcount;
 } MOJOSHADER_vkShader;
 
+typedef struct MOJOSHADER_vkProgram
+{
+    VkShaderModule vertexModule;
+    VkShaderModule pixelModule;
+    MOJOSHADER_vkShader *vertexShader;
+    MOJOSHADER_vkShader *pixelShader;
+} MOJOSHADER_vkProgram;
+
 typedef struct MOJOSHADER_vkUniformBuffer
 {
     VkBuffer buffer;
@@ -89,8 +97,8 @@
     MOJOSHADER_vkUniformBuffer *vertUboBuffer;
     MOJOSHADER_vkUniformBuffer *fragUboBuffer;
 
-    MOJOSHADER_vkShader *vertexShader;
-    MOJOSHADER_vkShader *pixelShader;
+    MOJOSHADER_vkProgram *bound_program;
+    HashTable *linker_cache;
 
     #define VULKAN_INSTANCE_FUNCTION(ret, func, params) \
         vkfntype_MOJOSHADER_##func func;
@@ -100,6 +108,7 @@
 } MOJOSHADER_vkContext;
 
 static MOJOSHADER_vkContext *ctx = NULL;
+static uint16_t tagCounter = 1;
 
 static uint8_t find_memory_type(
     MOJOSHADER_vkContext *ctx,
@@ -352,15 +361,81 @@
     return shader->parseData->output_len - sizeof(SpirvPatchTable);
 } // shader_bytecode_len
 
-static void delete_shader(
-    VkShaderModule shaderModule
-) {
-    ctx->vkDestroyShaderModule(
+static VkShaderModule compile_shader(MOJOSHADER_vkShader *shader)
+{
+    VkResult result;
+    VkShaderModule module;
+    VkShaderModuleCreateInfo shaderModuleCreateInfo =
+    {
+        VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO
+    };
+
+    shaderModuleCreateInfo.flags = 0;
+    shaderModuleCreateInfo.codeSize = shader_bytecode_len(shader);
+    shaderModuleCreateInfo.pCode = (uint32_t*) shader->parseData->output;
+
+    result = ctx->vkCreateShaderModule(
         *ctx->logical_device,
-        shaderModule,
-        NULL
+        &shaderModuleCreateInfo,
+        NULL,
+        &module
     );
-} // delete_shader
+
+    if (result != VK_SUCCESS)
+    {
+        // FIXME: should display VK error code
+        set_error("Error when creating VkShaderModule");
+        ctx->vkDestroyShaderModule(
+            *ctx->logical_device,
+            module,
+            NULL
+        );
+        return VK_NULL_HANDLE;
+    } // if
+
+    return module;
+} // compile_shader
+
+typedef struct
+{
+    MOJOSHADER_vkShader *vertex;
+    MOJOSHADER_vkShader *fragment;
+} BoundShaders;
+
+static uint32_t hash_shaders(const void *sym, void *data)
+{
+    (void) data;
+    const BoundShaders *s = (const BoundShaders *) sym;
+    const uint16_t v = (s->vertex) ? s->vertex->tag : 0;
+    const uint16_t f = (s->fragment) ? s->fragment->tag : 0;
+    return ((uint32_t) v << 16) | (uint32_t) f;
+} // hash_shaders
+
+static int match_shaders(const void *_a, const void *_b, void *data)
+{
+    (void) data;
+    const BoundShaders *a = (const BoundShaders *) _a;
+    const BoundShaders *b = (const BoundShaders *) _b;
+
+    const uint16_t av = (a->vertex) ? a->vertex->tag : 0;
+    const uint16_t bv = (b->vertex) ? b->vertex->tag : 0;
+    if (av != bv)
+        return 0;
+
+    const uint16_t af = (a->fragment) ? a->fragment->tag : 0;
+    const uint16_t bf = (b->fragment) ? b->fragment->tag : 0;
+    if (af != bf)
+        return 0;
+
+    return 1;
+} // match_shaders
+
+static void nuke_shaders(const void *key, const void *value, void *data)
+{
+    (void) data;
+    ctx->free_fn((void *) key, ctx->malloc_data); // this was a BoundShaders struct.
+    MOJOSHADER_vkDeleteProgram((MOJOSHADER_vkProgram *) value);
+} // nuke_shaders
 
 // Public API
 
@@ -422,8 +497,15 @@
     ctx = _ctx;
 } // MOJOSHADER_vkMakeContextCurrent
 
-void MOJOSHADER_vkDestroyContext()
+void MOJOSHADER_vkDestroyContext(MOJOSHADER_vkContext *_ctx)
 {
+    MOJOSHADER_vkContext *current_ctx = ctx;
+    ctx = _ctx;
+
+    MOJOSHADER_vkBindProgram(NULL);
+    if (ctx->linker_cache)
+        hash_destroy(ctx->linker_cache);
+
     ctx->vkDestroyBuffer(*ctx->logical_device,
                          ctx->vertUboBuffer->buffer,
                          NULL);
@@ -444,6 +526,8 @@
     ctx->free_fn(ctx->fragUboBuffer, ctx->malloc_data);
 
     ctx->free_fn(ctx, ctx->malloc_data);
+
+    ctx = ((current_ctx == _ctx) ? NULL : current_ctx);
 } // MOJOSHADER_vkDestroyContext
 
 MOJOSHADER_vkShader *MOJOSHADER_vkCompileShader(
@@ -455,11 +539,6 @@
     const MOJOSHADER_samplerMap *smap,
     const unsigned int smapcount
 ) {
-    VkResult result;
-    VkShaderModuleCreateInfo shaderModuleCreateInfo =
-    {
-        VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO
-    };
     MOJOSHADER_vkShader *shader;
 
     const MOJOSHADER_parseData *pd = MOJOSHADER_parse(
@@ -475,49 +554,27 @@
     if (pd->error_count > 0)
     {
         set_error(pd->errors[0].error);
-        goto compile_shader_fail;
+        goto parse_shader_fail;
     } // if
 
     shader = (MOJOSHADER_vkShader *) ctx->malloc_fn(sizeof(MOJOSHADER_vkShader), ctx->malloc_data);
     if (shader == NULL)
     {
         out_of_memory();
-        goto compile_shader_fail;
+        goto parse_shader_fail;
     } // if
 
     shader->parseData = pd;
     shader->refcount = 1;
-
-    shaderModuleCreateInfo.flags = 0;
-    shaderModuleCreateInfo.codeSize = shader_bytecode_len(shader);
-    shaderModuleCreateInfo.pCode = (uint32_t*) pd->output;
-
-    result = ctx->vkCreateShaderModule(
-        *ctx->logical_device,
-        &shaderModuleCreateInfo,
-        NULL,
-        &shader->shaderModule
-    );
-
-    if (result != VK_SUCCESS)
-    {
-        // FIXME: should display VK error code
-        set_error("Error when creating VkShaderModule");
-        goto compile_shader_fail;
-    } // if
-
+    shader->tag = tagCounter++;
     return shader;
 
-compile_shader_fail:
+parse_shader_fail:
     MOJOSHADER_freeParseData(pd);
     if (shader != NULL)
-    {
-        delete_shader(shader->shaderModule);
         ctx->free_fn(shader, ctx->malloc_data);
-    } // if
     return NULL;
-
-} // MOJOSHADER_vkMakeContextCurrent
+} // MOJOSHADER_vkCompileShader
 
 void MOJOSHADER_vkShaderAddRef(MOJOSHADER_vkShader *shader)
 {
@@ -533,7 +590,25 @@
             shader->refcount--;
         else
         {
-            delete_shader(shader->shaderModule);
+            // See if this was bound as an unlinked program anywhere...
+            if (ctx->linker_cache)
+            {
+                const void *key = NULL;
+                void *iter = NULL;
+                int morekeys = hash_iter_keys(ctx->linker_cache, &key, &iter);
+                while (morekeys)
+                {
+                    const BoundShaders *shaders = (const BoundShaders *) key;
+                    // Do this here so we don't confuse the iteration by removing...
+                    morekeys = hash_iter_keys(ctx->linker_cache, &key, &iter);
+                    if ((shaders->vertex == shader) || (shaders->fragment == shader))
+                    {
+                        // Deletes the linked program
+                        hash_remove(ctx->linker_cache, shaders);
+                    } // if
+                } // while
+            } // if
+
             MOJOSHADER_freeParseData(shader->parseData);
             ctx->free_fn(shader, ctx->malloc_data);
         } // else
@@ -546,22 +621,119 @@
     return (shader != NULL) ? shader->parseData : NULL;
 } // MOJOSHADER_vkGetShaderParseData
 
+void MOJOSHADER_vkDeleteProgram(MOJOSHADER_vkProgram *p)
+{
+    if (p->vertexModule != VK_NULL_HANDLE)
+        ctx->vkDestroyShaderModule(*ctx->logical_device, p->vertexModule, NULL);
+    if (p->pixelModule != VK_NULL_HANDLE)
+        ctx->vkDestroyShaderModule(*ctx->logical_device, p->pixelModule, NULL);
+    ctx->free_fn(p, ctx->malloc_data);
+} // MOJOSHADER_vkDeleteProgram
+
+MOJOSHADER_vkProgram *MOJOSHADER_vkLinkProgram(MOJOSHADER_vkShader *vshader,
+                                               MOJOSHADER_vkShader *pshader)
+{
+    MOJOSHADER_vkProgram *result;
+
+    if ((vshader == NULL) && (pshader == NULL))
+        return NULL;
+
+    result = ctx->malloc_fn(sizeof (MOJOSHADER_vkProgram), ctx->malloc_data);
+    if (result == NULL)
+    {
+        out_of_memory();
+        return NULL;
+    } // if
+
+    MOJOSHADER_spirv_link_attributes(vshader->parseData, pshader->parseData);
+    result->vertexModule = compile_shader(vshader);
+    result->pixelModule = compile_shader(pshader);
+    result->vertexShader = vshader;
+    result->pixelShader = pshader;
+
+    if (result->vertexModule == VK_NULL_HANDLE
+     || result->pixelModule == VK_NULL_HANDLE)
+    {
+        MOJOSHADER_vkDeleteProgram(result);
+        return NULL;
+    }
+    return result;
+} // MOJOSHADER_vkLinkProgram
+
+void MOJOSHADER_vkBindProgram(MOJOSHADER_vkProgram *p)
+{
+    ctx->bound_program = p;
+} // MOJOSHADER_vkBindProgram
+
 void MOJOSHADER_vkBindShaders(MOJOSHADER_vkShader *vshader,
                               MOJOSHADER_vkShader *pshader)
 {
-    // NOOP if shader is null
+    if (ctx->linker_cache == NULL)
+    {
+        ctx->linker_cache = hash_create(NULL, hash_shaders, match_shaders,
+                                        nuke_shaders, 0, ctx->malloc_fn,
+                                        ctx->free_fn, ctx->malloc_data);
+
+        if (ctx->linker_cache == NULL)
+        {
+            out_of_memory();
+            return;
+        } // if
+    } // if
+
+    MOJOSHADER_vkProgram *program = NULL;
+    BoundShaders shaders;
+    shaders.vertex = vshader;
+    shaders.fragment = pshader;
 
-    if (vshader != NULL)
-        ctx->vertexShader = vshader;
-    if (pshader != NULL)
-        ctx->pixelShader = pshader;
+    const void *val = NULL;
+    if (hash_find(ctx->linker_cache, &shaders, &val))
+        program = (MOJOSHADER_vkProgram *) val;
+    else
+    {
+        program = MOJOSHADER_vkLinkProgram(vshader, pshader);
+        if (program == NULL)
+            return;
+
+        BoundShaders *item = (BoundShaders *) ctx->malloc_fn(sizeof (BoundShaders),
+                                                             ctx->malloc_data);
+        if (item == NULL)
+        {
+            MOJOSHADER_vkDeleteProgram(program);
+            return;
+        } // if
+
+        memcpy(item, &shaders, sizeof (BoundShaders));
+        if (hash_insert(ctx->linker_cache, item, program) != 1)
+        {
+            ctx->free_fn(item, ctx->malloc_data);
+            MOJOSHADER_vkDeleteProgram(program);
+            out_of_memory();
+            return;
+        } // if
+    } // else
+
+    assert(program != NULL);
+    ctx->bound_program = program;
 } // MOJOSHADER_vkBindShaders
 
 void MOJOSHADER_vkGetBoundShaders(MOJOSHADER_vkShader **vshader,
                                   MOJOSHADER_vkShader **pshader)
 {
-    *vshader = ctx->vertexShader;
-    *pshader = ctx->pixelShader;
+    if (vshader != NULL)
+    {
+        if (ctx->bound_program != NULL)
+            *vshader = ctx->bound_program->vertexShader;
+        else
+            *vshader = NULL;
+    } // if
+    if (pshader != NULL)
+    {
+        if (ctx->bound_program != NULL)
+            *pshader = ctx->bound_program->pixelShader;
+        else
+            *pshader = NULL;
+    } // if
 } // MOJOSHADER_vkGetBoundShaders
 
 void MOJOSHADER_vkMapUniformBufferMemory(float **vsf, int **vsi, unsigned char **vsb,
@@ -580,20 +752,21 @@
     /* Why is this function named unmap instead of update?
      * the world may never know...
      */
-
-    update_uniform_buffer(ctx->vertexShader);
-    update_uniform_buffer(ctx->pixelShader);
+    assert(ctx->bound_program != NULL);
+    update_uniform_buffer(ctx->bound_program->vertexShader);
+    update_uniform_buffer(ctx->bound_program->pixelShader);
 } // MOJOSHADER_vkUnmapUniformBufferMemory
 
 void MOJOSHADER_vkGetUniformBuffers(VkBuffer *vbuf, unsigned long long *voff, unsigned long long *vsize,
                                     VkBuffer *pbuf, unsigned long long *poff, unsigned long long *psize)
 {
-    *vbuf = get_uniform_buffer(ctx->vertexShader);
-    *voff = get_uniform_offset(ctx->vertexShader);
-    *vsize = get_uniform_size(ctx->vertexShader);
-    *pbuf = get_uniform_buffer(ctx->pixelShader);
-    *poff = get_uniform_offset(ctx->pixelShader);
-    *psize = get_uniform_size(ctx->pixelShader);
+    assert(ctx->bound_program != NULL);
+    *vbuf = get_uniform_buffer(ctx->bound_program->vertexShader);
+    *voff = get_uniform_offset(ctx->bound_program->vertexShader);
+    *vsize = get_uniform_size(ctx->bound_program->vertexShader);
+    *pbuf = get_uniform_buffer(ctx->bound_program->pixelShader);
+    *poff = get_uniform_offset(ctx->bound_program->pixelShader);
+    *psize = get_uniform_size(ctx->bound_program->pixelShader);
 } // MOJOSHADER_vkGetUniformBuffers
 
 void MOJOSHADER_vkEndFrame()
@@ -624,13 +797,15 @@
     return -1;
 } //MOJOSHADER_vkGetVertexAttribLocation
 
-unsigned long long MOJOSHADER_vkGetShaderModule(MOJOSHADER_vkShader *shader)
+void MOJOSHADER_vkGetShaderModules(VkShaderModule *vmodule,
+                                   VkShaderModule *pmodule)
 {
-    if (shader == NULL)
-        return 0;
-
-    return (unsigned long long) shader->shaderModule;
-} //MOJOSHADER_vkGetShaderModule
+    assert(ctx->bound_program != NULL);
+    if (vmodule != NULL)
+        *vmodule = ctx->bound_program->vertexModule;
+    if (pmodule != NULL)
+        *pmodule = ctx->bound_program->pixelModule;
+} //MOJOSHADER_vkGetShaderModules
 
 const char *MOJOSHADER_vkGetError(void)
 {
--- a/profiles/mojoshader_profile_spirv.c	Mon Jul 06 16:23:06 2020 -0400
+++ b/profiles/mojoshader_profile_spirv.c	Tue Jul 07 17:19:45 2020 -0400
@@ -271,6 +271,19 @@
     return (buffer_size(ctx->helpers) >> 2) - 1;
 } // spv_output_location
 
+static uint32 spv_output_attrib_location(Context *ctx, uint32 id,
+                                         MOJOSHADER_usage usage, uint32 index)
+{
+    uint32 result;
+    SpirvPatchTable* table = &ctx->spirv.patch_table;
+    push_output(ctx, &ctx->helpers);
+    spv_emit(ctx, 4, SpvOpDecorate, id, SpvDecorationLocation, 0xDEADBEEF);
+    pop_output(ctx);
+    result = (buffer_size(ctx->helpers) >> 2) - 1;
+    table->attrib_offsets[usage][index] = result;
+    return result;
+} // spv_output_attrib_location
+
 static void spv_output_sampler_binding(Context *ctx, uint32 id, uint32 binding)
 {
     if (isfail(ctx))
@@ -1510,65 +1523,6 @@
     pop_output(ctx);
 } // spv_emit_func_end
 
-// These are prioritized by most common to least common...
-#define SPV_OFFSET_COLOR        0
-#define SPV_LENGTH_COLOR        4 // Based on max render target count for XNA
-#define SPV_OFFSET_TEXCOORD     (SPV_OFFSET_COLOR + SPV_LENGTH_COLOR)
-#define SPV_LENGTH_TEXCOORD     16
-#define SPV_OFFSET_NORMAL       (SPV_OFFSET_TEXCOORD + SPV_LENGTH_TEXCOORD)
-#define SPV_LENGTH_NORMAL       8 // Arbitrary!
-#define SPV_OFFSET_FOG          (SPV_OFFSET_NORMAL + SPV_LENGTH_NORMAL)
-#define SPV_LENGTH_FOG          2 // Arbitrary!
-#define SPV_OFFSET_TANGENT      (SPV_OFFSET_FOG + SPV_LENGTH_FOG)
-#define SPV_LENGTH_TANGENT      1 // Arbitrary!
-#define SPV_OFFSET_BLENDINDICES (SPV_OFFSET_TANGENT + SPV_LENGTH_TANGENT)
-#define SPV_LENGTH_BLENDINDICES 4 // Arbitrary!
-#define SPV_OFFSET_POSITION     (SPV_OFFSET_BLENDINDICES + SPV_LENGTH_BLENDINDICES)
-#define SPV_LENGTH_POSITION     (4 - 1) // Arbitrary!
-#define SPV_OFFSET_PSIZE        (SPV_OFFSET_POSITION + SPV_LENGTH_POSITION)
-#define SPV_LENGTH_PSIZE        (4 - 1) // Arbitrary!
-
-static void spv_link_vs_attributes(Context *ctx, uint32 id, MOJOSHADER_usage usage, int index)
-{
-    // Some usages map to specific ranges. Keep those in sync with spv_link_ps_attributes().
-    switch (usage)
-    {
-        case MOJOSHADER_USAGE_POSITION:
-            if (index == 0)
-                spv_output_builtin(ctx, id, SpvBuiltInPosition);
-            else
-            {
-                assert(index <= SPV_LENGTH_POSITION);
-                spv_output_location(ctx, id, SPV_OFFSET_POSITION + (index - 1));
-            } // else
-            break;
-        case MOJOSHADER_USAGE_POINTSIZE:
-            if (index == 0)
-                spv_output_builtin(ctx, id, SpvBuiltInPointSize);
-            else
-            {
-                assert(index <= SPV_LENGTH_PSIZE);
-                spv_output_location(ctx, id, SPV_OFFSET_PSIZE + (index - 1));
-            }
-            break;
-        #define SPV_DECORATION_USAGE(usage) \
-            case MOJOSHADER_USAGE_##usage: \
-                assert(index < SPV_LENGTH_##usage); \
-                spv_output_location(ctx, id, SPV_OFFSET_##usage + index); \
-                break;
-        SPV_DECORATION_USAGE(COLOR)
-        SPV_DECORATION_USAGE(TEXCOORD)
-        SPV_DECORATION_USAGE(NORMAL)
-        SPV_DECORATION_USAGE(FOG)
-        SPV_DECORATION_USAGE(TANGENT)
-        SPV_DECORATION_USAGE(BLENDINDICES)
-        #undef SPV_DECORATION_USAGE
-        default:
-            failf(ctx, "unexpected attribute usage %d in vertex shader", usage);
-            break;
-    } // switch
-} // spv_link_vs_attributes
-
 static void spv_emit_vpos_glmode(Context *ctx, uint32 id)
 {
     // In SM3.0 vPos only has x and y defined, but we should be
@@ -1671,7 +1625,21 @@
     ctx->spirv.id_var_vpos = id_var_vpos;
 } // spv_emit_vpos_vkmode
 
-static void spv_link_ps_attributes(Context *ctx, uint32 id, RegisterType regtype, MOJOSHADER_usage usage, int index)
+static void spv_link_vs_attributes(Context *ctx, uint32 id,
+                                   MOJOSHADER_usage usage, int index)
+{
+    if (usage == MOJOSHADER_USAGE_POSITION && index == 0)
+        spv_output_builtin(ctx, id, SpvBuiltInPosition);
+    else if (usage == MOJOSHADER_USAGE_POINTSIZE && index == 0)
+        spv_output_builtin(ctx, id, SpvBuiltInPointSize);
+    else if (usage == MOJOSHADER_USAGE_COLOR && index == 0)
+        spv_output_location(ctx, id, 0);
+    else
+        spv_output_attrib_location(ctx, id, usage, index);
+} // spv_link_vs_attributes
+
+static void spv_link_ps_attributes(Context *ctx, uint32 id, RegisterType regtype,
+                                   MOJOSHADER_usage usage, int index)
 {
     switch (regtype)
     {
@@ -1684,48 +1652,19 @@
             // - decorated with location 0
             // - not decorated as a built-in variable.
             // There is no implicit broadcast.
-            assert(index < SPV_LENGTH_COLOR);
-            spv_output_location(ctx, id, SPV_OFFSET_COLOR + index);
+            if (index == 0)
+                spv_output_location(ctx, id, 0);
+            else
+                spv_output_attrib_location(ctx, id, MOJOSHADER_USAGE_COLOR, index);
             break;
         case REG_TYPE_INPUT: // v# (MOJOSHADER_USAGE_COLOR aka `oC#` in vertex shader)
-            switch (usage)
-            {
-                #define SPV_DECORATION_USAGE(usage) \
-                    case MOJOSHADER_USAGE_##usage: \
-                        assert(index < SPV_LENGTH_##usage); \
-                        spv_output_location(ctx, id, SPV_OFFSET_##usage + index); \
-                        break;
-                SPV_DECORATION_USAGE(COLOR)
-                SPV_DECORATION_USAGE(NORMAL)
-                SPV_DECORATION_USAGE(FOG)
-                SPV_DECORATION_USAGE(TANGENT)
-                SPV_DECORATION_USAGE(BLENDINDICES)
-                #undef SPV_DECORATION_USAGE
-                case MOJOSHADER_USAGE_TEXCOORD:
-                {
-                    uint32 location_offset = spv_output_location(ctx, id, SPV_OFFSET_TEXCOORD + index);
-                    if (index == 0)
-                        ctx->spirv.patch_table.ps_texcoord0_offset = location_offset;
-                    break;
-                } // case
-                case MOJOSHADER_USAGE_POSITION:
-                {
-                    assert(index <= SPV_LENGTH_POSITION);
-                    spv_output_location(ctx, id, SPV_OFFSET_POSITION + (index - 1));
-                } // case
-                case MOJOSHADER_USAGE_POINTSIZE:
-                {
-                    assert(index <= SPV_LENGTH_PSIZE);
-                    spv_output_location(ctx, id, SPV_OFFSET_PSIZE + (index - 1));
-                } // case
-                default:
-                    failf(ctx, "unexpected attribute usage %d in pixel shader", usage);
-                    break;
-            } // switch
+            if (usage == MOJOSHADER_USAGE_COLOR && index == 0)
+                spv_output_location(ctx, id, 0);
+            else
+                spv_output_attrib_location(ctx, id, usage, index);
             break;
         case REG_TYPE_TEXTURE: // t# (MOJOSHADER_USAGE_TEXCOORD aka `oT#` in vertex shader)
-            assert(index < SPV_LENGTH_TEXCOORD);
-            spv_output_location(ctx, id, SPV_OFFSET_TEXCOORD + index);
+            spv_output_attrib_location(ctx, id, MOJOSHADER_USAGE_TEXCOORD, index);
             break;
         case REG_TYPE_DEPTHOUT:
             spv_output_builtin(ctx, id, SpvBuiltInFragDepth);
@@ -2277,13 +2216,11 @@
             {
                 push_output(ctx, &ctx->mainline_intro);
                 SpirvTypeIdx sti = STI_PTR_VEC4_O;
-                if (usage == MOJOSHADER_USAGE_POINTSIZE)
+                if (usage == MOJOSHADER_USAGE_POINTSIZE
+                 || usage == MOJOSHADER_USAGE_FOG)
                 {
                     sti = STI_PTR_FLOAT_O;
-                    ctx->spirv.patch_table.vs_has_psize = 1;
                 } // if
-                else if (usage == MOJOSHADER_USAGE_FOG)
-                    sti = STI_PTR_FLOAT_O;
 
                 tid = spv_get_type(ctx, sti);
                 spv_emit(ctx, 4, SpvOpVariable, tid, r->spirv.iddecl, SpvStorageClassOutput);
@@ -2385,7 +2322,7 @@
 
 void emit_SPIRV_finalize(Context *ctx)
 {
-    size_t i, max;
+    size_t i, j, max;
 
     /* The generator's magic number, this could be registered with Khronos
      * if we wanted to. 0 is fine though, so use that for now. */
@@ -2657,11 +2594,13 @@
             entry->location = -1;
     } // for
 
-    if (shader_is_pixel(ctx) && table->ps_texcoord0_offset)
-        table->ps_texcoord0_offset += base_offset;
-
     table->location_count = location_count;
 
+    for (i = 0; i < MOJOSHADER_USAGE_TOTAL; i++)
+        for (j = 0; j < 16; j++)
+            if (table->attrib_offsets[i][j])
+                table->attrib_offsets[i][j] += base_offset;
+
     push_output(ctx, &ctx->postflight);
     buffer_append(ctx->output, &ctx->spirv.patch_table, sizeof(ctx->spirv.patch_table));
     pop_output(ctx);