mojoshader_vulkan.c
changeset 1277 da61410edbc9
parent 1276 89c389e4112f
child 1280 d2a0d76469f9
--- a/mojoshader_vulkan.c	Mon Jul 06 16:23:06 2020 -0400
+++ b/mojoshader_vulkan.c	Tue Jul 07 17:19:45 2020 -0400
@@ -26,11 +26,19 @@
 
 typedef struct MOJOSHADER_vkShader
 {
-    VkShaderModule shaderModule;
     const MOJOSHADER_parseData *parseData;
+    uint16_t tag;
     uint32_t refcount;
 } MOJOSHADER_vkShader;
 
+typedef struct MOJOSHADER_vkProgram
+{
+    VkShaderModule vertexModule;
+    VkShaderModule pixelModule;
+    MOJOSHADER_vkShader *vertexShader;
+    MOJOSHADER_vkShader *pixelShader;
+} MOJOSHADER_vkProgram;
+
 typedef struct MOJOSHADER_vkUniformBuffer
 {
     VkBuffer buffer;
@@ -89,8 +97,8 @@
     MOJOSHADER_vkUniformBuffer *vertUboBuffer;
     MOJOSHADER_vkUniformBuffer *fragUboBuffer;
 
-    MOJOSHADER_vkShader *vertexShader;
-    MOJOSHADER_vkShader *pixelShader;
+    MOJOSHADER_vkProgram *bound_program;
+    HashTable *linker_cache;
 
     #define VULKAN_INSTANCE_FUNCTION(ret, func, params) \
         vkfntype_MOJOSHADER_##func func;
@@ -100,6 +108,7 @@
 } MOJOSHADER_vkContext;
 
 static MOJOSHADER_vkContext *ctx = NULL;
+static uint16_t tagCounter = 1;
 
 static uint8_t find_memory_type(
     MOJOSHADER_vkContext *ctx,
@@ -352,15 +361,81 @@
     return shader->parseData->output_len - sizeof(SpirvPatchTable);
 } // shader_bytecode_len
 
-static void delete_shader(
-    VkShaderModule shaderModule
-) {
-    ctx->vkDestroyShaderModule(
+static VkShaderModule compile_shader(MOJOSHADER_vkShader *shader)
+{
+    VkResult result;
+    VkShaderModule module;
+    VkShaderModuleCreateInfo shaderModuleCreateInfo =
+    {
+        VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO
+    };
+
+    shaderModuleCreateInfo.flags = 0;
+    shaderModuleCreateInfo.codeSize = shader_bytecode_len(shader);
+    shaderModuleCreateInfo.pCode = (uint32_t*) shader->parseData->output;
+
+    result = ctx->vkCreateShaderModule(
         *ctx->logical_device,
-        shaderModule,
-        NULL
+        &shaderModuleCreateInfo,
+        NULL,
+        &module
     );
-} // delete_shader
+
+    if (result != VK_SUCCESS)
+    {
+        // FIXME: should display VK error code
+        set_error("Error when creating VkShaderModule");
+        ctx->vkDestroyShaderModule(
+            *ctx->logical_device,
+            module,
+            NULL
+        );
+        return VK_NULL_HANDLE;
+    } // if
+
+    return module;
+} // compile_shader
+
+typedef struct
+{
+    MOJOSHADER_vkShader *vertex;
+    MOJOSHADER_vkShader *fragment;
+} BoundShaders;
+
+static uint32_t hash_shaders(const void *sym, void *data)
+{
+    (void) data;
+    const BoundShaders *s = (const BoundShaders *) sym;
+    const uint16_t v = (s->vertex) ? s->vertex->tag : 0;
+    const uint16_t f = (s->fragment) ? s->fragment->tag : 0;
+    return ((uint32_t) v << 16) | (uint32_t) f;
+} // hash_shaders
+
+static int match_shaders(const void *_a, const void *_b, void *data)
+{
+    (void) data;
+    const BoundShaders *a = (const BoundShaders *) _a;
+    const BoundShaders *b = (const BoundShaders *) _b;
+
+    const uint16_t av = (a->vertex) ? a->vertex->tag : 0;
+    const uint16_t bv = (b->vertex) ? b->vertex->tag : 0;
+    if (av != bv)
+        return 0;
+
+    const uint16_t af = (a->fragment) ? a->fragment->tag : 0;
+    const uint16_t bf = (b->fragment) ? b->fragment->tag : 0;
+    if (af != bf)
+        return 0;
+
+    return 1;
+} // match_shaders
+
+static void nuke_shaders(const void *key, const void *value, void *data)
+{
+    (void) data;
+    ctx->free_fn((void *) key, ctx->malloc_data); // this was a BoundShaders struct.
+    MOJOSHADER_vkDeleteProgram((MOJOSHADER_vkProgram *) value);
+} // nuke_shaders
 
 // Public API
 
@@ -422,8 +497,15 @@
     ctx = _ctx;
 } // MOJOSHADER_vkMakeContextCurrent
 
