Yet-another attempt at getting function overloading selection correct.
authorRyan C. Gordon <icculus@icculus.org>
Mon, 21 Feb 2011 18:25:42 -0500
changeset 996 164238a438e1
parent 995 7e7826e191f5
child 997 a5f4e546b24e
Yet-another attempt at getting function overloading selection correct.
mojoshader_compiler.c
--- a/mojoshader_compiler.c	Mon Feb 21 03:06:00 2011 -0500
+++ b/mojoshader_compiler.c	Mon Feb 21 18:25:42 2011 -0500
@@ -2258,7 +2258,39 @@
     } // switch
 } // is_scalar_datatype
 
-static int compatible_arg_datatype(Context *ctx,
+static const MOJOSHADER_astDataType *datatype_base(Context *ctx, const MOJOSHADER_astDataType *dt)
+{
+    dt = reduce_datatype(ctx, dt);
+    switch (dt->type)
+    {
+        case MOJOSHADER_AST_DATATYPE_VECTOR:
+            dt = dt->vector.base;
+            break;
+        case MOJOSHADER_AST_DATATYPE_MATRIX:
+            dt = dt->matrix.base;
+            break;
+        case MOJOSHADER_AST_DATATYPE_BUFFER:
+            dt = dt->buffer.base;
+            break;
+        case MOJOSHADER_AST_DATATYPE_ARRAY:
+            dt = dt->array.base;
+            break;
+        default: break;
+    } // switch
+
+    return dt;
+} // datatype_base
+
+typedef enum
+{
+    DT_MATCH_INCOMPATIBLE,
+    DT_MATCH_COMPATIBLE_DOWNCAST,
+    DT_MATCH_COMPATIBLE_UPCAST,
+    DT_MATCH_COMPATIBLE,
+    DT_MATCH_PERFECT
+} DatatypeMatch;
+
+static DatatypeMatch compatible_arg_datatype(Context *ctx,
                                    const MOJOSHADER_astDataType *arg,
                                    const MOJOSHADER_astDataType *param)
 {
@@ -2269,44 +2301,162 @@
     //   after possible type promotion via the following rules.
     // - Scalars can be promoted to vectors to make a parameter match.
     // - Scalars can promote to other scalars (short to int, etc).
-    // - Vectors may NOT be promoted (a float2 can't extend to a float4).
-    // - Vectors with the same elements can promote (a half2 can become a float2...I _think_ it can't downcast here.).
+    // - Datatypes can downcast, but should generate a warning.
+    //   (calling void fn(float x); as fn((double)1.0) should warn).
+    // - Vectors may NOT be extend (a float2 can't implicity extend to a
+    //   float4).
+    // - Vectors with the same elements can promote (a half2 can become
+    //   a float2...I _think_ it can't downcast here.).
     // - A perfect match of all params will be favored over any functions
     //   that only match if type promotion is applied.
+    // - An imperfect match that doesn't require downcasting will be
+    //   favored over one that does.
     // - If more than one function matches after this (all params that
-    //   would be different between two functions are passed scalars)
+    //   would be different between two functions can be legally type-promoted)
     //   then fail().
 
     if (datatypes_match(arg, param))
-        return 1;  // that was easy.
+        return DT_MATCH_PERFECT;  // that was easy.
 
     arg = reduce_datatype(ctx, arg);
     param = reduce_datatype(ctx, param);
 
-    // we let this go for now if we passed a scalar.
-    //  !!! FIXME: should warn when downcasting.
-    //  !!! FIXME: also, being a bit more picky would be good.
+    int do_size_test = 0;
+
     if (is_scalar_datatype(arg))
-        return 1;
+        do_size_test = 1; // we let these all go through for now.
 
     else if (arg->type == param->type)
     {
         if (arg->type == MOJOSHADER_AST_DATATYPE_VECTOR)
-        {
-            if (arg->vector.elements == param->vector.elements)
-                return datatype_size(arg->vector.base) <= datatype_size(param->vector.base);
-        } // if
+            do_size_test = (arg->vector.elements == param->vector.elements);
         else if (arg->type == MOJOSHADER_AST_DATATYPE_MATRIX)
         {
-            if ((arg->matrix.rows == param->matrix.rows) &&
-                (arg->matrix.columns == param->matrix.columns))
-                return datatype_size(arg->matrix.base) <= datatype_size(param->matrix.base);
+            do_size_test =
+                ((arg->matrix.rows == param->matrix.rows) &&
+                 (arg->matrix.columns == param->matrix.columns));
         } // if
     } // if
 
-    return 0;
+    if (do_size_test)
+    {
+        const int argsize = datatype_size(datatype_base(ctx, arg));
+        const int paramsize = datatype_size(datatype_base(ctx, param));
+        if (argsize == paramsize)
+            return DT_MATCH_COMPATIBLE;
+        else if (argsize < paramsize)
+            return DT_MATCH_COMPATIBLE_UPCAST;
+        else /* if (argsize > paramsize) */
+            return DT_MATCH_COMPATIBLE_DOWNCAST;
+    } // if
+
+    return DT_MATCH_INCOMPATIBLE;
 } // compatible_arg_datatype
 
+static void print_ast_datatype(FILE *io, const MOJOSHADER_astDataType *dt)
+{
+    int i;
+
+    if (dt == NULL)
+        return;
+
+    switch (dt->type)
+    {
+        case MOJOSHADER_AST_DATATYPE_BOOL:
+            fprintf(io, "bool");
+            return;
+        case MOJOSHADER_AST_DATATYPE_INT:
+            fprintf(io, "int");
+            return;
+        case MOJOSHADER_AST_DATATYPE_UINT:
+            fprintf(io, "uint");
+            return;
+        case MOJOSHADER_AST_DATATYPE_FLOAT:
+            fprintf(io, "float");
+            return;
+        case MOJOSHADER_AST_DATATYPE_FLOAT_SNORM:
+            fprintf(io, "snorm float");
+            return;
+        case MOJOSHADER_AST_DATATYPE_FLOAT_UNORM:
+            fprintf(io, "unorm float");
+            return;
+        case MOJOSHADER_AST_DATATYPE_HALF:
+            fprintf(io, "half");
+            return;
+        case MOJOSHADER_AST_DATATYPE_DOUBLE:
+            fprintf(io, "double");
+            return;
+        case MOJOSHADER_AST_DATATYPE_STRING:
+            fprintf(io, "string");
+            return;
+        case MOJOSHADER_AST_DATATYPE_SAMPLER_1D:
+            fprintf(io, "sampler1D");
+            return;
+        case MOJOSHADER_AST_DATATYPE_SAMPLER_2D:
+            fprintf(io, "sampler2D");
+            return;
+        case MOJOSHADER_AST_DATATYPE_SAMPLER_3D:
+            fprintf(io, "sampler3D");
+            return;
+        case MOJOSHADER_AST_DATATYPE_SAMPLER_CUBE:
+            fprintf(io, "samplerCUBE");
+            return;
+        case MOJOSHADER_AST_DATATYPE_SAMPLER_STATE:
+            fprintf(io, "sampler_state");
+            return;
+        case MOJOSHADER_AST_DATATYPE_SAMPLER_COMPARISON_STATE:
+            fprintf(io, "SamplerComparisonState");
+            return;
+
+        case MOJOSHADER_AST_DATATYPE_STRUCT:
+            fprintf(io, "struct { ");
+            for (i = 0; i < dt->structure.member_count; i++)
+            {
+                print_ast_datatype(io, dt->structure.members[i].datatype);
+                fprintf(io, " %s; ", dt->structure.members[i].identifier);
+            } // for
+            fprintf(io, "}");
+            return;
+
+        case MOJOSHADER_AST_DATATYPE_ARRAY:
+            print_ast_datatype(io, dt->array.base);
+            if (dt->array.elements < 0)
+                fprintf(io, "[]");
+            else
+                fprintf(io, "[%d]", dt->array.elements);
+            return;
+
+        case MOJOSHADER_AST_DATATYPE_VECTOR:
+            fprintf(io, "vector<");
+            print_ast_datatype(io, dt->vector.base);
+            fprintf(io, ",%d>", dt->vector.elements);
+            return;
+
+        case MOJOSHADER_AST_DATATYPE_MATRIX:
+            fprintf(io, "matrix<");
+            print_ast_datatype(io, dt->matrix.base);
+            fprintf(io, ",%d,%d>", dt->matrix.rows, dt->matrix.columns);
+            return;
+
+        case MOJOSHADER_AST_DATATYPE_BUFFER:
+            fprintf(io, "buffer<");
+            print_ast_datatype(io, dt->buffer.base);
+            fprintf(io, ">");
+            return;
+
+        case MOJOSHADER_AST_DATATYPE_USER:
+            fprintf(io, "%s", dt->user.name);
+            return;
+
+        //case MOJOSHADER_AST_DATATYPE_NONE:
+        //case MOJOSHADER_AST_DATATYPE_FUNCTION:
+
+        default:
+            assert(0 && "Unexpected datatype.");
+            return;
+    } // switch
+} // print_ast_datatype
+
 
 static const MOJOSHADER_astDataType *type_check_ast(Context *ctx, void *_ast);
 
@@ -2315,6 +2465,7 @@
                                     MOJOSHADER_astExpressionCallFunction *ast)
 {
     SymbolScope *best = NULL;  // best choice we find.
+    DatatypeMatch best_match = DT_MATCH_INCOMPATIBLE;
     MOJOSHADER_astExpressionIdentifier *ident = ast->identifier;
     const char *sym = ident->identifier;
     const void *value = NULL;
@@ -2329,8 +2480,21 @@
         args = args->next;
     } // while;
 
