mojoshader_compiler.c
changeset 973 6d4cdbc21301
parent 972 993c4d6e21a8
child 974 e4c4963e8889
--- a/mojoshader_compiler.c	Tue Jan 11 20:46:24 2011 -0500
+++ b/mojoshader_compiler.c	Wed Jan 12 03:46:17 2011 -0500
@@ -100,6 +100,8 @@
     int is_func_scope; // non-zero if semantic analysis is in function scope.
     int var_index;  // next variable index for current function.
     int global_var_index;  // next variable index for global scope.
+    int user_func_index;  // next function index for user-defined functions.
+    int intrinsic_func_index;  // next function index for intrinsic functions.
 
     // Cache intrinsic types for fast lookup and consistent pointer values.
     MOJOSHADER_astDataType dt_bool;
@@ -224,37 +226,109 @@
     return (map->hash != NULL);
 } // create_symbolmap
 
+static int datatypes_match(const MOJOSHADER_astDataType *a,
+                           const MOJOSHADER_astDataType *b)
+{
+    int i;
+
+    if (a == b)
+        return 1;
+    else if (a->type != b->type)
+        return 0;
+
+    switch (a->type)
+    {
+        case MOJOSHADER_AST_DATATYPE_STRUCT:
+            if (a->structure.member_count != b->structure.member_count)
+                return 0;
+            for (i = 0; i < a->structure.member_count; i++)
+            {
+                if (!datatypes_match(a->structure.members[i].datatype,
+                                     b->structure.members[i].datatype))
+                    return 0;
+                // stringcache'd, pointer compare is safe.
+                else if (a->structure.members[i].identifier !=
+                         b->structure.members[i].identifier)
+                    return 0;
+            } // for
+            return 1;
+
+        case MOJOSHADER_AST_DATATYPE_ARRAY:
+            if (a->array.elements != b->array.elements)
+                return 0;
+            else if (!datatypes_match(a->array.base, b->array.base))
+                return 0;
+            return 1;
+
+        case MOJOSHADER_AST_DATATYPE_VECTOR:
+            if (a->vector.elements != b->vector.elements)
+                return 0;
+            else if (!datatypes_match(a->vector.base, b->vector.base))
+                return 0;
+            return 1;
+
+        case MOJOSHADER_AST_DATATYPE_MATRIX:
+            if (a->matrix.rows != b->matrix.rows)
+                return 0;
+            else if (a->matrix.columns != b->matrix.columns)
+                return 0;
+            else if (!datatypes_match(a->matrix.base, b->matrix.base))
+                return 0;
+            return 1;
+
+        case MOJOSHADER_AST_DATATYPE_BUFFER:
+            return datatypes_match(a->buffer.base, b->buffer.base);
+
+        case MOJOSHADER_AST_DATATYPE_FUNCTION:
+            if (a->function.num_params != b->function.num_params)
+                return 0;
+            else if (a->function.intrinsic != b->function.intrinsic)
+                return 0;
+            else if (!datatypes_match(a->function.retval, b->function.retval))
+                return 0;
+            for (i = 0; i < a->function.num_params; i++)
+            {
+                if (!datatypes_match(a->function.params[i], b->function.params[i]))
+                    return 0;
+            } // for
+            return 1;
+
+        case MOJOSHADER_AST_DATATYPE_USER:
+            return 0;  // pointers must match, this clearly didn't.
+
+        default:
+            assert(0 && "unexpected case");
+            return 0;
+    } // switch
+
+    return 0;
+} // datatypes_match
 
 static void push_symbol(Context *ctx, SymbolMap *map, const char *sym,
-                        const MOJOSHADER_astDataType *dt, const int index)
+                        const MOJOSHADER_astDataType *dt, const int index,
+                        const int check_dupes)
 {
+    if (ctx->out_of_memory)
+        return;
+
     // Decide if this symbol is defined, and if it's in the current scope.
     SymbolScope *item = NULL;
     const void *value = NULL;
-    void *iter = NULL;
-    if ((sym != NULL) && (hash_iter(map->hash, sym, &value, &iter)))
+    if ((check_dupes) && (sym != NULL) && (hash_find(map->hash, sym, &value)))
     {
-        item = (SymbolScope *) value;
-        // Functions are always global, so no need to search scopes.
-            // !!! FIXME: Functions overload, though, so we have to continue
-            // !!! FIXME: iterating to see if it matches anything.
-        //const MOJOSHADER_astDataType *dt = item->datatype;
-        //if (dt->type == MOJOSHADER_AST_DATATYPE_FUNCTION)
-        //{
-        //} // if
-        //else  // check the current scope for a dupe.
+        // check the current scope for a dupe.
+        // !!! FIXME: note current scope's starting index, see if found
+        // !!! FIXME:  item is < index (and thus, a previous scope).
+        item = map->scope;
+        while ((item) && (item->symbol))
         {
-            item = map->scope;
-            while ((item) && (item->symbol))
+            if ( ((const void *) item) == value )
             {
-                if (strcmp(item->symbol, sym) == 0)
-                {
-                    failf(ctx, "Symbol '%s' already defined", sym);
-                    return;
-                } // if
-                item = item->next;
-            } // while
-        } // else
+                failf(ctx, "Symbol '%s' already defined", sym);
+                return;
+            } // if
+            item = item->next;
+        } // while
     } // if
 
     // Add the symbol to our map and scope stack.