-void MOJOSHADER_vkDestroyContext()
+void MOJOSHADER_vkDestroyContext(MOJOSHADER_vkContext *_ctx)
 {
+    MOJOSHADER_vkContext *current_ctx = ctx;
+    ctx = _ctx;
+
+    MOJOSHADER_vkBindProgram(NULL);
+    if (ctx->linker_cache)
+        hash_destroy(ctx->linker_cache);
+
     ctx->vkDestroyBuffer(*ctx->logical_device,
                          ctx->vertUboBuffer->buffer,
                          NULL);
@@ -444,6 +526,8 @@
     ctx->free_fn(ctx->fragUboBuffer, ctx->malloc_data);
 
     ctx->free_fn(ctx, ctx->malloc_data);
+
+    ctx = ((current_ctx == _ctx) ? NULL : current_ctx);
 } // MOJOSHADER_vkDestroyContext
 
 MOJOSHADER_vkShader *MOJOSHADER_vkCompileShader(
@@ -455,11 +539,6 @@
     const MOJOSHADER_samplerMap *smap,
     const unsigned int smapcount
 ) {
-    VkResult result;
-    VkShaderModuleCreateInfo shaderModuleCreateInfo =
-    {
-        VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO
-    };
     MOJOSHADER_vkShader *shader;
 
     const MOJOSHADER_parseData *pd = MOJOSHADER_parse(
@@ -475,49 +554,27 @@
     if (pd->error_count > 0)
     {
         set_error(pd->errors[0].error);
-        goto compile_shader_fail;
+        goto parse_shader_fail;
     } // if
 
     shader = (MOJOSHADER_vkShader *) ctx->malloc_fn(sizeof(MOJOSHADER_vkShader), ctx->malloc_data);
     if (shader == NULL)
     {
         out_of_memory();
-        goto compile_shader_fail;
+        goto parse_shader_fail;
     } // if
 
     shader->parseData = pd;
     shader->refcount = 1;
-
-    shaderModuleCreateInfo.flags = 0;
-    shaderModuleCreateInfo.codeSize = shader_bytecode_len(shader);
-    shaderModuleCreateInfo.pCode = (uint32_t*) pd->output;
-
-    result = ctx->vkCreateShaderModule(
-        *ctx->logical_device,
-        &shaderModuleCreateInfo,
-        NULL,
-        &shader->shaderModule
-    );
-
-    if (result != VK_SUCCESS)
-    {
-        // FIXME: should display VK error code
-        set_error("Error when creating VkShaderModule");
-        goto compile_shader_fail;
-    } // if
-
+    shader->tag = tagCounter++;
     return shader;
 
-compile_shader_fail:
+parse_shader_fail:
     MOJOSHADER_freeParseData(pd);
     if (shader != NULL)
-    {
-        delete_shader(shader->shaderModule);
         ctx->free_fn(shader, ctx->malloc_data);
-    } // if
     return NULL;
-
-} // MOJOSHADER_vkMakeContextCurrent
+} // MOJOSHADER_vkCompileShader
 
 void MOJOSHADER_vkShaderAddRef(MOJOSHADER_vkShader *shader)
 {
@@ -533,7 +590,25 @@
             shader->refcount--;
         else
         {
-            delete_shader(shader->shaderModule);
+            // See if this was bound as an unlinked program anywhere...
+            if (ctx->linker_cache)
+            {
+                const void *key = NULL;
+                void *iter = NULL;
+                int morekeys = hash_iter_keys(ctx->linker_cache, &key, &iter);
+                while (morekeys)
+                {
+                    const BoundShaders *shaders = (const BoundShaders *) key;
+                    // Do this here so we don't confuse the iteration by removing...
+                    morekeys = hash_iter_keys(ctx->linker_cache, &key, &iter);
+                    if ((shaders->vertex == shader) || (shaders->fragment == shader))
+                    {
+                        // Deletes the linked program
+                        hash_remove(ctx->linker_cache, shaders);
+                    } // if
+                } // while
+            } // if
+
             MOJOSHADER_freeParseData(shader->parseData);
             ctx->free_fn(shader, ctx->malloc_data);
         } // else
@@ -546,22 +621,119 @@
     return (shader != NULL) ? shader->parseData : NULL;
 } // MOJOSHADER_vkGetShaderParseData
 