+printf("Attempt to call function %s(", sym);
+args = ast->args;
+int qq = 0;
+for (qq = 0; qq < argcount; qq++)
+{
+assert(args != NULL);
+const MOJOSHADER_astDataType *x = args->argument->datatype;
+args = args->next;
+print_ast_datatype(stdout, x);
+if (args) printf(", ");
+}
+printf("); ...\n");
+
     // we do some tapdancing to handle function overloading here.
-    int match = 0;
+    DatatypeMatch match = 0;
     while (hash_iter(ctx->variables.hash, sym, &value, &iter))
     {
         SymbolScope *item = (SymbolScope *) value;
@@ -2341,7 +2505,7 @@
             return dt;
 
         const MOJOSHADER_astDataTypeFunction *dtfn = (MOJOSHADER_astDataTypeFunction *) dt;
-        int this_match = 2;  // 2 == perfect, 1 == compatible, 0 == not.
+        int this_match = DT_MATCH_PERFECT;
         int i;
 
         if (argcount != dtfn->num_params)  // !!! FIXME: default args.
@@ -2354,38 +2518,76 @@
                 assert(args != NULL);
                 dt = args->argument->datatype;
                 args = args->next;
-                if (datatypes_match(dt, dtfn->params[i]))
-                    continue;
-
-                // not perfect, but maybe compatible?
-                this_match = compatible_arg_datatype(ctx, dt, dtfn->params[i]);
+                const int compatible = compatible_arg_datatype(ctx, dt, dtfn->params[i]);
+                if (this_match > compatible)
+                    this_match = compatible;
                 if (!this_match)
                     break;
             } // for
 
             if (args != NULL)
