diff --git a/checker/checker.go b/checker/checker.go index 3dc4e95a..1d62aef3 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -964,6 +964,10 @@ func (v *checker) checkArguments( continue } + if kind(in) != reflect.Ptr && kind(t) == reflect.Ptr { + t = deref(t) + } + if !t.AssignableTo(in) && kind(t) != reflect.Interface { return anyType, &file.Error{ Location: arg.Location(), diff --git a/compiler/compiler.go b/compiler/compiler.go index ac11805e..867ce261 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -363,8 +363,8 @@ func (c *compiler) UnaryNode(node *ast.UnaryNode) { } func (c *compiler) BinaryNode(node *ast.BinaryNode) { - l := kind(node.Left) - r := kind(node.Right) + l := kind(node.Left.Type()) + r := kind(node.Right.Type()) leftIsSimple := isSimpleType(node.Left) rightIsSimple := isSimpleType(node.Right) @@ -650,8 +650,23 @@ func (c *compiler) SliceNode(node *ast.SliceNode) { } func (c *compiler) CallNode(node *ast.CallNode) { - for _, arg := range node.Arguments { + fn := node.Callee.Type() + + for i, arg := range node.Arguments { c.compile(arg) + + // Check func in argument type and deref arg in needed. + if kind(fn) == reflect.Func { + var in reflect.Type + if fn.IsVariadic() && i >= fn.NumIn()-1 { + in = fn.In(fn.NumIn() - 1).Elem() + } else { + in = fn.In(i) + } + if kind(in) != reflect.Ptr && kind(arg.Type()) == reflect.Ptr { + c.emit(OpDeref) + } + } } if ident, ok := node.Callee.(*ast.IdentifierNode); ok { if c.config != nil { @@ -1044,14 +1059,13 @@ func (c *compiler) PairNode(node *ast.PairNode) { } func (c *compiler) derefInNeeded(node ast.Node) { - switch kind(node) { + switch kind(node.Type()) { case reflect.Ptr, reflect.Interface: c.emit(OpDeref) } } -func kind(node ast.Node) reflect.Kind { - t := node.Type() +func kind(t reflect.Type) reflect.Kind { if t == nil { return reflect.Invalid } diff --git a/test/deref/deref_test.go b/test/deref/deref_test.go index 684794a0..fe6fca86 100644 --- a/test/deref/deref_test.go +++ b/test/deref/deref_test.go @@ -237,3 +237,37 @@ func TestDeref_сommutative(t *testing.T) { }) } } + +func TestDeref_func_args(t *testing.T) { + i := 20 + env := map[string]any{ + "var": &i, + "fn": func(p int) int { + return p + 1 + }, + } + + program, err := expr.Compile(`fn(var) + fn(var + 0)`, expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, 42, out) +} + +func TestDeref_func_args_not_needed(t *testing.T) { + f := foo(1) + env := map[string]any{ + "foo": &f, + "fn": func(f *foo) int { + return f.Bar() + }, + } + + program, err := expr.Compile(`fn(foo)`, expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, 42, out) +}