Skip to content

Commit ce66a2b

Browse files
authored
Merge branch 'master' into fix/vm-jump-offset
2 parents b875d6c + f423c10 commit ce66a2b

File tree

19 files changed

+507
-31
lines changed

19 files changed

+507
-31
lines changed

builtin/builtin.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package builtin
33
import (
44
"encoding/base64"
55
"encoding/json"
6+
"errors"
67
"fmt"
78
"reflect"
89
"sort"
@@ -16,6 +17,10 @@ import (
1617
var (
1718
Index map[string]int
1819
Names []string
20+
21+
// MaxDepth limits the recursion depth for nested structures.
22+
MaxDepth = 10000
23+
ErrorMaxDepth = errors.New("recursion depth exceeded")
1924
)
2025

2126
func init() {
@@ -377,7 +382,7 @@ var Builtins = []*Function{
377382
{
378383
Name: "max",
379384
Func: func(args ...any) (any, error) {
380-
return minMax("max", runtime.Less, args...)
385+
return minMax("max", runtime.Less, 0, args...)
381386
},
382387
Validate: func(args []reflect.Type) (reflect.Type, error) {
383388
return validateAggregateFunc("max", args)
@@ -386,7 +391,7 @@ var Builtins = []*Function{
386391
{
387392
Name: "min",
388393
Func: func(args ...any) (any, error) {
389-
return minMax("min", runtime.More, args...)
394+
return minMax("min", runtime.More, 0, args...)
390395
},
391396
Validate: func(args []reflect.Type) (reflect.Type, error) {
392397
return validateAggregateFunc("min", args)
@@ -395,7 +400,7 @@ var Builtins = []*Function{
395400
{
396401
Name: "mean",
397402
Func: func(args ...any) (any, error) {
398-
count, sum, err := mean(args...)
403+
count, sum, err := mean(0, args...)
399404
if err != nil {
400405
return nil, err
401406
}
@@ -411,7 +416,7 @@ var Builtins = []*Function{
411416
{
412417
Name: "median",
413418
Func: func(args ...any) (any, error) {
414-
values, err := median(args...)
419+
values, err := median(0, args...)
415420
if err != nil {
416421
return nil, err
417422
}
@@ -940,7 +945,10 @@ var Builtins = []*Function{
940945
if v.Kind() != reflect.Array && v.Kind() != reflect.Slice {
941946
return nil, size, fmt.Errorf("cannot flatten %s", v.Kind())
942947
}
943-
ret := flatten(v)
948+
ret, err := flatten(v, 0)
949+
if err != nil {
950+
return nil, 0, err
951+
}
944952
size = uint(len(ret))
945953
return ret, size, nil
946954
},

builtin/builtin_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,3 +722,100 @@ func TestBuiltin_with_deref(t *testing.T) {
722722
})
723723
}
724724
}
725+
726+
func TestBuiltin_flatten_recursion(t *testing.T) {
727+
var s []any
728+
s = append(s, &s) // s contains a pointer to itself
729+
730+
env := map[string]any{
731+
"arr": s,
732+
}
733+
734+
program, err := expr.Compile("flatten(arr)", expr.Env(env))
735+
require.NoError(t, err)
736+
737+
_, err = expr.Run(program, env)
738+
require.Error(t, err)
739+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
740+
}
741+
742+
func TestBuiltin_flatten_recursion_slice(t *testing.T) {
743+
s := make([]any, 1)
744+
s[0] = s
745+
746+
env := map[string]any{
747+
"arr": s,
748+
}
749+
750+
program, err := expr.Compile("flatten(arr)", expr.Env(env))
751+
require.NoError(t, err)
752+
753+
_, err = expr.Run(program, env)
754+
require.Error(t, err)
755+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
756+
}
757+
758+
func TestBuiltin_numerical_recursion(t *testing.T) {
759+
s := make([]any, 1)
760+
s[0] = s
761+
762+
env := map[string]any{
763+
"arr": s,
764+
}
765+
766+
tests := []string{
767+
"max(arr)",
768+
"min(arr)",
769+
"mean(arr)",
770+
"median(arr)",
771+
}
772+
773+
for _, input := range tests {
774+
t.Run(input, func(t *testing.T) {
775+
program, err := expr.Compile(input, expr.Env(env))
776+
require.NoError(t, err)
777+
778+
_, err = expr.Run(program, env)
779+
require.Error(t, err)
780+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
781+
})
782+
}
783+
}
784+
785+
func TestBuiltin_recursion_custom_max_depth(t *testing.T) {
786+
originalMaxDepth := builtin.MaxDepth
787+
defer func() {
788+
builtin.MaxDepth = originalMaxDepth
789+
}()
790+
791+
// Set a small depth limit
792+
builtin.MaxDepth = 2
793+
794+
// Create a deeply nested array (depth 5)
795+
// [1, [2, [3, [4, [5]]]]]
796+
arr := []any{1, []any{2, []any{3, []any{4, []any{5}}}}}
797+
798+
env := map[string]any{
799+
"arr": arr,
800+
}
801+
802+
t.Run("flatten exceeds max depth", func(t *testing.T) {
803+
program, err := expr.Compile("flatten(arr)", expr.Env(env))
804+
require.NoError(t, err)
805+
806+
_, err = expr.Run(program, env)
807+
require.Error(t, err)
808+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
809+
})
810+
811+
t.Run("flatten within max depth", func(t *testing.T) {
812+
// Depth 2: [1, [2]]
813+
shallowArr := []any{1, []any{2}}
814+
envShallow := map[string]any{"arr": shallowArr}
815+
program, err := expr.Compile("flatten(arr)", expr.Env(envShallow))
816+
require.NoError(t, err)
817+
818+
_, err = expr.Run(program, envShallow)
819+
require.NoError(t, err)
820+
})
821+
}

builtin/lib.go

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -253,15 +253,18 @@ func String(arg any) any {
253253
return fmt.Sprintf("%v", arg)
254254
}
255255

256-
func minMax(name string, fn func(any, any) bool, args ...any) (any, error) {
256+
func minMax(name string, fn func(any, any) bool, depth int, args ...any) (any, error) {
257+
if depth > MaxDepth {
258+
return nil, ErrorMaxDepth
259+
}
257260
var val any
258261
for _, arg := range args {
259262
rv := reflect.ValueOf(arg)
260263
switch rv.Kind() {
261264
case reflect.Array, reflect.Slice:
262265
size := rv.Len()
263266
for i := 0; i < size; i++ {
264-
elemVal, err := minMax(name, fn, rv.Index(i).Interface())
267+
elemVal, err := minMax(name, fn, depth+1, rv.Index(i).Interface())
265268
if err != nil {
266269
return nil, err
267270
}
@@ -294,7 +297,10 @@ func minMax(name string, fn func(any, any) bool, args ...any) (any, error) {
294297
return val, nil
295298
}
296299

297-
func mean(args ...any) (int, float64, error) {
300+
func mean(depth int, args ...any) (int, float64, error) {
301+
if depth > MaxDepth {
302+
return 0, 0, ErrorMaxDepth
303+
}
298304
var total float64
299305
var count int
300306

@@ -304,7 +310,7 @@ func mean(args ...any) (int, float64, error) {
304310
case reflect.Array, reflect.Slice:
305311
size := rv.Len()
306312
for i := 0; i < size; i++ {
307-
elemCount, elemSum, err := mean(rv.Index(i).Interface())
313+
elemCount, elemSum, err := mean(depth+1, rv.Index(i).Interface())
308314
if err != nil {
309315
return 0, 0, err
310316
}
@@ -327,7 +333,10 @@ func mean(args ...any) (int, float64, error) {
327333
return count, total, nil
328334
}
329335

330-
func median(args ...any) ([]float64, error) {
336+
func median(depth int, args ...any) ([]float64, error) {
337+
if depth > MaxDepth {
338+
return nil, ErrorMaxDepth
339+
}
331340
var values []float64
332341

333342
for _, arg := range args {
@@ -336,7 +345,7 @@ func median(args ...any) ([]float64, error) {
336345
case reflect.Array, reflect.Slice:
337346
size := rv.Len()
338347
for i := 0; i < size; i++ {
339-
elems, err := median(rv.Index(i).Interface())
348+
elems, err := median(depth+1, rv.Index(i).Interface())
340349
if err != nil {
341350
return nil, err
342351
}
@@ -355,18 +364,24 @@ func median(args ...any) ([]float64, error) {
355364
return values, nil
356365
}
357366

358-
func flatten(arg reflect.Value) []any {
367+
func flatten(arg reflect.Value, depth int) ([]any, error) {
368+
if depth > MaxDepth {
369+
return nil, ErrorMaxDepth
370+
}
359371
ret := []any{}
360372
for i := 0; i < arg.Len(); i++ {
361373
v := deref.Value(arg.Index(i))
362374
if v.Kind() == reflect.Array || v.Kind() == reflect.Slice {
363-
x := flatten(v)
375+
x, err := flatten(v, depth+1)
376+
if err != nil {
377+
return nil, err
378+
}
364379
ret = append(ret, x...)
365380
} else {
366381
ret = append(ret, v.Interface())
367382
}
368383
}
369-
return ret
384+
return ret, nil
370385
}
371386

372387
func get(params ...any) (out any, err error) {

checker/checker.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,6 +1277,13 @@ func (v *Checker) conditionalNode(node *ast.ConditionalNode) Nature {
12771277
return v.config.NtCache.NatureOf(nil)
12781278
}
12791279
if t1.AssignableTo(t2) {
1280+
if t1.IsArray() && t2.IsArray() {
1281+
e1 := t1.Elem(&v.config.NtCache)
1282+
e2 := t2.Elem(&v.config.NtCache)
1283+
if !e1.AssignableTo(e2) || !e2.AssignableTo(e1) {
1284+
return v.config.NtCache.FromType(arrayType)
1285+
}
1286+
}
12801287
return t1
12811288
}
12821289
return Nature{}

checker/checker_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ func TestCheck(t *testing.T) {
134134
{"Bool ?? Bool"},
135135
{"let foo = 1; foo == 1"},
136136
{"(Embed).EmbedPointerEmbedInt > 0"},
137+
{"(true ? [1] : [[1]])[0][0] == 1"},
137138
}
138139

139140
c := new(checker.Checker)

compiler/compiler.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro
5353
c.emit(OpCast, 1)
5454
case reflect.Float64:
5555
c.emit(OpCast, 2)
56+
case reflect.Bool:
57+
c.emit(OpCast, 3)
5658
}
5759
if c.config.Optimize {
5860
c.optimize()
@@ -1103,7 +1105,9 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
11031105
if f.Fast != nil {
11041106
c.emit(OpCallBuiltin1, id)
11051107
} else if f.Safe != nil {
1106-
c.emit(OpPush, c.addConstant(f.Safe))
1108+
id := c.addConstant(f.Safe)
1109+
c.emit(OpPush, id)
1110+
c.debugInfo[fmt.Sprintf("const_%d", id)] = node.Name
11071111
c.emit(OpCallSafe, len(node.Arguments))
11081112
} else if f.Func != nil {
11091113
c.emitFunction(f, len(node.Arguments))

compiler/compiler_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package compiler_test
22

33
import (
44
"math"
5+
"reflect"
56
"testing"
67

78
"github.com/expr-lang/expr/internal/testify/assert"
@@ -675,3 +676,50 @@ func TestCompile_call_on_nil(t *testing.T) {
675676
require.Error(t, err)
676677
require.Contains(t, err.Error(), "foo is nil; cannot call nil as function")
677678
}
679+
680+
func TestCompile_Expect(t *testing.T) {
681+
tests := []struct {
682+
input string
683+
option expr.Option
684+
op vm.Opcode
685+
arg int
686+
}{
687+
{
688+
input: "1",
689+
option: expr.AsKind(reflect.Int),
690+
op: vm.OpCast,
691+
arg: 0,
692+
},
693+
{
694+
input: "1",
695+
option: expr.AsInt64(),
696+
op: vm.OpCast,
697+
arg: 1,
698+
},
699+
{
700+
input: "1",
701+
option: expr.AsFloat64(),
702+
op: vm.OpCast,
703+
arg: 2,
704+
},
705+
{
706+
input: "true",
707+
option: expr.AsBool(),
708+
op: vm.OpCast,
709+
arg: 3,
710+
},
711+
}
712+
713+
for _, tt := range tests {
714+
t.Run(tt.input, func(t *testing.T) {
715+
program, err := expr.Compile(tt.input, tt.option)
716+
require.NoError(t, err)
717+
718+
lastOp := program.Bytecode[len(program.Bytecode)-1]
719+
lastArg := program.Arguments[len(program.Arguments)-1]
720+
721+
assert.Equal(t, tt.op, lastOp)
722+
assert.Equal(t, tt.arg, lastArg)
723+
})
724+
}
725+
}

optimizer/fold.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,10 +296,12 @@ func (fold *fold) Visit(node *Node) {
296296
Name: "filter",
297297
Arguments: []Node{
298298
base.Arguments[0],
299-
&BinaryNode{
300-
Operator: "&&",
301-
Left: base.Arguments[1].(*PredicateNode).Node,
302-
Right: n.Arguments[1].(*PredicateNode).Node,
299+
&PredicateNode{
300+
Node: &BinaryNode{
301+
Operator: "&&",
302+
Left: base.Arguments[1].(*PredicateNode).Node,
303+
Right: n.Arguments[1].(*PredicateNode).Node,
304+
},
303305
},
304306
},
305307
})

optimizer/fold_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,10 @@ func TestOptimize_constant_folding_filter_filter(t *testing.T) {
7070
Value: 2,
7171
},
7272
},
73-
&ast.BoolNode{
74-
Value: true,
73+
&ast.PredicateNode{
74+
Node: &ast.BoolNode{
75+
Value: true,
76+
},
7577
},
7678
},
7779
Throws: false,

0 commit comments

Comments
 (0)