diff --git a/ast/node.go b/ast/node.go index 03e8cf622..191ab1886 100644 --- a/ast/node.go +++ b/ast/node.go @@ -3,13 +3,21 @@ package ast import ( "reflect" + "github.com/expr-lang/expr/checker/nature" "github.com/expr-lang/expr/file" ) +var ( + anyType = reflect.TypeOf(new(any)).Elem() +) + // Node represents items of abstract syntax tree. type Node interface { Location() file.Location SetLocation(file.Location) + Nature() nature.Nature + SetNature(nature.Nature) + Kind() reflect.Kind Type() reflect.Type SetType(reflect.Type) String() string @@ -25,8 +33,8 @@ func Patch(node *Node, newNode Node) { // base is a base struct for all nodes. type base struct { - loc file.Location - nodeType reflect.Type + loc file.Location + nature nature.Nature } // Location returns the location of the node in the source code. @@ -39,14 +47,36 @@ func (n *base) SetLocation(loc file.Location) { n.loc = loc } +// Nature returns the nature of the node. +func (n *base) Nature() nature.Nature { + return n.nature +} + +// SetNature sets the nature of the node. +func (n *base) SetNature(nature nature.Nature) { + n.nature = nature +} + +// Kind returns the kind of the node. +// If the type is nil (meaning unknown) then it returns reflect.Interface. +func (n *base) Kind() reflect.Kind { + if n.nature.Type == nil { + return reflect.Interface + } + return n.nature.Type.Kind() +} + // Type returns the type of the node. func (n *base) Type() reflect.Type { - return n.nodeType + if n.nature.Type == nil { + return anyType + } + return n.nature.Type } // SetType sets the type of the node. func (n *base) SetType(t reflect.Type) { - n.nodeType = t + n.nature.Type = t } // NilNode represents nil. diff --git a/builtin/builtin_test.go b/builtin/builtin_test.go index b99ed42ce..97f249896 100644 --- a/builtin/builtin_test.go +++ b/builtin/builtin_test.go @@ -235,7 +235,7 @@ func TestBuiltin_errors(t *testing.T) { {`bitushr(-5, -2)`, "invalid operation: negative shift count -2 (type int) (1:1)"}, {`now(nil)`, "invalid number of arguments (expected 0, got 1)"}, {`date(nil)`, "interface {} is nil, not string (1:1)"}, - {`timezone(nil)`, "interface {} is nil, not string (1:1)"}, + {`timezone(nil)`, "cannot use nil as argument (type string) to call timezone (1:10)"}, } for _, test := range errorTests { t.Run(test.input, func(t *testing.T) { diff --git a/checker/checker.go b/checker/checker.go index c71a98f07..fae8f5a16 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -7,9 +7,9 @@ import ( "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/builtin" + . "github.com/expr-lang/expr/checker/nature" "github.com/expr-lang/expr/conf" "github.com/expr-lang/expr/file" - "github.com/expr-lang/expr/internal/deref" "github.com/expr-lang/expr/parser" ) @@ -52,14 +52,20 @@ func ParseCheck(input string, config *conf.Config) (*parser.Tree, error) { // Check checks types of the expression tree. It returns type of the expression // and error if any. If config is nil, then default configuration will be used. -func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) { +func Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { if config == nil { config = conf.New(nil) } v := &checker{config: config} - t, _ = v.visit(tree.Node) + nt := v.visit(tree.Node) + + // To keep compatibility with previous versions, we should return any, if nature is unknown. + t := nt.Type + if t == nil { + t = anyType + } if v.err != nil { return t, v.err.Bind(tree.Source) @@ -67,23 +73,20 @@ func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) { if v.config.Expect != reflect.Invalid { if v.config.ExpectAny { - if isAny(t) { + if isUnknown(nt) { return t, nil } } switch v.config.Expect { case reflect.Int, reflect.Int64, reflect.Float64: - if !isNumber(t) { - return nil, fmt.Errorf("expected %v, but got %v", v.config.Expect, t) + if !isNumber(nt) { + return nil, fmt.Errorf("expected %v, but got %v", v.config.Expect, nt) } default: - if t != nil { - if t.Kind() == v.config.Expect { - return t, nil - } + if nt.Kind() != v.config.Expect { + return nil, fmt.Errorf("expected %v, but got %s", v.config.Expect, nt) } - return nil, fmt.Errorf("expected %v, but got %v", v.config.Expect, t) } } @@ -98,14 +101,13 @@ type checker struct { } type predicateScope struct { - vtype reflect.Type - vars map[string]reflect.Type + collection Nature + vars map[string]Nature } type varScope struct { - name string - vtype reflect.Type - info info + name string + nature Nature } type info struct { @@ -119,285 +121,278 @@ type info struct { elem reflect.Type } -func (v *checker) visit(node ast.Node) (reflect.Type, info) { - var t reflect.Type - var i info +func (v *checker) visit(node ast.Node) Nature { + var nt Nature switch n := node.(type) { case *ast.NilNode: - t, i = v.NilNode(n) + nt = v.NilNode(n) case *ast.IdentifierNode: - t, i = v.IdentifierNode(n) + nt = v.IdentifierNode(n) case *ast.IntegerNode: - t, i = v.IntegerNode(n) + nt = v.IntegerNode(n) case *ast.FloatNode: - t, i = v.FloatNode(n) + nt = v.FloatNode(n) case *ast.BoolNode: - t, i = v.BoolNode(n) + nt = v.BoolNode(n) case *ast.StringNode: - t, i = v.StringNode(n) + nt = v.StringNode(n) case *ast.ConstantNode: - t, i = v.ConstantNode(n) + nt = v.ConstantNode(n) case *ast.UnaryNode: - t, i = v.UnaryNode(n) + nt = v.UnaryNode(n) case *ast.BinaryNode: - t, i = v.BinaryNode(n) + nt = v.BinaryNode(n) case *ast.ChainNode: - t, i = v.ChainNode(n) + nt = v.ChainNode(n) case *ast.MemberNode: - t, i = v.MemberNode(n) + nt = v.MemberNode(n) case *ast.SliceNode: - t, i = v.SliceNode(n) + nt = v.SliceNode(n) case *ast.CallNode: - t, i = v.CallNode(n) + nt = v.CallNode(n) case *ast.BuiltinNode: - t, i = v.BuiltinNode(n) + nt = v.BuiltinNode(n) case *ast.ClosureNode: - t, i = v.ClosureNode(n) + nt = v.ClosureNode(n) case *ast.PointerNode: - t, i = v.PointerNode(n) + nt = v.PointerNode(n) case *ast.VariableDeclaratorNode: - t, i = v.VariableDeclaratorNode(n) + nt = v.VariableDeclaratorNode(n) case *ast.ConditionalNode: - t, i = v.ConditionalNode(n) + nt = v.ConditionalNode(n) case *ast.ArrayNode: - t, i = v.ArrayNode(n) + nt = v.ArrayNode(n) case *ast.MapNode: - t, i = v.MapNode(n) + nt = v.MapNode(n) case *ast.PairNode: - t, i = v.PairNode(n) + nt = v.PairNode(n) default: panic(fmt.Sprintf("undefined node type (%T)", node)) } - node.SetType(t) - return t, i + node.SetNature(nt) + return nt } -func (v *checker) error(node ast.Node, format string, args ...any) (reflect.Type, info) { +func (v *checker) error(node ast.Node, format string, args ...any) Nature { if v.err == nil { // show first error v.err = &file.Error{ Location: node.Location(), Message: fmt.Sprintf(format, args...), } } - return anyType, info{} // interface represent undefined type + return unknown } -func (v *checker) NilNode(*ast.NilNode) (reflect.Type, info) { - return nilType, info{} +func (v *checker) NilNode(*ast.NilNode) Nature { + return nilNature } -func (v *checker) IdentifierNode(node *ast.IdentifierNode) (reflect.Type, info) { - if s, ok := v.lookupVariable(node.Value); ok { - return s.vtype, s.info +func (v *checker) IdentifierNode(node *ast.IdentifierNode) Nature { + if variable, ok := v.lookupVariable(node.Value); ok { + return variable.nature } if node.Value == "$env" { - return mapType, info{} + return unknown } return v.ident(node, node.Value, true, true) } // ident method returns type of environment variable, builtin or function. -func (v *checker) ident(node ast.Node, name string, strict, builtins bool) (reflect.Type, info) { +func (v *checker) ident(node ast.Node, name string, strict, builtins bool) Nature { if t, ok := v.config.Types[name]; ok { if t.Ambiguous { return v.error(node, "ambiguous identifier %v", name) } - return t.Type, info{method: t.Method} + if t.Type == nil { + return nilNature + } + return Nature{Type: t.Type, Method: t.Method} } if builtins { if fn, ok := v.config.Functions[name]; ok { - return fn.Type(), info{fn: fn} + return Nature{Type: fn.Type(), Func: fn} } if fn, ok := v.config.Builtins[name]; ok { - return fn.Type(), info{fn: fn} + return Nature{Type: fn.Type(), Func: fn} } } if v.config.Strict && strict { return v.error(node, "unknown name %v", name) } if v.config.DefaultType != nil { - return v.config.DefaultType, info{} + return Nature{Type: v.config.DefaultType} } - return anyType, info{} + return unknown } -func (v *checker) IntegerNode(*ast.IntegerNode) (reflect.Type, info) { - return integerType, info{} +func (v *checker) IntegerNode(*ast.IntegerNode) Nature { + return integerNature } -func (v *checker) FloatNode(*ast.FloatNode) (reflect.Type, info) { - return floatType, info{} +func (v *checker) FloatNode(*ast.FloatNode) Nature { + return floatNature } -func (v *checker) BoolNode(*ast.BoolNode) (reflect.Type, info) { - return boolType, info{} +func (v *checker) BoolNode(*ast.BoolNode) Nature { + return boolNature } -func (v *checker) StringNode(*ast.StringNode) (reflect.Type, info) { - return stringType, info{} +func (v *checker) StringNode(*ast.StringNode) Nature { + return stringNature } -func (v *checker) ConstantNode(node *ast.ConstantNode) (reflect.Type, info) { - return reflect.TypeOf(node.Value), info{} +func (v *checker) ConstantNode(node *ast.ConstantNode) Nature { + return Nature{Type: reflect.TypeOf(node.Value)} } -func (v *checker) UnaryNode(node *ast.UnaryNode) (reflect.Type, info) { - t, _ := v.visit(node.Node) - t = deref.Type(t) +func (v *checker) UnaryNode(node *ast.UnaryNode) Nature { + nt := v.visit(node.Node) + nt = nt.Deref() switch node.Operator { case "!", "not": - if isBool(t) { - return boolType, info{} + if isBool(nt) { + return boolNature } - if isAny(t) { - return boolType, info{} + if isUnknown(nt) { + return boolNature } case "+", "-": - if isNumber(t) { - return t, info{} + if isNumber(nt) { + return nt } - if isAny(t) { - return anyType, info{} + if isUnknown(nt) { + return unknown } default: return v.error(node, "unknown operator (%v)", node.Operator) } - return v.error(node, `invalid operation: %v (mismatched type %v)`, node.Operator, t) + return v.error(node, `invalid operation: %v (mismatched type %s)`, node.Operator, nt) } -func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) { - l, _ := v.visit(node.Left) - r, ri := v.visit(node.Right) +func (v *checker) BinaryNode(node *ast.BinaryNode) Nature { + l := v.visit(node.Left) + r := v.visit(node.Right) - l = deref.Type(l) - r = deref.Type(r) + l = l.Deref() + r = r.Deref() switch node.Operator { case "==", "!=": if isComparable(l, r) { - return boolType, info{} + return boolNature } case "or", "||", "and", "&&": if isBool(l) && isBool(r) { - return boolType, info{} + return boolNature } if or(l, r, isBool) { - return boolType, info{} + return boolNature } case "<", ">", ">=", "<=": if isNumber(l) && isNumber(r) { - return boolType, info{} + return boolNature } if isString(l) && isString(r) { - return boolType, info{} + return boolNature } if isTime(l) && isTime(r) { - return boolType, info{} + return boolNature } if or(l, r, isNumber, isString, isTime) { - return boolType, info{} + return boolNature } case "-": if isNumber(l) && isNumber(r) { - return combined(l, r), info{} + return combined(l, r) } if isTime(l) && isTime(r) { - return durationType, info{} + return durationNature } if isTime(l) && isDuration(r) { - return timeType, info{} + return timeNature } - if or(l, r, isNumber, isTime) { - return anyType, info{} + if or(l, r, isNumber, isTime, isDuration) { + return unknown } case "*": if isNumber(l) && isNumber(r) { - return combined(l, r), info{} + return combined(l, r) } if or(l, r, isNumber) { - return anyType, info{} + return unknown } case "/": if isNumber(l) && isNumber(r) { - return floatType, info{} + return floatNature } if or(l, r, isNumber) { - return floatType, info{} + return floatNature } case "**", "^": if isNumber(l) && isNumber(r) { - return floatType, info{} + return floatNature } if or(l, r, isNumber) { - return floatType, info{} + return floatNature } case "%": if isInteger(l) && isInteger(r) { - return combined(l, r), info{} + return integerNature } if or(l, r, isInteger) { - return anyType, info{} + return integerNature } case "+": if isNumber(l) && isNumber(r) { - return combined(l, r), info{} + return combined(l, r) } if isString(l) && isString(r) { - return stringType, info{} + return stringNature } if isTime(l) && isDuration(r) { - return timeType, info{} + return timeNature } if isDuration(l) && isTime(r) { - return timeType, info{} + return timeNature } if or(l, r, isNumber, isString, isTime, isDuration) { - return anyType, info{} + return unknown } case "in": - if (isString(l) || isAny(l)) && isStruct(r) { - return boolType, info{} + if (isString(l) || isUnknown(l)) && isStruct(r) { + return boolNature } if isMap(r) { - if l == nil { // It is possible to compare with nil. - return boolType, info{} - } - if !isAny(l) && !l.AssignableTo(r.Key()) { + if !isUnknown(l) && !l.AssignableTo(r.Key()) { return v.error(node, "cannot use %v as type %v in map key", l, r.Key()) } - return boolType, info{} + return boolNature } if isArray(r) { - if l == nil { // It is possible to compare with nil. - return boolType, info{} - } if !isComparable(l, r.Elem()) { return v.error(node, "cannot use %v as type %v in array", l, r.Elem()) } - if !isComparable(l, ri.elem) { - return v.error(node, "cannot use %v as type %v in array", l, ri.elem) - } - return boolType, info{} + return boolNature } - if isAny(l) && anyOf(r, isString, isArray, isMap) { - return boolType, info{} + if isUnknown(l) && anyOf(r, isString, isArray, isMap) { + return boolNature } - if isAny(r) { - return boolType, info{} + if isUnknown(r) { + return boolNature } case "matches": @@ -408,43 +403,48 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) { } } if isString(l) && isString(r) { - return boolType, info{} + return boolNature } if or(l, r, isString) { - return boolType, info{} + return boolNature } case "contains", "startsWith", "endsWith": if isString(l) && isString(r) { - return boolType, info{} + return boolNature } if or(l, r, isString) { - return boolType, info{} + return boolNature } case "..": - ret := reflect.SliceOf(integerType) if isInteger(l) && isInteger(r) { - return ret, info{} + return Nature{ + Type: arrayType, + SubType: Array{Of: integerNature}, + } } if or(l, r, isInteger) { - return ret, info{} + return Nature{ + Type: arrayType, + SubType: Array{Of: integerNature}, + } } case "??": - if l == nil && r != nil { - return r, info{} + if isNil(l) && !isNil(r) { + return r } - if l != nil && r == nil { - return l, info{} + if !isNil(l) && isNil(r) { + return l } - if l == nil && r == nil { - return nilType, info{} + if isNil(l) && isNil(r) { + return nilNature } if r.AssignableTo(l) { - return l, info{} + return l } - return anyType, info{} + return unknown default: return v.error(node, "unknown operator (%v)", node.Operator) @@ -454,11 +454,11 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) { return v.error(node, `invalid operation: %v (mismatched types %v and %v)`, node.Operator, l, r) } -func (v *checker) ChainNode(node *ast.ChainNode) (reflect.Type, info) { +func (v *checker) ChainNode(node *ast.ChainNode) Nature { return v.visit(node.Node) } -func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) { +func (v *checker) MemberNode(node *ast.MemberNode) Nature { // $env variable if an, ok := node.Node.(*ast.IdentifierNode); ok && an.Value == "$env" { if name, ok := node.Property.(*ast.StringNode); ok { @@ -472,59 +472,48 @@ func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) { } return v.ident(node, name.Value, strict, false /* no builtins and no functions */) } - return anyType, info{} + return unknown } - base, _ := v.visit(node.Node) - prop, _ := v.visit(node.Property) + base := v.visit(node.Node) + prop := v.visit(node.Property) + + if isUnknown(base) { + return unknown + } if name, ok := node.Property.(*ast.StringNode); ok { - if base == nil { - return v.error(node, "type %v has no field %v", base, name.Value) + if isNil(base) { + return v.error(node, "type nil has no field %v", name.Value) } + // First, check methods defined on base type itself, // independent of which type it is. Without dereferencing. if m, ok := base.MethodByName(name.Value); ok { - if kind(base) == reflect.Interface { - // In case of interface type method will not have a receiver, - // and to prevent checker decreasing numbers of in arguments - // return method type as not method (second argument is false). - - // Also, we can not use m.Index here, because it will be - // different indexes for different types which implement - // the same interface. - return m.Type, info{} - } else { - return m.Type, info{method: true} - } + return m } } - if kind(base) == reflect.Ptr { - base = base.Elem() - } - - switch kind(base) { - case reflect.Interface: - return anyType, info{} + base = base.Deref() + switch base.Kind() { case reflect.Map: - if prop != nil && !prop.AssignableTo(base.Key()) && !isAny(prop) { + if !prop.AssignableTo(base.Key()) && !isUnknown(prop) { return v.error(node.Property, "cannot use %v to get an element from %v", prop, base) } - return base.Elem(), info{} + return base.Elem() case reflect.Array, reflect.Slice: - if !isInteger(prop) && !isAny(prop) { + if !isInteger(prop) && !isUnknown(prop) { return v.error(node.Property, "array elements can only be selected using an integer (got %v)", prop) } - return base.Elem(), info{} + return base.Elem() case reflect.Struct: if name, ok := node.Property.(*ast.StringNode); ok { propertyName := name.Value if field, ok := fetchField(base, propertyName); ok { - return field.Type, info{} + return Nature{Type: field.Type} } if node.Method { return v.error(node, "type %v has no method %v", base, propertyName) @@ -536,35 +525,39 @@ func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) { return v.error(node, "type %v[%v] is undefined", base, prop) } -func (v *checker) SliceNode(node *ast.SliceNode) (reflect.Type, info) { - t, _ := v.visit(node.Node) +func (v *checker) SliceNode(node *ast.SliceNode) Nature { + nt := v.visit(node.Node) - switch kind(t) { - case reflect.Interface: - // ok + if isUnknown(nt) { + return unknown + } + + switch nt.Kind() { case reflect.String, reflect.Array, reflect.Slice: // ok default: - return v.error(node, "cannot slice %v", t) + return v.error(node, "cannot slice %s", nt) } if node.From != nil { - from, _ := v.visit(node.From) - if !isInteger(from) && !isAny(from) { + from := v.visit(node.From) + if !isInteger(from) && !isUnknown(from) { return v.error(node.From, "non-integer slice index %v", from) } } + if node.To != nil { - to, _ := v.visit(node.To) - if !isInteger(to) && !isAny(to) { + to := v.visit(node.To) + if !isInteger(to) && !isUnknown(to) { return v.error(node.To, "non-integer slice index %v", to) } } - return t, info{} + + return nt } -func (v *checker) CallNode(node *ast.CallNode) (reflect.Type, info) { - t, i := v.functionReturnType(node) +func (v *checker) CallNode(node *ast.CallNode) Nature { + nt := v.functionReturnType(node) // Check if type was set on node (for example, by patcher) // and use node type instead of function return type. @@ -578,17 +571,17 @@ func (v *checker) CallNode(node *ast.CallNode) (reflect.Type, info) { // checker pass we should replace anyType on method node // with new correct function return type. if node.Type() != nil && node.Type() != anyType { - return node.Type(), i + return node.Nature() } - return t, i + return nt } -func (v *checker) functionReturnType(node *ast.CallNode) (reflect.Type, info) { - fn, fnInfo := v.visit(node.Callee) +func (v *checker) functionReturnType(node *ast.CallNode) Nature { + nt := v.visit(node.Callee) - if fnInfo.fn != nil { - return v.checkFunction(fnInfo.fn, node, node.Arguments) + if nt.Func != nil { + return v.checkFunction(nt.Func, node, node.Arguments) } fnName := "function" @@ -601,240 +594,248 @@ func (v *checker) functionReturnType(node *ast.CallNode) (reflect.Type, info) { } } - if fn == nil { + if isUnknown(nt) { + return unknown + } + + if isNil(nt) { return v.error(node, "%v is nil; cannot call nil as function", fnName) } - switch fn.Kind() { - case reflect.Interface: - return anyType, info{} + switch nt.Kind() { case reflect.Func: - outType, err := v.checkArguments(fnName, fn, fnInfo.method, node.Arguments, node) + outType, err := v.checkArguments(fnName, nt, node.Arguments, node) if err != nil { if v.err == nil { v.err = err } - return anyType, info{} + return unknown } - return outType, info{} + return outType } - return v.error(node, "%v is not callable", fn) + return v.error(node, "%s is not callable", nt) } -func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) { +func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature { switch node.Name { case "all", "none", "any", "one": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]) + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } v.begin(collection) - closure, _ := v.visit(node.Arguments[1]) + closure := v.visit(node.Arguments[1]) v.end() if isFunc(closure) && closure.NumOut() == 1 && - closure.NumIn() == 1 && isAny(closure.In(0)) { + closure.NumIn() == 1 && isUnknown(closure.In(0)) { - if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) { + if !isBool(closure.Out(0)) && !isUnknown(closure.Out(0)) { return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String()) } - return boolType, info{} + return boolNature } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "filter": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]) + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } v.begin(collection) - closure, _ := v.visit(node.Arguments[1]) + closure := v.visit(node.Arguments[1]) v.end() if isFunc(closure) && closure.NumOut() == 1 && - closure.NumIn() == 1 && isAny(closure.In(0)) { + closure.NumIn() == 1 && isUnknown(closure.In(0)) { - if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) { + if !isBool(closure.Out(0)) && !isUnknown(closure.Out(0)) { return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String()) } - if isAny(collection) { - return arrayType, info{} + if isUnknown(collection) { + return arrayNature + } + return Nature{ + Type: arrayType, + SubType: Array{Of: collection.Elem()}, } - return arrayType, info{} } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "map": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]) + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } - v.begin(collection, scopeVar{"index", integerType}) - closure, _ := v.visit(node.Arguments[1]) + v.begin(collection, scopeVar{"index", integerNature}) + closure := v.visit(node.Arguments[1]) v.end() if isFunc(closure) && closure.NumOut() == 1 && - closure.NumIn() == 1 && isAny(closure.In(0)) { + closure.NumIn() == 1 && isUnknown(closure.In(0)) { - return arrayType, info{} + return Nature{ + Type: arrayType, + SubType: Array{Of: closure.Out(0)}, + } } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "count": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]) + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } if len(node.Arguments) == 1 { - return integerType, info{} + return integerNature } v.begin(collection) - closure, _ := v.visit(node.Arguments[1]) + closure := v.visit(node.Arguments[1]) v.end() if isFunc(closure) && closure.NumOut() == 1 && - closure.NumIn() == 1 && isAny(closure.In(0)) { - if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) { + closure.NumIn() == 1 && isUnknown(closure.In(0)) { + if !isBool(closure.Out(0)) && !isUnknown(closure.Out(0)) { return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String()) } - return integerType, info{} + return integerNature } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "sum": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]) + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } if len(node.Arguments) == 2 { v.begin(collection) - closure, _ := v.visit(node.Arguments[1]) + closure := v.visit(node.Arguments[1]) v.end() if isFunc(closure) && closure.NumOut() == 1 && - closure.NumIn() == 1 && isAny(closure.In(0)) { - return closure.Out(0), info{} + closure.NumIn() == 1 && isUnknown(closure.In(0)) { + return closure.Out(0) } } else { - if isAny(collection) { - return anyType, info{} + if isUnknown(collection) { + return unknown } - return collection.Elem(), info{} + return collection.Elem() } case "find", "findLast": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]) + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } v.begin(collection) - closure, _ := v.visit(node.Arguments[1]) + closure := v.visit(node.Arguments[1]) v.end() if isFunc(closure) && closure.NumOut() == 1 && - closure.NumIn() == 1 && isAny(closure.In(0)) { + closure.NumIn() == 1 && isUnknown(closure.In(0)) { - if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) { + if !isBool(closure.Out(0)) && !isUnknown(closure.Out(0)) { return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String()) } - if isAny(collection) { - return anyType, info{} + if isUnknown(collection) { + return unknown } - return collection.Elem(), info{} + return collection.Elem() } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "findIndex", "findLastIndex": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]) + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } v.begin(collection) - closure, _ := v.visit(node.Arguments[1]) + closure := v.visit(node.Arguments[1]) v.end() if isFunc(closure) && closure.NumOut() == 1 && - closure.NumIn() == 1 && isAny(closure.In(0)) { + closure.NumIn() == 1 && isUnknown(closure.In(0)) { - if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) { + if !isBool(closure.Out(0)) && !isUnknown(closure.Out(0)) { return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String()) } - return integerType, info{} + return integerNature } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "groupBy": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]) + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } v.begin(collection) - closure, _ := v.visit(node.Arguments[1]) + closure := v.visit(node.Arguments[1]) v.end() if isFunc(closure) && closure.NumOut() == 1 && - closure.NumIn() == 1 && isAny(closure.In(0)) { + closure.NumIn() == 1 && isUnknown(closure.In(0)) { - return reflect.TypeOf(map[any][]any{}), info{} + return Nature{Type: reflect.TypeOf(map[any][]any{})} } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "sortBy": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]) + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } v.begin(collection) - closure, _ := v.visit(node.Arguments[1]) + closure := v.visit(node.Arguments[1]) v.end() if len(node.Arguments) == 3 { - _, _ = v.visit(node.Arguments[2]) + _ = v.visit(node.Arguments[2]) } if isFunc(closure) && closure.NumOut() == 1 && - closure.NumIn() == 1 && isAny(closure.In(0)) { + closure.NumIn() == 1 && isUnknown(closure.In(0)) { - return reflect.TypeOf([]any{}), info{} + return Nature{Type: reflect.TypeOf([]any{})} } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "reduce": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]) + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } - v.begin(collection, scopeVar{"index", integerType}, scopeVar{"acc", anyType}) - closure, _ := v.visit(node.Arguments[1]) + v.begin(collection, scopeVar{"index", integerNature}, scopeVar{"acc", unknown}) + closure := v.visit(node.Arguments[1]) v.end() if len(node.Arguments) == 3 { - _, _ = v.visit(node.Arguments[2]) + _ = v.visit(node.Arguments[2]) } if isFunc(closure) && closure.NumOut() == 1 { - return closure.Out(0), info{} + return closure.Out(0) } return v.error(node.Arguments[1], "predicate should has two input and one output param") @@ -852,14 +853,14 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) { } type scopeVar struct { - name string - vtype reflect.Type + varName string + varNature Nature } -func (v *checker) begin(vtype reflect.Type, vars ...scopeVar) { - scope := predicateScope{vtype: vtype, vars: make(map[string]reflect.Type)} +func (v *checker) begin(collectionNature Nature, vars ...scopeVar) { + scope := predicateScope{collection: collectionNature, vars: make(map[string]Nature)} for _, v := range vars { - scope.vars[v.name] = v.vtype + scope.vars[v.varName] = v.varNature } v.predicateScopes = append(v.predicateScopes, scope) } @@ -868,83 +869,81 @@ func (v *checker) end() { v.predicateScopes = v.predicateScopes[:len(v.predicateScopes)-1] } -func (v *checker) checkBuiltinGet(node *ast.BuiltinNode) (reflect.Type, info) { +func (v *checker) checkBuiltinGet(node *ast.BuiltinNode) Nature { if len(node.Arguments) != 2 { return v.error(node, "invalid number of arguments (expected 2, got %d)", len(node.Arguments)) } - val := node.Arguments[0] - prop := node.Arguments[1] - if id, ok := val.(*ast.IdentifierNode); ok && id.Value == "$env" { - if s, ok := prop.(*ast.StringNode); ok { - return v.config.Types[s.Value].Type, info{} + base := v.visit(node.Arguments[0]) + prop := v.visit(node.Arguments[1]) + + if id, ok := node.Arguments[0].(*ast.IdentifierNode); ok && id.Value == "$env" { + if s, ok := node.Arguments[1].(*ast.StringNode); ok { + return Nature{Type: v.config.Types[s.Value].Type} } - return anyType, info{} + return unknown } - t, _ := v.visit(val) + if isUnknown(base) { + return unknown + } - switch kind(t) { - case reflect.Interface: - return anyType, info{} + switch base.Kind() { case reflect.Slice, reflect.Array: - p, _ := v.visit(prop) - if p == nil { - return v.error(prop, "cannot use nil as slice index") + if !isInteger(prop) && !isUnknown(prop) { + return v.error(node.Arguments[1], "non-integer slice index %s", prop) } - if !isInteger(p) && !isAny(p) { - return v.error(prop, "non-integer slice index %v", p) - } - return t.Elem(), info{} + return base.Elem() case reflect.Map: - p, _ := v.visit(prop) - if p == nil { - return v.error(prop, "cannot use nil as map index") - } - if !p.AssignableTo(t.Key()) && !isAny(p) { - return v.error(prop, "cannot use %v to get an element from %v", p, t) + if !prop.AssignableTo(base.Key()) && !isUnknown(prop) { + return v.error(node.Arguments[1], "cannot use %s to get an element from %s", prop, base) } - return t.Elem(), info{} + return base.Elem() } - return v.error(val, "type %v does not support indexing", t) + return v.error(node.Arguments[0], "type %v does not support indexing", base) } -func (v *checker) checkFunction(f *builtin.Function, node ast.Node, arguments []ast.Node) (reflect.Type, info) { +func (v *checker) checkFunction(f *builtin.Function, node ast.Node, arguments []ast.Node) Nature { if f.Validate != nil { args := make([]reflect.Type, len(arguments)) for i, arg := range arguments { - args[i], _ = v.visit(arg) + argNature := v.visit(arg) + if isUnknown(argNature) { + args[i] = anyType + } else { + args[i] = argNature.Type + } } t, err := f.Validate(args) if err != nil { return v.error(node, "%v", err) } - return t, info{} + return Nature{Type: t} } else if len(f.Types) == 0 { - t, err := v.checkArguments(f.Name, f.Type(), false, arguments, node) + nt, err := v.checkArguments(f.Name, Nature{Type: f.Type()}, arguments, node) if err != nil { if v.err == nil { v.err = err } - return anyType, info{} + return unknown } // No type was specified, so we assume the function returns any. - return t, info{} + return nt } var lastErr *file.Error for _, t := range f.Types { - outType, err := v.checkArguments(f.Name, t, false, arguments, node) + outNature, err := v.checkArguments(f.Name, Nature{Type: t}, arguments, node) if err != nil { lastErr = err continue } - return outType, info{} + return outNature } if lastErr != nil { if v.err == nil { v.err = lastErr } - return anyType, info{} + return unknown } return v.error(node, "no matching overload for %v", f.Name) @@ -952,23 +951,22 @@ func (v *checker) checkFunction(f *builtin.Function, node ast.Node, arguments [] func (v *checker) checkArguments( name string, - fn reflect.Type, - method bool, + fn Nature, arguments []ast.Node, node ast.Node, -) (reflect.Type, *file.Error) { - if isAny(fn) { - return anyType, nil +) (Nature, *file.Error) { + if isUnknown(fn) { + return unknown, nil } if fn.NumOut() == 0 { - return anyType, &file.Error{ + return unknown, &file.Error{ Location: node.Location(), Message: fmt.Sprintf("func %v doesn't return value", name), } } if numOut := fn.NumOut(); numOut > 2 { - return anyType, &file.Error{ + return unknown, &file.Error{ Location: node.Location(), Message: fmt.Sprintf("func %v returns more then two values", name), } @@ -977,12 +975,12 @@ func (v *checker) checkArguments( // If func is method on an env, first argument should be a receiver, // and actual arguments less than fnNumIn by one. fnNumIn := fn.NumIn() - if method { + if fn.Method { // TODO: Move subtraction to the Nature.NumIn() and Nature.In() methods. fnNumIn-- } // Skip first argument in case of the receiver. fnInOffset := 0 - if method { + if fn.Method { fnInOffset = 1 } @@ -1013,15 +1011,15 @@ func (v *checker) checkArguments( // If we have an error, we should still visit all arguments to // type check them, as a patch can fix the error later. for _, arg := range arguments { - _, _ = v.visit(arg) + _ = v.visit(arg) } return fn.Out(0), err } for i, arg := range arguments { - t, _ := v.visit(arg) + argNature := v.visit(arg) - var in reflect.Type + var in Nature if fn.IsVariadic() && i >= fnNumIn-1 { // For variadic arguments fn(xs ...int), go replaces type of xs (int) with ([]int). // As we compare arguments one by one, we need underling type. @@ -1030,24 +1028,40 @@ func (v *checker) checkArguments( in = fn.In(i + fnInOffset) } - if isFloat(in) && isInteger(t) { + if isFloat(in) && isInteger(argNature) { traverseAndReplaceIntegerNodesWithFloatNodes(&arguments[i], in) continue } - if isInteger(in) && isInteger(t) && kind(t) != kind(in) { + if isInteger(in) && isInteger(argNature) && argNature.Kind() != in.Kind() { traverseAndReplaceIntegerNodesWithIntegerNodes(&arguments[i], in) continue } - if t == nil { - continue + if isNil(argNature) { + if in.Kind() == reflect.Ptr || in.Kind() == reflect.Interface { + continue + } + return unknown, &file.Error{ + Location: arg.Location(), + Message: fmt.Sprintf("cannot use nil as argument (type %s) to call %v", in, name), + } } - if !(t.AssignableTo(in) || deref.Type(t).AssignableTo(in)) && kind(t) != reflect.Interface { - return anyType, &file.Error{ + // Check if argument is assignable to the function input type. + // We check original type (like *time.Time), not dereferenced type, + // as function input type can be pointer to a struct. + assignable := argNature.AssignableTo(in) + + // We also need to check if dereference arg type is assignable to the function input type. + // For example, func(int) and argument *int. In this case we will add OpDeref to the argument, + // so we can call the function with *int argument. + assignable = assignable || argNature.Deref().AssignableTo(in) + + if !assignable && !isUnknown(argNature) { + return unknown, &file.Error{ Location: arg.Location(), - Message: fmt.Sprintf("cannot use %v as argument (type %v) to call %v ", t, in, name), + Message: fmt.Sprintf("cannot use %s as argument (type %s) to call %v ", argNature, in, name), } } } @@ -1055,74 +1069,82 @@ func (v *checker) checkArguments( return fn.Out(0), nil } -func traverseAndReplaceIntegerNodesWithFloatNodes(node *ast.Node, newType reflect.Type) { +func traverseAndReplaceIntegerNodesWithFloatNodes(node *ast.Node, newNature Nature) { switch (*node).(type) { case *ast.IntegerNode: *node = &ast.FloatNode{Value: float64((*node).(*ast.IntegerNode).Value)} - (*node).SetType(newType) + (*node).SetType(newNature.Type) case *ast.UnaryNode: unaryNode := (*node).(*ast.UnaryNode) - traverseAndReplaceIntegerNodesWithFloatNodes(&unaryNode.Node, newType) + traverseAndReplaceIntegerNodesWithFloatNodes(&unaryNode.Node, newNature) case *ast.BinaryNode: binaryNode := (*node).(*ast.BinaryNode) switch binaryNode.Operator { case "+", "-", "*": - traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Left, newType) - traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Right, newType) + traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Left, newNature) + traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Right, newNature) } } } -func traverseAndReplaceIntegerNodesWithIntegerNodes(node *ast.Node, newType reflect.Type) { +func traverseAndReplaceIntegerNodesWithIntegerNodes(node *ast.Node, newNature Nature) { switch (*node).(type) { case *ast.IntegerNode: - (*node).SetType(newType) + (*node).SetType(newNature.Type) case *ast.UnaryNode: - (*node).SetType(newType) + (*node).SetType(newNature.Type) unaryNode := (*node).(*ast.UnaryNode) - traverseAndReplaceIntegerNodesWithIntegerNodes(&unaryNode.Node, newType) + traverseAndReplaceIntegerNodesWithIntegerNodes(&unaryNode.Node, newNature) case *ast.BinaryNode: // TODO: Binary node return type is dependent on the type of the operands. We can't just change the type of the node. binaryNode := (*node).(*ast.BinaryNode) switch binaryNode.Operator { case "+", "-", "*": - traverseAndReplaceIntegerNodesWithIntegerNodes(&binaryNode.Left, newType) - traverseAndReplaceIntegerNodesWithIntegerNodes(&binaryNode.Right, newType) + traverseAndReplaceIntegerNodesWithIntegerNodes(&binaryNode.Left, newNature) + traverseAndReplaceIntegerNodesWithIntegerNodes(&binaryNode.Right, newNature) } } } -func (v *checker) ClosureNode(node *ast.ClosureNode) (reflect.Type, info) { - t, _ := v.visit(node.Node) - if t == nil { - return v.error(node.Node, "closure cannot be nil") +func (v *checker) ClosureNode(node *ast.ClosureNode) Nature { + nt := v.visit(node.Node) + var out reflect.Type + if isUnknown(nt) { + out = anyType + } else { + out = nt.Type } - return reflect.FuncOf([]reflect.Type{anyType}, []reflect.Type{t}, false), info{} + return Nature{Type: reflect.FuncOf( + []reflect.Type{anyType}, + []reflect.Type{out}, + false, + )} } -func (v *checker) PointerNode(node *ast.PointerNode) (reflect.Type, info) { +func (v *checker) PointerNode(node *ast.PointerNode) Nature { if len(v.predicateScopes) == 0 { return v.error(node, "cannot use pointer accessor outside closure") } scope := v.predicateScopes[len(v.predicateScopes)-1] if node.Name == "" { - switch scope.vtype.Kind() { - case reflect.Interface: - return anyType, info{} + if isUnknown(scope.collection) { + return unknown + } + switch scope.collection.Kind() { case reflect.Array, reflect.Slice: - return scope.vtype.Elem(), info{} + return scope.collection.Elem() } return v.error(node, "cannot use %v as array", scope) } if scope.vars != nil { if t, ok := scope.vars[node.Name]; ok { - return t, info{} + return t } } return v.error(node, "unknown pointer #%v", node.Name) } -func (v *checker) VariableDeclaratorNode(node *ast.VariableDeclaratorNode) (reflect.Type, info) { +func (v *checker) VariableDeclaratorNode(node *ast.VariableDeclaratorNode) Nature { if _, ok := v.config.Types[node.Name]; ok { return v.error(node, "cannot redeclare %v", node.Name) } @@ -1135,11 +1157,11 @@ func (v *checker) VariableDeclaratorNode(node *ast.VariableDeclaratorNode) (refl if _, ok := v.lookupVariable(node.Name); ok { return v.error(node, "cannot redeclare variable %v", node.Name) } - vtype, vinfo := v.visit(node.Value) - v.varScopes = append(v.varScopes, varScope{node.Name, vtype, vinfo}) - t, i := v.visit(node.Expr) + varNature := v.visit(node.Value) + v.varScopes = append(v.varScopes, varScope{node.Name, varNature}) + exprNature := v.visit(node.Expr) v.varScopes = v.varScopes[:len(v.varScopes)-1] - return t, i + return exprNature } func (v *checker) lookupVariable(name string) (varScope, bool) { @@ -1151,59 +1173,60 @@ func (v *checker) lookupVariable(name string) (varScope, bool) { return varScope{}, false } -func (v *checker) ConditionalNode(node *ast.ConditionalNode) (reflect.Type, info) { - c, _ := v.visit(node.Cond) - if !isBool(c) && !isAny(c) { +func (v *checker) ConditionalNode(node *ast.ConditionalNode) Nature { + c := v.visit(node.Cond) + if !isBool(c) && !isUnknown(c) { return v.error(node.Cond, "non-bool expression (type %v) used as condition", c) } - t1, _ := v.visit(node.Exp1) - t2, _ := v.visit(node.Exp2) + t1 := v.visit(node.Exp1) + t2 := v.visit(node.Exp2) - if t1 == nil && t2 != nil { - return t2, info{} + if isNil(t1) && !isNil(t2) { + return t2 } - if t1 != nil && t2 == nil { - return t1, info{} + if !isNil(t1) && isNil(t2) { + return t1 } - if t1 == nil && t2 == nil { - return nilType, info{} + if isNil(t1) && isNil(t2) { + return nilNature } if t1.AssignableTo(t2) { - return t1, info{} + return t1 } - return anyType, info{} + return unknown } -func (v *checker) ArrayNode(node *ast.ArrayNode) (reflect.Type, info) { - var prev reflect.Type +func (v *checker) ArrayNode(node *ast.ArrayNode) Nature { + var prev Nature allElementsAreSameType := true for i, node := range node.Nodes { - curr, _ := v.visit(node) + curr := v.visit(node) if i > 0 { - if curr == nil || prev == nil { - allElementsAreSameType = false - } else if curr.Kind() != prev.Kind() { + if curr.Kind() != prev.Kind() { allElementsAreSameType = false } } prev = curr } - if allElementsAreSameType && prev != nil { - return arrayType, info{elem: prev} + if allElementsAreSameType { + return Nature{ + Type: arrayNature.Type, + SubType: Array{Of: prev}, + } } - return arrayType, info{} + return arrayNature } -func (v *checker) MapNode(node *ast.MapNode) (reflect.Type, info) { +func (v *checker) MapNode(node *ast.MapNode) Nature { for _, pair := range node.Pairs { v.visit(pair) } - return mapType, info{} + return mapNature } -func (v *checker) PairNode(node *ast.PairNode) (reflect.Type, info) { +func (v *checker) PairNode(node *ast.PairNode) Nature { v.visit(node.Key) v.visit(node.Value) - return nilType, info{} + return nilNature } diff --git a/checker/checker_test.go b/checker/checker_test.go index 1509045e3..ae42392c0 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -149,427 +149,541 @@ func TestCheck(t *testing.T) { } } -const errorTests = ` -Foo.Bar.Not +func TestCheck_error(t *testing.T) { + errorTests := []struct{ code, err string }{ + { + `Foo.Bar.Not`, + ` type mock.Bar has no field Not (1:9) | Foo.Bar.Not | ........^ - -Noo +`, + }, + { + `Noo`, + ` unknown name Noo (1:1) | Noo | ^ - -Foo() +`, + }, + { + `Foo()`, ` mock.Foo is not callable (1:1) | Foo() | ^ - -Foo['bar'] +`, + }, + { + `Foo['bar']`, ` type mock.Foo has no field bar (1:4) | Foo['bar'] | ...^ - -Foo.Method(42) +`, + }, + { + `Foo.Method(42)`, + ` too many arguments to call Method (1:5) | Foo.Method(42) | ....^ - -Foo.Bar() +`, + }, + {`Foo.Bar()`, ` mock.Bar is not callable (1:5) | Foo.Bar() | ....^ - -Foo.Bar.Not() +`, + }, + {`Foo.Bar.Not()`, ` type mock.Bar has no method Not (1:9) | Foo.Bar.Not() | ........^ - -ArrayOfFoo[0].Not +`, + }, + {`ArrayOfFoo[0].Not`, ` type mock.Foo has no field Not (1:15) | ArrayOfFoo[0].Not | ..............^ - -ArrayOfFoo[Not] +`, + }, + {`ArrayOfFoo[Not]`, ` unknown name Not (1:12) | ArrayOfFoo[Not] | ...........^ - -Not[0] +`, + }, + {`Not[0]`, ` unknown name Not (1:1) | Not[0] | ^ - -Not.Bar +`, + }, + {`Not.Bar`, ` unknown name Not (1:1) | Not.Bar | ^ - -ArrayOfFoo.Not +`, + }, + {`ArrayOfFoo.Not`, ` array elements can only be selected using an integer (got string) (1:12) | ArrayOfFoo.Not | ...........^ - -FuncParam(true) +`, + }, + {`FuncParam(true)`, ` not enough arguments to call FuncParam (1:1) | FuncParam(true) | ^ - -MapOfFoo['str'].Not +`, + }, + {`MapOfFoo['str'].Not`, ` type mock.Foo has no field Not (1:17) | MapOfFoo['str'].Not | ................^ - -Bool && IntPtr +`, + }, + {`Bool && IntPtr`, ` invalid operation: && (mismatched types bool and int) (1:6) | Bool && IntPtr | .....^ - -No ? Any.Bool : Any.Not +`, + }, + {`No ? Any.Bool : Any.Not`, ` unknown name No (1:1) | No ? Any.Bool : Any.Not | ^ - -Any.Cond ? No : Any.Not +`, + }, + {`Any.Cond ? No : Any.Not`, ` unknown name No (1:12) | Any.Cond ? No : Any.Not | ...........^ - -Any.Cond ? Any.Bool : No +`, + }, + {`Any.Cond ? Any.Bool : No`, ` unknown name No (1:23) | Any.Cond ? Any.Bool : No | ......................^ - -MapOfAny ? Any : Any +`, + }, + {`MapOfAny ? Any : Any`, ` non-bool expression (type map[string]interface {}) used as condition (1:1) | MapOfAny ? Any : Any | ^ - -String matches Int +`, + }, + {`String matches Int`, ` invalid operation: matches (mismatched types string and int) (1:8) | String matches Int | .......^ - -Int matches String +`, + }, + {`Int matches String`, ` invalid operation: matches (mismatched types int and string) (1:5) | Int matches String | ....^ - -String contains Int +`, + }, + {`String contains Int`, ` invalid operation: contains (mismatched types string and int) (1:8) | String contains Int | .......^ - -Int contains String +`, + }, + {`Int contains String`, ` invalid operation: contains (mismatched types int and string) (1:5) | Int contains String | ....^ - -!Not +`, + }, + {`!Not`, ` unknown name Not (1:2) | !Not | .^ - -Not == Any +`, + }, + {`Not == Any`, ` unknown name Not (1:1) | Not == Any | ^ - -[Not] +`, + }, + {`[Not]`, ` unknown name Not (1:2) | [Not] | .^ - -{id: Not} +`, + }, + {`{id: Not}`, ` unknown name Not (1:6) | {id: Not} | .....^ - -(nil).Foo -type has no field Foo (1:7) +`, + }, + {`(nil).Foo`, ` +type nil has no field Foo (1:7) | (nil).Foo | ......^ - -(nil)['Foo'] -type has no field Foo (1:6) +`, + }, + {`(nil)['Foo']`, ` +type nil has no field Foo (1:6) | (nil)['Foo'] | .....^ - -1 and false +`, + }, + {`1 and false`, ` invalid operation: and (mismatched types int and bool) (1:3) | 1 and false | ..^ - -true or 0 +`, + }, + {`true or 0`, ` invalid operation: or (mismatched types bool and int) (1:6) | true or 0 | .....^ - -not IntPtr +`, + }, + {`not IntPtr`, ` invalid operation: not (mismatched type int) (1:1) | not IntPtr | ^ - -len(Not) +`, + }, + {`len(Not)`, ` unknown name Not (1:5) | len(Not) | ....^ - -Int < Bool +`, + }, + {`Int < Bool`, ` invalid operation: < (mismatched types int and bool) (1:5) | Int < Bool | ....^ - -Int > Bool +`, + }, + {`Int > Bool`, ` invalid operation: > (mismatched types int and bool) (1:5) | Int > Bool | ....^ - -Int >= Bool +`, + }, + {`Int >= Bool`, ` invalid operation: >= (mismatched types int and bool) (1:5) | Int >= Bool | ....^ - -Int <= Bool +`, + }, + {`Int <= Bool`, ` invalid operation: <= (mismatched types int and bool) (1:5) | Int <= Bool | ....^ - -Int + Bool +`, + }, + {`Int + Bool`, ` invalid operation: + (mismatched types int and bool) (1:5) | Int + Bool | ....^ - -Int - Bool +`, + }, + {`Int - Bool`, ` invalid operation: - (mismatched types int and bool) (1:5) | Int - Bool | ....^ - -Int * Bool +`, + }, + {`Int * Bool`, ` invalid operation: * (mismatched types int and bool) (1:5) | Int * Bool | ....^ - -Int / Bool +`, + }, + {`Int / Bool`, ` invalid operation: / (mismatched types int and bool) (1:5) | Int / Bool | ....^ - -Int % Bool +`, + }, + {`Int % Bool`, ` invalid operation: % (mismatched types int and bool) (1:5) | Int % Bool | ....^ - -Int ** Bool +`, + }, + {`Int ** Bool`, ` invalid operation: ** (mismatched types int and bool) (1:5) | Int ** Bool | ....^ - -Int .. Bool +`, + }, + {`Int .. Bool`, ` invalid operation: .. (mismatched types int and bool) (1:5) | Int .. Bool | ....^ - -Any > Foo +`, + }, + {`Any > Foo`, ` invalid operation: > (mismatched types interface {} and mock.Foo) (1:5) | Any > Foo | ....^ - -NilFn() and BoolFn() +`, + }, + {`NilFn() and BoolFn()`, ` func NilFn doesn't return value (1:1) | NilFn() and BoolFn() | ^ - -'str' in String +`, + }, + {`'str' in String`, ` invalid operation: in (mismatched types string and string) (1:7) | 'str' in String | ......^ - -1 in Foo +`, + }, + {`1 in Foo`, ` invalid operation: in (mismatched types int and mock.Foo) (1:3) | 1 in Foo | ..^ - -1 + '' +`, + }, + {`1 + ''`, ` invalid operation: + (mismatched types int and string) (1:3) | 1 + '' | ..^ - -all(ArrayOfFoo, {#.Method() < 0}) +`, + }, + {`all(ArrayOfFoo, {#.Method() < 0})`, ` invalid operation: < (mismatched types mock.Bar and int) (1:29) | all(ArrayOfFoo, {#.Method() < 0}) | ............................^ - -Variadic() +`, + }, + {`Variadic()`, ` not enough arguments to call Variadic (1:1) | Variadic() | ^ - -Variadic(0, '') +`, + }, + {`Variadic(0, '')`, ` cannot use string as argument (type int) to call Variadic (1:13) | Variadic(0, '') | ............^ - -count(1, {#}) +`, + }, + {`count(1, {#})`, ` builtin count takes only array (got int) (1:7) | count(1, {#}) | ......^ - -count(ArrayOfInt, {#}) +`, + }, + {`count(ArrayOfInt, {#})`, ` predicate should return boolean (got int) (1:19) | count(ArrayOfInt, {#}) | ..................^ - -all(ArrayOfInt, {# + 1}) +`, + }, + {`all(ArrayOfInt, {# + 1})`, ` predicate should return boolean (got int) (1:17) | all(ArrayOfInt, {# + 1}) | ................^ - -filter(ArrayOfFoo, {.Bar.Baz}) +`, + }, + {`filter(ArrayOfFoo, {.Bar.Baz})`, ` predicate should return boolean (got string) (1:20) | filter(ArrayOfFoo, {.Bar.Baz}) | ...................^ - -find(ArrayOfFoo, {.Bar.Baz}) +`, + }, + {`find(ArrayOfFoo, {.Bar.Baz})`, ` predicate should return boolean (got string) (1:18) | find(ArrayOfFoo, {.Bar.Baz}) | .................^ - -map(1, {2}) +`, + }, + {`map(1, {2})`, ` builtin map takes only array (got int) (1:5) | map(1, {2}) | ....^ - -ArrayOfFoo[Foo] +`, + }, + {`ArrayOfFoo[Foo]`, ` array elements can only be selected using an integer (got mock.Foo) (1:12) | ArrayOfFoo[Foo] | ...........^ - -ArrayOfFoo[Bool:] +`, + }, + {`ArrayOfFoo[Bool:]`, ` non-integer slice index bool (1:12) | ArrayOfFoo[Bool:] | ...........^ - -ArrayOfFoo[1:Bool] +`, + }, + {`ArrayOfFoo[1:Bool]`, ` non-integer slice index bool (1:14) | ArrayOfFoo[1:Bool] | .............^ - -Bool[:] +`, + }, + {`Bool[:]`, ` cannot slice bool (1:5) | Bool[:] | ....^ - -FuncTooManyReturns() +`, + }, + {`FuncTooManyReturns()`, ` func FuncTooManyReturns returns more then two values (1:1) | FuncTooManyReturns() | ^ - -len(42) +`, + }, + {`len(42)`, ` invalid argument for len (type int) (1:1) | len(42) | ^ - -any(42, {#}) +`, + }, + {`any(42, {#})`, ` builtin any takes only array (got int) (1:5) | any(42, {#}) | ....^ - -filter(42, {#}) +`, + }, + {`filter(42, {#})`, ` builtin filter takes only array (got int) (1:8) | filter(42, {#}) | .......^ - -MapOfAny[0] +`, + }, + {`MapOfAny[0]`, ` cannot use int to get an element from map[string]interface {} (1:10) | MapOfAny[0] | .........^ - -1 /* one */ + "2" +`, + }, + {`1 /* one */ + "2"`, ` invalid operation: + (mismatched types int and string) (1:13) | 1 /* one */ + "2" | ............^ - -FuncTyped(42) +`, + }, + {`FuncTyped(42)`, ` cannot use int as argument (type string) to call FuncTyped (1:11) | FuncTyped(42) | ..........^ - -.0 in MapOfFoo +`, + }, + {`.0 in MapOfFoo`, ` cannot use float64 as type string in map key (1:4) | .0 in MapOfFoo | ...^ - -1/2 in MapIntAny +`, + }, + {`1/2 in MapIntAny`, ` cannot use float64 as type int in map key (1:5) | 1/2 in MapIntAny | ....^ - -0.5 in ArrayOfFoo +`, + }, + {`0.5 in ArrayOfFoo`, ` cannot use float64 as type *mock.Foo in array (1:5) | 0.5 in ArrayOfFoo | ....^ - -repeat("0", 1/0) +`, + }, + {`repeat("0", 1/0)`, ` cannot use float64 as argument (type int) to call repeat (1:14) | repeat("0", 1/0) | .............^ - -let map = 42; map +`, + }, + {`let map = 42; map`, ` cannot redeclare builtin map (1:5) | let map = 42; map | ....^ - -let len = 42; len +`, + }, + {`let len = 42; len`, ` cannot redeclare builtin len (1:5) | let len = 42; len | ....^ - -let Float = 42; Float +`, + }, + {`let Float = 42; Float`, ` cannot redeclare Float (1:5) | let Float = 42; Float | ....^ - -let foo = 1; let foo = 2; foo +`, + }, + {`let foo = 1; let foo = 2; foo`, ` cannot redeclare variable foo (1:18) | let foo = 1; let foo = 2; foo | .................^ - -map(1..9, #unknown) +`, + }, + {`map(1..9, #unknown)`, ` unknown pointer #unknown (1:11) | map(1..9, #unknown) | ..........^ - -42 in ["a", "b", "c"] +`, + }, + {`42 in ["a", "b", "c"]`, ` cannot use int as type string in array (1:4) | 42 in ["a", "b", "c"] | ...^ - -"foo" matches "[+" +`, + }, + {`"foo" matches "[+"`, ` error parsing regexp: missing closing ]: ` + "`[+`" + ` (1:7) | "foo" matches "[+" | ......^ -` - -func TestCheck_error(t *testing.T) { - tests := strings.Split(strings.Trim(errorTests, "\n"), "\n\n") - - for _, test := range tests { - input := strings.SplitN(test, "\n", 2) - if len(input) != 2 { - t.Errorf("syntax error in test: %q", test) - break - } +`, + }, + {`get(false, 2)`, ` +type bool does not support indexing (1:5) + | get(false, 2) + | ....^ +`, + }, + {`get(1..2, 0.5)`, ` +non-integer slice index float64 (1:11) + | get(1..2, 0.5) + | ..........^`, + }, + {`trimPrefix(nil)`, ` +cannot use nil as argument (type string) to call trimPrefix (1:12) + | trimPrefix(nil) + | ...........^ +`, + }, + {`1..3 | filter(# > 1) | filter(# == "str")`, + ` +invalid operation: == (mismatched types int and string) (1:33) + | 1..3 | filter(# > 1) | filter(# == "str") + | ................................^ +`, + }, + {`1..3 | map("str") | filter(# > 1)`, + ` +invalid operation: > (mismatched types string and int) (1:30) + | 1..3 | map("str") | filter(# > 1) + | .............................^ +`, + }, + } - tree, err := parser.Parse(input[0]) - assert.NoError(t, err) + for _, tt := range errorTests { + t.Run(tt.code, func(t *testing.T) { + tree, err := parser.Parse(tt.code) + require.NoError(t, err) - _, err = checker.Check(tree, conf.New(mock.Env{})) - if err == nil { - err = fmt.Errorf("") - } + _, err = checker.Check(tree, conf.New(mock.Env{})) + if err == nil { + err = fmt.Errorf("") + } - assert.Equal(t, input[1], err.Error(), input[0]) + assert.Equal(t, strings.Trim(tt.err, "\n"), err.Error()) + }) } } diff --git a/checker/info.go b/checker/info.go index 112bfab31..3c9396fd5 100644 --- a/checker/info.go +++ b/checker/info.go @@ -15,11 +15,9 @@ func FieldIndex(types conf.TypesTable, node ast.Node) (bool, []int, string) { return true, t.FieldIndex, n.Value } case *ast.MemberNode: - base := n.Node.Type() - if kind(base) == reflect.Ptr { - base = base.Elem() - } - if kind(base) == reflect.Struct { + base := n.Node.Nature() + base = base.Deref() + if base.Kind() == reflect.Struct { if prop, ok := n.Property.(*ast.StringNode); ok { name := prop.Value if field, ok := fetchField(base, name); ok { @@ -114,8 +112,7 @@ func IsFastFunc(fn reflect.Type, method bool) bool { if method { numIn = 2 } - if !isAny(fn) && - fn.IsVariadic() && + if fn.IsVariadic() && fn.NumIn() == numIn && fn.NumOut() == 1 && fn.Out(0).Kind() == reflect.Interface { diff --git a/checker/nature/nature.go b/checker/nature/nature.go new file mode 100644 index 000000000..a7365998c --- /dev/null +++ b/checker/nature/nature.go @@ -0,0 +1,142 @@ +package nature + +import ( + "reflect" + + "github.com/expr-lang/expr/builtin" + "github.com/expr-lang/expr/internal/deref" +) + +var ( + unknown = Nature{} +) + +type Nature struct { + Type reflect.Type + SubType SubType + Func *builtin.Function + Method bool +} + +func (n Nature) String() string { + if n.SubType != nil { + return n.SubType.String() + } + if n.Type != nil { + return n.Type.String() + } + return "unknown" +} + +func (n Nature) Deref() Nature { + if n.Type != nil { + n.Type = deref.Type(n.Type) + } + return n +} + +func (n Nature) Kind() reflect.Kind { + if n.Type != nil { + return n.Type.Kind() + } + return reflect.Invalid +} + +func (n Nature) Key() Nature { + if n.Kind() == reflect.Map { + return Nature{Type: n.Type.Key()} + } + return unknown +} + +func (n Nature) Elem() Nature { + switch n.Kind() { + case reflect.Map, reflect.Ptr: + return Nature{Type: n.Type.Elem()} + case reflect.Array, reflect.Slice: + if array, ok := n.SubType.(Array); ok { + return array.Of + } + return Nature{Type: n.Type.Elem()} + } + return unknown +} + +func (n Nature) AssignableTo(nt Nature) bool { + if n.Type == nil || nt.Type == nil { + return false + } + return n.Type.AssignableTo(nt.Type) +} + +func (n Nature) MethodByName(name string) (Nature, bool) { + if n.Type == nil { + return unknown, false + } + method, ok := n.Type.MethodByName(name) + if !ok { + return unknown, false + } + + if n.Type.Kind() == reflect.Interface { + // In case of interface type method will not have a receiver, + // and to prevent checker decreasing numbers of in arguments + // return method type as not method (second argument is false). + + // Also, we can not use m.Index here, because it will be + // different indexes for different types which implement + // the same interface. + return Nature{Type: method.Type}, true + } else { + return Nature{Type: method.Type, Method: true}, true + } +} + +func (n Nature) NumField() int { + if n.Type == nil { + return 0 + } + return n.Type.NumField() +} + +func (n Nature) Field(i int) reflect.StructField { + if n.Type == nil { + return reflect.StructField{} + } + return n.Type.Field(i) +} + +func (n Nature) NumIn() int { + if n.Type == nil { + return 0 + } + return n.Type.NumIn() +} + +func (n Nature) In(i int) Nature { + if n.Type == nil { + return unknown + } + return Nature{Type: n.Type.In(i)} +} + +func (n Nature) NumOut() int { + if n.Type == nil { + return 0 + } + return n.Type.NumOut() +} + +func (n Nature) Out(i int) Nature { + if n.Type == nil { + return unknown + } + return Nature{Type: n.Type.Out(i)} +} + +func (n Nature) IsVariadic() bool { + if n.Type == nil { + return false + } + return n.Type.IsVariadic() +} diff --git a/checker/nature/types.go b/checker/nature/types.go new file mode 100644 index 000000000..1f9955e92 --- /dev/null +++ b/checker/nature/types.go @@ -0,0 +1,13 @@ +package nature + +type SubType interface { + String() string +} + +type Array struct { + Of Nature +} + +func (a Array) String() string { + return "[]" + a.Of.String() +} diff --git a/checker/types.go b/checker/types.go index d10736a77..2eb5392e0 100644 --- a/checker/types.go +++ b/checker/types.go @@ -4,204 +4,193 @@ import ( "reflect" "time" + . "github.com/expr-lang/expr/checker/nature" "github.com/expr-lang/expr/conf" ) var ( - nilType = reflect.TypeOf(nil) - boolType = reflect.TypeOf(true) - integerType = reflect.TypeOf(0) - floatType = reflect.TypeOf(float64(0)) - stringType = reflect.TypeOf("") - arrayType = reflect.TypeOf([]any{}) - mapType = reflect.TypeOf(map[string]any{}) + unknown = Nature{} + nilNature = Nature{Type: reflect.TypeOf(Nil{})} + boolNature = Nature{Type: reflect.TypeOf(true)} + integerNature = Nature{Type: reflect.TypeOf(0)} + floatNature = Nature{Type: reflect.TypeOf(float64(0))} + stringNature = Nature{Type: reflect.TypeOf("")} + arrayNature = Nature{Type: reflect.TypeOf([]any{})} + mapNature = Nature{Type: reflect.TypeOf(map[string]any{})} + timeNature = Nature{Type: reflect.TypeOf(time.Time{})} + durationNature = Nature{Type: reflect.TypeOf(time.Duration(0))} +) + +var ( anyType = reflect.TypeOf(new(any)).Elem() timeType = reflect.TypeOf(time.Time{}) durationType = reflect.TypeOf(time.Duration(0)) + arrayType = reflect.TypeOf([]any{}) ) -func combined(a, b reflect.Type) reflect.Type { - if a.Kind() == b.Kind() { - return a +// Nil is a special type to represent nil. +type Nil struct{} + +func isNil(nt Nature) bool { + if nt.Type == nil { + return false + } + return nt.Type == nilNature.Type +} + +func combined(l, r Nature) Nature { + if isUnknown(l) || isUnknown(r) { + return unknown } - if isFloat(a) || isFloat(b) { - return floatType + if isFloat(l) || isFloat(r) { + return floatNature } - return integerType + return integerNature } -func anyOf(t reflect.Type, fns ...func(reflect.Type) bool) bool { +func anyOf(nt Nature, fns ...func(Nature) bool) bool { for _, fn := range fns { - if fn(t) { + if fn(nt) { return true } } return false } -func or(l, r reflect.Type, fns ...func(reflect.Type) bool) bool { - if isAny(l) && isAny(r) { +func or(l, r Nature, fns ...func(Nature) bool) bool { + if isUnknown(l) && isUnknown(r) { return true } - if isAny(l) && anyOf(r, fns...) { + if isUnknown(l) && anyOf(r, fns...) { return true } - if isAny(r) && anyOf(l, fns...) { + if isUnknown(r) && anyOf(l, fns...) { return true } return false } -func isAny(t reflect.Type) bool { - if t != nil { - switch t.Kind() { - case reflect.Interface: - return true - } +func isUnknown(nt Nature) bool { + switch { + case nt.Type == nil: + return true + case nt.Kind() == reflect.Interface: + return true } return false } -func isInteger(t reflect.Type) bool { - if t != nil { - switch t.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - fallthrough - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return true - } +func isInteger(nt Nature) bool { + switch nt.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + fallthrough + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return true } return false } -func isFloat(t reflect.Type) bool { - if t != nil { - switch t.Kind() { - case reflect.Float32, reflect.Float64: - return true - } +func isFloat(nt Nature) bool { + switch nt.Kind() { + case reflect.Float32, reflect.Float64: + return true } return false } -func isNumber(t reflect.Type) bool { - return isInteger(t) || isFloat(t) +func isNumber(nt Nature) bool { + return isInteger(nt) || isFloat(nt) } -func isTime(t reflect.Type) bool { - if t != nil { - switch t { - case timeType: - return true - } +func isTime(nt Nature) bool { + switch nt.Type { + case timeType: + return true } return false } -func isDuration(t reflect.Type) bool { - if t != nil { - switch t { - case durationType: - return true - } +func isDuration(nt Nature) bool { + switch nt.Type { + case durationType: + return true } return false } -func isBool(t reflect.Type) bool { - if t != nil { - switch t.Kind() { - case reflect.Bool: - return true - } +func isBool(nt Nature) bool { + switch nt.Kind() { + case reflect.Bool: + return true } return false } -func isString(t reflect.Type) bool { - if t != nil { - switch t.Kind() { - case reflect.String: - return true - } +func isString(nt Nature) bool { + switch nt.Kind() { + case reflect.String: + return true } return false } -func isArray(t reflect.Type) bool { - if t != nil { - switch t.Kind() { - case reflect.Ptr: - return isArray(t.Elem()) - case reflect.Slice, reflect.Array: - return true - } +func isArray(nt Nature) bool { + switch nt.Kind() { + case reflect.Slice, reflect.Array: + return true } return false } -func isMap(t reflect.Type) bool { - if t != nil { - switch t.Kind() { - case reflect.Ptr: - return isMap(t.Elem()) - case reflect.Map: - return true - } +func isMap(nt Nature) bool { + switch nt.Kind() { + case reflect.Map: + return true } return false } -func isStruct(t reflect.Type) bool { - if t != nil { - switch t.Kind() { - case reflect.Ptr: - return isStruct(t.Elem()) - case reflect.Struct: - return true - } +func isStruct(nt Nature) bool { + switch nt.Kind() { + case reflect.Struct: + return true } return false } -func isFunc(t reflect.Type) bool { - if t != nil { - switch t.Kind() { - case reflect.Ptr: - return isFunc(t.Elem()) - case reflect.Func: - return true - } +func isFunc(nt Nature) bool { + switch nt.Kind() { + case reflect.Func: + return true } return false } -func fetchField(t reflect.Type, name string) (reflect.StructField, bool) { - if t != nil { - // First check all structs fields. - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - // Search all fields, even embedded structs. - if conf.FieldName(field) == name { - return field, true - } +func fetchField(nt Nature, name string) (reflect.StructField, bool) { + // First check all structs fields. + for i := 0; i < nt.NumField(); i++ { + field := nt.Field(i) + // Search all fields, even embedded structs. + if conf.FieldName(field) == name { + return field, true } + } - // Second check fields of embedded structs. - for i := 0; i < t.NumField(); i++ { - anon := t.Field(i) - if anon.Anonymous { - anonType := anon.Type - if anonType.Kind() == reflect.Pointer { - anonType = anonType.Elem() - } - if field, ok := fetchField(anonType, name); ok { - field.Index = append(anon.Index, field.Index...) - return field, true - } + // Second check fields of embedded structs. + for i := 0; i < nt.NumField(); i++ { + anon := nt.Field(i) + if anon.Anonymous { + anonType := anon.Type + if anonType.Kind() == reflect.Pointer { + anonType = anonType.Elem() + } + if field, ok := fetchField(Nature{Type: anonType}, name); ok { + field.Index = append(anon.Index, field.Index...) + return field, true } } } + return reflect.StructField{}, false } @@ -212,16 +201,15 @@ func kind(t reflect.Type) reflect.Kind { return t.Kind() } -func isComparable(l, r reflect.Type) bool { - if l == nil || r == nil { - return true - } +func isComparable(l, r Nature) bool { switch { case l.Kind() == r.Kind(): return true case isNumber(l) && isNumber(r): return true - case isAny(l) || isAny(r): + case isNil(l) || isNil(r): + return true + case isUnknown(l) || isUnknown(r): return true } return false diff --git a/compiler/compiler.go b/compiler/compiler.go index 29a6d4cdc..b4ec6dbd4 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -377,16 +377,13 @@ func (c *compiler) IntegerNode(node *ast.IntegerNode) { } func (c *compiler) FloatNode(node *ast.FloatNode) { - t := node.Type() - if t == nil { - c.emitPush(node.Value) - return - } - switch t.Kind() { + switch node.Kind() { case reflect.Float32: c.emitPush(float32(node.Value)) case reflect.Float64: c.emitPush(node.Value) + default: + c.emitPush(node.Value) } } @@ -1202,7 +1199,7 @@ func (c *compiler) PairNode(node *ast.PairNode) { } func (c *compiler) derefInNeeded(node ast.Node) { - switch kind(node.Type()) { + switch node.Kind() { case reflect.Ptr, reflect.Interface: c.emit(OpDeref) } diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index ba5d6dc54..526c5d67d 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -442,10 +442,11 @@ func TestCompile_OpCallFast(t *testing.T) { func TestCompile_optimizes_jumps(t *testing.T) { env := map[string]any{ - "a": true, - "b": true, - "c": true, - "d": true, + "a": true, + "b": true, + "c": true, + "d": true, + "i64": int64(1), } tests := []struct { code string @@ -497,36 +498,33 @@ func TestCompile_optimizes_jumps(t *testing.T) { `filter([1, 2, 3, 4, 5], # > 3 && # != 4 && # != 5)`, `0 OpPush <0> [1 2 3 4 5] 1 OpBegin -2 OpJumpIfEnd <26> (29) +2 OpJumpIfEnd <23> (26) 3 OpPointer -4 OpDeref -5 OpPush <1> 3 -6 OpMore -7 OpJumpIfFalse <18> (26) -8 OpPop -9 OpPointer -10 OpDeref -11 OpPush <2> 4 -12 OpEqual -13 OpNot -14 OpJumpIfFalse <11> (26) -15 OpPop -16 OpPointer -17 OpDeref -18 OpPush <3> 5 -19 OpEqual -20 OpNot -21 OpJumpIfFalse <4> (26) -22 OpPop -23 OpIncrementCount -24 OpPointer -25 OpJump <1> (27) -26 OpPop -27 OpIncrementIndex -28 OpJumpBackward <27> (2) -29 OpGetCount -30 OpEnd -31 OpArray +4 OpPush <1> 3 +5 OpMore +6 OpJumpIfFalse <16> (23) +7 OpPop +8 OpPointer +9 OpPush <2> 4 +10 OpEqualInt +11 OpNot +12 OpJumpIfFalse <10> (23) +13 OpPop +14 OpPointer +15 OpPush <3> 5 +16 OpEqualInt +17 OpNot +18 OpJumpIfFalse <4> (23) +19 OpPop +20 OpIncrementCount +21 OpPointer +22 OpJump <1> (24) +23 OpPop +24 OpIncrementIndex +25 OpJumpBackward <24> (2) +26 OpGetCount +27 OpEnd +28 OpArray `, }, { @@ -650,3 +648,12 @@ func TestCompile_IntegerArgsFunc(t *testing.T) { }) } } + +func TestCompile_call_on_nil(t *testing.T) { + env := map[string]any{ + "foo": nil, + } + _, err := expr.Compile(`foo()`, expr.Env(env)) + require.Error(t, err) + require.Contains(t, err.Error(), "foo is nil; cannot call nil as function") +} diff --git a/expr_test.go b/expr_test.go index 8b7856a43..3724467fc 100644 --- a/expr_test.go +++ b/expr_test.go @@ -2296,14 +2296,6 @@ func TestIssue432(t *testing.T) { assert.Equal(t, float64(10), out) } -func TestIssue453(t *testing.T) { - env := map[string]any{ - "foo": nil, - } - _, err := expr.Compile(`foo()`, expr.Env(env)) - require.Error(t, err) -} - func TestIssue461(t *testing.T) { type EnvStr string type EnvField struct { diff --git a/optimizer/const_expr.go b/optimizer/const_expr.go index 501ea3c58..1b45385f6 100644 --- a/optimizer/const_expr.go +++ b/optimizer/const_expr.go @@ -30,11 +30,6 @@ func (c *constExpr) Visit(node *Node) { } }() - patch := func(newNode Node) { - c.applied = true - Patch(node, newNode) - } - if call, ok := (*node).(*CallNode); ok { if name, ok := call.Callee.(*IdentifierNode); ok { fn, ok := c.fns[name.Value] @@ -78,7 +73,8 @@ func (c *constExpr) Visit(node *Node) { return } constNode := &ConstantNode{Value: value} - patch(constNode) + patchWithType(node, constNode) + c.applied = true } } } diff --git a/optimizer/filter_first.go b/optimizer/filter_first.go index 7ea8f6fa4..b04a5cb34 100644 --- a/optimizer/filter_first.go +++ b/optimizer/filter_first.go @@ -12,7 +12,7 @@ func (*filterFirst) Visit(node *Node) { if filter, ok := member.Node.(*BuiltinNode); ok && filter.Name == "filter" && len(filter.Arguments) == 2 { - Patch(node, &BuiltinNode{ + patchCopyType(node, &BuiltinNode{ Name: "find", Arguments: filter.Arguments, Throws: true, // to match the behavior of filter()[0] @@ -27,7 +27,7 @@ func (*filterFirst) Visit(node *Node) { if filter, ok := first.Arguments[0].(*BuiltinNode); ok && filter.Name == "filter" && len(filter.Arguments) == 2 { - Patch(node, &BuiltinNode{ + patchCopyType(node, &BuiltinNode{ Name: "find", Arguments: filter.Arguments, Throws: false, // as first() will return nil if not found diff --git a/optimizer/filter_last.go b/optimizer/filter_last.go index 9a1cc5e29..8c046bf88 100644 --- a/optimizer/filter_last.go +++ b/optimizer/filter_last.go @@ -12,7 +12,7 @@ func (*filterLast) Visit(node *Node) { if filter, ok := member.Node.(*BuiltinNode); ok && filter.Name == "filter" && len(filter.Arguments) == 2 { - Patch(node, &BuiltinNode{ + patchCopyType(node, &BuiltinNode{ Name: "findLast", Arguments: filter.Arguments, Throws: true, // to match the behavior of filter()[-1] @@ -27,7 +27,7 @@ func (*filterLast) Visit(node *Node) { if filter, ok := first.Arguments[0].(*BuiltinNode); ok && filter.Name == "filter" && len(filter.Arguments) == 2 { - Patch(node, &BuiltinNode{ + patchCopyType(node, &BuiltinNode{ Name: "findLast", Arguments: filter.Arguments, Throws: false, // as last() will return nil if not found diff --git a/optimizer/filter_len.go b/optimizer/filter_len.go index 6577163ec..c66fde961 100644 --- a/optimizer/filter_len.go +++ b/optimizer/filter_len.go @@ -13,7 +13,7 @@ func (*filterLen) Visit(node *Node) { if filter, ok := ln.Arguments[0].(*BuiltinNode); ok && filter.Name == "filter" && len(filter.Arguments) == 2 { - Patch(node, &BuiltinNode{ + patchCopyType(node, &BuiltinNode{ Name: "count", Arguments: filter.Arguments, }) diff --git a/optimizer/filter_map.go b/optimizer/filter_map.go index d988dc692..e916dd75b 100644 --- a/optimizer/filter_map.go +++ b/optimizer/filter_map.go @@ -14,7 +14,7 @@ func (*filterMap) Visit(node *Node) { if filter, ok := mapBuiltin.Arguments[0].(*BuiltinNode); ok && filter.Name == "filter" && filter.Map == nil /* not already optimized */ { - Patch(node, &BuiltinNode{ + patchCopyType(node, &BuiltinNode{ Name: "filter", Arguments: filter.Arguments, Map: closure.Node, diff --git a/optimizer/fold.go b/optimizer/fold.go index 910c92402..2f4562c22 100644 --- a/optimizer/fold.go +++ b/optimizer/fold.go @@ -1,20 +1,12 @@ package optimizer import ( - "fmt" "math" - "reflect" . "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/file" ) -var ( - integerType = reflect.TypeOf(0) - floatType = reflect.TypeOf(float64(0)) - stringType = reflect.TypeOf("") -) - type fold struct { applied bool err *file.Error @@ -23,20 +15,11 @@ type fold struct { func (fold *fold) Visit(node *Node) { patch := func(newNode Node) { fold.applied = true - Patch(node, newNode) + patchWithType(node, newNode) } - patchWithType := func(newNode Node) { - patch(newNode) - switch newNode.(type) { - case *IntegerNode: - newNode.SetType(integerType) - case *FloatNode: - newNode.SetType(floatType) - case *StringNode: - newNode.SetType(stringType) - default: - panic(fmt.Sprintf("unknown type %T", newNode)) - } + patchCopy := func(newNode Node) { + fold.applied = true + patchCopyType(node, newNode) } switch n := (*node).(type) { @@ -44,17 +27,17 @@ func (fold *fold) Visit(node *Node) { switch n.Operator { case "-": if i, ok := n.Node.(*IntegerNode); ok { - patchWithType(&IntegerNode{Value: -i.Value}) + patch(&IntegerNode{Value: -i.Value}) } if i, ok := n.Node.(*FloatNode); ok { - patchWithType(&FloatNode{Value: -i.Value}) + patch(&FloatNode{Value: -i.Value}) } case "+": if i, ok := n.Node.(*IntegerNode); ok { - patchWithType(&IntegerNode{Value: i.Value}) + patch(&IntegerNode{Value: i.Value}) } if i, ok := n.Node.(*FloatNode); ok { - patchWithType(&FloatNode{Value: i.Value}) + patch(&FloatNode{Value: i.Value}) } case "!", "not": if a := toBool(n.Node); a != nil { @@ -69,28 +52,28 @@ func (fold *fold) Visit(node *Node) { a := toInteger(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&IntegerNode{Value: a.Value + b.Value}) + patch(&IntegerNode{Value: a.Value + b.Value}) } } { a := toInteger(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: float64(a.Value) + b.Value}) + patch(&FloatNode{Value: float64(a.Value) + b.Value}) } } { a := toFloat(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: a.Value + float64(b.Value)}) + patch(&FloatNode{Value: a.Value + float64(b.Value)}) } } { a := toFloat(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: a.Value + b.Value}) + patch(&FloatNode{Value: a.Value + b.Value}) } } { @@ -105,28 +88,28 @@ func (fold *fold) Visit(node *Node) { a := toInteger(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&IntegerNode{Value: a.Value - b.Value}) + patch(&IntegerNode{Value: a.Value - b.Value}) } } { a := toInteger(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: float64(a.Value) - b.Value}) + patch(&FloatNode{Value: float64(a.Value) - b.Value}) } } { a := toFloat(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: a.Value - float64(b.Value)}) + patch(&FloatNode{Value: a.Value - float64(b.Value)}) } } { a := toFloat(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: a.Value - b.Value}) + patch(&FloatNode{Value: a.Value - b.Value}) } } case "*": @@ -134,28 +117,28 @@ func (fold *fold) Visit(node *Node) { a := toInteger(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&IntegerNode{Value: a.Value * b.Value}) + patch(&IntegerNode{Value: a.Value * b.Value}) } } { a := toInteger(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: float64(a.Value) * b.Value}) + patch(&FloatNode{Value: float64(a.Value) * b.Value}) } } { a := toFloat(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: a.Value * float64(b.Value)}) + patch(&FloatNode{Value: a.Value * float64(b.Value)}) } } { a := toFloat(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: a.Value * b.Value}) + patch(&FloatNode{Value: a.Value * b.Value}) } } case "/": @@ -163,28 +146,28 @@ func (fold *fold) Visit(node *Node) { a := toInteger(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: float64(a.Value) / float64(b.Value)}) + patch(&FloatNode{Value: float64(a.Value) / float64(b.Value)}) } } { a := toInteger(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: float64(a.Value) / b.Value}) + patch(&FloatNode{Value: float64(a.Value) / b.Value}) } } { a := toFloat(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: a.Value / float64(b.Value)}) + patch(&FloatNode{Value: a.Value / float64(b.Value)}) } } { a := toFloat(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: a.Value / b.Value}) + patch(&FloatNode{Value: a.Value / b.Value}) } } case "%": @@ -205,28 +188,28 @@ func (fold *fold) Visit(node *Node) { a := toInteger(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: math.Pow(float64(a.Value), float64(b.Value))}) + patch(&FloatNode{Value: math.Pow(float64(a.Value), float64(b.Value))}) } } { a := toInteger(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: math.Pow(float64(a.Value), b.Value)}) + patch(&FloatNode{Value: math.Pow(float64(a.Value), b.Value)}) } } { a := toFloat(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: math.Pow(a.Value, float64(b.Value))}) + patch(&FloatNode{Value: math.Pow(a.Value, float64(b.Value))}) } } { a := toFloat(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: math.Pow(a.Value, b.Value)}) + patch(&FloatNode{Value: math.Pow(a.Value, b.Value)}) } } case "and", "&&": @@ -234,9 +217,9 @@ func (fold *fold) Visit(node *Node) { b := toBool(n.Right) if a != nil && a.Value { // true and x - patch(n.Right) + patchCopy(n.Right) } else if b != nil && b.Value { // x and true - patch(n.Left) + patchCopy(n.Left) } else if (a != nil && !a.Value) || (b != nil && !b.Value) { // "x and false" or "false and x" patch(&BoolNode{Value: false}) } @@ -245,9 +228,9 @@ func (fold *fold) Visit(node *Node) { b := toBool(n.Right) if a != nil && !a.Value { // false or x - patch(n.Right) + patchCopy(n.Right) } else if b != nil && !b.Value { // x or false - patch(n.Left) + patchCopy(n.Left) } else if (a != nil && a.Value) || (b != nil && b.Value) { // "x or true" or "true or x" patch(&BoolNode{Value: true}) } @@ -302,20 +285,21 @@ func (fold *fold) Visit(node *Node) { } case *BuiltinNode: + // TODO: Move this to a separate visitor filter_filter.go switch n.Name { case "filter": if len(n.Arguments) != 2 { return } if base, ok := n.Arguments[0].(*BuiltinNode); ok && base.Name == "filter" { - patch(&BuiltinNode{ + patchCopy(&BuiltinNode{ Name: "filter", Arguments: []Node{ base.Arguments[0], &BinaryNode{ Operator: "&&", - Left: base.Arguments[1], - Right: n.Arguments[1], + Left: base.Arguments[1].(*ClosureNode).Node, + Right: n.Arguments[1].(*ClosureNode).Node, }, }, }) diff --git a/optimizer/fold_test.go b/optimizer/fold_test.go new file mode 100644 index 000000000..d3f44fcf4 --- /dev/null +++ b/optimizer/fold_test.go @@ -0,0 +1,82 @@ +package optimizer_test + +import ( + "reflect" + "testing" + + "github.com/expr-lang/expr/ast" + "github.com/expr-lang/expr/internal/testify/assert" + "github.com/expr-lang/expr/internal/testify/require" + "github.com/expr-lang/expr/optimizer" + "github.com/expr-lang/expr/parser" +) + +func TestOptimize_constant_folding(t *testing.T) { + tree, err := parser.Parse(`[1,2,3][5*5-25]`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.MemberNode{ + Node: &ast.ConstantNode{Value: []any{1, 2, 3}}, + Property: &ast.IntegerNode{Value: 0}, + } + + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) +} + +func TestOptimize_constant_folding_with_floats(t *testing.T) { + tree, err := parser.Parse(`1 + 2.0 * ((1.0 * 2) / 2) - 0`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.FloatNode{Value: 3.0} + + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) + assert.Equal(t, reflect.Float64, tree.Node.Type().Kind()) +} + +func TestOptimize_constant_folding_with_bools(t *testing.T) { + tree, err := parser.Parse(`(true and false) or (true or false) or (false and false) or (true and (true == false))`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.BoolNode{Value: true} + + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) +} + +func TestOptimize_constant_folding_filter_filter(t *testing.T) { + tree, err := parser.Parse(`filter(filter(1..2, true), true)`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.BuiltinNode{ + Name: "filter", + Arguments: []ast.Node{ + &ast.BinaryNode{ + Operator: "..", + Left: &ast.IntegerNode{ + Value: 1, + }, + Right: &ast.IntegerNode{ + Value: 2, + }, + }, + &ast.BoolNode{ + Value: true, + }, + }, + Throws: false, + Map: nil, + } + + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) +} diff --git a/optimizer/in_array.go b/optimizer/in_array.go index 8933d9b91..e91320c0f 100644 --- a/optimizer/in_array.go +++ b/optimizer/in_array.go @@ -32,10 +32,12 @@ func (*inArray) Visit(node *Node) { for _, a := range array.Nodes { value[a.(*IntegerNode).Value] = struct{}{} } - Patch(node, &BinaryNode{ + m := &ConstantNode{Value: value} + m.SetType(reflect.TypeOf(value)) + patchCopyType(node, &BinaryNode{ Operator: n.Operator, Left: n.Left, - Right: &ConstantNode{Value: value}, + Right: m, }) } @@ -50,10 +52,12 @@ func (*inArray) Visit(node *Node) { for _, a := range array.Nodes { value[a.(*StringNode).Value] = struct{}{} } - Patch(node, &BinaryNode{ + m := &ConstantNode{Value: value} + m.SetType(reflect.TypeOf(value)) + patchCopyType(node, &BinaryNode{ Operator: n.Operator, Left: n.Left, - Right: &ConstantNode{Value: value}, + Right: m, }) } diff --git a/optimizer/in_range.go b/optimizer/in_range.go index 01faabbdf..ed2f557ea 100644 --- a/optimizer/in_range.go +++ b/optimizer/in_range.go @@ -22,7 +22,7 @@ func (*inRange) Visit(node *Node) { if rangeOp, ok := n.Right.(*BinaryNode); ok && rangeOp.Operator == ".." { if from, ok := rangeOp.Left.(*IntegerNode); ok { if to, ok := rangeOp.Right.(*IntegerNode); ok { - Patch(node, &BinaryNode{ + patchCopyType(node, &BinaryNode{ Operator: "and", Left: &BinaryNode{ Operator: ">=", diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index 4ceb3fa43..9a9677c1b 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -1,6 +1,9 @@ package optimizer import ( + "fmt" + "reflect" + . "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/conf" ) @@ -41,3 +44,36 @@ func Optimize(node *Node, config *conf.Config) error { Walk(node, &sumMap{}) return nil } + +var ( + boolType = reflect.TypeOf(true) + integerType = reflect.TypeOf(0) + floatType = reflect.TypeOf(float64(0)) + stringType = reflect.TypeOf("") +) + +func patchWithType(node *Node, newNode Node) { + switch n := newNode.(type) { + case *BoolNode: + newNode.SetType(boolType) + case *IntegerNode: + newNode.SetType(integerType) + case *FloatNode: + newNode.SetType(floatType) + case *StringNode: + newNode.SetType(stringType) + case *ConstantNode: + newNode.SetType(reflect.TypeOf(n.Value)) + case *BinaryNode: + newNode.SetType(n.Type()) + default: + panic(fmt.Sprintf("unknown type %T", newNode)) + } + Patch(node, newNode) +} + +func patchCopyType(node *Node, newNode Node) { + t := (*node).Type() + newNode.SetType(t) + Patch(node, newNode) +} diff --git a/optimizer/optimizer_test.go b/optimizer/optimizer_test.go index 0ff4b4901..56a890492 100644 --- a/optimizer/optimizer_test.go +++ b/optimizer/optimizer_test.go @@ -2,7 +2,6 @@ package optimizer_test import ( "fmt" - "reflect" "strings" "testing" @@ -92,46 +91,6 @@ func TestOptimize(t *testing.T) { } } -func TestOptimize_constant_folding(t *testing.T) { - tree, err := parser.Parse(`[1,2,3][5*5-25]`) - require.NoError(t, err) - - err = optimizer.Optimize(&tree.Node, nil) - require.NoError(t, err) - - expected := &ast.MemberNode{ - Node: &ast.ConstantNode{Value: []any{1, 2, 3}}, - Property: &ast.IntegerNode{Value: 0}, - } - - assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) -} - -func TestOptimize_constant_folding_with_floats(t *testing.T) { - tree, err := parser.Parse(`1 + 2.0 * ((1.0 * 2) / 2) - 0`) - require.NoError(t, err) - - err = optimizer.Optimize(&tree.Node, nil) - require.NoError(t, err) - - expected := &ast.FloatNode{Value: 3.0} - - assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) - assert.Equal(t, reflect.Float64, tree.Node.Type().Kind()) -} - -func TestOptimize_constant_folding_with_bools(t *testing.T) { - tree, err := parser.Parse(`(true and false) or (true or false) or (false and false) or (true and (true == false))`) - require.NoError(t, err) - - err = optimizer.Optimize(&tree.Node, nil) - require.NoError(t, err) - - expected := &ast.BoolNode{Value: true} - - assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) -} - func TestOptimize_in_array(t *testing.T) { config := conf.New(map[string]int{"v": 0}) diff --git a/optimizer/predicate_combination.go b/optimizer/predicate_combination.go index 6e8a7f7cf..62e296d1f 100644 --- a/optimizer/predicate_combination.go +++ b/optimizer/predicate_combination.go @@ -29,7 +29,7 @@ func (v *predicateCombination) Visit(node *Node) { }, } v.Visit(&closure.Node) - Patch(node, &BuiltinNode{ + patchCopyType(node, &BuiltinNode{ Name: left.Name, Arguments: []Node{ left.Arguments[0], diff --git a/optimizer/sum_array.go b/optimizer/sum_array.go index 0a05d1f2e..3c96795ef 100644 --- a/optimizer/sum_array.go +++ b/optimizer/sum_array.go @@ -14,7 +14,7 @@ func (*sumArray) Visit(node *Node) { len(sumBuiltin.Arguments) == 1 { if array, ok := sumBuiltin.Arguments[0].(*ArrayNode); ok && len(array.Nodes) >= 2 { - Patch(node, sumArrayFold(array)) + patchCopyType(node, sumArrayFold(array)) } } } diff --git a/optimizer/sum_map.go b/optimizer/sum_map.go index a41a53732..6de97d373 100644 --- a/optimizer/sum_map.go +++ b/optimizer/sum_map.go @@ -13,7 +13,7 @@ func (*sumMap) Visit(node *Node) { if mapBuiltin, ok := sumBuiltin.Arguments[0].(*BuiltinNode); ok && mapBuiltin.Name == "map" && len(mapBuiltin.Arguments) == 2 { - Patch(node, &BuiltinNode{ + patchCopyType(node, &BuiltinNode{ Name: "sum", Arguments: []Node{ mapBuiltin.Arguments[0], diff --git a/test/fuzz/fuzz_corpus.txt b/test/fuzz/fuzz_corpus.txt index 7b0174923..59349d3d5 100644 --- a/test/fuzz/fuzz_corpus.txt +++ b/test/fuzz/fuzz_corpus.txt @@ -6262,9 +6262,7 @@ get(array, true ? i : 1) get(false ? "bar" : f64, list) get(false ? "bar" : i32, greet) get(false ? "foo" : i, ok) -get(false ? 0.5 : i64, true not in String) get(false ? 1 : greet, ok) -get(false ? div : true, min("foo")) get(false ? greet : score, score) get(false ? half : list, i32) get(false ? list : 1, i32) @@ -6350,11 +6348,7 @@ get(map(list, true), i64) get(ok ? 0.5 : add, f32) get(ok ? add : 0.5, ok) get(ok ? add : false, div) -get(ok ? array : 0.5, Bar) get(ok ? array : i, add) -get(ok ? div : foo, filter("bar", ok)) -get(ok ? greet : 0.5, foo in String) -get(ok ? greet : f32, Bar) get(ok ? i : f64, greet)?.f64 get(ok ? i32 : "foo", score) get(ok ? i32 : add, div) @@ -6365,14 +6359,11 @@ get(sort(array), i32) get(true ? 0.5 : score, foo) get(true ? 1 : f64, foo) get(true ? 1 : greet, greet)?.foo -get(true ? 1 : score, String?.Qux()) get(true ? add : score, i) get(true ? array : 1, half) get(true ? array : i64, ok) get(true ? div : foo, list) get(true ? f32 : 1, greet) -get(true ? f64 : greet, String) -get(true ? f64 : ok, Bar?.ok) get(true ? i : f64, add) get(true ? i64 : score, i) get(true ? list : ok, ok) @@ -7595,7 +7586,6 @@ i32 != max(f64, 0.5) i32 != min(1, 0.5, f64) i32 != min(i) i32 != nil ? half : 1 -i32 != nil ? list : half(nil) i32 % -i i32 % -i32 i32 % 1 != i64 @@ -9647,7 +9637,6 @@ map([i64], #) map([i], max(#)) map([list, nil], i64) map([list], #) -map([nil], #?.div(#)) map([score, "bar"], #) map([score], #) map([score], array) diff --git a/test/fuzz/fuzz_expr_seed_corpus.zip b/test/fuzz/fuzz_expr_seed_corpus.zip index 0fd46ab3a..d812de92e 100644 Binary files a/test/fuzz/fuzz_expr_seed_corpus.zip and b/test/fuzz/fuzz_expr_seed_corpus.zip differ diff --git a/testdata/examples.txt b/testdata/examples.txt index 9c18442e4..e4f6d36c8 100644 --- a/testdata/examples.txt +++ b/testdata/examples.txt @@ -2041,7 +2041,6 @@ 0.5 < 1 || ok 0.5 < f32 && ok 0.5 < f32 || ok -0.5 <= 0.5 ? half : trimPrefix(nil) 0.5 <= f64 ? i : i64 0.5 <= f64 || ok 0.5 == 1 and ok @@ -3144,7 +3143,6 @@ all(reduce(array, array), # > #) all(reduce(array, list), ok) all(true ? list : false, not true) any(["bar"], # not endsWith #) -any([foo], # != i64) any([greet], ok) any(array, !(i32 <= #)) any(array, "foo" >= "foo") @@ -6261,7 +6259,6 @@ findLast(map(list, #), greet == nil) findLast(map(list, #), ok) findLast(map(list, ok), # ? # : #) findLastIndex(1 .. 1, i64 < #) -findLastIndex([false], 0.5 == #) findLastIndex([i64, half], ok) findLastIndex(array, "bar" not matches "foo") findLastIndex(array, "bar" startsWith "foo") @@ -7231,7 +7228,6 @@ get(false ? f64 : 1, ok) get(false ? f64 : score, add) get(false ? false : f32, i) get(false ? i32 : list, i64) -get(false ? i64 : foo, Bar) get(false ? i64 : true, f64) get(false ? score : ok, trimSuffix("bar", "bar")) get(filter(list, true), i) @@ -7370,27 +7366,19 @@ get(map(list, i32), i64) get(map(list, i64), i64) get(ok ? "bar" : false, array) get(ok ? 1 : "bar", f32) -get(ok ? 1 : add, String) get(ok ? add : f32, array) -get(ok ? array : i, Bar) -get(ok ? f32 : list, get(half, nil)) get(ok ? f64 : div, i64) -get(ok ? f64 : list, half in Bar)?.array() get(ok ? false : list, f32 > i) get(ok ? foo : 0.5, div) -get(ok ? greet : "bar", get(add, String)) get(ok ? half : i32, i) get(ok ? half : ok, f64) get(ok ? i : 0.5, half) -get(ok ? i : half, String) get(ok ? i32 : half, f64) get(ok ? i64 : foo, f32) get(ok ? list : "foo", add) get(ok ? list : i32, f32) get(ok ? ok : div, greet) -get(ok ? score : 1, Qux?.i) get(ok ? score : f64, i) -get(ok ? score : foo, String?.foo()) get(ok ? score : i64, foo) get(reduce(list, array), i32) get(sort(array), i32) @@ -7398,19 +7386,14 @@ get(take(list, i), i64) get(true ? "bar" : ok, score(i)) get(true ? "foo" : half, list) get(true ? 0.5 : i32, array) -get(true ? 1 : array, Bar)?.f64 get(true ? f32 : 0.5, ok) get(true ? false : foo, i64 > 0.5) -get(true ? greet : false, Bar) get(true ? greet : i32, score) -get(true ? half : 0.5, Qux) get(true ? half : f32, greet) get(true ? half : list, add) get(true ? i64 : greet, i32) get(true ? score : true, half)?.half -get(true ? true : add, Qux) get(true ? true : i, f64) -get({"bar": true, "foo": greet}.i, String) get({"foo": foo, "bar": false}, type(i)) greet greet != add @@ -12465,7 +12448,6 @@ map(1 .. i64, # ^ #) map(1 .. i64, #) map(1 .. i64, half) map(1 .. i64, i32) -map([1], get(#, 1)) map([f64], half) map([false], ok) map([half], #) @@ -12475,7 +12457,6 @@ map([i32], foo) map([i32], greet) map([i32], half) map([list, 1, foo], i32) -map([nil], #?.f32(ok, #, false)) map([nil], foo) map([score, "bar"], f32) map([true, i32, 1], #) @@ -14870,7 +14851,6 @@ ok ? array : foo ok ? array : foo.String ok ? array : foo?.String ok ? array : greet -ok ? array : greet(nil) ok ? array : i ok ? array : i32 ok ? array : list @@ -14890,7 +14870,6 @@ ok ? div : ok ok ? div : score ok ? f32 != i32 : f32 ok ? f32 ** i64 : score -ok ? f32 : "foo" endsWith lower(nil) ok ? f32 : -0.5 ok ? f32 : add ok ? f32 : array @@ -15104,7 +15083,6 @@ ok or f64 <= 1 ok or false ? ok : 1 ok or false and false ok or foo != foo -ok or fromJSON(nil)?.ok ok or half != add ok or half == div ok or i < 1 @@ -15340,29 +15318,15 @@ reduce(1 .. i64, greet) reduce(1 .. i64, half) reduce(["foo"], f32) reduce([0.5], f64) -reduce([1], .Qux) reduce([div, "foo"], i32) -reduce([div], # >= 0.5) -reduce([div], list == #) reduce([f32, f64], f32) -reduce([f32], .f64) -reduce([f64], # or #) -reduce([f64], .greet) -reduce([false], # ** #) reduce([foo], ok) -reduce([greet], # % 1) reduce([i32], list) reduce([i64, half], "bar" not startsWith "foo") reduce([i], # < #) -reduce([i], #[array]) reduce([i], ok) -reduce([list], # - #) -reduce([list], # or #) reduce([list], f64) reduce([nil], # in array) -reduce([ok], #?.half) -reduce([true], #?.score()) -reduce([true], 0.5 > #) reduce(array, !false) reduce(array, "foo" endsWith "foo") reduce(array, "foo") not in foo