mojoshader_d3d11.c
changeset 1255 0135d797e287
child 1262 8df5c62abd22
equal deleted inserted replaced
1254:422f68756c9f 1255:0135d797e287
       
     1 /**
       
     2  * MojoShader; generate shader programs from bytecode of compiled
       
     3  *  Direct3D shaders.
       
     4  *
       
     5  * Please see the file LICENSE.txt in the source's root directory.
       
     6  *
       
     7  *  This file written by Ryan C. Gordon.
       
     8  */
       
     9 
       
    10 #ifdef _WIN32
       
    11 #define WIN32_LEAN_AND_MEAN 1
       
    12 #include <windows.h> // Include this early to avoid SDL conflicts
       
    13 #endif
       
    14 
       
    15 #define __MOJOSHADER_INTERNAL__ 1
       
    16 #include "mojoshader_internal.h"
       
    17 
       
    18 #if SUPPORT_PROFILE_HLSL
       
    19 
       
    20 #define D3D11_NO_HELPERS
       
    21 #define CINTERFACE
       
    22 #define COBJMACROS
       
    23 #include <d3d11.h>
       
    24 
       
    25 #ifndef WINAPI_FAMILY_WINRT
       
    26 #define WINAPI_FAMILY_WINRT 0
       
    27 #endif
       
    28 #if WINAPI_FAMILY_WINRT
       
    29 #include <d3dcompiler.h>
       
    30 #endif
       
    31 
       
    32 /* Error state */
       
    33 
       
    34 static char error_buffer[1024] = { '\0' };
       
    35 
       
    36 static void set_error(const char *str)
       
    37 {
       
    38     snprintf(error_buffer, sizeof (error_buffer), "%s", str);
       
    39 } // set_error
       
    40 
       
    41 static inline void out_of_memory(void)
       
    42 {
       
    43     set_error("out of memory");
       
    44 } // out_of_memory
       
    45 
       
    46 /* D3DCompile signature */
       
    47 
       
    48 typedef HRESULT(WINAPI *PFN_D3DCOMPILE)(
       
    49     LPCVOID pSrcData,
       
    50     SIZE_T SrcDataSize,
       
    51     LPCSTR pSourceName,
       
    52     const D3D_SHADER_MACRO *pDefines,
       
    53     ID3DInclude *pInclude,
       
    54     LPCSTR pEntrypoint,
       
    55     LPCSTR pTarget,
       
    56     UINT Flags1,
       
    57     UINT Flags2,
       
    58     ID3DBlob **ppCode,
       
    59     ID3DBlob **ppErrorMsgs
       
    60 );
       
    61 
       
    62 /* Structs */
       
    63 
       
/* One cached D3D11 shader object owned by a MOJOSHADER_d3d11Shader (stored
 *  in its shaderMaps array). Which union member is meaningful depends on the
 *  owning shader's type: MOJOSHADER_d3d11DeleteShader releases `val` as an
 *  ID3D11VertexShader (plus vertex.blob) for vertex shaders, and as an
 *  ID3D11PixelShader for pixel shaders. */
typedef struct d3d11ShaderMap
{
    void *val;  // the ID3D11VertexShader* or ID3D11PixelShader*
    union
    {
        struct
        {
            // NOTE(review): presumably identifies the vertex input layout
            //  this variant was built for; confirm against the code that
            //  fills it in.
            uint64 layoutHash;
            ID3D10Blob *blob;  // compiled bytecode; released with the shader
        } vertex;
        struct
        {
            // NOTE(review): presumably the vertex shader this pixel shader
            //  variant was linked against (see compilePixelShader); confirm.
            MOJOSHADER_d3d11Shader *vshader;
        } pixel;
    };
} d3d11ShaderMap;
       
    80 
       
/* Backend state for one shader created by MOJOSHADER_d3d11CompileShader and
 *  destroyed by MOJOSHADER_d3d11DeleteShader. */
typedef struct MOJOSHADER_d3d11Shader
{
    const MOJOSHADER_parseData *parseData;  // MojoShader parse result (HLSL text, uniforms, attributes)
    uint32 refcount;              // starts at 1; freed when DeleteShader drops the last reference
    ID3D11Buffer *ubo;            // D3D11 constant buffer; NULL unless the shader has uniforms
    size_t buflen;                // size in bytes of constantData (the GPU buffer is rounded up to 16)
    unsigned char *constantData;  // CPU-side copy of the uniform values, compared to skip redundant uploads
    unsigned int mapCapacity;     // allocated entries in shaderMaps
    unsigned int numMaps;         // used entries in shaderMaps
    d3d11ShaderMap *shaderMaps;   // cached D3D shader objects created from this shader
} MOJOSHADER_d3d11Shader;
       
    92 
       
// Max entries for each register file type...
// These size the register-file arrays in MOJOSHADER_d3d11Context: float and
//  int files get four components per entry, the bool file one uint8 each.
#define MAX_REG_FILE_F 8192
#define MAX_REG_FILE_I 2047
#define MAX_REG_FILE_B 2047
       
    97 
       
/* Global state for the D3D11 backend. A single instance is allocated by
 *  MOJOSHADER_d3d11CreateContext, kept in the file-static `ctx`, and freed by
 *  MOJOSHADER_d3d11DestroyContext. */
