Skip to content

Commit

Permalink
Invoke the Deref function as needed for the function arguments. (#651)
Browse files Browse the repository at this point in the history
  • Loading branch information
ckganesan authored May 21, 2024
1 parent c6c7227 commit 596f54f
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 8 deletions.
2 changes: 1 addition & 1 deletion checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,7 @@ func (v *checker) checkArguments(
continue
}

if !t.AssignableTo(in) && kind(t) != reflect.Interface {
if !(t.AssignableTo(in) || deref.Type(t).AssignableTo(in)) && kind(t) != reflect.Interface {
return anyType, &file.Error{
Location: arg.Location(),
Message: fmt.Sprintf("cannot use %v as argument (type %v) to call %v ", t, in, name),
Expand Down
48 changes: 41 additions & 7 deletions compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -592,8 +592,8 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) {
}

func (c *compiler) equalBinaryNode(node *ast.BinaryNode) {
l := kind(node.Left)
r := kind(node.Right)
l := kind(node.Left.Type())
r := kind(node.Right.Type())

leftIsSimple := isSimpleType(node.Left)
rightIsSimple := isSimpleType(node.Right)
Expand Down Expand Up @@ -727,9 +727,44 @@ func (c *compiler) SliceNode(node *ast.SliceNode) {
}

func (c *compiler) CallNode(node *ast.CallNode) {
for _, arg := range node.Arguments {
c.compile(arg)
fn := node.Callee.Type()
if kind(fn) == reflect.Func {
fnInOffset := 0
fnNumIn := fn.NumIn()
switch callee := node.Callee.(type) {
case *ast.MemberNode:
if prop, ok := callee.Property.(*ast.StringNode); ok {
if _, ok = callee.Node.Type().MethodByName(prop.Value); ok && callee.Node.Type().Kind() != reflect.Interface {
fnInOffset = 1
fnNumIn--
}
}
case *ast.IdentifierNode:
if t, ok := c.config.Types[callee.Value]; ok && t.Method {
fnInOffset = 1
fnNumIn--
}
}
for i, arg := range node.Arguments {
c.compile(arg)
if k := kind(arg.Type()); k == reflect.Ptr || k == reflect.Interface {
var in reflect.Type
if fn.IsVariadic() && i >= fnNumIn-1 {
in = fn.In(fn.NumIn() - 1).Elem()
} else {
in = fn.In(i + fnInOffset)
}
if k = kind(in); k != reflect.Ptr && k != reflect.Interface {
c.emit(OpDeref)
}
}
}
} else {
for _, arg := range node.Arguments {
c.compile(arg)
}
}

if ident, ok := node.Callee.(*ast.IdentifierNode); ok {
if c.config != nil {
if fn, ok := c.config.Functions[ident.Value]; ok {
Expand Down Expand Up @@ -1162,7 +1197,7 @@ func (c *compiler) PairNode(node *ast.PairNode) {
}

func (c *compiler) derefInNeeded(node ast.Node) {
switch kind(node) {
switch kind(node.Type()) {
case reflect.Ptr, reflect.Interface:
c.emit(OpDeref)
}
Expand All @@ -1181,8 +1216,7 @@ func (c *compiler) optimize() {
}
}

func kind(node ast.Node) reflect.Kind {
t := node.Type()
func kind(t reflect.Type) reflect.Kind {
if t == nil {
return reflect.Invalid
}
Expand Down
67 changes: 67 additions & 0 deletions test/deref/deref_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package deref_test
import (
"context"
"testing"
"time"

"github.com/expr-lang/expr/internal/testify/assert"
"github.com/expr-lang/expr/internal/testify/require"
Expand Down Expand Up @@ -253,3 +254,69 @@ func TestDeref_fetch_from_interface_mix_pointer(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "waldo", res)
}

func TestDeref_func_args(t *testing.T) {
i := 20
env := map[string]any{
"var": &i,
"fn": func(p int) int {
return p + 1
},
}

program, err := expr.Compile(`fn(var) + fn(var + 0)`, expr.Env(env))
require.NoError(t, err)

out, err := expr.Run(program, env)
require.NoError(t, err)
require.Equal(t, 42, out)
}

func TestDeref_struct_func_args(t *testing.T) {
n, _ := time.Parse(time.RFC3339, "2024-05-12T18:30:00+00:00")
duration := 30 * time.Minute
env := map[string]any{
"time": n,
"duration": &duration,
}

program, err := expr.Compile(`time.Add(duration).Format('2006-01-02T15:04:05Z07:00')`, expr.Env(env))
require.NoError(t, err)

out, err := expr.Run(program, env)
require.NoError(t, err)
require.Equal(t, "2024-05-12T19:00:00Z", out)
}

func TestDeref_ignore_func_args(t *testing.T) {
f := foo(1)
env := map[string]any{
"foo": &f,
"fn": func(f *foo) int {
return f.Bar()
},
}

program, err := expr.Compile(`fn(foo)`, expr.Env(env))
require.NoError(t, err)

out, err := expr.Run(program, env)
require.NoError(t, err)
require.Equal(t, 42, out)
}

func TestDeref_ignore_struct_func_args(t *testing.T) {
n := time.Now()
location, _ := time.LoadLocation("UTC")
env := map[string]any{
"time": n,
"location": location,
}

program, err := expr.Compile(`time.In(location).Location().String()`, expr.Env(env))
require.NoError(t, err)

out, err := expr.Run(program, env)
require.NoError(t, err)
require.Equal(t, "UTC", out)
}

0 comments on commit 596f54f

Please sign in to comment.