+void MOJOSHADER_vkDeleteProgram(MOJOSHADER_vkProgram *p)
+{
+    if (p->vertexModule != VK_NULL_HANDLE)
+        ctx->vkDestroyShaderModule(*ctx->logical_device, p->vertexModule, NULL);
+    if (p->pixelModule != VK_NULL_HANDLE)
+        ctx->vkDestroyShaderModule(*ctx->logical_device, p->pixelModule, NULL);
+    ctx->free_fn(p, ctx->malloc_data);
+} // MOJOSHADER_vkDeleteProgram
+
+MOJOSHADER_vkProgram *MOJOSHADER_vkLinkProgram(MOJOSHADER_vkShader *vshader,
+                                               MOJOSHADER_vkShader *pshader)
+{
+    MOJOSHADER_vkProgram *result;
+
+    if ((vshader == NULL) && (pshader == NULL))
+        return NULL;
+
+    result = ctx->malloc_fn(sizeof (MOJOSHADER_vkProgram), ctx->malloc_data);
+    if (result == NULL)
+    {
+        out_of_memory();
+        return NULL;
+    } // if
+
+    MOJOSHADER_spirv_link_attributes(vshader->parseData, pshader->parseData);
+    result->vertexModule = compile_shader(vshader);
+    result->pixelModule = compile_shader(pshader);
+    result->vertexShader = vshader;
+    result->pixelShader = pshader;
+
+    if (result->vertexModule == VK_NULL_HANDLE
+     || result->pixelModule == VK_NULL_HANDLE)
+    {
+        MOJOSHADER_vkDeleteProgram(result);
+        return NULL;
+    }
+    return result;
+} // MOJOSHADER_vkLinkProgram
+
+void MOJOSHADER_vkBindProgram(MOJOSHADER_vkProgram *p)
+{
+    ctx->bound_program = p;
+} // MOJOSHADER_vkBindProgram
+
 void MOJOSHADER_vkBindShaders(MOJOSHADER_vkShader *vshader,
                               MOJOSHADER_vkShader *pshader)
 {
-    // NOOP if shader is null
+    if (ctx->linker_cache == NULL)
+    {
+        ctx->linker_cache = hash_create(NULL, hash_shaders, match_shaders,
+                                        nuke_shaders, 0, ctx->malloc_fn,
+                                        ctx->free_fn, ctx->malloc_data);
+
+        if (ctx->linker_cache == NULL)
+        {
+            out_of_memory();
+            return;
+        } // if
+    } // if
+
+    MOJOSHADER_vkProgram *program = NULL;
+    BoundShaders shaders;
+    shaders.vertex = vshader;
+    shaders.fragment = pshader;
 
-    if (vshader != NULL)
-        ctx->vertexShader = vshader;
-    if (pshader != NULL)
-        ctx->pixelShader = pshader;
+    const void *val = NULL;
+    if (hash_find(ctx->linker_cache, &shaders, &val))
+        program = (MOJOSHADER_vkProgram *) val;
+    else
+    {
+        program = MOJOSHADER_vkLinkProgram(vshader, pshader);
+        if (program == NULL)
+            return;
+
+        BoundShaders *item = (BoundShaders *) ctx->malloc_fn(sizeof (BoundShaders),
+                                                             ctx->malloc_data);
+        if (item == NULL)
+        {
+            MOJOSHADER_vkDeleteProgram(program);
+            return;
+        } // if
+
+        memcpy(item, &shaders, sizeof (BoundShaders));
+        if (hash_insert(ctx->linker_cache, item, program) != 1)
+        {
+            ctx->free_fn(item, ctx->malloc_data);
+            MOJOSHADER_vkDeleteProgram(program);
+            out_of_memory();
+            return;
+        } // if
+    } // else
+
+    assert(program != NULL);
+    ctx->bound_program = program;
 } // MOJOSHADER_vkBindShaders
 
 void MOJOSHADER_vkGetBoundShaders(MOJOSHADER_vkShader **vshader,
                                   MOJOSHADER_vkShader **pshader)
 {
-    *vshader = ctx->vertexShader;
-    *pshader = ctx->pixelShader;
+    if (vshader != NULL)
+    {
+        if (ctx->bound_program != NULL)
+            *vshader = ctx->bound_program->vertexShader;
+        else
+            *vshader = NULL;
+    } // if
+    if (pshader != NULL)
+    {
+        if (ctx->bound_program != NULL)
+            *pshader = ctx->bound_program->pixelShader;
+        else
+            *pshader = NULL;
+    } // if
 } // MOJOSHADER_vkGetBoundShaders
 
 void MOJOSHADER_vkMapUniformBufferMemory(float **vsf, int **vsi, unsigned char **vsb,
@@ -580,20 +752,21 @@
     /* Why is this function named unmap instead of update?
      * the world may never know...
      */
-
-    update_uniform_buffer(ctx->vertexShader);
-    update_uniform_buffer(ctx->pixelShader);
+    assert(ctx->bound_program != NULL);
+    update_uniform_buffer(ctx->bound_program->vertexShader);
+    update_uniform_buffer(ctx->bound_program->pixelShader);
 } // MOJOSHADER_vkUnmapUniformBufferMemory
 
 void MOJOSHADER_vkGetUniformBuffers(VkBuffer *vbuf, unsigned long long *voff, unsigned long long *vsize,
                                     VkBuffer *pbuf, unsigned long long *poff, unsigned long long *psize)
 {
-    *vbuf = get_uniform_buffer(ctx->vertexShader);
-    *voff = get_uniform_offset(ctx->vertexShader);
-    *vsize = get_uniform_size(ctx->vertexShader);
-    *pbuf = get_uniform_buffer(ctx->pixelShader);
-    *poff = get_uniform_offset(ctx->pixelShader);
-    *psize = get_uniform_size(ctx->pixelShader);
+    assert(ctx->bound_program != NULL);
+    *vbuf = get_uniform_buffer(ctx->bound_program->vertexShader);
+    *voff = get_uniform_offset(ctx->bound_program->vertexShader);
+    *vsize = get_uniform_size(ctx->bound_program->vertexShader);
+    *pbuf = get_uniform_buffer(ctx->bound_program->pixelShader);
+    *poff = get_uniform_offset(ctx->bound_program->pixelShader);
+    *psize = get_uniform_size(ctx->bound_program->pixelShader);
 } // MOJOSHADER_vkGetUniformBuffers
 
 void MOJOSHADER_vkEndFrame()
@@ -624,13 +797,15 @@
     return -1;
 } //MOJOSHADER_vkGetVertexAttribLocation
 
-unsigned long long MOJOSHADER_vkGetShaderModule(MOJOSHADER_vkShader *shader)
+void MOJOSHADER_vkGetShaderModules(VkShaderModule *vmodule,
+                                   VkShaderModule *pmodule)
 {
-    if (shader == NULL)
-        return 0;
-
-    return (unsigned long long) shader->shaderModule;
-} //MOJOSHADER_vkGetShaderModule
+    assert(ctx->bound_program != NULL);
+    if (vmodule != NULL)
+        *vmodule = ctx->bound_program->vertexModule;
+    if (pmodule != NULL)
+        *pmodule = ctx->bound_program->pixelModule;
+} //MOJOSHADER_vkGetShaderModules
 
 const char *MOJOSHADER_vkGetError(void)
 {