mojoshader_compiler.c
changeset 988 018e77a5ba67
parent 987 109aeb1b6409
child 989 b3dacb4c2804
--- a/mojoshader_compiler.c	Sun Feb 06 04:01:43 2011 -0500
+++ b/mojoshader_compiler.c	Wed Feb 09 04:32:00 2011 -0500
@@ -2221,6 +2221,24 @@
     return ((is_rgba + is_xyzw) == 1);  // can only be one or the other.
 } // is_swizzle_str
 
+static inline int is_scalar_datatype(const MOJOSHADER_astDataType *dt)
+{
+    switch (dt->type)
+    {
+        case MOJOSHADER_AST_DATATYPE_BOOL:
+        case MOJOSHADER_AST_DATATYPE_INT:
+        case MOJOSHADER_AST_DATATYPE_UINT:
+        case MOJOSHADER_AST_DATATYPE_FLOAT:
+        case MOJOSHADER_AST_DATATYPE_FLOAT_SNORM:
+        case MOJOSHADER_AST_DATATYPE_FLOAT_UNORM:
+        case MOJOSHADER_AST_DATATYPE_HALF:
+        case MOJOSHADER_AST_DATATYPE_DOUBLE:
+            return 1;
+        default:
+            return 0;
+    } // switch
+} // is_scalar_datatype
+
 static const MOJOSHADER_astDataType *type_check_ast(Context *ctx, void *_ast);
 
 // !!! FIXME: this function sucks.
@@ -2243,6 +2261,7 @@
     } // while;
 
     // we do some tapdancing to handle function overloading here.
+    int match = 0;
     while (hash_iter(ctx->variables.hash, sym, &value, &iter))
     {
         SymbolScope *item = (SymbolScope *) value;
@@ -2252,16 +2271,23 @@
         if (dt->type != MOJOSHADER_AST_DATATYPE_FUNCTION)
             return dt;
 
-        // !!! FIXME: this needs to find functions that implicit casts would catch:
-        //          void fn(int x);
-        //          needs to match:
-        //          short q = 2; fn(q);
+        // The matching rules for HLSL function overloading, as far as I can
+        //  tell from experimenting with Microsoft's compiler, seem to be this:
+        //
+        // - All parameters of a function must match what the caller specified.
+        // - Scalars can be promoted to vectors to make a parameter match.
+        // - Scalars can promote to other scalars (short to int, etc).
+        // - Vectors may NOT be promoted (a float2 can't extend to a float4).
+        // - If more than one function matches after this (all params that
+        //   would be different between two functions are passed scalars)
+        //   then fail().
+
         const MOJOSHADER_astDataTypeFunction *dtfn = (MOJOSHADER_astDataTypeFunction *) dt;
-        int match = 1;
+        int this_match = 1;
         int i;
 
         if (argcount != dtfn->num_params)  // !!! FIXME: default args.
-            match = 0;
+            this_match = 0;
         else
         {
             args = ast->args;
@@ -2271,21 +2297,32 @@
                 dt = args->argument->datatype;
                 args = args->next;
 
-                if (!datatypes_match(dt, dtfn->params[i]))
+                if (datatypes_match(dt, dtfn->params[i]))
+                    continue;  // so far, so good!
+
+                // we let this go for now if we passed a scalar.
+                //  !!! FIXME: should warn when downcasting.
+                if (!is_scalar_datatype(reduce_datatype(ctx, dt)))
                 {
-                    match = 0;  // can't be perfect match.
+                    this_match = 0;  // can't be perfect match.
                     break;
                 } // if
             } // for
 
             if (args != NULL)
-                match = 0;  // too many arguments supplied. No match.
+                this_match = 0;  // too many arguments supplied. No match.
         } // else
 
-        if (match)
+        if (this_match)
         {
-            best = item;
-            break;
+            if (match++ == 0)
+                best = item;
+            else
+            {
+                if (match++ == 1)
+                    failf(ctx, "Ambiguous function call to '%s'", sym);
+                // !!! FIXME: list each possible function in a fail() here.
+            } // else
         } // if
     } // while
 
@@ -2302,14 +2339,17 @@
 
 
 static const MOJOSHADER_astDataType *vectype_from_base(Context *ctx,
-                                            const MOJOSHADER_astDataTypeType base,
+                                            const MOJOSHADER_astDataType *base,
                                             const int len)
 {
     assert(len > 0);
     assert(len <= 4);
 
+    if (len == 1)  // return "float" and not "float1"
+        return base;
+
     const char *typestr = NULL;
-    switch (base)
+    switch (base->type)
     {
         case MOJOSHADER_AST_DATATYPE_BOOL: typestr = "bool"; break;
         case MOJOSHADER_AST_DATATYPE_INT: typestr = "int"; break;
@@ -2378,7 +2418,7 @@
             if (datatype->type == MOJOSHADER_AST_DATATYPE_VECTOR)
                 ast->binary.datatype = datatype->vector.base;
             else if (datatype->type == MOJOSHADER_AST_DATATYPE_MATRIX)
-                ast->binary.datatype = vectype_from_base(ctx, datatype->matrix.base->type, datatype->matrix.columns);  // !!! FIXME: rows?
+                ast->binary.datatype = vectype_from_base(ctx, datatype->matrix.base, datatype->matrix.columns);  // !!! FIXME: rows?
             else
             {
                 require_array_datatype(ctx, datatype);
@@ -2408,7 +2448,7 @@
 
                 const int swizlen = (int) strlen(member);
                 if (swizlen != veclen)
-                    datatype = vectype_from_base(ctx, reduced->vector.base->type, swizlen);
+                    datatype = vectype_from_base(ctx, reduced->vector.base, swizlen);
 
                 ast->derefstruct.datatype = datatype;
                 return ast->derefstruct.datatype;
@@ -2651,7 +2691,7 @@
                 {
                     // make sure things like float4(half3(1,2,3),1) convert that half3 to float3.
                     const int count = reduced->vector.elements;
-                    datatype3 = vectype_from_base(ctx, base_dt->type, count);
+                    datatype3 = vectype_from_base(ctx, base_dt, count);
                     add_type_coercion(ctx, NULL, datatype3, &arg->argument, datatype2);
                     i += count - 1;
                 } // else