typedef struct MOJOSHADER_d3d11Context
{
    // Allocators...
    // (caller-supplied, or MojoShader's internal defaults when NULL is passed)
    MOJOSHADER_malloc malloc_fn;
    MOJOSHADER_free free_fn;
    void *malloc_data;  // opaque pointer handed back to malloc_fn / free_fn

    // The constant register files...
    // Float/int files hold four components per register, bool files one
    //  uint8 per register. MOJOSHADER_d3d11MapUniformBufferMemory hands out
    //  pointers to these; update_uniform_buffer copies from them into each
    //  shader's constant buffer.
    // !!! FIXME: Man, it kills me how much memory this takes...
    // !!! FIXME:  ... make this dynamically allocated on demand.
    float vs_reg_file_f[MAX_REG_FILE_F * 4];
    int vs_reg_file_i[MAX_REG_FILE_I * 4];
    uint8 vs_reg_file_b[MAX_REG_FILE_B];
    float ps_reg_file_f[MAX_REG_FILE_F * 4];
    int ps_reg_file_i[MAX_REG_FILE_I * 4];
    uint8 ps_reg_file_b[MAX_REG_FILE_B];

    // Pointer to the active ID3D11Device.
    // (stored as given at context creation; no reference is added here)
    ID3D11Device *device;

    // Pointer to the ID3D11DeviceContext.
    // (stored as given at context creation; no reference is added here)
    ID3D11DeviceContext *deviceContext;

    // Currently bound vertex and pixel shaders.
    MOJOSHADER_d3d11Shader *vertexShader;
    MOJOSHADER_d3d11Shader *pixelShader;
    // Set to 1 by MOJOSHADER_d3d11BindShaders. NOTE(review): presumably
    //  cleared once the shader is actually set on the device context by
    //  code not visible here; confirm.
    int vertexNeedsBound;
    int pixelNeedsBound;

    // D3DCompile function pointer.
    // (linked directly on WinRT, otherwise looked up in d3dcompiler_47.dll)
    PFN_D3DCOMPILE D3DCompileFunc;
#if !WINAPI_FAMILY_WINRT
    HMODULE d3dcompilerDLL;  // module handle, freed in MOJOSHADER_d3d11DestroyContext
#endif
} MOJOSHADER_d3d11Context;
       
   133 
       
   134 static MOJOSHADER_d3d11Context *ctx = NULL;
       
   135 
       
   136 /* Uniform buffer utilities */
       
   137 
       
   138 static inline int next_highest_alignment(int n)
       
   139 {
       
   140     const int align = 16;
       
   141     return align * ((n + align - 1) / align);
       
   142 } // next_highest_alignment
       
   143 
       
   144 static inline void *get_uniform_buffer(MOJOSHADER_d3d11Shader *shader)
       
   145 {
       
   146     return (shader == NULL || shader->ubo == NULL) ? NULL : shader->ubo;
       
   147 } // get_uniform_buffer
       
   148 
       
   149 static void update_uniform_buffer(MOJOSHADER_d3d11Shader *shader)
       
   150 {
       
   151     if (shader == NULL || shader->ubo == NULL)
       
   152         return;
       
   153 
       
   154     float *regF; int *regI; uint8 *regB;
       
   155     if (shader->parseData->shader_type == MOJOSHADER_TYPE_VERTEX)
       
   156     {
       
   157         regF = ctx->vs_reg_file_f;
       
   158         regI = ctx->vs_reg_file_i;
       
   159         regB = ctx->vs_reg_file_b;
       
   160     } // if
       
   161     else
       
   162     {
       
   163         regF = ctx->ps_reg_file_f;
       
   164         regI = ctx->ps_reg_file_i;
       
   165         regB = ctx->ps_reg_file_b;
       
   166     } // else
       
   167 
       
   168     // Update the buffer contents
       
   169     int needsUpdate = 0;
       
   170     size_t offset = 0;
       
   171     for (int i = 0; i < shader->parseData->uniform_count; i++)
       
   172     {
       
   173         if (shader->parseData->uniforms[i].constant)
       
   174             continue;
       
   175 
       
   176         int idx = shader->parseData->uniforms[i].index;
       
   177         int arrayCount = shader->parseData->uniforms[i].array_count;
       
   178 
       
   179         void *src = NULL;
       
   180         void *dst = NULL;
       
   181         size_t size = arrayCount ? arrayCount : 1;
       
   182 
       
   183         switch (shader->parseData->uniforms[i].type)
       
   184         {
       
   185             case MOJOSHADER_UNIFORM_FLOAT:
       
   186                 src = &regF[4 * idx];
       
   187                 dst = shader->constantData + offset;
       
   188                 size *= 16;
       
   189                 break;
       
   190 
       
   191             case MOJOSHADER_UNIFORM_INT:
       
   192                 src = &regI[4 * idx];
       
   193                 dst = shader->constantData + offset;
       
   194                 size *= 16;
       
   195                 break;
       
   196 
       
   197             case MOJOSHADER_UNIFORM_BOOL:
       
   198                 src = &regB[idx];
       
   199                 dst = shader->constantData + offset;
       
   200                 break;
       
   201 
       
   202             default:
       
   203                 assert(0); // This should never happen.
       
   204                 break;
       
   205         } // switch
       
   206 
       
   207         if (memcmp(dst, src, size) != 0)
       
   208         {
       
   209             memcpy(dst, src, size);
       
   210             needsUpdate = 1;
       
   211         } // if
       
   212 
       
   213         offset += size;
       
   214     } // for
       
   215 
       
   216     if (needsUpdate)
       
   217     {
       
   218         // Map the buffer
       
   219         D3D11_MAPPED_SUBRESOURCE res;
       
   220         ID3D11DeviceContext_Map((ID3D11DeviceContext*) ctx->deviceContext,
       
   221                                 (ID3D11Resource*) shader->ubo, 0,
       
   222                                 D3D11_MAP_WRITE_DISCARD, 0, &res);
       
   223 
       
   224         // Copy the contents
       
   225         memcpy(res.pData, shader->constantData, shader->buflen);
       
   226 
       
   227         // Unmap the buffer
       
   228         ID3D11DeviceContext_Unmap(
       
   229             (ID3D11DeviceContext*) ctx->deviceContext,
       
   230             (ID3D11Resource*) shader->ubo,
       
   231             0
       
   232         );
       
   233     } // if
       
   234 } // update_uniform_buffer
       
   235 
       
   236 static inline void expand_map(MOJOSHADER_d3d11Shader *shader)
       
   237 {
       
   238     if (shader->numMaps == shader->mapCapacity)
       
   239     {
       
   240         d3d11ShaderMap *newMap = (d3d11ShaderMap *) ctx->malloc_fn(
       
   241             sizeof(d3d11ShaderMap) * shader->mapCapacity * 2,
       
   242             ctx->malloc_data
       
   243         );
       
   244         memcpy(newMap, shader->shaderMaps,
       
   245             sizeof(d3d11ShaderMap) * shader->mapCapacity);
       
   246         shader->mapCapacity *= 2;
       
   247         ctx->free_fn(shader->shaderMaps, ctx->malloc_data);
       
   248         shader->shaderMaps = newMap;
       
   249         newMap = NULL;
       
   250     } // if
       
   251 } // expand_map
       
   252 
       
   253 static inline int element_is_uint(DXGI_FORMAT format)
       
   254 {
       
   255     return  format == DXGI_FORMAT_R32G32B32A32_UINT ||
       
   256             format == DXGI_FORMAT_R32G32B32_UINT ||
       
   257             format == DXGI_FORMAT_R16G16B16A16_UINT ||
       
   258             format == DXGI_FORMAT_R32G32_UINT ||
       
   259             format == DXGI_FORMAT_R10G10B10A2_UINT ||
       
   260             format == DXGI_FORMAT_R8G8B8A8_UINT ||
       
   261             format == DXGI_FORMAT_R16G16_UINT ||
       
   262             format == DXGI_FORMAT_R32_UINT ||
       
   263             format == DXGI_FORMAT_R8G8_UINT ||
       
   264             format == DXGI_FORMAT_R16_UINT ||
       
   265             format == DXGI_FORMAT_R8_UINT;
       
   266 } // element_is_uint
       
   267 
       
   268 static inline int element_is_int(DXGI_FORMAT format)
       
   269 {
       
   270     return  format == DXGI_FORMAT_R32G32B32A32_SINT ||
       
   271             format == DXGI_FORMAT_R32G32B32_SINT ||
       
   272             format == DXGI_FORMAT_R16G16B16A16_SINT ||
       
   273             format == DXGI_FORMAT_R32G32_SINT ||
       
   274             format == DXGI_FORMAT_R8G8B8A8_SINT ||
       
   275             format == DXGI_FORMAT_R16G16_SINT ||
       
   276             format == DXGI_FORMAT_R32_SINT ||
       
   277             format == DXGI_FORMAT_R8G8_SINT ||
       
   278             format == DXGI_FORMAT_R16_SINT ||
       
   279             format == DXGI_FORMAT_R8_SINT;
       
   280 } // element_is_int
       
   281 
       
   282 /* Shader Compilation Utilities */
       
   283 
       
   284 static ID3D11VertexShader *compileVertexShader(MOJOSHADER_d3d11Shader *shader,
       
   285                                                const char *src, int src_len,
       
   286                                                ID3D10Blob **blob)
       
   287 {
       
   288     const MOJOSHADER_parseData *pd = shader->parseData;
       
   289     HRESULT result = ctx->D3DCompileFunc(src, src_len, pd->mainfn,
       
   290                                          NULL, NULL, pd->mainfn, "vs_4_0",
       
   291                                          0, 0, blob, blob);
       
   292 
       
   293     if (result < 0)
       
   294     {
       
   295         set_error((const char *) ID3D10Blob_GetBufferPointer(*blob));
       
   296         ID3D10Blob_Release(*blob);
       
   297         return NULL;
       
   298     } // if
       
   299 
       
   300     void *bytecode = ID3D10Blob_GetBufferPointer(*blob);
       
   301     int bytecodeLength = ID3D10Blob_GetBufferSize(*blob);
       
   302     ID3D11VertexShader *ret = NULL;
       
   303     ID3D11Device_CreateVertexShader(ctx->device, bytecode, bytecodeLength,
       
   304                                     NULL, &ret);
       
   305     return ret;
       
   306 } // compileVertexShader
       
   307 
       
   308 static void replaceVarname(const char *find, const char *replace,
       
   309                            const char **source)
       
   310 {
       
   311     const char *srcbuf = *source;
       
   312     size_t find_len = strlen(find);
       
   313     size_t replace_len = strlen(replace);
       
   314 
       
   315     #define IS_PARTIAL_TOKEN(token) \
       
   316         (isalnum(*(token + find_len)) || isalnum(*(token-1)))
       
   317 
       
   318     // How many times does `find` occur in the source buffer?
       
   319     int count = 0;
       
   320     char *ptr = (char *) strstr(srcbuf, find);
       
   321     while (ptr != NULL)
       
   322     {
       
   323         if (!IS_PARTIAL_TOKEN(ptr))
       
   324             count++;
       
   325         ptr = strstr(ptr + find_len, find);
       
   326     } // while
       
   327 
       
   328     // How big should we make the new text buffer?
       
   329     size_t oldlen = strlen(srcbuf) + 1;
       
   330     size_t newlen = oldlen + (count * (replace_len - find_len));
       
   331 
       
   332     // Easy case; just find/replace in the original buffer
       
   333     if (newlen == oldlen)
       
   334     {
       
   335         ptr = (char *) strstr(srcbuf, find);
       
   336         while (ptr != NULL)
       
   337         {
       
   338             if (!IS_PARTIAL_TOKEN(ptr))
       
   339                 memcpy(ptr, replace, replace_len);
       
   340             ptr = strstr(ptr + find_len, find);
       
   341         } // while
       
   342         return;
       
   343     } // if
       
   344 
       
   345     // Allocate a new buffer
       
   346     char *newbuf = (char *) ctx->malloc_fn(newlen, ctx->malloc_data);
       
   347     memset(newbuf, '\0', newlen);
       
   348 
       
   349     // Find + replace
       
   350     char *prev_ptr = (char *) srcbuf;
       
   351     char *curr_ptr = (char *) newbuf;
       
   352     ptr = (char*) strstr(srcbuf, find);
       
   353     while (ptr != NULL)
       
   354     {
       
   355         memcpy(curr_ptr, prev_ptr, ptr - prev_ptr);
       
   356         curr_ptr += ptr - prev_ptr;
       
   357 
       
   358         if (!IS_PARTIAL_TOKEN(ptr))
       
   359         {
       
   360             memcpy(curr_ptr, replace, replace_len);
       
   361             curr_ptr += replace_len;
       
   362         } // if
       
   363         else
       
   364         {
       
   365             // Don't accidentally eat partial tokens...
       
   366             memcpy(curr_ptr, find, find_len);
       
   367             curr_ptr += find_len;
       
   368         } // else
       
   369 
       
   370         prev_ptr = ptr + find_len;
       
   371         ptr = strstr(prev_ptr, find);
       
   372     } // while
       
   373 
       
   374     #undef IS_PARTIAL_TOKEN
       
   375 
       
   376     // Copy the remaining part of the source buffer
       
   377     memcpy(curr_ptr, prev_ptr, (srcbuf + oldlen) - prev_ptr);
       
   378 
       
   379     // Free the source buffer
       
   380     ctx->free_fn((void *) srcbuf, ctx->malloc_data);
       
   381 
       
   382     // Point the source parameter to the new buffer
       
   383     *source = newbuf;
       
   384 } // replaceVarname
       
   385 
       
   386 static char *rewritePixelShader(MOJOSHADER_d3d11Shader *vshader,
       
   387                                 MOJOSHADER_d3d11Shader *pshader)
       
   388 {
       
   389     const MOJOSHADER_parseData *vpd = vshader->parseData;
       
   390     const MOJOSHADER_parseData *ppd = pshader->parseData;
       
   391     const char *_Output = "_Output" ENDLINE_STR "{" ENDLINE_STR;
       
   392     const char *_Input = "_Input" ENDLINE_STR "{" ENDLINE_STR;
       
   393     const char *vsrc = vpd->output;
       
   394     const char *psrc = ppd->output;
       
   395     const char *a, *b, *vout, *pstart, *vface, *pend;
       
   396     size_t substr_len;
       
   397     char *pfinal;
       
   398 
       
   399     #define MAKE_STRBUF(buf) \
       
   400         substr_len = b - a; \
       
   401         buf = (const char *) ctx->malloc_fn(substr_len + 1, ctx->malloc_data); \
       
   402         memset((void *) buf, '\0', substr_len + 1); \
       
   403         memcpy((void *) buf, a, substr_len);
       
   404 
       
   405     // Copy the vertex function's output struct into a buffer
       
   406     a = strstr(vsrc, _Output) + strlen(_Output);
       
   407     b = a;
       
   408     while (*(b++) != '}');
       
   409     b--;
       
   410     MAKE_STRBUF(vout)
       
   411 
       
   412     // Split up the pixel shader text...
       
   413 
       
   414     // ...everything up to the input contents...
       
   415     a = psrc;
       
   416     b = strstr(psrc, _Input) + strlen(_Input);
       
   417     MAKE_STRBUF(pstart)
       
   418 
       
   419     // ...everything after the input contents.
       
   420     a = b;
       
   421     while (*(a++) != '}');
       
   422     a--;
       
   423     while (*(b++) != '\0');
       
   424     MAKE_STRBUF(pend)
       
   425 
       
   426     // Find matching semantics
       
   427     int i, j;
       
   428     int vfaceidx = -1;
       
   429     const char *pvarname, *vvarname;
       
   430     for (i = 0; i < ppd->attribute_count; i++)
       
   431     {
       
   432         for (j = 0; j < vpd->output_count; j++)
       
   433         {
       
   434             if (ppd->attributes[i].usage == vpd->outputs[j].usage &&
       
   435                 ppd->attributes[i].index == vpd->outputs[j].index)
       
   436             {
       
   437                 pvarname = ppd->attributes[i].name;
       
   438                 vvarname = vpd->outputs[j].name;
       
   439                 if (strcmp(pvarname, vvarname) != 0)
       
   440                     replaceVarname(pvarname, vvarname, &pend);
       
   441             } // if
       
   442             else if (strcmp(ppd->attributes[i].name, "vPos") == 0 &&
       
   443                      vpd->outputs[j].usage == MOJOSHADER_USAGE_POSITION &&
       
   444                      vpd->outputs[j].index == 0)
       
   445             {
       
   446                 pvarname = ppd->attributes[i].name;
       
   447                 vvarname = vpd->outputs[j].name;
       
   448                 if (strcmp(pvarname, vvarname) != 0)
       
   449                     replaceVarname(pvarname, vvarname, &pend);
       
   450             } // else if
       
   451         } // for
       
   452 
       
   453         if (strcmp(ppd->attributes[i].name, "vFace") == 0)
       
   454             vfaceidx = i;
       
   455     } // for
       
   456 
       
   457     // Special handling for VFACE
       
   458     vface = (vfaceidx != -1) ? "\tbool m_vFace : SV_IsFrontFace;\n" : "";
       
   459 
       
   460     // Concatenate the shader pieces together
       
   461     substr_len = strlen(pstart) + strlen(vout) + strlen(vface) + strlen(pend);
       
   462     pfinal = (char *) ctx->malloc_fn(substr_len + 1, ctx->malloc_data);
       
   463     memset((void *) pfinal, '\0', substr_len + 1);
       
   464     memcpy(pfinal, pstart, strlen(pstart));
       
   465     memcpy(pfinal + strlen(pstart), vout, strlen(vout));
       
   466     memcpy(pfinal + strlen(pstart) + strlen(vout), vface, strlen(vface));
       
   467     memcpy(pfinal + strlen(pstart) + strlen(vout) + strlen(vface), pend, strlen(pend));
       
   468 
       
   469     // Free the temporary buffers
       
   470     ctx->free_fn((void *) vout, ctx->malloc_data);
       
   471     ctx->free_fn((void *) pstart, ctx->malloc_data);
       
   472     ctx->free_fn((void *) pend, ctx->malloc_data);
       
   473 
       
   474     #undef MAKE_STRBUF
       
   475 
       
   476     return pfinal;
       
   477 } // spliceVertexShaderInput
       
   478 
       
   479 static ID3D11PixelShader *compilePixelShader(MOJOSHADER_d3d11Shader *vshader,
       
   480                                              MOJOSHADER_d3d11Shader *pshader)
       
   481 {
       
   482     ID3D11PixelShader *retval = NULL;
       
   483     const char *source;
       
   484     ID3DBlob *blob;
       
   485     HRESULT result;
       
   486     int needs_free;
       
   487 
       
   488     if (pshader->parseData->attribute_count > 0)
       
   489     {
       
   490         source = rewritePixelShader(vshader, pshader);
       
   491         needs_free = 1;
       
   492     } // if
       
   493     else
       
   494     {
       
   495         source = pshader->parseData->output;
       
   496         needs_free = 0;
       
   497     } // else
       
   498 
       
   499     result = ctx->D3DCompileFunc(source, strlen(source),
       
   500                                  pshader->parseData->mainfn, NULL, NULL,
       
   501                                  pshader->parseData->mainfn, "ps_4_0", 0, 0,
       
   502                                  &blob, &blob);
       
   503 
       
   504     if (result < 0)
       
   505     {
       
   506         set_error((const char *) ID3D10Blob_GetBufferPointer(blob));
       
   507         ctx->free_fn((void *) source, ctx->malloc_data);
       
   508         return NULL;
       
   509     } // if
       
   510 
       
   511     ID3D11Device_CreatePixelShader(ctx->device,
       
   512                                    ID3D10Blob_GetBufferPointer(blob),
       
   513                                    ID3D10Blob_GetBufferSize(blob),
       
   514                                    NULL, &retval);
       
   515 
       
   516     ID3D10Blob_Release(blob);
       
   517     if (needs_free)
       
   518         ctx->free_fn((void *) source, ctx->malloc_data);
       
   519     return retval;
       
   520 } // compilePixelShader
       
   521 
       
   522 /* Public API */
       
   523 
       
   524 int MOJOSHADER_d3d11CreateContext(void *device, void *deviceContext,
       
   525                                   MOJOSHADER_malloc m, MOJOSHADER_free f,
       
   526                                   void *malloc_d)
       
   527 {
       
   528     assert(ctx == NULL);
       
   529 
       
   530     if (m == NULL) m = MOJOSHADER_internal_malloc;
       
   531     if (f == NULL) f = MOJOSHADER_internal_free;
       
   532 
       
   533     ctx = (MOJOSHADER_d3d11Context *) m(sizeof(MOJOSHADER_d3d11Context), malloc_d);
       
   534     if (ctx == NULL)
       
   535     {
       
   536         out_of_memory();
       
   537         goto init_fail;
       
   538     } // if
       
   539 
       
   540     memset(ctx, '\0', sizeof (MOJOSHADER_d3d11Context));
       
   541     ctx->malloc_fn = m;
       
   542     ctx->free_fn = f;
       
   543     ctx->malloc_data = malloc_d;
       
   544 
       
   545     // Store references to the D3D device and immediate context
       
   546     ctx->device = (ID3D11Device*) device;
       
   547     ctx->deviceContext = (ID3D11DeviceContext*) deviceContext;
       
   548 
       
   549     // Grab the D3DCompile function pointer
       
   550 #if WINAPI_FAMILY_WINRT
       
   551     ctx->D3DCompileFunc = D3DCompile;
       
   552 #else
       
   553     ctx->d3dcompilerDLL = LoadLibrary("d3dcompiler_47.dll");
       
   554     assert(ctx->d3dcompilerDLL != NULL);
       
   555     ctx->D3DCompileFunc = (PFN_D3DCOMPILE) GetProcAddress(ctx->d3dcompilerDLL,
       
   556                                                           "D3DCompile");
       
   557 #endif /* WINAPI_FAMILY_WINRT */
       
   558 
       
   559     return 0;
       
   560 
       
   561 init_fail:
       
   562     if (ctx != NULL)
       
   563         f(ctx, malloc_d);
       
   564     return -1;
       
   565 } // MOJOSHADER_d3d11CreateContext
       
   566 
       
   567 void MOJOSHADER_d3d11DestroyContext(void)
       
   568 {
       
   569 #if !WINAPI_FAMILY_WINRT
       
   570     FreeLibrary(ctx->d3dcompilerDLL);
       
   571 #endif
       
   572     ctx->free_fn(ctx, ctx->malloc_data);
       
   573     ctx = NULL;
       
   574 } // MOJOSHADER_d3d11DestroyContext
       
   575 
       
   576 MOJOSHADER_d3d11Shader *MOJOSHADER_d3d11CompileShader(const char *mainfn,
       
   577                                                       const unsigned char *tokenbuf,
       
   578                                                       const unsigned int bufsize,
       
   579                                                       const MOJOSHADER_swizzle *swiz,
       
   580                                                       const unsigned int swizcount,
       
   581                                                       const MOJOSHADER_samplerMap *smap,
       
   582                                                       const unsigned int smapcount)
       
   583 {
       
   584     MOJOSHADER_malloc m = ctx->malloc_fn;
       
   585     MOJOSHADER_free f = ctx->free_fn;
       
   586     void *d = ctx->malloc_data;
       
   587     int i;
       
   588 
       
   589     const MOJOSHADER_parseData *pd = MOJOSHADER_parse("hlsl", mainfn, tokenbuf,
       
   590                                                      bufsize, swiz, swizcount,
       
   591                                                      smap, smapcount, m, f, d);
       
   592 
       
   593     if (pd->error_count > 0)
       
   594     {
       
   595         // !!! FIXME: put multiple errors in the buffer? Don't use
       
   596         // !!! FIXME:  MOJOSHADER_d3d11GetError() for this?
       
   597         set_error(pd->errors[0].error);
       
   598         goto compile_shader_fail;
       
   599     } // if
       
   600 
       
   601     MOJOSHADER_d3d11Shader *retval = (MOJOSHADER_d3d11Shader *) m(sizeof(MOJOSHADER_d3d11Shader), d);
       
   602     if (retval == NULL)
       
   603         goto compile_shader_fail;
       
   604 
       
   605     retval->parseData = pd;
       
   606     retval->refcount = 1;
       
   607     retval->ubo = NULL;
       
   608     retval->constantData = NULL;
       
   609     retval->buflen = 0;
       
   610     retval->numMaps = 0;
       
   611 
       
   612     // Allocate shader maps
       
   613     retval->mapCapacity = 4; // arbitrary!
       
   614     retval->shaderMaps = (d3d11ShaderMap *) m(retval->mapCapacity * sizeof(d3d11ShaderMap), d);
       
   615     if (retval->shaderMaps == NULL)
       
   616         goto compile_shader_fail;
       
   617 
       
   618     memset(retval->shaderMaps, '\0', retval->mapCapacity * sizeof(d3d11ShaderMap));
       
   619 
       
   620     // Create the uniform buffer, if needed
       
   621     if (pd->uniform_count > 0)
       
   622     {
       
   623         // Calculate how big we need to make the buffer
       
   624         for (i = 0; i < pd->uniform_count; i++)
       
   625         {
       
   626             int arrayCount = pd->uniforms[i].array_count;
       
   627             int uniformSize = 16;
       
   628             if (pd->uniforms[i].type == MOJOSHADER_UNIFORM_BOOL)
       
   629                 uniformSize = 1;
       
   630             retval->buflen += (arrayCount ? arrayCount : 1) * uniformSize;
       
   631         } // for
       
   632 
       
   633         D3D11_BUFFER_DESC bdesc;
       
   634         bdesc.ByteWidth = next_highest_alignment(retval->buflen);
       
   635         bdesc.Usage = D3D11_USAGE_DYNAMIC;
       
   636         bdesc.BindFlags = D3D11_BIND_CONSTANT_BUFFER;
       
   637         bdesc.CPUAccessFlags = D3D11_CPU_ACCESS_WRITE;
       
   638         bdesc.MiscFlags = 0;
       
   639         bdesc.StructureByteStride = 0;
       
   640         ID3D11Device_CreateBuffer((ID3D11Device*) ctx->device, &bdesc, NULL,
       
   641                                   (ID3D11Buffer**) &retval->ubo);
       
   642 
       
   643         // Additionally allocate a CPU-side staging buffer
       
   644         retval->constantData = (unsigned char *) m(retval->buflen, d);
       
   645         memset(retval->constantData, '\0', retval->buflen);
       
   646     } // if
       
   647 
       
   648     return retval;
       
   649 
       
   650 compile_shader_fail:
       
   651     MOJOSHADER_freeParseData(pd);
       
   652     return NULL;
       
   653 } // MOJOSHADER_d3d11CompileShader
       
   654 
       
   655 void MOJOSHADER_d3d11ShaderAddRef(MOJOSHADER_d3d11Shader *shader)
       
   656 {
       
   657     if (shader != NULL)
       
   658         shader->refcount++;
       
   659 } // MOJOSHADER_d3d11ShaderAddRef
       
   660 
       
   661 void MOJOSHADER_d3d11DeleteShader(MOJOSHADER_d3d11Shader *shader)
       
   662 {
       
   663     if (shader != NULL)
       
   664     {
       
   665         if (shader->refcount > 1)
       
   666             shader->refcount--;
       
   667         else
       
   668         {
       
   669             if (shader->ubo != NULL)
       
   670             {
       
   671                 ID3D11Buffer_Release((ID3D11Buffer*) shader->ubo);
       
   672                 ctx->free_fn(shader->constantData, ctx->malloc_data);
       
   673             } // if
       
   674 
       
   675             if (shader->parseData->shader_type == MOJOSHADER_TYPE_VERTEX)
       
   676             {
       
   677                 for (int i = 0; i < shader->numMaps; i++)
       
   678                 {
       
   679                     ID3D11VertexShader_Release((ID3D11VertexShader *) shader->shaderMaps[i].val);
       
   680                     ID3D10Blob_Release(shader->shaderMaps[i].vertex.blob);
       
   681                 } // for
       
   682             } // if
       
   683             else if (shader->parseData->shader_type == MOJOSHADER_TYPE_PIXEL)
       
   684             {
       
   685                 for (int i = 0; i < shader->numMaps; i++)
       
   686                     ID3D11PixelShader_Release((ID3D11PixelShader *) shader->shaderMaps[i].val);
       
   687             } // else if
       
   688 
       
   689             ctx->free_fn(shader->shaderMaps, ctx->malloc_data);
       
   690             shader->shaderMaps = NULL;
       
   691             MOJOSHADER_freeParseData(shader->parseData);
       
   692             ctx->free_fn(shader, ctx->malloc_data);
       
   693         } // else
       
   694     } // if
       
   695 } // MOJOSHADER_d3d11DeleteShader
       
   696 
       
   697 const MOJOSHADER_parseData *MOJOSHADER_d3d11GetShaderParseData(
       
   698                                                 MOJOSHADER_d3d11Shader *shader)
       
   699 {
       
   700     return (shader != NULL) ? shader->parseData : NULL;
       
   701 } // MOJOSHADER_d3d11GetParseData
       
   702 
       
   703 void MOJOSHADER_d3d11BindShaders(MOJOSHADER_d3d11Shader *vshader,
       
   704                                  MOJOSHADER_d3d11Shader *pshader)
       
   705 {
       
   706     // Use the last bound shaders in case of NULL
       
   707     if (vshader != NULL)
       
   708     {
       
   709         ctx->vertexShader = vshader;
       
   710         ctx->vertexNeedsBound = 1;
       
   711     } // if
       
   712 
       
   713     if (pshader != NULL)
       
   714     {
       
   715         ctx->pixelShader = pshader;
       
   716         ctx->pixelNeedsBound = 1;
       
   717     } // if
       
   718 } // MOJOSHADER_d3d11BindShaders
       
   719 
       
   720 void MOJOSHADER_d3d11GetBoundShaders(MOJOSHADER_d3d11Shader **vshader,
       
   721                                      MOJOSHADER_d3d11Shader **pshader)
       
   722 {
       
   723     *vshader = ctx->vertexShader;
       
   724     *pshader = ctx->pixelShader;
       
   725 } // MOJOSHADER_d3d11GetBoundShaders
       
   726 
       
   727 void MOJOSHADER_d3d11MapUniformBufferMemory(float **vsf, int **vsi, unsigned char **vsb,
       
   728                                             float **psf, int **psi, unsigned char **psb)
       
   729 {
       
   730     *vsf = ctx->vs_reg_file_f;
       
   731     *vsi = ctx->vs_reg_file_i;
       
   732     *vsb = ctx->vs_reg_file_b;
       
   733     *psf = ctx->ps_reg_file_f;
       
   734     *psi = ctx->ps_reg_file_i;
       
   735     *psb = ctx->ps_reg_file_b;
       
   736 } // MOJOSHADER_d3d11MapUniformBufferMemory
       
   737 
       
   738 void MOJOSHADER_d3d11UnmapUniformBufferMemory()
       
   739 {
       
   740     /* This has nothing to do with unmapping memory
       
   741      * and everything to do with updating uniform
       
   742      * buffers with the latest parameter contents.
       
   743      */
       
   744     MOJOSHADER_d3d11Shader *vs, *ps;
       
   745     MOJOSHADER_d3d11GetBoundShaders(&vs, &ps);
       
   746     update_uniform_buffer(vs);
       
   747     update_uniform_buffer(ps);
       
   748 } // MOJOSHADER_d3d11UnmapUniformBufferMemory
       
   749 
       
   750 int MOJOSHADER_d3d11GetVertexAttribLocation(MOJOSHADER_d3d11Shader *vert,
       
   751                                             MOJOSHADER_usage usage, int index)
       
   752 {
       
   753     if (vert == NULL)
       
   754         return -1;
       
   755 
       
   756     for (int i = 0; i < vert->parseData->attribute_count; i++)
       
   757     {
       
   758         if (vert->parseData->attributes[i].usage == usage &&
       
   759             vert->parseData->attributes[i].index == index)
       
   760         {
       
   761             return i;
       
   762         } // if
       
   763     } // for
       
   764 
       
   765     // failure, couldn't find requested attribute
       
   766     return -1;
       
   767 } // MOJOSHADER_d3d11GetVertexAttribLocation
       
   768 
       
   769 void MOJOSHADER_d3d11CompileVertexShader(unsigned long long inputLayoutHash,
       
   770                                          void* elements, int elementCount,
       
   771                                          void **bytecode, int *bytecodeLength)
       
   772 {
       
   773     MOJOSHADER_d3d11Shader *vshader = ctx->vertexShader;
       
   774     ID3D10Blob *blob;
       
   775 
       
   776     // Don't compile if there's already a mapping for this layout.
       
   777     for (int i = 0; i < vshader->numMaps; i++)
       
   778     {
       
   779         if (inputLayoutHash == vshader->shaderMaps[i].vertex.layoutHash)
       
   780         {
       
   781             blob = vshader->shaderMaps[i].vertex.blob;
       
   782             *bytecode = ID3D10Blob_GetBufferPointer(blob);
       
   783             *bytecodeLength = ID3D10Blob_GetBufferSize(blob);
       
   784             return;
       
   785         } // if
       
   786     } // for
       
   787 
       
   788     // Check for and replace non-float types
       
   789     D3D11_INPUT_ELEMENT_DESC *d3dElements = (D3D11_INPUT_ELEMENT_DESC*) elements;
       
   790     const char *origSource = vshader->parseData->output;
       
   791     int srcLength = vshader->parseData->output_len;
       
   792     char *newSource = (char*) origSource;
       
   793     for (int i = 0; i < elementCount; i += 1)
       
   794     {
       
   795         D3D11_INPUT_ELEMENT_DESC e = d3dElements[i];
       
   796 
       
   797         const char *replace;
       
   798         if (element_is_uint(e.Format))
       
   799             replace = " uint4";
       
   800         else if (element_is_int(e.Format))
       
   801             replace = "  int4";
       
   802         else
       
   803             replace = NULL;
       
   804 
       
   805         if (replace != NULL)
       
   806         {
       
   807             char sem[16];
       
   808             memset(sem, '\0', sizeof(sem));
       
   809             snprintf(sem, sizeof(sem), "%s%d", e.SemanticName, e.SemanticIndex);
       
   810             // !!! FIXME: POSITIONT has no index. What to do? -caleb
       
   811 
       
   812             if (newSource == origSource)
       
   813             {
       
   814                 newSource = (char *) ctx->malloc_fn(srcLength + 1,
       
   815                                                     ctx->malloc_data);
       
   816                 strcpy(newSource, origSource);
       
   817             } // if
       
   818 
       
   819             char *ptr = strstr(newSource, sem);
       
   820             assert(ptr != NULL && "Could not find semantic in shader source!");
       
   821 
       
   822             int spaces = 0;
       
   823             while (spaces < 3)
       
   824                 if (*(--ptr) == ' ') spaces++;
       
   825             memcpy(ptr - strlen("float4"), replace, strlen(replace));
       
   826         } // if
       
   827     } // for
       
   828 
       
   829     // Expand the map array, if needed
       
   830     expand_map(vshader);
       
   831 
       
   832     // Add the new mapping
       
   833     vshader->shaderMaps[vshader->numMaps].vertex.layoutHash = inputLayoutHash;
       
   834     ID3D11VertexShader *vs = compileVertexShader(vshader, newSource,
       
   835                                                  srcLength, &blob);
       
   836     vshader->shaderMaps[ctx->vertexShader->numMaps].val = vs;
       
   837     vshader->shaderMaps[ctx->vertexShader->numMaps].vertex.blob = blob;
       
   838     ctx->vertexShader->numMaps++;
       
   839     assert(vs != NULL);
       
   840 
       
   841     // Return the bytecode info
       
   842     *bytecode = ID3D10Blob_GetBufferPointer(blob);
       
   843     *bytecodeLength = ID3D10Blob_GetBufferSize(blob);
       
   844 } // MOJOSHADER_d3d11CompileVertexShader
       
   845 
       
   846 void MOJOSHADER_d3d11ProgramReady(unsigned long long inputLayoutHash)
       
   847 {
       
   848     MOJOSHADER_d3d11Shader *vshader = ctx->vertexShader;
       
   849     MOJOSHADER_d3d11Shader *pshader = ctx->pixelShader;
       
   850 
       
   851     // Vertex shader...
       
   852     if (ctx->vertexNeedsBound)
       
   853     {
       
   854         ID3D11VertexShader *realVS = NULL;
       
   855         for (int i = 0; i < vshader->numMaps; i++)
       
   856         {
       
   857             if (inputLayoutHash == vshader->shaderMaps[i].vertex.layoutHash)
       
   858             {
       
   859                 realVS = (ID3D11VertexShader *) vshader->shaderMaps[i].val;
       
   860                 break;
       
   861             } // if
       
   862         } // for
       
   863         assert(realVS != NULL);
       
   864         ID3D11DeviceContext_VSSetShader(ctx->deviceContext, realVS, NULL, 0);
       
   865         ID3D11DeviceContext_VSSetConstantBuffers(ctx->deviceContext, 0, 1,
       
   866                                                  &vshader->ubo);
       
   867         ctx->vertexNeedsBound = 0;
       
   868     } // if
       
   869 
       
   870     // Pixel shader...
       
   871     if (ctx->pixelNeedsBound)
       
   872     {
       
   873         // Is there already a mapping for the current vertex shader?
       
   874         ID3D11PixelShader *realPS = NULL;
       
   875         for (int i = 0; i < pshader->numMaps; i++)
       
   876         {
       
   877             if (pshader->shaderMaps[i].pixel.vshader == vshader)
       
   878             {
       
   879                 realPS = (ID3D11PixelShader *) pshader->shaderMaps[i].val;
       
   880                 break;
       
   881             } // if
       
   882         } // for
       
   883 
       
   884         // We have to create a new vertex/pixel shader mapping...
       
   885         if (realPS == NULL)
       
   886         {
       
   887             // Expand the map array, if needed
       
   888             expand_map(pshader);
       
   889 
       
   890             // Add the new mapping
       
   891             pshader->shaderMaps[pshader->numMaps].pixel.vshader = vshader;
       
   892             realPS = compilePixelShader(vshader, pshader);
       
   893             pshader->shaderMaps[pshader->numMaps].val = realPS;
       
   894             pshader->numMaps++;
       
   895             assert(realPS != NULL);
       
   896         } // if
       
   897 
       
   898         ID3D11DeviceContext_PSSetShader(ctx->deviceContext, realPS, NULL, 0);
       
   899         ID3D11DeviceContext_PSSetConstantBuffers(ctx->deviceContext, 0, 1,
       
   900                                                  &pshader->ubo);
       
   901         ctx->pixelNeedsBound = 0;
       
   902     } // if
       
   903 } // MOJOSHADER_d3d11ProgramReady
       
   904 
       
   905 const char *MOJOSHADER_d3d11GetError(void)
       
   906 {
       
   907     return error_buffer;
       
   908 } // MOJOSHADER_d3d11GetError
       
   909 
       
   910 #endif /* SUPPORT_PROFILE_HLSL */
       
   911 
       
   912 // end of mojoshader_d3d11.c ...