-                this_match = 0;  // too many arguments supplied. No match.
+                this_match = DT_MATCH_INCOMPATIBLE;  // too many arguments supplied. No match.
         } // else
 
-        if (this_match == 2)  // perfect match.
+#if 0
+        if (this_match == DT_MATCH_PERFECT)  // perfect match.
         {
+FILE *io = stdout;
+printf("%d PERFECT MATCH: ", ctx->sourceline);
+if (dtfn->intrinsic)
+    printf("/* intrinsic */ ");
+print_ast_datatype(io, dt->function.retval);
+printf(" %s(", sym);
+int i;
+for (i = 0; i < dtfn->num_params; i++) {
+    print_ast_datatype(io, dtfn->params[i]);
+    if (i < dtfn->num_params-1)
+        printf(", ");
+}
+printf(");\n");
             match = 1;  // ignore all other compatible matches.
             best = item;
             break;
         } // if
 
-        else if (this_match == 1)  // compatible, but not perfect, match.
+        else
+#endif
+
+        if (this_match > DT_MATCH_INCOMPATIBLE)  // compatible, but not perfect, match.
         {
-            match++;
-            if (match == 1)
-                best = item;
-            else
+#if 1
+FILE *io = stdout;
+printf("%d COMPATIBLE MATCH (%d): ", ctx->sourceline, (int) this_match);
+if (dtfn->intrinsic)
+    printf("/* intrinsic */ ");
+if (dt->function.retval)
+    print_ast_datatype(io, dt->function.retval);
+else
+    printf("void");
+printf(" %s(", sym);
+int i;
+for (i = 0; i < dtfn->num_params; i++) {
+    print_ast_datatype(io, dtfn->params[i]);
+    if (i < dtfn->num_params-1)
+        printf(", ");
+}
+printf(");\n");
+#endif
+
+            if (this_match == best_match)
             {
+                match++;
                 // !!! FIXME: list each possible function in a fail(),
                 // !!! FIXME:  but you can't actually fail() here, since
                 // !!! FIXME:  this will cease to be ambiguous if we get
                 // !!! FIXME:  a perfect match on a later overload.
-            } // else
+            } // if
+            else if (this_match > best_match)
+            {
+                match = 1;  // reset the ambiguousness count.
+                best = item;
+                best_match = this_match;
+            } // if
         } // else if
     } // while