More work on overloaded function matching.
authorRyan C. Gordon <icculus@icculus.org>
Wed, 09 Feb 2011 04:32:00 -0500
changeset 988 018e77a5ba67
parent 987 109aeb1b6409
child 989 b3dacb4c2804
More work on overloaded function matching. This works more like Microsoft's compiler appears to now, but I've still got failing cases to shake out.
mojoshader_compiler.c
--- a/mojoshader_compiler.c	Sun Feb 06 04:01:43 2011 -0500
+++ b/mojoshader_compiler.c	Wed Feb 09 04:32:00 2011 -0500
@@ -2221,6 +2221,24 @@
     return ((is_rgba + is_xyzw) == 1);  // can only be one or the other.
 } // is_swizzle_str
 
+static inline int is_scalar_datatype(const MOJOSHADER_astDataType *dt)
+{
+    switch (dt->type)
+    {
+        case MOJOSHADER_AST_DATATYPE_BOOL:
+        case MOJOSHADER_AST_DATATYPE_INT:
+        case MOJOSHADER_AST_DATATYPE_UINT:
+        case MOJOSHADER_AST_DATATYPE_FLOAT:
+        case MOJOSHADER_AST_DATATYPE_FLOAT_SNORM:
+        case MOJOSHADER_AST_DATATYPE_FLOAT_UNORM:
+        case MOJOSHADER_AST_DATATYPE_HALF:
+        case MOJOSHADER_AST_DATATYPE_DOUBLE:
+            return 1;
+        default:
+            return 0;
+    } // switch
+} // is_scalar_datatype
+
 static const MOJOSHADER_astDataType *type_check_ast(Context *ctx, void *_ast);
 
 // !!! FIXME: this function sucks.
@@ -2243,6 +2261,7 @@
     } // while;
 
     // we do some tapdancing to handle function overloading here.
+    int match = 0;
     while (hash_iter(ctx->variables.hash, sym, &value, &iter))
     {
         SymbolScope *item = (SymbolScope *) value;
@@ -2252,16 +2271,23 @@
         if (dt->type != MOJOSHADER_AST_DATATYPE_FUNCTION)
             return dt;
 
-        // !!! FIXME: this needs to find functions that implicit casts would catch:
-        //          void fn(int x);
-        //          needs to match:
-        //          short q = 2; fn(q);
+        // The matching rules for HLSL function overloading, as far as I can
+        //  tell from experimenting with Microsoft's compiler, seem to be this:
+        //
+        // - All parameters of a function must match what the caller specified.
+        // - Scalars can be promoted to vectors to make a parameter match.
+        // - Scalars can promote to other scalars (short to int, etc).
+        // - Vectors may NOT be promoted (a float2 can't extend to a float4).
+        // - If more than one function matches after this (all params that
+        //   would be different between two functions are passed scalars)
+        //   then fail().
+
         const MOJOSHADER_astDataTypeFunction *dtfn = (MOJOSHADER_astDataTypeFunction *) dt;
-        int match = 1;
+        int this_match = 1;
         int i;
 
         if (argcount != dtfn->num_params)  // !!! FIXME: default args.
-            match = 0;
+            this_match = 0;
         else
         {
             args = ast->args;
@@ -2271,21 +2297,32 @@
                 dt = args->argument->datatype;
                 args = args->next;
 
-                if (!datatypes_match(dt, dtfn->params[i]))
+                if (datatypes_match(dt, dtfn->params[i]))
+                    continue;  // so far, so good!
+
+                // we let this go for now if we passed a scalar.
+                //  !!! FIXME: should warn when downcasting.
+                if (!is_scalar_datatype(reduce_datatype(ctx, dt)))
                 {
-                    match = 0;  // can't be perfect match.
+                    this_match = 0;  // can't be perfect match.
                     break;
                 } // if
             } // for
 
             if (args != NULL)
-                match = 0;  // too many arguments supplied. No match.
+                this_match = 0;  // too many arguments supplied. No match.
         } // else
 
-        if (match)
+        if (this_match)
         {
-            best = item;
-            break;
+            if (match++ == 0)
+                best = item;
+            else
+            {
+                if (match++ == 1)
+                    failf(ctx, "Ambiguous function call to '%s'", sym);
+                // !!! FIXME: list each possible function in a fail() here.
+            } // else
         } // if
     } // while
 
@@ -2302,14 +2339,17 @@
 
 
 static const MOJOSHADER_astDataType *vectype_from_base(Context *ctx,
-                                            const MOJOSHADER_astDataTypeType base,
+                                            const MOJOSHADER_astDataType *base,
                                             const int len)
 {
     assert(len > 0);
     assert(len <= 4);
 
+    if (len == 1)  // return "float" and not "float1"
+        return base;
+
     const char *typestr = NULL;
-    switch (base)
+    switch (base->type)
     {
         case MOJOSHADER_AST_DATATYPE_BOOL: typestr = "bool"; break;
         case MOJOSHADER_AST_DATATYPE_INT: typestr = "int"; break;
@@ -2378,7 +2418,7 @@
             if (datatype->type == MOJOSHADER_AST_DATATYPE_VECTOR)
                 ast->binary.datatype = datatype->vector.base;
             else if (datatype->type == MOJOSHADER_AST_DATATYPE_MATRIX)
-                ast->binary.datatype = vectype_from_base(ctx, datatype->matrix.base->type, datatype->matrix.columns);  // !!! FIXME: rows?
+                ast->binary.datatype = vectype_from_base(ctx, datatype->matrix.base, datatype->matrix.columns);  // !!! FIXME: rows?
             else
             {
                 require_array_datatype(ctx, datatype);
@@ -2408,7 +2448,7 @@
 
                 const int swizlen = (int) strlen(member);
                 if (swizlen != veclen)
-                    datatype = vectype_from_base(ctx, reduced->vector.base->type, swizlen);
+                    datatype = vectype_from_base(ctx, reduced->vector.base, swizlen);
 
                 ast->derefstruct.datatype = datatype;
                 return ast->derefstruct.datatype;
@@ -2651,7 +2691,7 @@
                 {
                     // make sure things like float4(half3(1,2,3),1) convert that half3 to float3.
                     const int count = reduced->vector.elements;
-                    datatype3 = vectype_from_base(ctx, base_dt->type, count);
+                    datatype3 = vectype_from_base(ctx, base_dt, count);
                     add_type_coercion(ctx, NULL, datatype3, &arg->argument, datatype2);
                     i += count - 1;
                 } // else