@@ -301,7 +375,7 @@
         } // if
     } // if
 
-    push_symbol(ctx, &ctx->usertypes, sym, dt, 0);
+    push_symbol(ctx, &ctx->usertypes, sym, dt, 0, 1);
 } // push_usertype
 
 static inline void push_variable(Context *ctx, const char *sym, const MOJOSHADER_astDataType *dt)
@@ -315,9 +389,47 @@
             idx = --ctx->global_var_index;  // these are negative.
     } // if
 
-    push_symbol(ctx, &ctx->variables, sym, dt, idx);
+    push_symbol(ctx, &ctx->variables, sym, dt, idx, 1);
 } // push_variable
 
+static void push_function(Context *ctx, const char *sym,
+                          const MOJOSHADER_astDataType *dt,
+                          const int just_declare)
+{
+    // we don't have any reason to support nested functions at the moment,
+    //  so this would be a bug.
+    assert(!ctx->is_func_scope);
+    assert(dt->type == MOJOSHADER_AST_DATATYPE_FUNCTION);
+
+    int idx = 0;
+    if ((sym != NULL) && (dt != NULL))
+    {
+        if (!dt->function.intrinsic)
+            idx = ++ctx->user_func_index;  // these are positive.
+        else
+            idx = --ctx->intrinsic_func_index;  // these are negative.
+    } // if
+
+    // Functions are always global, so no need to search scopes.
+    //  Functions overload, though, so we have to continue iterating to
+    //  see if it matches anything.
+    const void *value = NULL;
+    void *iter = NULL;
+    while (hash_iter(ctx->variables.hash, sym, &value, &iter))
+    {
+        // there's already something called this.
+        if (datatypes_match(dt, ((SymbolScope *) value)->datatype))
+        {
+            if (!just_declare)
+                failf(ctx, "Function '%s' already defined.", sym);
+            return;
+        } // if
+    } // while
+
+    // push_symbol() doesn't check dupes, because we just did.
+    push_symbol(ctx, &ctx->variables, sym, dt, idx, 0);
+} // push_function
+
 static inline void push_scope(Context *ctx)
 {
     push_usertype(ctx, NULL, NULL);
@@ -461,13 +573,23 @@
 static void delete_compilation_unit(Context*, MOJOSHADER_astCompilationUnit*);
 static void delete_statement(Context *ctx, MOJOSHADER_astStatement *stmt);
 
+static MOJOSHADER_astExpression *new_identifier_expr(Context *ctx,
+                                                     const char *string)
+{
+    NEW_AST_NODE(retval, MOJOSHADER_astExpressionIdentifier,
+                 MOJOSHADER_AST_OP_IDENTIFIER);
+    retval->identifier = string;  // cached; don't copy string.
+    return (MOJOSHADER_astExpression *) retval;
+} // new_identifier_expr
+
 static MOJOSHADER_astExpression *new_callfunc_expr(Context *ctx,
-                                        MOJOSHADER_astExpression *identifier,
+                                        const char *identifier,
                                         MOJOSHADER_astArguments *args)
 {
     NEW_AST_NODE(retval, MOJOSHADER_astExpressionCallFunction,
                  MOJOSHADER_AST_OP_CALLFUNC);
-    retval->identifier = identifier;
+    MOJOSHADER_astExpression *expr = new_identifier_expr(ctx, identifier);
+    retval->identifier = (MOJOSHADER_astExpressionIdentifier *) expr;
     retval->args = args;
     return (MOJOSHADER_astExpression *) retval;
 } // new_callfunc_expr
@@ -543,15 +665,6 @@
     return (MOJOSHADER_astExpression *) retval;
 } // new_deref_struct_expr
 
