diff --git a/ast/term.go b/ast/term.go index 1dc5bd1bbf..ce254685b1 100644 --- a/ast/term.go +++ b/ast/term.go @@ -1076,26 +1076,30 @@ type QueryIterator func(map[Var]Value, Value) error // ArrayTerm creates a new Term with an Array value. func ArrayTerm(a ...*Term) *Term { - return &Term{Value: &Array{a, 0}} + return &Term{Value: &Array{elems: a, hash: 0, ground: termSliceIsGround(a)}} } // NewArray creates an Array with the terms provided. The array will // use the provided term slice. func NewArray(a ...*Term) *Array { - return &Array{a, 0} + return &Array{elems: a, hash: 0, ground: termSliceIsGround(a)} } // Array represents an array as defined by the language. Arrays are similar to the // same types as defined by JSON with the exception that they can contain Vars // and References. type Array struct { - elems []*Term - hash int + elems []*Term + hash int + ground bool } // Copy returns a deep copy of arr. func (arr *Array) Copy() *Array { - return &Array{termSliceCopy(arr.elems), arr.hash} + return &Array{ + elems: termSliceCopy(arr.elems), + hash: arr.hash, + ground: arr.IsGround()} } // Equal returns true if arr is equal to other. @@ -1170,13 +1174,13 @@ func (arr *Array) Hash() int { // IsGround returns true if all of the Array elements are ground. func (arr *Array) IsGround() bool { - return termSliceIsGround(arr.elems) + return arr.ground } // MarshalJSON returns JSON encoded bytes representing arr. func (arr *Array) MarshalJSON() ([]byte, error) { if len(arr.elems) == 0 { - return json.Marshal([]interface{}{}) + return []byte(`[]`), nil } return json.Marshal(arr.elems) } @@ -1206,6 +1210,7 @@ func (arr *Array) Elem(i int) *Term { // set sets the element i of arr. func (arr *Array) set(i int, v *Term) { + arr.ground = arr.ground && v.IsGround() arr.elems[i] = v arr.hash = 0 } @@ -1215,11 +1220,16 @@ func (arr *Array) set(i int, v *Term) { // copy and any modifications to either of arrays may be reflected to // the other. func (arr *Array) Slice(i, j int) *Array { + var elems []*Term if j == -1 { - return &Array{elems: arr.elems[i:]} + elems = arr.elems[i:] + } else { + elems = arr.elems[i:j] } - - return &Array{elems: arr.elems[i:j]} + // If arr is ground, the slice is, too. + // If it's not, the slice could still be. + gr := arr.ground || termSliceIsGround(elems) + return &Array{elems: elems, ground: gr} } // Iter calls f on each element in arr. If f returns an error, @@ -1257,6 +1267,7 @@ func (arr *Array) Append(v *Term) *Array { cpy := *arr cpy.elems = append(arr.elems, v) cpy.hash = 0 + cpy.ground = arr.ground && v.IsGround() return &cpy } @@ -1509,7 +1520,7 @@ func (s *set) Len() int { // MarshalJSON returns JSON encoded bytes representing s. func (s *set) MarshalJSON() ([]byte, error) { if s.keys == nil { - return json.Marshal([]interface{}{}) + return []byte(`[]`), nil } return json.Marshal(s.keys) } diff --git a/test/cases/testdata/walkbuiltin/test-walkbuiltin-0971.yaml b/test/cases/testdata/walkbuiltin/test-walkbuiltin-0971.yaml index adede97f41..95d0b69cd3 100644 --- a/test/cases/testdata/walkbuiltin/test-walkbuiltin-0971.yaml +++ b/test/cases/testdata/walkbuiltin/test-walkbuiltin-0971.yaml @@ -5,83 +5,6 @@ cases: - 2 - 3 - 4 - b: - v1: hello - v2: goodbye - c: - - x: - - true - - false - - foo - "y": - - null - - 3.14159 - z: - p: true - q: false - d: - e: - - bar - - baz - f: - - xs: - - 1 - ys: - - 2 - - xs: - - 2 - ys: - - 3 - g: - a: - - 1 - - 0 - - 0 - - 0 - b: - - 0 - - 2 - - 0 - - 0 - c: - - 0 - - 0 - - 0 - - 4 - h: - - - 1 - - 2 - - 3 - - - 2 - - 3 - - 4 - l: - - a: bob - b: -1 - c: - - 1 - - 2 - - 3 - - 4 - - a: alice - b: 1 - c: - - 2 - - 3 - - 4 - - 5 - d: null - m: [] - numbers: - - "1" - - "2" - - "3" - - "4" - strings: - bar: 2 - baz: 3 - foo: 1 - three: 3 modules: - | package generated diff --git a/topdown/bindings.go b/topdown/bindings.go index 52321bb139..7621dac529 100644 --- a/topdown/bindings.go +++ b/topdown/bindings.go @@ -87,6 +87,9 @@ func (u *bindings) plugNamespaced(a *ast.Term, caller *bindings) *ast.Term { } return u.namespaceVar(b, caller) case *ast.Array: + if a.IsGround() { + return a + } cpy := *a arr := make([]*ast.Term, v.Len()) for i := 0; i < len(arr); i++ { @@ -104,6 +107,9 @@ func (u *bindings) plugNamespaced(a *ast.Term, caller *bindings) *ast.Term { }) return &cpy case ast.Set: + if a.IsGround() { + return a + } cpy := *a cpy.Value, _ = v.Map(func(x *ast.Term) (*ast.Term, error) { return u.plugNamespaced(x, caller), nil @@ -242,6 +248,9 @@ func (vis namespacingVisitor) namespaceTerm(a *ast.Term) *ast.Term { case ast.Var: return vis.b.namespaceVar(a, vis.caller) case *ast.Array: + if a.IsGround() { + return a + } cpy := *a arr := make([]*ast.Term, v.Len()) for i := 0; i < len(arr); i++ { @@ -259,6 +268,9 @@ func (vis namespacingVisitor) namespaceTerm(a *ast.Term) *ast.Term { }) return &cpy case ast.Set: + if a.IsGround() { + return a + } cpy := *a cpy.Value, _ = v.Map(func(x *ast.Term) (*ast.Term, error) { return vis.namespaceTerm(x), nil diff --git a/topdown/topdown_bench_test.go b/topdown/topdown_bench_test.go index d18e465a8d..1d128ac386 100644 --- a/topdown/topdown_bench_test.go +++ b/topdown/topdown_bench_test.go @@ -29,6 +29,54 @@ func BenchmarkArrayIteration(b *testing.B) { } } +func BenchmarkArrayPlugging(b *testing.B) { + ctx := context.Background() + + sizes := []int{10, 100, 1000, 10000} + + for _, n := range sizes { + b.Run(fmt.Sprint(n), func(b *testing.B) { + data := make([]interface{}, n) + for i := 0; i < n; i++ { + data[i] = fmt.Sprintf("whatever%d", i) + } + store := inmem.NewFromObject(map[string]interface{}{"fixture": data}) + module := `package test + fixture := data.fixture + main { x := fixture }` + + query := ast.MustParseBody("data.test.main") + compiler := ast.MustCompileModules(map[string]string{ + "test.rego": module, + }) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + + err := storage.Txn(ctx, store, storage.TransactionParams{}, func(txn storage.Transaction) error { + + q := NewQuery(query). + WithCompiler(compiler). + WithStore(store). + WithTransaction(txn) + + _, err := q.Run(ctx) + if err != nil { + return err + } + + return nil + }) + + if err != nil { + b.Fatal(err) + } + } + }) + } +} + func BenchmarkSetIteration(b *testing.B) { sizes := []int{10, 100, 1000, 10000} for _, n := range sizes { diff --git a/topdown/walk.go b/topdown/walk.go index 520098f855..092014f12e 100644 --- a/topdown/walk.go +++ b/topdown/walk.go @@ -8,7 +8,7 @@ import ( "github.com/open-policy-agent/opa/ast" ) -func evalWalk(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error { +func evalWalk(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error { input := args[0] filter := getOutputPath(args) return walk(filter, nil, input, iter) @@ -21,7 +21,7 @@ func walk(filter, path *ast.Array, input *ast.Term, iter func(*ast.Term) error) path = ast.NewArray() } - if err := iter(ast.ArrayTerm(ast.NewTerm(path), input)); err != nil { + if err := iter(ast.ArrayTerm(ast.NewTerm(path.Copy()), input)); err != nil { return err } }