-static MOJOSHADER_astExpression *new_identifier_expr(Context *ctx,
-                                                     const char *string)
-{
-    NEW_AST_NODE(retval, MOJOSHADER_astExpressionIdentifier,
-                 MOJOSHADER_AST_OP_IDENTIFIER);
-    retval->identifier = string;  // cached; don't copy string.
-    return (MOJOSHADER_astExpression *) retval;
-} // new_identifier_expr
-
 static MOJOSHADER_astExpression *new_literal_int_expr(Context *ctx,
                                                        const int value)
 {
@@ -627,7 +740,7 @@
 
     else if (expr->ast.type == MOJOSHADER_AST_OP_CALLFUNC)
     {
-        delete_expr(ctx, expr->callfunc.identifier);
+        delete_expr(ctx, (MOJOSHADER_astExpression*)expr->callfunc.identifier);
         delete_arguments(ctx, expr->callfunc.args);
     } // else if
 
@@ -1700,6 +1813,47 @@
 } // reduce_datatype
 
 
+static const MOJOSHADER_astDataType *build_function_datatype(Context *ctx,
+                                        const MOJOSHADER_astDataType *rettype,
+                                        const int paramcount,
+                                        const MOJOSHADER_astDataType **params,
+                                        const int intrinsic)
+{
+    assert( ((paramcount == 0) && (params == NULL)) ||
+            ((paramcount > 0) && (params != NULL)) );
+
+    // !!! FIXME: this is hacky.
+    const MOJOSHADER_astDataType **dtparams;
+    void *ptr = Malloc(ctx, sizeof (*params) * paramcount);
+    if (ptr == NULL)
+        return NULL;
+    if (!buffer_append(ctx->garbage, &ptr, sizeof (ptr)))
+    {
+        Free(ctx, ptr);
+        return NULL;
+    } // if
+    dtparams = (const MOJOSHADER_astDataType **) ptr;
+    memcpy(dtparams, params, sizeof (*params) * paramcount);
+
+    ptr = Malloc(ctx, sizeof (MOJOSHADER_astDataType));
+    if (ptr == NULL)
+        return NULL;
+    if (!buffer_append(ctx->garbage, &ptr, sizeof (ptr)))
+    {
+        Free(ctx, ptr);
+        return NULL;
+    } // if
+
+    MOJOSHADER_astDataType *dt = (MOJOSHADER_astDataType *) ptr;
+    dt->type = MOJOSHADER_AST_DATATYPE_FUNCTION;
+    dt->function.retval = rettype;
+    dt->function.params = dtparams;
+    dt->function.num_params = paramcount;
+    dt->function.intrinsic = intrinsic;
+    return dt;
+} // build_function_datatype
+
+
 static const MOJOSHADER_astDataType *build_datatype(Context *ctx,
                                             const int isconst,
                                             const MOJOSHADER_astDataType *dt,
@@ -1851,18 +2005,17 @@
 } // require_struct_datatype
 
 
-static const MOJOSHADER_astDataType *require_function_datatype(Context *ctx,
-                                        const MOJOSHADER_astDataType *datatype)
+static int require_function_datatype(Context *ctx,
+                                     const MOJOSHADER_astDataType *datatype)
 {
     datatype = reduce_datatype(datatype);
-    if (datatype->type != MOJOSHADER_AST_DATATYPE_FUNCTION)
+    if ((!datatype) || (datatype->type != MOJOSHADER_AST_DATATYPE_FUNCTION))
     {
         fail(ctx, "expected function");
-        // !!! FIXME: delete function call for further processing.
-        return &ctx->dt_int;
+        return 0;
     } // if
 
-    return datatype->function.retval;
+    return 1;
 } // require_function_datatype
 
 
@@ -1994,6 +2147,85 @@
     return ((is_rgba + is_xyzw) == 1);  // can only be one or the other.
 } // is_swizzle_str
 
+static const MOJOSHADER_astDataType *type_check_ast(Context *ctx, void *_ast);
+
+// !!! FIXME: this function sucks.
+static const MOJOSHADER_astDataType *match_func_to_call(Context *ctx,
+                                    MOJOSHADER_astExpressionCallFunction *ast)
+{
+    SymbolScope *best = NULL;  // best choice we find.
+    MOJOSHADER_astExpressionIdentifier *ident = ast->identifier;
+    const char *sym = ident->identifier;
+    const void *value = NULL;
+    void *iter = NULL;
+
+    int argcount = 0;
+    MOJOSHADER_astArguments *args = ast->args;
+    while (args != NULL)
+    {
+        argcount++;
+        type_check_ast(ctx, args->argument);
+        args = args->next;
+    } // while;
+
+    // we do some tapdancing to handle function overloading here.
+    while (hash_iter(ctx->variables.hash, sym, &value, &iter))
+    {
+        SymbolScope *item = (SymbolScope *) value;
+        const MOJOSHADER_astDataType *dt = item->datatype;
+        dt = reduce_datatype(dt);
+        // there's a locally-scoped symbol with this name? It takes precedence.
+        if (dt->type != MOJOSHADER_AST_DATATYPE_FUNCTION)
+            return dt;
+
+        // !!! FIXME: this needs to find functions that implicit casts would catch:
+        //          void fn(int x);
+        //          needs to match:
+        //          short q = 2; fn(q);
+        const MOJOSHADER_astDataTypeFunction *dtfn = (MOJOSHADER_astDataTypeFunction *) dt;
+        args = ast->args;
+        int match = 1;
+        int i;
+
+        if (argcount != dtfn->num_params)  // !!! FIXME: default args.
+            match = 0;
+        else
+        {
+            for (i = 0; i < argcount; i++)
+            {
+                assert(args != NULL);
+                dt = args->argument->datatype;
+                args = args->next;
+
+                if (!datatypes_match(dt, dtfn->params[i]))
+                {
+                    match = 0;  // can't be perfect match.
+                    break;
+                } // if
+            } // for
+
+            if (args != NULL)
+                match = 0;  // too many arguments supplied. No match.
+        } // else
+
+        if (match)
+        {
+            best = item;
+            break;
+        } // if
+    } // while
+
+    if (best == NULL)
+        failf(ctx, "No matching function named '%s'", sym);
+    else
+    {
+        ident->datatype = reduce_datatype(best->datatype);
+        ident->index = best->index;
+    } // else
+
+    return ident->datatype;
+} // match_func_to_call
+
 
 // Go through the AST and make sure all datatypes check out okay. For datatypes
 //  that are compatible but are relying on an implicit cast, we add explicit
@@ -2204,39 +2436,34 @@
 
         case MOJOSHADER_AST_OP_CALLFUNC:
         {
-            datatype = type_check_ast(ctx, ast->callfunc.identifier);
+            datatype = match_func_to_call(ctx, &ast->callfunc);
             const MOJOSHADER_astDataType *reduced = reduce_datatype(datatype);
-            require_function_datatype(ctx, reduced);
-            // !!! FIXME: replace with an int literal if this isn't a function.
+            // !!! FIXME: replace AST node with an int if this isn't a func.
+            if (!require_function_datatype(ctx, reduced))
+            {
+                ast->callfunc.datatype = &ctx->dt_int;
+                return ast->callfunc.datatype;
+            } // if
+
             MOJOSHADER_astArguments *arg = ast->callfunc.args;
             MOJOSHADER_astArguments *prev = NULL;
             int i;
             for (i = 0; i < reduced->function.num_params; i++)
             {
-                if (arg == NULL)  // !!! FIXME: check for default parameters.
+                if (arg == NULL)  // !!! FIXME: check for default parameters, fill them in.
                 {
                     fail(ctx, "Too few arguments");
                     // !!! FIXME: replace AST here.
                     break;
                 } // if
-                datatype2 = type_check_ast(ctx, arg->argument);
-                add_type_coercion(ctx, NULL, &reduced->function.params[i],
+                datatype2 = arg->argument->datatype;  // already type-checked.
+                add_type_coercion(ctx, NULL, reduced->function.params[i],
                                   &arg->argument, datatype2);
                 prev = arg;
                 arg = arg->next;
             } // for
 
-            if (arg != NULL)
-            {
-                // Process extra arguments then chop them out.
-                MOJOSHADER_astArguments *argi;
-                for (argi = arg; argi != NULL; argi = argi->next)
-                    type_check_ast(ctx, argi->argument);
-                if (prev != NULL)
-                    prev->next = NULL;
-                delete_arguments(ctx, arg);
-                fail(ctx, "Too many arguments");
-            } // if
+            assert(arg == NULL);  // shouldn't have chosen func if too many args.
 
             ast->callfunc.datatype = reduced->function.retval;
             return ast->callfunc.datatype;
@@ -2422,36 +2649,30 @@
             return NULL;
 
         case MOJOSHADER_AST_COMPUNIT_FUNCTION:
-            // !!! FIXME: this is totally broken for function overloading.
-//fsdfsdf
-            datatype = get_usertype(ctx, ast->funcunit.declaration->identifier);
-            if (datatype == NULL)
-            {
-                // add function declaration if we've not seen it.
-                datatype = ast->funcunit.declaration->datatype;
-                push_usertype(ctx, ast->funcunit.declaration->identifier, datatype);
-            } // if
-
-            // declarations can be done multiple times if they match.
-            else if (datatype != ast->funcunit.declaration->datatype)
+            assert(!ctx->is_func_scope);
+
+            // We have to tapdance here to make sure the function is in
+            //  the global scope, but it's parameters are pushed as variables
+            //  in the function's scope.
+
+            datatype = type_check_ast(ctx, ast->funcunit.declaration);
+            push_function(ctx, ast->funcunit.declaration->identifier,
+                          datatype, ast->funcunit.definition == NULL);
+
+            // not just a declaration, but a full function definition?
+            if (ast->funcunit.definition != NULL)
             {
-                // !!! FIXME: function overloading is legal.
-                fail(ctx, "function sigs don't match");
-            } // else
-
-            ctx->is_func_scope = 1;
-            ctx->var_index = 0;  // reset this every function.
-            push_scope(ctx);  // so function params are in function scope.
-            type_check_ast(ctx, ast->funcunit.declaration);
-            if (ast->funcunit.definition == NULL)
-                pop_scope(ctx);
-            else
-            {
+                ctx->is_func_scope = 1;
+                ctx->var_index = 0;  // reset this every function.
+                push_scope(ctx);  // so function params are in function scope.
+                // repush the parameters before checking the actual function.
+                MOJOSHADER_astFunctionParameters *param;
+                for (param = ast->funcunit.declaration->params; param; param = param->next)
+                    push_variable(ctx, param->identifier, param->datatype);
                 type_check_ast(ctx, ast->funcunit.definition);
                 pop_scope(ctx);
-                push_variable(ctx, ast->funcunit.declaration->identifier, datatype);
+                ctx->is_func_scope = 0;
             } // else
-            ctx->is_func_scope = 0;
 
             type_check_ast(ctx, ast->funcunit.next);
             return NULL;
@@ -2502,52 +2723,25 @@
         case MOJOSHADER_AST_FUNCTION_SIGNATURE:
         {
             MOJOSHADER_astFunctionParameters *param;
-            int count = 0;
-
-            // !!! FIXME: pre-count this?
-            for (param = ast->funcsig.params; param; param = param->next)
-                count++;
-
-            // !!! FIXME: this is hacky.
-            MOJOSHADER_astDataType *dtparams;
-            void *ptr = Malloc(ctx, sizeof (*dtparams) * count);
-            if (ptr == NULL)
-                return NULL;
-            if (!buffer_append(ctx->garbage, &ptr, sizeof (ptr)))
-            {
-                Free(ctx, ptr);
-                return NULL;
-            } // if
-            dtparams = (MOJOSHADER_astDataType *) ptr;
-
-            ptr = Malloc(ctx, sizeof (MOJOSHADER_astDataType));
-            if (ptr == NULL)
-                return NULL;
-            if (!buffer_append(ctx->garbage, &ptr, sizeof (ptr)))
-            {
-                Free(ctx, ptr);
-                return NULL;
-            } // if
-            MOJOSHADER_astDataType *dt = (MOJOSHADER_astDataType *) ptr;
+            const MOJOSHADER_astDataType *dtparams[64];
 
             int i = 0;
             for (param = ast->funcsig.params; param; param = param->next)
             {
-                assert(i < count);
-                push_variable(ctx, param->identifier, param->datatype);
-                datatype2 = type_check_ast(ctx, param->initializer);
-                add_type_coercion(ctx, NULL, param->datatype,
-                                  &param->initializer, datatype2);
-                memcpy(&dtparams[i], param->datatype, sizeof (*param->datatype));
+                assert(i <= STATICARRAYLEN(dtparams));  // laziness.
+                if (param->initializer != NULL)
+                {
+                    datatype2 = type_check_ast(ctx, param->initializer);
+                    add_type_coercion(ctx, NULL, param->datatype,
+                                      &param->initializer, datatype2);
+                } // if
+                dtparams[i] = param->datatype;
                 i++;
             } // for
 
-            dt->type = MOJOSHADER_AST_DATATYPE_FUNCTION;
-            dt->function.retval = ast->funcsig.datatype;
-            dt->function.params = dtparams;
-            dt->function.num_params = count;
-
-            ast->funcsig.datatype = dt;
+            ast->funcsig.datatype = build_function_datatype(ctx,
+                                                        ast->funcsig.datatype,
+                                                        i, dtparams, 0);
             return ast->funcsig.datatype;
         } // case