From 965301f90e1c10900c2c134ee21e486993796a20 Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Fri, 14 Oct 2022 10:15:54 +0200 Subject: [PATCH] ast: support dotted heads (#4660) This change allows rules to have string prefixes in their heads -- we've come to call them "ref heads". String prefixes means that where before, you had package a.b.c allow = true you can now have package a b.c.allow = true This allows for more concise policies, and different ways to structure larger rule corpuses. Backwards-compatibility: - There are code paths that accept ast.Module structs that don't necessarily come from the parser -- so we're backfilling the rule's Head.Reference field from the Name when it's not present. This is exposed through (Head).Ref() which always returns a Ref. This also affects the `opa parse` "pretty" output: With x.rego as package x import future.keywords a.b.c.d if true e[x] if true we get $ opa parse x rego module package ref data "x" import ref future "keywords" rule head ref a "b" "c" "d" true body expr index=0 true rule head ref e x true body expr index=0 true Note that Name: e Key: x becomes Reference: e[x] in the output above (since that's how we're parsing it, back-compat edge cases aside) - One special case for backcompat is `p[x] { ... }`: rule | ref | key | value | name ------------------------+-------+-----+-------+----- p[x] { ... } | p | x | nil | "p" p contains x if { ... } | p | x | nil | "p" p[x] if { ... } | p[x] | nil | true | "" For interpreting a rule, we now have the following procedure: 1. if it has a Key, it's a multi-value rule; and its Ref defines the set: Head{Key: x, Ref: p} ~> p is a set ^-- we'd get this from `p contains x if true` or `p[x] { true }` (back compat) 2. 
if it has a Value, it's a single-value rule; its Ref may contain vars: Head{Ref: p.q.r[s], Value: 12} ~> body determines s, `p.q.r[s]` is 12 ^-- we'd get this from `p.q.r[s] = 12 { s := "whatever" }` Head{Key: x, Ref: p[x], Value: 3} ~> `p[x]` has value 3, `x` is determined by the rule body ^-- we'd get this from `p[x] = 3 if x := 2` or `p[x] = 3 { x := 2 }` (back compat) Here, the Key isn't used, it's present for backwards compatibility: for ref-less rule heads, `p[x] = 3` used to be a partial object: key x, value 3, name "p" - The distinction between complete rules and partial object rules disappears. They're both single-value rules now. - We're now outputting the refs of the rules completely in error messages, as it's hard to make sense of "rule r" when there's rule r in package a.b.c and rule b.c.r in package a. Restrictions/next steps: - Support for ref head rules in the REPL is pretty poor so far. Anything that works does so rather accidentally. You should be able to work with policies that contain ref heads, but you cannot interactively define them. This is because before, we'd looked at REPL input like p.foo.bar = true and noticed that it cannot be a rule, so it's got to be a query. This is no longer the case with ref heads. - Currently vars in Refs are only allowed in the last position. This is expected to change in the future. - Also, for multi-value rules, we cannot have a var at all -- so the following isn't supported yet: p.q.r[s] contains t if { ... } ----- Most of the work happens when the RuleTree is derived from the ModuleTree -- in the RuleTree, it doesn't matter if a rule was `p` in `package a.b.c` or `b.c.p` in `package a`. 
As such, the planner and wasm compiler haven't seen that many adaptations: - We're putting rules into the ruletree _including_ the var parts, so p.q.a = 1 p.q[x] = 2 { x := "b" } end up in two different leaves: p `-> q `-> a = 1 `-> [x] = 2` - When planning a ref, we're checking if a rule tree node's children have var keys, and plan "one level higher" accordingly: Both sets of rules, p.q.a and p.q[x] will be planned into one function (same as before); and accordingly return an object {"a": 1, "b": 2} - When we don't have vars in the last ref part, we'll end up planning the rules separately. This will have an effect on the IR. p.q = 1 p.r = 2 Before, these would have been one function; now, it's two. As a result, in Wasm, some "object insertion" conflicts can become "var assignment conflicts", but that's in line with the now-new view of "multi-value" and "single-value" rules, not partial {set/obj} vs complete. * planner: only check ref.GroundPrefix() for optimizations In a previous commit, we've only mapped p.q.r[7] as p.q.r; and as such, also need to lookup the ref p.q.r[__local0__] via p.q.r (I think. Full disclosure: there might be edge cases here that are unaccounted for, but right now, I'm aiming for making the existing tests green...) New compiler stage: In the compiler, we're having a new early rewriting step to ensure that the RuleTree's keys are comparable. They're ast.Value, but some of them cause us grief: - ast.Object cannot be compared structurally; so _, ok := map[ast.Value]bool{ast.NewObject([2]*ast.Term{ast.StringTerm("foo"), ast.StringTerm("bar")}): true}[ast.NewObject([2]*ast.Term{ast.StringTerm("foo"), ast.StringTerm("bar")})] `ok` will never be true here. 
- ast.Ref is a slice type, not hashable, so adding that to the RuleTree would cause a runtime panic: p[y.z] { y := input } is now rewritten to p[__local0__] { y := input; __local0__ := y.z } This required moving the InitLocalVarGen stage up the chain, but as it's still below ResolveRefs, we should be OK. As a consequence, we've had to adapt `oracle` to cope with that rewriting: 1. The compiler rewrites rule head refs early because the rule tree expects only simple vars, no refs, in rule head refs. So `p[x.y]` becomes `p[local] { local = x.y }` 2. The oracle circles in on the node it's finding the definition for based on source location, and the logic for doing that depends on unaltered modules. So here, (2.) is relaxed: the logic for building the lookup node stack can now cope with generated statements that have been appended to the rule bodies. There is a peculiarity about ref rules and extents: See the added tests: having a ref rule implies that we get an empty object in the full extent: package p foo.bar if false makes the extent of data.p: {"foo": {}} This is somewhat odd, but also follows from the behaviour we have right now with empty modules: package p.foo bar if false this also gives data.p the extent {"foo": {}}. This could be worked around by recording, in the rule tree, when a node was added because it's an intermediary with no values, but only children. 
Signed-off-by: Stephan Renatus --- CHANGELOG.md | 112 +- ast/annotations_test.go | 106 +- ast/check.go | 128 +- ast/check_test.go | 160 +- ast/compile.go | 484 +++++-- ast/compile_test.go | 1288 +++++++++++++++-- ast/conflicts.go | 7 +- ast/env.go | 16 +- ast/index.go | 8 +- ast/index_test.go | 64 +- ast/parser.go | 139 +- ast/parser_ext.go | 161 ++- ast/parser_test.go | 785 +++++++--- ast/policy.go | 134 +- ast/policy_test.go | 211 ++- ast/pretty_test.go | 8 +- ast/term.go | 4 +- ast/transform.go | 19 +- ast/transform_test.go | 22 + ast/visit.go | 25 +- ast/visit_test.go | 8 +- cmd/build_test.go | 11 +- compile/compile.go | 48 +- compile/compile_test.go | 33 +- docs/content/policy-language.md | 18 + docs/content/policy-reference.md | 39 +- format/format.go | 12 +- format/testfiles/test.rego.formatted | 8 +- format/testfiles/test_ref_heads.rego | 13 + .../testfiles/test_ref_heads.rego.formatted | 19 + internal/oracle/oracle.go | 22 +- internal/oracle/oracle_test.go | 33 +- internal/planner/planner.go | 226 ++- internal/planner/planner_test.go | 11 +- internal/planner/rules.go | 3 + refactor/refactor_test.go | 2 +- rego/rego_test.go | 1 - repl/repl.go | 31 +- server/server.go | 8 +- .../testdata/jsonpatch/json-patch-tests.yaml | 2 +- .../partialobjectdoc/test-wasm-cases.yaml | 106 ++ .../refheads/test-refs-as-rule-heads.yaml | 313 ++++ .../testdata/refheads/test-regressions.yaml | 116 ++ test/wasm/assets/012_partialobjects.yaml | 18 +- tester/reporter_test.go | 51 +- tester/runner.go | 49 +- tester/runner_test.go | 98 +- topdown/eval.go | 138 +- wasm/src/value.c | 9 +- 49 files changed, 4423 insertions(+), 904 deletions(-) create mode 100644 format/testfiles/test_ref_heads.rego create mode 100644 format/testfiles/test_ref_heads.rego.formatted create mode 100644 test/cases/testdata/partialobjectdoc/test-wasm-cases.yaml create mode 100644 test/cases/testdata/refheads/test-refs-as-rule-heads.yaml create mode 100644 test/cases/testdata/refheads/test-regressions.yaml 
diff --git a/CHANGELOG.md b/CHANGELOG.md index a67ee8f226..0633393050 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,117 @@ project adheres to [Semantic Versioning](http://semver.org/). ## Unreleased +### Refs in Rule Heads + +With this version of OPA, we can use a shorthand for defining deeply-nested structures +in Rego: + +Before, we had to use multiple packages, and hence multiple files to define a structure +like this: +```json +{ + "method": { + "get": { + "allowed": true + } + "post": { + "allowed": true + } + } +} +``` + +```rego +package method.get +default allowed := false +allowed { ... } +``` + + +```rego +package method.post +default allowed := false +allowed { ... } +``` + +Now, we can define those rules in single package (and file): + +```rego +package method +import future.keywords.if +default get.allowed := false +get.allowed if { ... } + +default post.allowed := false +post.allowed if { ... } +``` + +Note that in this example, the use of the future keyword `if` is mandatory +for backwards-compatibility: without it, `get.allowed` would be interpreted +as `get["allowed"]`, a definition of a partial set rule. 
+ +Currently, variables may only appear in the last part of the rule head: + +```rego +package method +import future.keywords.if + +endpoints[ep].allowed if ep := "/v1/data" # invalid +repos.get.endpoint[x] if x := "/v1/data" # valid +``` + +The valid rule defines this structure: +```json +{ + "method": { + "repos": { + "get": { + "endpoint": { + "/v1/data": true + } + } + } + } +} +``` + +To define a nested key-value pair, we would use + +```rego +package method +import future.keywords.if + +repos.get.endpoint[x] = y if { + x := "/v1/data" + y := "example" +} +``` + +Multi-value rules (previously referred to as "partial set rules") that are +nested like this need to use `contains` future keyword, to differentiate them +from the "last part is a variable" case mentioned just above: + +```rego +package method +import future.keywords.contains + +repos.get.endpoint contains x if x := "/v1/data" +``` + +This rule defines the same structure, but with multiple values instead of a key: +```json +{ + "method": { + "repos": { + "get": { + "endpoint": ["/v1/data"] + } + } + } + } +} +``` + ## 0.45.0 This release contains a mix of bugfixes, optimizations, and new features. @@ -319,7 +430,6 @@ This is a security release fixing the following vulnerabilities: Note that CVE-2022-32190 is most likely not relevant for OPA's usage of net/url. But since these CVEs tend to come up in security assessment tooling regardless, it's better to get it out of the way. - ## 0.43.0 This release contains a number of fixes, enhancements, and performance improvements. 
diff --git a/ast/annotations_test.go b/ast/annotations_test.go index 322cfc00fd..249d0daf8a 100644 --- a/ast/annotations_test.go +++ b/ast/annotations_test.go @@ -421,6 +421,37 @@ p[v] {v = 2}`, }, }, }, + { + note: "overlapping rule paths (different modules, rule head refs)", + modules: map[string]string{ + "mod1": `package test.a +# METADATA +# title: P1 +b.c.p[v] {v = 1}`, + "mod2": `package test +# METADATA +# title: P2 +a.b.c.p[v] {v = 2}`, + }, + expected: []AnnotationsRef{ + { + Path: MustParseRef("data.test.a.b.c.p"), + Location: &Location{File: "mod1", Row: 4}, + Annotations: &Annotations{ + Scope: "rule", + Title: "P1", + }, + }, + { + Path: MustParseRef("data.test.a.b.c.p"), + Location: &Location{File: "mod2", Row: 4}, + Annotations: &Annotations{ + Scope: "rule", + Title: "P2", + }, + }, + }, + }, } for _, tc := range tests { @@ -737,6 +768,78 @@ p = 1`, }, }, }, + { + note: "multiple subpackages, refs in rule heads", // NOTE(sr): same as above, but last module's rule is `foo.bar.p` in package `root` + modules: map[string]string{ + "root": `# METADATA +# scope: subpackages +# title: ROOT +package root`, + "root.foo": `# METADATA +# title: FOO +# scope: subpackages +package root.foo`, + "root.foo.bar": `# METADATA +# scope: subpackages +# description: subpackages scope applied to rule in other module +# title: BAR-sub + +# METADATA +# title: BAR-other +# description: This metadata is on the path of the queried rule, but shouldn't show up in the result as it's in a different module. 
+package root.foo.bar + +# METADATA +# scope: document +# description: document scope applied to rule in other module +# title: P-doc +p = 1`, + "rule": `# METADATA +# title: BAR +package root + +# METADATA +# title: P +foo.bar.p = 1`, + }, + moduleToAnalyze: "rule", + ruleOnLineToAnalyze: 7, + expected: []AnnotationsRef{ + { + Path: MustParseRef("data.root.foo.bar.p"), + Location: &Location{File: "rule", Row: 7}, + Annotations: &Annotations{ + Scope: "rule", + Title: "P", + }, + }, + { + Path: MustParseRef("data.root.foo.bar.p"), + Location: &Location{File: "root.foo.bar", Row: 15}, + Annotations: &Annotations{ + Scope: "document", + Title: "P-doc", + Description: "document scope applied to rule in other module", + }, + }, + { + Path: MustParseRef("data.root"), + Location: &Location{File: "rule", Row: 3}, + Annotations: &Annotations{ + Scope: "package", + Title: "BAR", + }, + }, + { + Path: MustParseRef("data.root"), + Location: &Location{File: "root", Row: 4}, + Annotations: &Annotations{ + Scope: "subpackages", + Title: "ROOT", + }, + }, + }, + }, { note: "multiple metadata blocks for single rule (order)", modules: map[string]string{ @@ -824,6 +927,7 @@ p = true`, chain := as.Chain(rule) if len(chain) != len(tc.expected) { + t.Errorf("expected %d elements, got %d:", len(tc.expected), len(chain)) t.Fatalf("chained AnnotationSet\n%v\n\ndoesn't match expected\n\n%v", toJSON(chain), toJSON(tc.expected)) } @@ -1022,7 +1126,7 @@ func TestAnnotations_toObject(t *testing.T) { } func toJSON(v interface{}) string { - b, _ := json.Marshal(v) + b, _ := json.MarshalIndent(v, "", " ") return string(b) } diff --git a/ast/check.go b/ast/check.go index fd35d017aa..b671e82999 100644 --- a/ast/check.go +++ b/ast/check.go @@ -200,7 +200,7 @@ func (tc *typeChecker) checkRule(env *TypeEnv, as *AnnotationSet, rule *Rule) { cpy, err := tc.CheckBody(env, rule.Body) env = env.next - path := rule.Path() + path := rule.Ref() if len(err) > 0 { // if the rule/function contains an error, add 
it to the type env so @@ -235,23 +235,28 @@ func (tc *typeChecker) checkRule(env *TypeEnv, as *AnnotationSet, rule *Rule) { tpe = types.Or(exist, f) } else { - switch rule.Head.DocKind() { - case CompleteDoc: + switch rule.Head.RuleKind() { + case SingleValue: typeV := cpy.Get(rule.Head.Value) - if typeV != nil { - exist := env.tree.Get(path) - tpe = types.Or(typeV, exist) - } - case PartialObjectDoc: - typeK := cpy.Get(rule.Head.Key) - typeV := cpy.Get(rule.Head.Value) - if typeK != nil && typeV != nil { - exist := env.tree.Get(path) - typeV = types.Or(types.Values(exist), typeV) - typeK = types.Or(types.Keys(exist), typeK) - tpe = types.NewObject(nil, types.NewDynamicProperty(typeK, typeV)) + if last := path[len(path)-1]; !last.IsGround() { + + // e.g. store object[string: whatever] at data.p.q.r, not data.p.q.r[x] + path = path.GroundPrefix() + + typeK := cpy.Get(last) + if typeK != nil && typeV != nil { + exist := env.tree.Get(path) + typeV = types.Or(types.Values(exist), typeV) + typeK = types.Or(types.Keys(exist), typeK) + tpe = types.NewObject(nil, types.NewDynamicProperty(typeK, typeV)) + } + } else { + if typeV != nil { + exist := env.tree.Get(path) + tpe = types.Or(typeV, exist) + } } - case PartialSetDoc: + case MultiValue: typeK := cpy.Get(rule.Head.Key) if typeK != nil { exist := env.tree.Get(path) @@ -652,72 +657,57 @@ func (rc *refChecker) checkRef(curr *TypeEnv, node *typeTreeNode, ref Ref, idx i head := ref[idx] - // Handle constant ref operands, i.e., strings or the ref head. 
- if _, ok := head.Value.(String); ok || idx == 0 { - - child := node.Child(head.Value) - if child == nil { - - if curr.next != nil { - next := curr.next - return rc.checkRef(next, next.tree, ref, 0) - } - - if RootDocumentNames.Contains(ref[0]) { - return rc.checkRefLeaf(types.A, ref, 1) - } - - return rc.checkRefLeaf(types.A, ref, 0) - } - - if child.Leaf() { - return rc.checkRefLeaf(child.Value(), ref, idx+1) + // NOTE(sr): as long as package statements are required, this isn't possible: + // the shortest possible rule ref is data.a.b (b is idx 2), idx 1 and 2 need to + // be strings or vars. + if idx == 1 || idx == 2 { + switch head.Value.(type) { + case Var, String: // OK + default: + have := rc.env.Get(head.Value) + return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, have, types.S, getOneOfForNode(node)) } - - return rc.checkRef(curr, child, ref, idx+1) } - // Handle dynamic ref operands. - switch value := head.Value.(type) { - - case Var: - - if exist := rc.env.Get(value); exist != nil { - if !unifies(types.S, exist) { - return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, types.S, getOneOfForNode(node)) + if v, ok := head.Value.(Var); ok && idx != 0 { + tpe := types.Keys(rc.env.getRefRecExtent(node)) + if exist := rc.env.Get(v); exist != nil { + if !unifies(tpe, exist) { + return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, tpe, getOneOfForNode(node)) } } else { - rc.env.tree.PutOne(value, types.S) + rc.env.tree.PutOne(v, tpe) } + } - case Ref: + child := node.Child(head.Value) + if child == nil { + // NOTE(sr): idx is reset on purpose: we start over + switch { + case curr.next != nil: + next := curr.next + return rc.checkRef(next, next.tree, ref, 0) - exist := rc.env.Get(value) - if exist == nil { - // If ref type is unknown, an error will already be reported so - // stop here. 
- return nil - } + case RootDocumentNames.Contains(ref[0]): + if idx != 0 { + node.Children().Iter(func(_, child util.T) bool { + _ = rc.checkRef(curr, child.(*typeTreeNode), ref, idx+1) // ignore error + return false + }) + return nil + } + return rc.checkRefLeaf(types.A, ref, 1) - if !unifies(types.S, exist) { - return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, types.S, getOneOfForNode(node)) + default: + return rc.checkRefLeaf(types.A, ref, 0) } - - // Catch other ref operand types here. Non-leaf nodes must be referred to - // with string values. - default: - return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, nil, types.S, getOneOfForNode(node)) } - // Run checking on remaining portion of the ref. Note, since the ref - // potentially refers to data for which no type information exists, - // checking should never fail. - node.Children().Iter(func(_, child util.T) bool { - _ = rc.checkRef(curr, child.(*typeTreeNode), ref, idx+1) // ignore error - return false - }) + if child.Leaf() { + return rc.checkRefLeaf(child.Value(), ref, idx+1) + } - return nil + return rc.checkRef(curr, child, ref, idx+1) } func (rc *refChecker) checkRefLeaf(tpe types.Type, ref Ref, idx int) *Error { diff --git a/ast/check_test.go b/ast/check_test.go index 33a51ca063..56162c5a97 100644 --- a/ast/check_test.go +++ b/ast/check_test.go @@ -344,6 +344,12 @@ func TestCheckInferenceRules(t *testing.T) { {`number_key`, `q[x] = y { a = ["a", "b"]; y = a[x] }`}, {`non_leaf`, `p[x] { data.prefix.i[x][_] }`}, } + ruleset2 := [][2]string{ + {`ref_rule_single`, `p.q.r { true }`}, + {`ref_rule_single_with_number_key`, `p.q[3] { true }`}, + {`ref_regression_array_key`, + `walker[[p, v]] = o { l = input; walk(l, k); [p, v] = k; o = {} }`}, + } tests := []struct { note string @@ -465,6 +471,37 @@ func TestCheckInferenceRules(t *testing.T) { {"non-leaf", ruleset1, "data.non_leaf.p", types.NewSet( types.S, )}, + + {"ref-rules single value, full ref", ruleset2, 
"data.ref_rule_single.p.q.r", types.B}, + {"ref-rules single value, prefix", ruleset2, "data.ref_rule_single.p", + types.NewObject( + []*types.StaticProperty{{ + Key: "q", Value: types.NewObject( + []*types.StaticProperty{{Key: "r", Value: types.B}}, + types.NewDynamicProperty(types.S, types.A), + ), + }}, + types.NewDynamicProperty(types.S, types.A), + )}, + + {"ref-rules single value, number key, full ref", ruleset2, "data.ref_rule_single_with_number_key.p.q[3]", types.B}, + {"ref-rules single value, number key, prefix", ruleset2, "data.ref_rule_single_with_number_key.p", + types.NewObject( + []*types.StaticProperty{{ + Key: "q", Value: types.NewObject( + []*types.StaticProperty{{Key: json.Number("3"), Value: types.B}}, + types.NewDynamicProperty(types.S, types.A), + ), + }}, + types.NewDynamicProperty(types.S, types.A), + )}, + + {"ref_regression_array_key", ruleset2, "data.ref_regression_array_key.walker", + types.NewObject( + nil, + types.NewDynamicProperty(types.NewArray([]types.Type{types.NewArray(types.A, types.A), types.A}, nil), + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))), + )}, } for _, tc := range tests { @@ -489,7 +526,7 @@ func TestCheckInferenceRules(t *testing.T) { ref := MustParseRef(tc.ref) checker := newTypeChecker() - env, err := checker.CheckTypes(nil, elems, nil) + env, err := checker.CheckTypes(newTypeChecker().Env(map[string]*Builtin{"walk": BuiltinMap["walk"]}), elems, nil) if err != nil { t.Fatalf("Unexpected error %v:", err) @@ -512,6 +549,87 @@ func TestCheckInferenceRules(t *testing.T) { } +func TestCheckInferenceOverlapWithRules(t *testing.T) { + ruleset1 := [][2]string{ + {`prefix.i.j.k`, `p = 1 { true }`}, + {`prefix.i.j.k`, `p = "foo" { true }`}, + } + tests := []struct { + note string + rules [][2]string + ref string + expected types.Type // ref's type + query string + extra map[Var]types.Type + }{ + { + note: "non-leaf, extra vars", + rules: ruleset1, + ref: "data.prefix.i.j[k]", + expected: types.A, + 
query: "data.prefix.i.j[k][b]", + extra: map[Var]types.Type{ + Var("k"): types.S, + Var("b"): types.S, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + var elems []util.T + + // Convert test rules into rule slice for "warmup" call. + for i := range tc.rules { + pkg := MustParsePackage(`package ` + tc.rules[i][0]) + rule := MustParseRule(tc.rules[i][1]) + module := &Module{ + Package: pkg, + Rules: []*Rule{rule}, + } + rule.Module = module + elems = append(elems, rule) + for next := rule.Else; next != nil; next = next.Else { + next.Module = module + elems = append(elems, next) + } + } + + ref := MustParseRef(tc.ref) + checker := newTypeChecker() + env, err := checker.CheckTypes(nil, elems, nil) + if err != nil { + t.Fatalf("Unexpected error %v:", err) + } + + result := env.Get(ref) + if tc.expected == nil { + if result != nil { + t.Errorf("Expected %v type to be unset but got: %v", ref, result) + } + } else { + if result == nil { + t.Errorf("Expected to infer %v => %v but got nil", ref, tc.expected) + } else if types.Compare(tc.expected, result) != 0 { + t.Errorf("Expected to infer %v => %v but got %v", ref, tc.expected, result) + } + } + + body := MustParseBody(tc.query) + env, err = checker.CheckBody(env, body) + if len(err) != 0 { + t.Fatalf("Unexpected error: %v", err) + } + for ex, exp := range tc.extra { + act := env.Get(ex) + if types.Compare(act, exp) != 0 { + t.Errorf("Expected to infer extra %v => %v but got %v", ex, exp, act) + } + } + }) + } +} + func TestCheckErrorSuppression(t *testing.T) { query := `arr = [1,2,3]; arr[0].deadbeef = 1` @@ -642,7 +760,7 @@ func TestCheckBuiltinErrors(t *testing.T) { {"objects-any", `fake_builtin_2({"a": a, "c": c})`}, {"objects-bad-input", `sum({"a": 1, "b": 2}, x)`}, {"sets-any", `sum({1,2,"3",4}, x)`}, - {"virtual-ref", `plus(data.test.p, data.deabeef, 0)`}, + {"virtual-ref", `plus(data.test.p, data.coffee, 0)`}, } env := newTestEnv([]string{ @@ -781,6 +899,7 @@ func 
TestCheckRefErrInvalid(t *testing.T) { env := newTestEnv([]string{ `p { true }`, `q = {"foo": 1, "bar": 2} { true }`, + `a.b.c[3] = x { x = {"x": {"y": 2}} }`, }) tests := []struct { @@ -799,7 +918,7 @@ func TestCheckRefErrInvalid(t *testing.T) { pos: 2, have: types.N, want: types.S, - oneOf: []Value{String("p"), String("q")}, + oneOf: []Value{String("a"), String("p"), String("q")}, }, { note: "bad non-leaf ref", @@ -808,7 +927,7 @@ func TestCheckRefErrInvalid(t *testing.T) { pos: 2, have: types.N, want: types.S, - oneOf: []Value{String("p"), String("q")}, + oneOf: []Value{String("a"), String("p"), String("q")}, }, { note: "bad leaf ref", @@ -819,6 +938,24 @@ func TestCheckRefErrInvalid(t *testing.T) { want: types.S, oneOf: []Value{String("bar"), String("foo")}, }, + { + note: "bad ref hitting last term", + query: `x = true; data.test.a.b.c[x][_]`, + ref: `data.test.a.b.c[x][_]`, + pos: 5, + have: types.B, + want: types.Any{types.N, types.S}, + oneOf: []Value{Number("3")}, + }, + { + note: "bad ref hitting dynamic part", + query: `s = true; data.test.a.b.c[3].x[s][_] = _`, + ref: `data.test.a.b.c[3].x[s][_]`, + pos: 7, + have: types.B, + want: types.S, + oneOf: []Value{String("y")}, + }, { note: "bad leaf var", query: `x = 1; data.test.q[x]`, @@ -851,12 +988,25 @@ func TestCheckRefErrInvalid(t *testing.T) { oneOf: []Value{String("a"), String("c")}, }, { + // NOTE(sr): Thins one and the next are special: it cannot work with ref heads, either, since we need at + // least ONE string term after data.test: a module needs a package line, and the shortest head ref + // possible is thus data.x.y. 
note: "bad non-leaf value", query: `data.test[1]`, ref: "data.test[1]", pos: 2, + have: types.N, + want: types.S, + oneOf: []Value{String("a"), String("p"), String("q")}, + }, + { + note: "bad non-leaf value (package)", // See note above ^^ + query: `data[1]`, + ref: "data[1]", + pos: 1, + have: types.N, want: types.S, - oneOf: []Value{String("p"), String("q")}, + oneOf: []Value{String("test")}, }, { note: "composite ref operand", diff --git a/ast/compile.go b/ast/compile.go index a650033b13..81b3d6110c 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -64,6 +64,7 @@ type Compiler struct { // p[1] { true } // p[2] { true } // q = true + // a.b.c = 3 // // root // | @@ -74,6 +75,12 @@ type Compiler struct { // +--- p (2 rules) // | // +--- q (1 rule) + // | + // +--- a + // | + // +--- b + // | + // +--- c (1 rule) RuleTree *TreeNode // Graph contains dependencies between rules. An edge (u,v) is added to the @@ -265,15 +272,16 @@ func NewCompiler() *Compiler { // load additional modules. If any stages run before resolution, they // need to be re-run after resolution. {"ResolveRefs", "compile_stage_resolve_refs", c.resolveAllRefs}, - {"CheckKeywordOverrides", "compile_stage_check_keyword_overrides", c.checkKeywordOverrides}, - {"CheckDuplicateImports", "compile_stage_check_duplicate_imports", c.checkDuplicateImports}, - {"RemoveImports", "compile_stage_remove_imports", c.removeImports}, - {"SetModuleTree", "compile_stage_set_module_tree", c.setModuleTree}, - {"SetRuleTree", "compile_stage_set_rule_tree", c.setRuleTree}, // The local variable generator must be initialized after references are // resolved and the dynamic module loader has run but before subsequent // stages that need to generate variables. 
{"InitLocalVarGen", "compile_stage_init_local_var_gen", c.initLocalVarGen}, + {"RewriteRuleHeadRefs", "compile_stage_rewrite_rule_head_refs", c.rewriteRuleHeadRefs}, + {"CheckKeywordOverrides", "compile_stage_check_keyword_overrides", c.checkKeywordOverrides}, + {"CheckDuplicateImports", "compile_stage_check_duplicate_imports", c.checkDuplicateImports}, + {"RemoveImports", "compile_stage_remove_imports", c.removeImports}, + {"SetModuleTree", "compile_stage_set_module_tree", c.setModuleTree}, + {"SetRuleTree", "compile_stage_set_rule_tree", c.setRuleTree}, // depends on RewriteRuleHeadRefs {"RewriteLocalVars", "compile_stage_rewrite_local_vars", c.rewriteLocalVars}, {"CheckVoidCalls", "compile_stage_check_void_calls", c.checkVoidCalls}, {"RewritePrintCalls", "compile_stage_rewrite_print_calls", c.rewritePrintCalls}, @@ -570,9 +578,10 @@ func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) { return rules } -func extractRules(s []util.T) (rules []*Rule) { - for _, r := range s { - rules = append(rules, r.(*Rule)) +func extractRules(s []util.T) []*Rule { + rules := make([]*Rule, len(s)) + for i := range s { + rules[i] = s[i].(*Rule) } return rules } @@ -768,13 +777,30 @@ func (c *Compiler) buildRuleIndices() { if len(node.Values) == 0 { return false } + rules := extractRules(node.Values) + hasNonGroundKey := false + for _, r := range rules { + if ref := r.Head.Ref(); len(ref) > 1 { + if !ref[len(ref)-1].IsGround() { + hasNonGroundKey = true + } + } + } + if hasNonGroundKey { + // collect children: as of now, this cannot go deeper than one level, + // so we grab those, and abort the DepthFirst processing for this branch + for _, n := range node.Children { + rules = append(rules, extractRules(n.Values)...) 
+ } + } + index := newBaseDocEqIndex(func(ref Ref) bool { return isVirtual(c.RuleTree, ref.GroundPrefix()) }) - if rules := extractRules(node.Values); index.Build(rules) { - c.ruleIndices.Put(rules[0].Path(), index) + if index.Build(rules) { + c.ruleIndices.Put(rules[0].Ref().GroundPrefix(), index) } - return false + return hasNonGroundKey // currently, we don't allow those branches to go deeper }) } @@ -811,7 +837,7 @@ func (c *Compiler) checkRecursion() { func (c *Compiler) checkSelfPath(loc *Location, eq func(a, b util.T) bool, a, b util.T) { tr := NewGraphTraversal(c.Graph) if p := util.DFSPath(tr, eq, a, b); len(p) > 0 { - n := []string{} + n := make([]string, 0, len(p)) for _, x := range p { n = append(n, astNodeToString(x)) } @@ -820,40 +846,69 @@ func (c *Compiler) checkSelfPath(loc *Location, eq func(a, b util.T) bool, a, b } func astNodeToString(x interface{}) string { - switch x := x.(type) { - case *Rule: - return string(x.Head.Name) - default: - panic("not reached") - } + return x.(*Rule).Ref().String() } // checkRuleConflicts ensures that rules definitions are not in conflict. 
func (c *Compiler) checkRuleConflicts() { + rw := rewriteVarsInRef(c.RewrittenVars) + c.RuleTree.DepthFirst(func(node *TreeNode) bool { if len(node.Values) == 0 { - return false + return false // go deeper } - kinds := map[DocKind]struct{}{} + kinds := map[RuleKind]struct{}{} defaultRules := 0 arities := map[int]struct{}{} + name := "" + var singleValueConflicts []Ref for _, rule := range node.Values { r := rule.(*Rule) - kinds[r.Head.DocKind()] = struct{}{} + ref := r.Ref() + name = rw(ref.Copy()).String() // varRewriter operates in-place + kinds[r.Head.RuleKind()] = struct{}{} arities[len(r.Head.Args)] = struct{}{} if r.Default { defaultRules++ } + + // Single-value rules may not have any other rules in their extent: these pairs are invalid: + // + // data.p.q.r { true } # data.p.q is { "r": true } + // data.p.q.r.s { true } + // + // data.p.q[r] { r := input.r } # data.p.q could be { "r": true } + // data.p.q.r.s { true } + + // But this is allowed: + // data.p.q[r] = 1 { r := "r" } + // data.p.q.s = 2 + + if r.Head.RuleKind() == SingleValue && len(node.Children) > 0 { + if len(ref) > 1 && !ref[len(ref)-1].IsGround() { // p.q[x] and p.q.s.t => check grandchildren + for _, c := range node.Children { + if len(c.Children) > 0 { + singleValueConflicts = node.flattenChildren() + break + } + } + } else { // p.q.s and p.q.s.t => any children are in conflict + singleValueConflicts = node.flattenChildren() + } + } } - name := Var(node.Key.(String)) + switch { + case singleValueConflicts != nil: + c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "single-value rule %v conflicts with %v", name, singleValueConflicts)) - if len(kinds) > 1 || len(arities) > 1 { - c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "conflicting rules named %v found", name)) - } else if defaultRules > 1 { - c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multiple default rules named %s found", name)) + case len(kinds) > 1 || len(arities) > 1: + c.err(NewError(TypeErr, 
node.Values[0].(*Rule).Loc(), "conflicting rules %v found", name)) + + case defaultRules > 1: + c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multiple default rules %s found", name)) } return false @@ -865,13 +920,21 @@ func (c *Compiler) checkRuleConflicts() { } } + // NOTE(sr): depthfirst might better use sorted for stable errs? c.ModuleTree.DepthFirst(func(node *ModuleTreeNode) bool { for _, mod := range node.Modules { for _, rule := range mod.Rules { - if childNode, ok := node.Children[String(rule.Head.Name)]; ok { + ref := rule.Head.Ref().GroundPrefix() + childNode, tail := node.find(ref) + if childNode != nil { for _, childMod := range childNode.Modules { - msg := fmt.Sprintf("%v conflicts with rule defined at %v", childMod.Package, rule.Loc()) - c.err(NewError(TypeErr, mod.Package.Loc(), msg)) + if childMod.Equal(mod) { + continue // don't self-conflict + } + if len(tail) == 0 { + msg := fmt.Sprintf("%v conflicts with rule %v defined at %v", childMod.Package, rule.Head.Ref(), rule.Loc()) + c.err(NewError(TypeErr, mod.Package.Loc(), msg)) + } } } } @@ -1363,21 +1426,29 @@ func (c *Compiler) getExports() *util.HashMap { for _, name := range c.sorted { mod := c.Modules[name] - rv, ok := rules.Get(mod.Package.Path) - if !ok { - rv = []Var{} - } - rvs := rv.([]Var) for _, rule := range mod.Rules { - rvs = append(rvs, rule.Head.Name) + hashMapAdd(rules, mod.Package.Path, rule.Head.Ref().GroundPrefix()) } - rules.Put(mod.Package.Path, rvs) } return rules } +func hashMapAdd(rules *util.HashMap, pkg, rule Ref) { + prev, ok := rules.Get(pkg) + if !ok { + rules.Put(pkg, []Ref{rule}) + return + } + for _, p := range prev.([]Ref) { + if p.Equal(rule) { + return + } + } + rules.Put(pkg, append(prev.([]Ref), rule)) +} + func (c *Compiler) GetAnnotationSet() *AnnotationSet { return c.annotationSet } @@ -1450,6 +1521,15 @@ func checkKeywordOverrides(node interface{}, strict bool) Errors { // p[x] { bar[_] = x } // // The reference "bar[_]" would be resolved to 
"data.foo.bar[_]". +// +// Ref rules are resolved, too: +// +// package a.b +// q { c.d.e == 1 } +// c.d[e] := 1 if e := "e" +// +// The reference "c.d.e" would be resolved to "data.a.b.c.d.e". + func (c *Compiler) resolveAllRefs() { rules := c.getExports() @@ -1457,9 +1537,9 @@ func (c *Compiler) resolveAllRefs() { for _, name := range c.sorted { mod := c.Modules[name] - var ruleExports []Var + var ruleExports []Ref if x, ok := rules.Get(mod.Package.Path); ok { - ruleExports = x.([]Var) + ruleExports = x.([]Ref) } globals := getGlobals(mod.Package, ruleExports, mod.Imports) @@ -1542,6 +1622,52 @@ func (c *Compiler) rewriteExprTerms() { } } +func (c *Compiler) rewriteRuleHeadRefs() { + f := newEqualityFactory(c.localvargen) + for _, name := range c.sorted { + WalkRules(c.Modules[name], func(rule *Rule) bool { + + ref := rule.Head.Ref() + // NOTE(sr): We're backfilling Refs here -- all parser code paths would have them, but + // it's possible to construct Module{} instances from Golang code, so we need + // to accommodate for that, too. + if len(rule.Head.Reference) == 0 { + rule.Head.Reference = ref + } + + for i := 1; i < len(ref); i++ { + // NOTE(sr): In the first iteration, non-string values in the refs are forbidden + // except for the last position, e.g. + // OK: p.q.r[s] + // NOT OK: p[q].r.s + // TODO(sr): This is stricter than necessary. We could allow any non-var values there, + // but we'll also have to adjust the type tree, for example. + if i != len(ref)-1 { // last + if _, ok := ref[i].Value.(String); !ok { + c.err(NewError(TypeErr, rule.Loc(), "rule head must only contain string terms (except for last): %v", ref[i])) + continue + } + } + + // Rewrite so that any non-scalar elements that in the last position of + // the rule are vars: + // p.q.r[y.z] { ... } => p.q.r[__local0__] { __local0__ = y.z } + // because that's what the RuleTree knows how to deal with. 
+ if _, ok := ref[i].Value.(Var); !ok && !IsScalar(ref[i].Value) { + expr := f.Generate(ref[i]) + if i == len(ref)-1 && rule.Head.Key.Equal(ref[i]) { + rule.Head.Key = expr.Operand(0) + } + rule.Head.Reference[i] = expr.Operand(0) + rule.Body.Append(expr) + } + } + + return true + }) + } +} + func (c *Compiler) checkVoidCalls() { for _, name := range c.sorted { mod := c.Modules[name] @@ -2044,6 +2170,9 @@ func (c *Compiler) rewriteLocalVars() { // Rewrite assignments in body. used := NewVarSet() + last := rule.Head.Ref()[len(rule.Head.Ref())-1] + used.Update(last.Vars()) + if rule.Head.Key != nil { used.Update(rule.Head.Key.Vars()) } @@ -2076,6 +2205,9 @@ func (c *Compiler) rewriteLocalVars() { rule.Head.Args[i], _ = transformTerm(localXform, rule.Head.Args[i]) } + for i := 1; i < len(rule.Head.Ref()); i++ { + rule.Head.Reference[i], _ = transformTerm(localXform, rule.Head.Ref()[i]) + } if rule.Head.Key != nil { rule.Head.Key, _ = transformTerm(localXform, rule.Head.Key) } @@ -2406,10 +2538,10 @@ func (qc *queryCompiler) resolveRefs(qctx *QueryContext, body Body) (Body, error pkg = &Package{Path: RefTerm(VarTerm("")).Value.(Ref)} } if pkg != nil { - var ruleExports []Var + var ruleExports []Ref rules := qc.compiler.getExports() if exist, ok := rules.Get(pkg.Path); ok { - ruleExports = exist.([]Var) + ruleExports = exist.([]Ref) } globals = getGlobals(qctx.Package, ruleExports, qctx.Imports) @@ -2792,6 +2924,16 @@ type ModuleTreeNode struct { Hide bool } +func (n *ModuleTreeNode) String() string { + var rules []string + for _, m := range n.Modules { + for _, r := range m.Rules { + rules = append(rules, r.Head.String()) + } + } + return fmt.Sprintf("", n.Key, n.Children, rules, n.Hide) +} + // NewModuleTree returns a new ModuleTreeNode that represents the root // of the module tree populated with the given modules. 
func NewModuleTree(mods map[string]*Module) *ModuleTreeNode { @@ -2836,13 +2978,43 @@ func (n *ModuleTreeNode) Size() int { return s } +// Child returns n's child with key k. +func (n *ModuleTreeNode) child(k Value) *ModuleTreeNode { + switch k.(type) { + case String, Var: + return n.Children[k] + } + return nil +} + +// Find dereferences ref along the tree. ref[0] is converted to a String +// for convenience. +func (n *ModuleTreeNode) find(ref Ref) (*ModuleTreeNode, Ref) { + if v, ok := ref[0].Value.(Var); ok { + ref = Ref{StringTerm(string(v))}.Concat(ref[1:]) + } + node := n + for i, r := range ref { + next := node.child(r.Value) + if next == nil { + tail := make(Ref, len(ref)-i) + tail[0] = VarTerm(string(ref[i].Value.(String))) + copy(tail[1:], ref[i+1:]) + return node, tail + } + node = next + } + return node, nil +} + // DepthFirst performs a depth-first traversal of the module tree rooted at n. // If f returns true, traversal will not continue to the children of n. -func (n *ModuleTreeNode) DepthFirst(f func(node *ModuleTreeNode) bool) { - if !f(n) { - for _, node := range n.Children { - node.DepthFirst(f) - } +func (n *ModuleTreeNode) DepthFirst(f func(*ModuleTreeNode) bool) { + if f(n) { + return + } + for _, node := range n.Children { + node.DepthFirst(f) } } @@ -2856,49 +3028,56 @@ type TreeNode struct { Hide bool } +func (n *TreeNode) String() string { + return fmt.Sprintf("", n.Key, n.Values, n.Sorted, n.Hide) +} + // NewRuleTree returns a new TreeNode that represents the root // of the rule tree populated with the given rules. func NewRuleTree(mtree *ModuleTreeNode) *TreeNode { - - ruleSets := map[String][]util.T{} - - // Build rule sets for this package. - for _, mod := range mtree.Modules { - for _, rule := range mod.Rules { - key := String(rule.Head.Name) - ruleSets[key] = append(ruleSets[key], rule) - } + root := TreeNode{ + Key: mtree.Key, } - // Each rule set becomes a leaf node. 
- children := map[Value]*TreeNode{} - sorted := make([]Value, 0, len(ruleSets)) - - for key, rules := range ruleSets { - sorted = append(sorted, key) - children[key] = &TreeNode{ - Key: key, - Children: nil, - Values: rules, + mtree.DepthFirst(func(m *ModuleTreeNode) bool { + for _, mod := range m.Modules { + if len(mod.Rules) == 0 { + root.add(mod.Package.Path, nil) + } + for _, rule := range mod.Rules { + root.add(rule.Ref().GroundPrefix(), rule) + } } - } + return false + }) - // Each module in subpackage becomes child node. - for key, child := range mtree.Children { - sorted = append(sorted, key) - children[child.Key] = NewRuleTree(child) + // ensure that data.system's TreeNode is hidden + node, tail := root.find(DefaultRootRef.Append(NewTerm(SystemDocumentKey))) + if len(tail) == 0 { // found + node.Hide = true } - sort.Slice(sorted, func(i, j int) bool { - return sorted[i].Compare(sorted[j]) < 0 + root.DepthFirst(func(x *TreeNode) bool { + x.sort() + return false }) - return &TreeNode{ - Key: mtree.Key, - Values: nil, - Children: children, - Sorted: sorted, - Hide: mtree.Hide, + return &root +} + +func (n *TreeNode) add(path Ref, rule *Rule) { + node, tail := n.find(path) + if len(tail) > 0 { + sub := treeNodeFromRef(tail, rule) + if node.Children == nil { + node.Children = make(map[Value]*TreeNode, 1) + } + node.Children[sub.Key] = sub + node.Sorted = append(node.Sorted, sub.Key) + } else { + if rule != nil { + node.Values = append(node.Values, rule) + } } } @@ -2914,33 +3093,95 @@ func (n *TreeNode) Size() int { // Child returns n's child with key k. 
func (n *TreeNode) Child(k Value) *TreeNode { switch k.(type) { - case String, Var: + case Ref: + return nil + default: return n.Children[k] } - return nil } // Find dereferences ref along the tree func (n *TreeNode) Find(ref Ref) *TreeNode { node := n for _, r := range ref { - child := node.Child(r.Value) - if child == nil { + node = node.Child(r.Value) + if node == nil { return nil } - node = child } return node } +func (n *TreeNode) find(ref Ref) (*TreeNode, Ref) { + node := n + for i := range ref { + next := node.Child(ref[i].Value) + if next == nil { + tail := make(Ref, len(ref)-i) + copy(tail, ref[i:]) + return node, tail + } + node = next + } + return node, nil +} + // DepthFirst performs a depth-first traversal of the rule tree rooted at n. If // f returns true, traversal will not continue to the children of n. -func (n *TreeNode) DepthFirst(f func(node *TreeNode) bool) { - if !f(n) { - for _, node := range n.Children { - node.DepthFirst(f) +func (n *TreeNode) DepthFirst(f func(*TreeNode) bool) { + if f(n) { + return + } + for _, node := range n.Children { + node.DepthFirst(f) + } +} + +func (n *TreeNode) sort() { + sort.Slice(n.Sorted, func(i, j int) bool { + return n.Sorted[i].Compare(n.Sorted[j]) < 0 + }) +} + +func treeNodeFromRef(ref Ref, rule *Rule) *TreeNode { + depth := len(ref) - 1 + key := ref[depth].Value + node := &TreeNode{ + Key: key, + Children: nil, + } + if rule != nil { + node.Values = []util.T{rule} + } + + for i := len(ref) - 2; i >= 0; i-- { + key := ref[i].Value + node = &TreeNode{ + Key: key, + Children: map[Value]*TreeNode{ref[i+1].Value: node}, + Sorted: []Value{ref[i+1].Value}, } } + return node +} + +// flattenChildren flattens all children's rule refs into a sorted array. 
+func (n *TreeNode) flattenChildren() []Ref { + ret := newRefSet() + for _, sub := range n.Children { // we only want the children, so don't use n.DepthFirst() right away + sub.DepthFirst(func(x *TreeNode) bool { + for _, r := range x.Values { + rule := r.(*Rule) + ret.AddPrefix(rule.Ref()) + } + return false + }) + } + + sort.Slice(ret.s, func(i, j int) bool { + return ret.s[i].Compare(ret.s[j]) < 0 + }) + return ret.s } // Graph represents the graph of dependencies between rules. @@ -3554,15 +3795,14 @@ func (l *localVarGenerator) Generate() Var { } } -func getGlobals(pkg *Package, rules []Var, imports []*Import) map[Var]*usedRef { +func getGlobals(pkg *Package, rules []Ref, imports []*Import) map[Var]*usedRef { - globals := map[Var]*usedRef{} + globals := make(map[Var]*usedRef, len(rules)) // NB: might grow bigger with imports // Populate globals with exports within the package. - for _, v := range rules { - global := append(Ref{}, pkg.Path...) - global = append(global, &Term{Value: String(v)}) - globals[v] = &usedRef{ref: global} + for _, ref := range rules { + v := ref[0].Value.(Var) + globals[v] = &usedRef{ref: pkg.Path.Append(StringTerm(string(v)))} } // Populate globals with imports. 
@@ -3670,6 +3910,10 @@ func resolveRefsInRule(globals map[Var]*usedRef, rule *Rule) error { ignore.Push(vars) ignore.Push(declaredVars(rule.Body)) + ref := rule.Head.Ref() + for i := 1; i < len(ref); i++ { + ref[i] = resolveRefsInTerm(globals, ignore, ref[i]) + } if rule.Head.Key != nil { rule.Head.Key = resolveRefsInTerm(globals, ignore, rule.Head.Key) } @@ -4985,7 +5229,7 @@ func isBuiltinRefOrVar(bs map[string]*Builtin, unsafeBuiltinsMap map[string]stru } func isVirtual(node *TreeNode, ref Ref) bool { - for i := 0; i < len(ref); i++ { + for i := range ref { child := node.Child(ref[i].Value) if child == nil { return false @@ -5095,3 +5339,57 @@ func rewriteVarsInRef(vars ...map[Var]Var) varRewriter { return i.(Ref) } } + +// NOTE(sr): This is duplicated with compile/compile.go; but moving it into another location +// would cause a circular dependency -- the refSet definition needs ast.Ref. If we make it +// public in the ast package, the compile package could take it from there, but it would also +// increase our public interface. Let's reconsider if we need it in a third place. +type refSet struct { + s []Ref +} + +func newRefSet(x ...Ref) *refSet { + result := &refSet{} + for i := range x { + result.AddPrefix(x[i]) + } + return result +} + +// ContainsPrefix returns true if r is prefixed by any of the existing refs in the set. +func (rs *refSet) ContainsPrefix(r Ref) bool { + for i := range rs.s { + if r.HasPrefix(rs.s[i]) { + return true + } + } + return false +} + +// AddPrefix inserts r into the set if r is not prefixed by any existing +// refs in the set. If any existing refs are prefixed by r, those existing +// refs are removed. +func (rs *refSet) AddPrefix(r Ref) { + if rs.ContainsPrefix(r) { + return + } + cpy := []Ref{r} + for i := range rs.s { + if !rs.s[i].HasPrefix(r) { + cpy = append(cpy, rs.s[i]) + } + } + rs.s = cpy +} + +// Sorted returns a sorted slice of terms for refs in the set. 
+func (rs *refSet) Sorted() []*Term { + terms := make([]*Term, len(rs.s)) + for i := range rs.s { + terms[i] = NewTerm(rs.s[i]) + } + sort.Slice(terms, func(i, j int) bool { + return terms[i].Value.Compare(terms[j].Value) < 0 + }) + return terms +} diff --git a/ast/compile_test.go b/ast/compile_test.go index 0de8baae5e..8b02e49609 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -263,7 +263,7 @@ func TestOutputVarsForNode(t *testing.T) { func TestModuleTree(t *testing.T) { - mods := getCompilerTestModules() + mods := getCompilerTestModules() // 7 modules mods["system-mod"] = MustParseModule(` package system.foo @@ -274,8 +274,14 @@ func TestModuleTree(t *testing.T) { p = 1 `) + mods["dots-in-heads"] = MustParseModule(` + package dots + + a.b.c = 12 + d.e.f.g = 34 + `) tree := NewModuleTree(mods) - expectedSize := 9 + expectedSize := 10 if tree.Size() != expectedSize { t.Fatalf("Expected %v but got %v modules", expectedSize, tree.Size()) @@ -294,6 +300,727 @@ func TestModuleTree(t *testing.T) { } } +func TestCompilerGetExports(t *testing.T) { + tests := []struct { + note string + modules []*Module + exports map[string][]string + }{ + { + note: "simple", + modules: modules(`package p + r = 1`), + exports: map[string][]string{"data.p": {"r"}}, + }, + { + note: "simple single-value ref rule", + modules: modules(`package p + q.r.s = 1`), + exports: map[string][]string{"data.p": {"q.r.s"}}, + }, + { + note: "var key single-value ref rule", + modules: modules(`package p + q.r[s] = 1 { s := "foo" }`), + exports: map[string][]string{"data.p": {"q.r"}}, + }, + { + note: "simple multi-value ref rule", + modules: modules(`package p + import future.keywords + + q.r.s contains 1 { true }`), + exports: map[string][]string{"data.p": {"q.r.s"}}, + }, + { + note: "two simple, multiple rules", + modules: modules(`package p + r = 1 + s = 11`, + `package q + x = 2 + y = 22`), + exports: map[string][]string{"data.p": {"r", "s"}, "data.q": {"x", "y"}}, + }, + { + note: "ref 
head + simple, multiple rules", + modules: modules(`package p.a.b.c + r = 1 + s = 11`, + `package q + a.b.x = 2 + a.b.c.y = 22`), + exports: map[string][]string{ + "data.p.a.b.c": {"r", "s"}, + "data.q": {"a.b.x", "a.b.c.y"}, + }, + }, + { + note: "two ref head, multiple rules", + modules: modules(`package p.a.b.c + r = 1 + s = 11`, + `package p + a.b.x = 2 + a.b.c.y = 22`), + exports: map[string][]string{ + "data.p.a.b.c": {"r", "s"}, + "data.p": {"a.b.x", "a.b.c.y"}, + }, + }, + { + note: "single-value rule with number key", + modules: modules(`package p + q[1] = 1 + q[2] = 2`), + exports: map[string][]string{ + "data.p": {"q[1]", "q[2]"}, // TODO(sr): is this really what we want? + }, + }, + { + note: "single-value (ref) rule with number key", + modules: modules(`package p + a.b.q[1] = 1 + a.b.q[2] = 2`), + exports: map[string][]string{ + "data.p": {"a.b.q[1]", "a.b.q[2]"}, + }, + }, + { + note: "single-value (ref) rule with var key", + modules: modules(`package p + a.b.q[x] = y { x := 1; y := true } + a.b.q[2] = 2`), + exports: map[string][]string{ + "data.p": {"a.b.q", "a.b.q[2]"}, // TODO(sr): GroundPrefix? right thing here? + }, + }, + { // NOTE(sr): An ast.Module can be constructed in various ways, this is to assert that + // our compilation process doesn't explode here if we're fed a Rule that has no Ref. + note: "synthetic", + modules: func() []*Module { + ms := modules(`package p + r = 1`) + ms[0].Rules[0].Head.Reference = nil + return ms + }(), + exports: map[string][]string{"data.p": {"r"}}, + }, + // TODO(sr): add multi-val rule, and ref-with-var single-value rule. 
+ } + + hashMap := func(ms map[string][]string) *util.HashMap { + rules := util.NewHashMap(func(a, b util.T) bool { + switch a := a.(type) { + case Ref: + return a.Equal(b.(Ref)) + case []Ref: + b := b.([]Ref) + if len(b) != len(a) { + return false + } + for i := range a { + if !a[i].Equal(b[i]) { + return false + } + } + return true + default: + panic("unreachable") + } + }, func(v util.T) int { + return v.(Ref).Hash() + }) + for r, rs := range ms { + refs := make([]Ref, len(rs)) + for i := range rs { + refs[i] = toRef(rs[i]) + } + rules.Put(MustParseRef(r), refs) + } + return rules + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + c := NewCompiler() + for i, m := range tc.modules { + c.Modules[fmt.Sprint(i)] = m + c.sorted = append(c.sorted, fmt.Sprint(i)) + } + if exp, act := hashMap(tc.exports), c.getExports(); !exp.Equal(act) { + t.Errorf("expected %v, got %v", exp, act) + } + }) + } +} + +func toRef(s string) Ref { + switch t := MustParseTerm(s).Value.(type) { + case Var: + return Ref{NewTerm(t)} + case Ref: + return t + default: + panic("unreachable") + } +} + +func TestCompilerCheckRuleHeadRefs(t *testing.T) { + + tests := []struct { + note string + modules []*Module + expected *Rule + err string + }{ + { + note: "ref contains var", + modules: modules( + `package x + p.q[i].r = 1 { i := 10 }`, + ), + err: "rego_type_error: rule head must only contain string terms (except for last): i", + }, + { + note: "valid: ref is single-value rule with var key", + modules: modules( + `package x + p.q.r[i] { i := 10 }`, + ), + }, + { + note: "valid: ref is single-value rule with var key and value", + modules: modules( + `package x + p.q.r[i] = j { i := 10; j := 11 }`, + ), + }, + { + note: "valid: ref is single-value rule with var key and static value", + modules: modules( + `package x + p.q.r[i] = "ten" { i := 10 }`, + ), + }, + { + note: "valid: ref is single-value rule with number key", + modules: modules( + `package x + p.q.r[1] { true }`, + 
), + }, + { + note: "valid: ref is single-value rule with boolean key", + modules: modules( + `package x + p.q.r[true] { true }`, + ), + }, + { + note: "valid: ref is single-value rule with null key", + modules: modules( + `package x + p.q.r[null] { true }`, + ), + }, + { + note: "valid: ref is single-value rule with set literal key", + modules: modules( + `package x + p.q.r[set()] { true }`, + ), + }, + { + note: "valid: ref is single-value rule with array literal key", + modules: modules( + `package x + p.q.r[[]] { true }`, + ), + }, + { + note: "valid: ref is single-value rule with object literal key", + modules: modules( + `package x + p.q.r[{}] { true }`, + ), + }, + { + note: "valid: ref is single-value rule with ref key", + modules: modules( + `package x + x := [1,2,3] + p.q.r[x[i]] { i := 0}`, + ), + }, + { + note: "invalid: ref in ref", + modules: modules( + `package x + p.q[arr[0]].r { i := 10 }`, + ), + err: "rego_type_error: rule head must only contain string terms (except for last): arr[0]", + }, + { + note: "invalid: non-string in ref (not last position)", + modules: modules( + `package x + p.q[10].r { true }`, + ), + err: "rego_type_error: rule head must only contain string terms (except for last): 10", + }, + { + note: "valid: multi-value with var key", + modules: modules( + `package x + p.q.r contains i if i := 10`, + ), + }, + { + note: "rewrite: single-value with non-var key (ref)", + modules: modules( + `package x + p.q.r[y.z] if y := {"z": "a"}`, + ), + expected: MustParseRule(`p.q.r[__local0__] { y := {"z": "a"}; __local0__ = y.z }`), + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + mods := make(map[string]*Module, len(tc.modules)) + for i, m := range tc.modules { + mods[fmt.Sprint(i)] = m + } + c := NewCompiler() + c.Modules = mods + compileStages(c, c.rewriteRuleHeadRefs) + if tc.err != "" { + assertCompilerErrorStrings(t, c, []string{tc.err}) + } else { + if len(c.Errors) > 0 { + t.Fatalf("expected no errors, 
got %v", c.Errors) + } + if tc.expected != nil { + assertRulesEqual(t, tc.expected, mods["0"].Rules[0]) + } + } + }) + } +} + +func TestRuleTreeWithDotsInHeads(t *testing.T) { + + // TODO(sr): multi-val with var key in ref + tests := []struct { + note string + modules []*Module + size int // expected tree size = number of leaves + depth int // expected tree depth + }{ + { + note: "two modules, same package, one rule each", + modules: modules( + `package x + p.q.r = 1`, + `package x + p.q.w = 2`, + ), + size: 2, + }, + { + note: "two modules, sub-package, one rule each", + modules: modules( + `package x + p.q.r = 1`, + `package x.p + q.w.z = 2`, + ), + size: 2, + }, + { + note: "three modules, sub-package, incl simple rule", + modules: modules( + `package x + p.q.r = 1`, + `package x.p + q.w.z = 2`, + `package x.p.q.w + y = 3`, + ), + size: 3, + }, + { + note: "simple: two modules", + modules: modules( + `package x + p.q.r = 1`, + `package y + p.q.w = 2`, + ), + size: 2, + }, + { + note: "conflict: one module", + modules: modules( + `package q + p[x] = 1 + p = 2`, + ), + size: 2, + }, + { + note: "conflict: two modules", + modules: modules( + `package q + p.r.s[x] = 1`, + `package q.p + r.s = 2 if true`, + ), + size: 2, + }, + { + note: "simple: two modules, one using ref head, one package path", + modules: modules( + `package x + p.q.r = 1 { input == 1 }`, + `package x.p.q + r = 2 { input == 2 }`, + ), + size: 2, + }, + { + note: "conflict: two modules, both using ref head, different package paths", + modules: modules( + `package x + p.q.r = 1 { input == 1 }`, // x.p.q.r = 1 + `package x.p + q.r.s = 2 { input == 2 }`, // x.p.q.r.s = 2 + ), + size: 2, + }, + { + note: "overlapping: one module, two ref head", + modules: modules( + `package x + p.q.r = 1 + p.q.w.v = 2`, + ), + size: 2, + depth: 6, + }, + { + note: "last ref term != string", + modules: modules( + `package x + p.q.w[1] = 2 + p.q.w[{"foo": "baz"}] = 20 + p.q.x[true] = false + p.q.x[y] = y { y := "y" }`, 
+ ), + size: 4, + depth: 6, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + c := NewCompiler() + for i, m := range tc.modules { + c.Modules[fmt.Sprint(i)] = m + c.sorted = append(c.sorted, fmt.Sprint(i)) + } + compileStages(c, c.setRuleTree) + if len(c.Errors) > 0 { + t.Fatal(c.Errors) + } + tree := c.RuleTree + tree.DepthFirst(func(n *TreeNode) bool { + t.Log(n) + if !sort.SliceIsSorted(n.Sorted, func(i, j int) bool { + return n.Sorted[i].Compare(n.Sorted[j]) < 0 + }) { + t.Errorf("expected sorted to be sorted: %v", n.Sorted) + } + return false + }) + if tc.depth > 0 { + if exp, act := tc.depth, depth(tree); exp != act { + t.Errorf("expected tree depth %d, got %d", exp, act) + } + } + if exp, act := tc.size, tree.Size(); exp != act { + t.Errorf("expected tree size %d, got %d", exp, act) + } + }) + } +} + +func TestRuleTreeWithVars(t *testing.T) { + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + + t.Run("simple single-value rule", func(t *testing.T) { + mod0 := `package a.b +c.d.e = 1 if true` + + mods := map[string]*Module{"0.rego": MustParseModuleWithOpts(mod0, opts)} + tree := NewRuleTree(NewModuleTree(mods)) + + node := tree.Find(MustParseRef("data.a.b.c.d.e")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 1, len(node.Values); exp != act { + t.Errorf("expected %d values, found %d", exp, act) + } + if exp, act := 0, len(node.Children); exp != act { + t.Errorf("expected %d children, found %d", exp, act) + } + if exp, act := MustParseRef("c.d.e"), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + }) + + t.Run("two single-value rules", func(t *testing.T) { + mod0 := `package a.b +c.d.e = 1 if true` + mod1 := `package a.b.c +d.e = 2 if true` + + mods := map[string]*Module{ + "0.rego": MustParseModuleWithOpts(mod0, opts), + "1.rego": MustParseModuleWithOpts(mod1, opts), + } + tree := 
NewRuleTree(NewModuleTree(mods)) + + node := tree.Find(MustParseRef("data.a.b.c.d.e")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 2, len(node.Values); exp != act { + t.Errorf("expected %d values, found %d", exp, act) + } + if exp, act := 0, len(node.Children); exp != act { + t.Errorf("expected %d children, found %d", exp, act) + } + if exp, act := MustParseRef("c.d.e"), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if exp, act := MustParseRef("d.e"), node.Values[1].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + }) + + t.Run("one multi-value rule, one single-value, with var", func(t *testing.T) { + mod0 := `package a.b +c.d.e.g contains 1 if true` + mod1 := `package a.b.c +d.e.f = 2 if true` + + mods := map[string]*Module{ + "0.rego": MustParseModuleWithOpts(mod0, opts), + "1.rego": MustParseModuleWithOpts(mod1, opts), + } + tree := NewRuleTree(NewModuleTree(mods)) + + // var-key rules should be included in the results + node := tree.Find(MustParseRef("data.a.b.c.d.e.g")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 1, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d", exp, act) + } + if exp, act := 0, len(node.Children); exp != act { + t.Fatalf("expected %d children, found %d", exp, act) + } + node = tree.Find(MustParseRef("data.a.b.c.d.e.f")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 1, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d", exp, act) + } + if exp, act := MustParseRef("d.e.f"), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + }) + + t.Run("two multi-value rules, back compat", func(t *testing.T) { + mod0 := `package a +b[c] { c := "foo" }` + mod1 := `package a +b[d] { d := "bar" }` + + mods := map[string]*Module{ + 
"0.rego": MustParseModuleWithOpts(mod0, opts), + "1.rego": MustParseModuleWithOpts(mod1, opts), + } + tree := NewRuleTree(NewModuleTree(mods)) + + node := tree.Find(MustParseRef("data.a.b")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 2, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d: %v", exp, act, node.Values) + } + if exp, act := 0, len(node.Children); exp != act { + t.Errorf("expected %d children, found %d", exp, act) + } + if exp, act := (Ref{VarTerm("b")}), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if act := node.Values[0].(*Rule).Head.Value; act != nil { + t.Errorf("expected rule value nil, found %v", act) + } + if exp, act := VarTerm("c"), node.Values[0].(*Rule).Head.Key; !exp.Equal(act) { + t.Errorf("expected rule key %v, found %v", exp, act) + } + if exp, act := (Ref{VarTerm("b")}), node.Values[1].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if act := node.Values[1].(*Rule).Head.Value; act != nil { + t.Errorf("expected rule value nil, found %v", act) + } + if exp, act := VarTerm("d"), node.Values[1].(*Rule).Head.Key; !exp.Equal(act) { + t.Errorf("expected rule key %v, found %v", exp, act) + } + }) + + t.Run("two multi-value rules, back compat with short style", func(t *testing.T) { + mod0 := `package a +b[1]` + mod1 := `package a +b[2]` + mods := map[string]*Module{ + "0.rego": MustParseModuleWithOpts(mod0, opts), + "1.rego": MustParseModuleWithOpts(mod1, opts), + } + tree := NewRuleTree(NewModuleTree(mods)) + + node := tree.Find(MustParseRef("data.a.b")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 2, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d: %v", exp, act, node.Values) + } + if exp, act := 0, len(node.Children); exp != act { + t.Errorf("expected %d children, found %d", exp, act) + } + if exp, act := 
(Ref{VarTerm("b")}), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if act := node.Values[0].(*Rule).Head.Value; act != nil { + t.Errorf("expected rule value nil, found %v", act) + } + if exp, act := IntNumberTerm(1), node.Values[0].(*Rule).Head.Key; !exp.Equal(act) { + t.Errorf("expected rule key %v, found %v", exp, act) + } + if exp, act := (Ref{VarTerm("b")}), node.Values[1].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if act := node.Values[1].(*Rule).Head.Value; act != nil { + t.Errorf("expected rule value nil, found %v", act) + } + if exp, act := IntNumberTerm(2), node.Values[1].(*Rule).Head.Key; !exp.Equal(act) { + t.Errorf("expected rule key %v, found %v", exp, act) + } + }) + + t.Run("two single-value rules, back compat with short style", func(t *testing.T) { + mod0 := `package a +b[1] = 1` + mod1 := `package a +b[2] = 2` + mods := map[string]*Module{ + "0.rego": MustParseModuleWithOpts(mod0, opts), + "1.rego": MustParseModuleWithOpts(mod1, opts), + } + tree := NewRuleTree(NewModuleTree(mods)) + + // branch point + node := tree.Find(MustParseRef("data.a.b")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 0, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d: %v", exp, act, node.Values) + } + if exp, act := 2, len(node.Children); exp != act { + t.Fatalf("expected %d children, found %d", exp, act) + } + + // branch 1 + node = tree.Find(MustParseRef("data.a.b[1]")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 1, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d: %v", exp, act, node.Values) + } + if exp, act := MustParseRef("b[1]"), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if exp, act := IntNumberTerm(1), node.Values[0].(*Rule).Head.Value; !exp.Equal(act) { + 
t.Errorf("expected rule value %v, found %v", exp, act) + } + if exp, act := IntNumberTerm(1), node.Values[0].(*Rule).Head.Key; !exp.Equal(act) { + t.Errorf("expected rule key %v, found %v", exp, act) + } + + // branch 2 + node = tree.Find(MustParseRef("data.a.b[2]")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 1, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d: %v", exp, act, node.Values) + } + if exp, act := MustParseRef("b[2]"), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if exp, act := IntNumberTerm(2), node.Values[0].(*Rule).Head.Value; !exp.Equal(act) { + t.Errorf("expected rule value %v, found %v", exp, act) + } + if exp, act := IntNumberTerm(2), node.Values[0].(*Rule).Head.Key; !exp.Equal(act) { + t.Errorf("expected rule key %v, found %v", exp, act) + } + }) + + // NOTE(sr): Now this test seems obvious, but it's a bug that had snuck into the + // NewRuleTree code during development. 
+ t.Run("root node and data node unhidden if there are no system nodes", func(t *testing.T) { + mod0 := `package a +p = 1` + mods := map[string]*Module{ + "0.rego": MustParseModuleWithOpts(mod0, opts), + } + tree := NewRuleTree(NewModuleTree(mods)) + + if exp, act := false, tree.Hide; act != exp { + t.Errorf("expected tree.Hide=%v, got %v", exp, act) + } + dataNode := tree.Child(Var("data")) + if dataNode == nil { + t.Fatal("expected data node") + } + if exp, act := false, dataNode.Hide; act != exp { + t.Errorf("expected dataNode.Hide=%v, got %v", exp, act) + } + }) +} + +func depth(n *TreeNode) int { + d := -1 + for _, m := range n.Children { + if d0 := depth(m); d0 > d { + d = d0 + } + } + return d + 1 +} func TestModuleTreeFilenameOrder(t *testing.T) { // NOTE(sr): It doesn't matter that these are conflicting; but that's where it @@ -328,16 +1055,23 @@ func TestRuleTree(t *testing.T) { mods["non-system-mod"] = MustParseModule(` package user.system - p = 1 - `) - mods["mod-incr"] = MustParseModule(`package a.b.c + p = 1`) + mods["mod-incr"] = MustParseModule(` + package a.b.c -s[1] { true } -s[2] { true }`, + s[1] { true } + s[2] { true }`, ) + mods["dots-in-heads"] = MustParseModule(` + package dots + + a.b.c = 12 + d.e.f.g = 34 + `) + tree := NewRuleTree(NewModuleTree(mods)) - expectedNumRules := 23 + expectedNumRules := 25 if tree.Size() != expectedNumRules { t.Errorf("Expected %v but got %v rules", expectedNumRules, tree.Size()) @@ -345,14 +1079,18 @@ s[2] { true }`, // Check that empty packages are represented as leaves with no rules. 
node := tree.Children[Var("data")].Children[String("a")].Children[String("b")].Children[String("empty")] - if node == nil || len(node.Children) != 0 || len(node.Values) != 0 { t.Fatalf("Unexpected nil value or non-empty leaf of non-leaf node: %v", node) } + // Check that root node is not hidden + if exp, act := false, tree.Hide; act != exp { + t.Errorf("expected tree.Hide=%v, got %v", exp, act) + } + system := tree.Child(Var("data")).Child(String("system")) if !system.Hide { - t.Fatalf("Expected system node to be hidden") + t.Fatalf("Expected system node to be hidden: %v", system) } if system.Child(String("foo")).Hide { @@ -692,16 +1430,19 @@ func TestCompilerErrorLimit(t *testing.T) { func TestCompilerCheckSafetyHead(t *testing.T) { c := NewCompiler() c.Modules = getCompilerTestModules() - c.Modules["newMod"] = MustParseModule(`package a.b - -unboundKey[x] = y { q[y] = {"foo": [1, 2, [{"bar": y}]]} } -unboundVal[y] = x { q[y] = {"foo": [1, 2, [{"bar": y}]]} } -unboundCompositeVal[y] = [{"foo": x, "bar": y}] { q[y] = {"foo": [1, 2, [{"bar": y}]]} } -unboundCompositeKey[[{"x": x}]] { q[y] } -unboundBuiltinOperator = eq { x = 1 } + popts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + c.Modules["newMod"] = MustParseModuleWithOpts(`package a.b + +unboundKey[x1] = y { q[y] = {"foo": [1, 2, [{"bar": y}]]} } +unboundVal[y] = x2 { q[y] = {"foo": [1, 2, [{"bar": y}]]} } +unboundCompositeVal[y] = [{"foo": x3, "bar": y}] { q[y] = {"foo": [1, 2, [{"bar": y}]]} } +unboundCompositeKey[[{"x": x4}]] { q[y] } +unboundBuiltinOperator = eq { 4 = 1 } unboundElse { false } else = else_var { true } -`, - ) +c.d.e[x5] if true +f.g.h[y] = x6 if y := "y" +i.j.k contains x7 if true +`, popts) compileStages(c, c.checkSafetyRuleHeads) makeErrMsg := func(v string) string { @@ -709,10 +1450,13 @@ unboundElse { false } else = else_var { true } } expected := []string{ - makeErrMsg("x"), - makeErrMsg("x"), - makeErrMsg("x"), - makeErrMsg("x"), + makeErrMsg("x1"), + 
makeErrMsg("x2"), + makeErrMsg("x3"), + makeErrMsg("x4"), + makeErrMsg("x5"), + makeErrMsg("x6"), + makeErrMsg("x7"), makeErrMsg("eq"), makeErrMsg("else_var"), } @@ -1018,20 +1762,37 @@ q[1] { true }`, default foo = 1 default foo = 2 -foo = 3 { true }`, - "mod4.rego": `package adrules.arity +foo = 3 { true } + +default p.q.bar = 1 +default p.q.bar = 2 +p.q.bar = 3 { true } +`, + "mod4.rego": `package badrules.arity f(1) { true } f { true } g(1) { true } -g(1,2) { true }`, +g(1,2) { true } + +p.q.h(1) { true } +p.q.h { true } + +p.q.i(1) { true } +p.q.i(1,2) { true }`, "mod5.rego": `package badrules.dataoverlap p { true }`, "mod6.rego": `package badrules.existserr -p { true }`}) +p { true }`, + + "mod7.rego": `package badrules.foo +import future.keywords + +bar.baz contains "quz" if true`, + }) c.WithPathConflictsCheck(func(path []string) (bool, error) { if reflect.DeepEqual(path, []string{"badrules", "dataoverlap", "p"}) { @@ -1047,18 +1808,155 @@ p { true }`}) expected := []string{ "rego_compile_error: conflict check for data path badrules/existserr/p: unexpected error", "rego_compile_error: conflicting rule for data path badrules/dataoverlap/p found", - "rego_type_error: conflicting rules named f found", - "rego_type_error: conflicting rules named g found", - "rego_type_error: conflicting rules named p found", - "rego_type_error: conflicting rules named q found", - "rego_type_error: multiple default rules named foo found", - "rego_type_error: package badrules.r conflicts with rule defined at mod1.rego:7", - "rego_type_error: package badrules.r conflicts with rule defined at mod1.rego:8", + "rego_type_error: conflicting rules data.badrules.arity.f found", + "rego_type_error: conflicting rules data.badrules.arity.g found", + "rego_type_error: conflicting rules data.badrules.arity.p.q.h found", + "rego_type_error: conflicting rules data.badrules.arity.p.q.i found", + "rego_type_error: conflicting rules data.badrules.p[x] found", + "rego_type_error: conflicting rules 
data.badrules.q found", + "rego_type_error: multiple default rules data.badrules.defkw.foo found", + "rego_type_error: multiple default rules data.badrules.defkw.p.q.bar found", + "rego_type_error: package badrules.r conflicts with rule r[x] defined at mod1.rego:7", + "rego_type_error: package badrules.r conflicts with rule r[x] defined at mod1.rego:8", } assertCompilerErrorStrings(t, c, expected) } +func TestCompilerCheckRuleConflictsDotsInRuleHeads(t *testing.T) { + + tests := []struct { + note string + modules []*Module + err string + }{ + { + note: "arity mismatch, ref and non-ref rule", + modules: modules( + `package pkg + p.q.r { true }`, + `package pkg.p.q + r(_) = 2`), + err: "rego_type_error: conflicting rules data.pkg.p.q.r found", + }, + { + note: "two default rules, ref and non-ref rule", + modules: modules( + `package pkg + default p.q.r = 3 + p.q.r { true }`, + `package pkg.p.q + default r = 4 + r = 2`), + err: "rego_type_error: multiple default rules data.pkg.p.q.r found", + }, + { + note: "arity mismatch, ref and ref rule", + modules: modules( + `package pkg.a.b + p.q.r { true }`, + `package pkg.a + b.p.q.r(_) = 2`), + err: "rego_type_error: conflicting rules data.pkg.a.b.p.q.r found", + }, + { + note: "two default rules, ref and ref rule", + modules: modules( + `package pkg + default p.q.w.r = 3 + p.q.w.r { true }`, + `package pkg.p + default q.w.r = 4 + q.w.r = 2`), + err: "rego_type_error: multiple default rules data.pkg.p.q.w.r found", + }, + { + note: "multi-value + single-value rules, both with same ref prefix", + modules: modules( + `package pkg + p.q.w[x] = 1 if x := "foo"`, + `package pkg + p.q.w contains "bar"`), + err: "rego_type_error: conflicting rules data.pkg.p.q.w found", + }, + { + note: "two multi-value rules, both with same ref", + modules: modules( + `package pkg + p.q.w contains "baz"`, + `package pkg + p.q.w contains "bar"`), + }, + { + note: "module conflict: non-ref rule", + modules: modules( + `package pkg.q + r { true }`, + 
`package pkg.q.r`), + err: "rego_type_error: package pkg.q.r conflicts with rule r defined at mod0.rego:2", + }, + { + note: "module conflict: ref rule", + modules: modules( + `package pkg + p.q.r { true }`, + `package pkg.p.q.r`), + err: "rego_type_error: package pkg.p.q.r conflicts with rule p.q.r defined at mod0.rego:2", + }, + { + note: "single-value with other rule overlap", + modules: modules( + `package pkg + p.q.r { true }`, + `package pkg + p.q.r.s { true }`), + err: "rego_type_error: single-value rule data.pkg.p.q.r conflicts with [data.pkg.p.q.r.s]", + }, + { + note: "single-value with other rule overlap", + modules: modules( + `package pkg + p.q.r { true } + p.q.r.s { true } + p.q.r.t { true }`), + err: "rego_type_error: single-value rule data.pkg.p.q.r conflicts with [data.pkg.p.q.r.s data.pkg.p.q.r.t]", + }, + { + note: "single-value with other rule overlap, unknown key", + modules: modules( + `package pkg + p.q[r] = x { r = input.key; x = input.foo } + p.q.r.s = x { true } + `), + err: "rego_type_error: single-value rule data.pkg.p.q[r] conflicts with [data.pkg.p.q.r.s]", + }, + { + note: "single-value rule with known and unknown key", + modules: modules( + `package pkg + p.q[r] = x { r = input.key; x = input.foo } + p.q.s = "x" { true } + `), + }, + } + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + mods := make(map[string]*Module, len(tc.modules)) + for i, m := range tc.modules { + mods[fmt.Sprint(i)] = m + } + c := NewCompiler() + c.Modules = mods + compileStages(c, c.checkRuleConflicts) + if tc.err != "" { + assertCompilerErrorStrings(t, c, []string{tc.err}) + } else { + assertCompilerErrorStrings(t, c, []string{}) + } + }) + } +} + func TestCompilerCheckUndefinedFuncs(t *testing.T) { module := ` @@ -2072,6 +2970,84 @@ func assertErrors(t *testing.T, actual Errors, expected Errors, assertLocation b } } +// NOTE(sr): the tests below this function are unwieldy, let's keep adding new ones to this one +func 
TestCompilerResolveAllRefsNewTests(t *testing.T) { + tests := []struct { + note string + mod string + exp string + extra string + }{ + { + note: "ref-rules referenced in body", + mod: `package test +a.b.c = 1 +q if a.b.c == 1 +`, + exp: `package test +a.b.c = 1 { true } +q if data.test.a.b.c = 1 +`, + }, + { + // NOTE(sr): This is a conservative extension of how it worked before: + // we will not automatically extend references to other parts of the rule tree, + // only to ref rules defined on the same level. + note: "ref-rules from other module referenced in body", + mod: `package test +q if a.b.c == 1 +`, + extra: `package test +a.b.c = 1 +`, + exp: `package test +q if data.test.a.b.c = 1 +`, + }, + { + note: "single-value rule in comprehension in call", // NOTE(sr): this is TestRego/partialiter/objects_conflict + mod: `package test +p := count([x | q[x]]) +q[1] = 1 +`, + exp: `package test +p := __local0__ { true; __local1__ = [x | data.test.q[x]]; count(__local1__, __local0__) } +q[1] = 1 +`, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + c := NewCompiler() + mod, err := ParseModuleWithOpts("test.rego", tc.mod, opts) + if err != nil { + t.Fatal(err) + } + exp, err := ParseModuleWithOpts("test.rego", tc.exp, opts) + if err != nil { + t.Fatal(err) + } + mods := map[string]*Module{"test": mod} + if tc.extra != "" { + extra, err := ParseModuleWithOpts("test.rego", tc.extra, opts) + if err != nil { + t.Fatal(err) + } + mods["extra"] = extra + } + c.Compile(mods) + if err := c.Errors; len(err) > 0 { + t.Errorf("compile module: %v", err) + } + if act := c.Modules["test"]; !exp.Equal(act) { + t.Errorf("compiled: expected %v, got %v", exp, act) + } + }) + } +} + func TestCompilerResolveAllRefs(t *testing.T) { c := NewCompiler() c.Modules = getCompilerTestModules() @@ -2180,6 +3156,12 @@ p[foo[bar[i]]] = {"baz": baz} { true }`) } }`, 
ParserOptions{unreleasedKeywords: true, FutureKeywords: []string{"every", "in"}}) + c.Modules["heads_with_dots"] = MustParseModule(`package heads_with_dots + + this_is_not = true + this.is.dotted { this_is_not } + `) + compileStages(c, c.resolveAllRefs) assertNotFailed(t, c) @@ -2317,6 +3299,15 @@ p[foo[bar[i]]] = {"baz": baz} { true }`) gt10 := MustParseExpr("x > 10") gt10.Index++ // TODO(sr): why? assertExprEqual(t, everyExpr.Body[1], gt10) + + // head refs are kept as-is, but their bodies are replaced. + mod := c.Modules["heads_with_dots"] + rule := mod.Rules[1] + body := rule.Body[0].Terms.(*Term) + assertTermEqual(t, body, MustParseTerm("data.heads_with_dots.this_is_not")) + if act, exp := rule.Head.Ref(), MustParseRef("this.is.dotted"); act.Compare(exp) != 0 { + t.Errorf("expected %v to match %v", act, exp) + } } func TestCompilerResolveErrors(t *testing.T) { @@ -2340,48 +3331,91 @@ func TestCompilerResolveErrors(t *testing.T) { } func TestCompilerRewriteTermsInHead(t *testing.T) { - c := NewCompiler() - c.Modules["head"] = MustParseModule(`package head + popts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + tests := []struct { + note string + mod *Module + exp *Rule + }{ + { + note: "imports", + mod: MustParseModule(`package head import data.doc1 as bar import data.doc2 as corge import input.x.y.foo import input.qux as baz p[foo[bar[i]]] = {"baz": baz, "corge": corge} { true } +`), + exp: MustParseRule(`p[__local0__] = __local1__ { true; __local0__ = input.x.y.foo[data.doc1[i]]; __local1__ = {"baz": input.qux, "corge": data.doc2} }`), + }, + { + note: "array comprehension value", + mod: MustParseModule(`package head q = [true | true] { true } +`), + exp: MustParseRule(`q = __local0__ { true; __local0__ = [true | true] }`), + }, + { + note: "object comprehension value", + mod: MustParseModule(`package head r = {"true": true | true} { true } +`), + exp: MustParseRule(`r = __local0__ { true; __local0__ = {"true": true | true} }`), + }, + { 
+ note: "set comprehension value", + mod: MustParseModule(`package head s = {true | true} { true } - +`), + exp: MustParseRule(`s = __local0__ { true; __local0__ = {true | true} }`), + }, + { + note: "import in else value", + mod: MustParseModule(`package head +import input.qux as baz elsekw { false } else = baz { true } -`) - - compileStages(c, c.rewriteRefsInHead) - assertNotFailed(t, c) - - rule1 := c.Modules["head"].Rules[0] - expected1 := MustParseRule(`p[__local0__] = __local1__ { true; __local0__ = input.x.y.foo[data.doc1[i]]; __local1__ = {"baz": input.qux, "corge": data.doc2} }`) - assertRulesEqual(t, rule1, expected1) - - rule2 := c.Modules["head"].Rules[1] - expected2 := MustParseRule(`q = __local2__ { true; __local2__ = [true | true] }`) - assertRulesEqual(t, rule2, expected2) - - rule3 := c.Modules["head"].Rules[2] - expected3 := MustParseRule(`r = __local3__ { true; __local3__ = {"true": true | true} }`) - assertRulesEqual(t, rule3, expected3) - - rule4 := c.Modules["head"].Rules[3] - expected4 := MustParseRule(`s = __local4__ { true; __local4__ = {true | true} }`) - assertRulesEqual(t, rule4, expected4) +`), + exp: MustParseRule(`elsekw { false } else = __local0__ { true; __local0__ = input.qux }`), + }, + { + note: "import ref in last ref head term", + mod: MustParseModule(`package head +import data.doc1 as bar +x.y.z[bar[i]] = true +`), + exp: MustParseRule(`x.y.z[__local0__] = true { true; __local0__ = data.doc1[i] }`), + }, + { + note: "import ref in multi-value ref rule", + mod: MustParseModule(`package head +import future.keywords.if +import future.keywords.contains +import data.doc1 as bar +x.y.w contains bar[i] if true +`), + exp: func() *Rule { + exp, _ := ParseRuleWithOpts(`x.y.w contains __local0__ if {true; __local0__ = data.doc1[i] }`, popts) + return exp + }(), + }, + } - rule5 := c.Modules["head"].Rules[4] - expected5 := MustParseRule(`elsekw { false } else = __local5__ { true; __local5__ = input.qux }`) - assertRulesEqual(t, rule5, 
expected5) + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + c := NewCompiler() + c.Modules["head"] = tc.mod + compileStages(c, c.rewriteRefsInHead) + assertNotFailed(t, c) + act := c.Modules["head"].Rules[0] + assertRulesEqual(t, act, tc.exp) + }) + } } func TestCompilerRewriteRegoMetadataCalls(t *testing.T) { @@ -3384,6 +4418,25 @@ func TestRewriteDeclaredVars(t *testing.T) { p { __local1__ = data.test.y; data.test.q[[__local1__, __local0__]] } `, }, + { + note: "single-value rule with ref head", + module: ` + package test + + p.r.q[s] = t { + t := 1 + s := input.foo + } + `, + exp: ` + package test + + p.r.q[__local1__] = __local0__ { + __local0__ = 1 + __local1__ = input.foo + } + `, + }, { note: "rewrite some x in xs", module: ` @@ -3873,7 +4926,7 @@ func TestRewriteDeclaredVars(t *testing.T) { for _, tc := range tests { t.Run(tc.note, func(t *testing.T) { - opts := CompileOpts{ParserOptions: ParserOptions{FutureKeywords: []string{"in", "every"}, unreleasedKeywords: true}} + opts := CompileOpts{ParserOptions: ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true}} compiler, err := CompileModulesWithOpt(map[string]string{"test.rego": tc.module}, opts) if tc.wantErr != nil { if err == nil { @@ -5560,37 +6613,41 @@ dataref = true { data }`, compileStages(c, c.checkRecursion) - makeRuleErrMsg := func(rule string, loop ...string) string { - return fmt.Sprintf("rego_recursion_error: rule %v is recursive: %v", rule, strings.Join(loop, " -> ")) + makeRuleErrMsg := func(pkg, rule string, loop ...string) string { + l := make([]string, len(loop)) + for i, lo := range loop { + l[i] = "data." + pkg + "." 
+ lo + } + return fmt.Sprintf("rego_recursion_error: rule data.%s.%s is recursive: %v", pkg, rule, strings.Join(l, " -> ")) } expected := []string{ - makeRuleErrMsg("s", "s", "t", "s"), - makeRuleErrMsg("t", "t", "s", "t"), - makeRuleErrMsg("a", "a", "b", "c", "e", "a"), - makeRuleErrMsg("b", "b", "c", "e", "a", "b"), - makeRuleErrMsg("c", "c", "e", "a", "b", "c"), - makeRuleErrMsg("e", "e", "a", "b", "c", "e"), - makeRuleErrMsg("p", "p", "q", "p"), - makeRuleErrMsg("q", "q", "p", "q"), - makeRuleErrMsg("acq", "acq", "acp", "acq"), - makeRuleErrMsg("acp", "acp", "acq", "acp"), - makeRuleErrMsg("np", "np", "nq", "np"), - makeRuleErrMsg("nq", "nq", "np", "nq"), - makeRuleErrMsg("prefix", "prefix", "prefix"), - makeRuleErrMsg("dataref", "dataref", "dataref"), - makeRuleErrMsg("else_self", "else_self", "else_self"), - makeRuleErrMsg("elsetop", "elsetop", "elsemid", "elsebottom", "elsetop"), - makeRuleErrMsg("elsemid", "elsemid", "elsebottom", "elsetop", "elsemid"), - makeRuleErrMsg("elsebottom", "elsebottom", "elsetop", "elsemid", "elsebottom"), - makeRuleErrMsg("fn", "fn", "fn"), - makeRuleErrMsg("foo", "foo", "bar", "foo"), - makeRuleErrMsg("bar", "bar", "foo", "bar"), - makeRuleErrMsg("bar", "bar", "p", "foo", "bar"), - makeRuleErrMsg("foo", "foo", "bar", "p", "foo"), - makeRuleErrMsg("p", "p", "foo", "bar", "p"), - makeRuleErrMsg("everyp", "everyp", "everyp"), - makeRuleErrMsg("everyq", "everyq", "everyq"), + makeRuleErrMsg("rec", "s", "s", "t", "s"), + makeRuleErrMsg("rec", "t", "t", "s", "t"), + makeRuleErrMsg("rec", "a", "a", "b", "c", "e", "a"), + makeRuleErrMsg("rec", "b", "b", "c", "e", "a", "b"), + makeRuleErrMsg("rec", "c", "c", "e", "a", "b", "c"), + makeRuleErrMsg("rec", "e", "e", "a", "b", "c", "e"), + `rego_recursion_error: rule data.rec3.p[x] is recursive: data.rec3.p[x] -> data.rec4.q[x] -> data.rec3.p[x]`, // NOTE(sr): these two are hardcoded: they are + `rego_recursion_error: rule data.rec4.q[x] is recursive: data.rec4.q[x] -> data.rec3.p[x] -> 
data.rec4.q[x]`, // the only ones not fitting the pattern. + makeRuleErrMsg("rec5", "acq", "acq", "acp", "acq"), + makeRuleErrMsg("rec5", "acp", "acp", "acq", "acp"), + makeRuleErrMsg("rec6", "np[x]", "np[x]", "nq[x]", "np[x]"), + makeRuleErrMsg("rec6", "nq[x]", "nq[x]", "np[x]", "nq[x]"), + makeRuleErrMsg("rec7", "prefix", "prefix", "prefix"), + makeRuleErrMsg("rec8", "dataref", "dataref", "dataref"), + makeRuleErrMsg("rec9", "else_self", "else_self", "else_self"), + makeRuleErrMsg("rec9", "elsetop", "elsetop", "elsemid", "elsebottom", "elsetop"), + makeRuleErrMsg("rec9", "elsemid", "elsemid", "elsebottom", "elsetop", "elsemid"), + makeRuleErrMsg("rec9", "elsebottom", "elsebottom", "elsetop", "elsemid", "elsebottom"), + makeRuleErrMsg("f0", "fn", "fn", "fn"), + makeRuleErrMsg("f1", "foo", "foo", "bar", "foo"), + makeRuleErrMsg("f1", "bar", "bar", "foo", "bar"), + makeRuleErrMsg("f2", "bar", "bar", "p[x]", "foo", "bar"), + makeRuleErrMsg("f2", "foo", "foo", "bar", "p[x]", "foo"), + makeRuleErrMsg("f2", "p[x]", "p[x]", "foo", "bar", "p[x]"), + makeRuleErrMsg("everymod", "everyp", "everyp", "everyp"), + makeRuleErrMsg("everymod", "everyq", "everyq", "everyq"), } result := compilerErrsToStringSlice(c.Errors) @@ -5612,28 +6669,38 @@ func TestCompilerCheckDynamicRecursion(t *testing.T) { // references. For more background info, see // . 
- for note, mod := range map[string]*Module{ - "recursion": MustParseModule(` + for _, tc := range []struct { + note, err string + mod *Module + }{ + { + note: "recursion", + mod: MustParseModule(` package recursion pkg = "recursion" foo[x] { data[pkg]["foo"][x] } `), - "system.main": MustParseModule(` + err: "rego_recursion_error: rule data.recursion.foo is recursive: data.recursion.foo -> data.recursion.foo", + }, + {note: "system.main", + mod: MustParseModule(` package system.main foo { - data[input] + data[input] } `), + err: "rego_recursion_error: rule data.system.main.foo is recursive: data.system.main.foo -> data.system.main.foo", + }, } { - t.Run(note, func(t *testing.T) { + t.Run(tc.note, func(t *testing.T) { c := NewCompiler() - c.Modules = map[string]*Module{note: mod} + c.Modules = map[string]*Module{tc.note: tc.mod} compileStages(c, c.checkRecursion) result := compilerErrsToStringSlice(c.Errors) - expected := "rego_recursion_error: rule foo is recursive: foo -> foo" + expected := tc.err if len(result) != 1 || result[0] != expected { t.Errorf("Expected %v but got: %v", expected, result) @@ -5909,6 +6976,7 @@ func TestCompilerGetRulesDynamic(t *testing.T) { "mod1": `package a.b.c.d r1 = 1`, "mod2": `package a.b.c.e +default r2 = false r2 = 2`, "mod3": `package a.b r3 = 3`, @@ -5919,7 +6987,8 @@ r4 = 4`, compileStages(compiler, nil) rule1 := compiler.Modules["mod1"].Rules[0] - rule2 := compiler.Modules["mod2"].Rules[0] + rule2d := compiler.Modules["mod2"].Rules[0] + rule2 := compiler.Modules["mod2"].Rules[1] rule3 := compiler.Modules["mod3"].Rules[0] rule4 := compiler.Modules["hidden"].Rules[0] @@ -5929,15 +6998,16 @@ r4 = 4`, excludeHidden bool }{ {input: "data.a.b.c.d.r1", expected: []*Rule{rule1}}, - {input: "data.a.b[x]", expected: []*Rule{rule1, rule2, rule3}}, + {input: "data.a.b[x]", expected: []*Rule{rule1, rule2d, rule2, rule3}}, {input: "data.a.b[x].d", expected: []*Rule{rule1, rule3}}, - {input: "data.a.b.c", expected: []*Rule{rule1, rule2}}, + 
{input: "data.a.b.c", expected: []*Rule{rule1, rule2d, rule2}}, {input: "data.a.b.d"}, - {input: "data[x]", expected: []*Rule{rule1, rule2, rule3, rule4}}, - {input: "data[data.complex_computation].b[y]", expected: []*Rule{rule1, rule2, rule3}}, - {input: "data[x][y].c.e", expected: []*Rule{rule2}}, + {input: "data", expected: []*Rule{rule1, rule2d, rule2, rule3, rule4}}, + {input: "data[x]", expected: []*Rule{rule1, rule2d, rule2, rule3, rule4}}, + {input: "data[data.complex_computation].b[y]", expected: []*Rule{rule1, rule2d, rule2, rule3}}, + {input: "data[x][y].c.e", expected: []*Rule{rule2d, rule2}}, {input: "data[x][y].r3", expected: []*Rule{rule3}}, - {input: "data[x][y]", expected: []*Rule{rule1, rule2, rule3}, excludeHidden: true}, // old behaviour of GetRulesDynamic + {input: "data[x][y]", expected: []*Rule{rule1, rule2d, rule2, rule3}, excludeHidden: true}, // old behaviour of GetRulesDynamic } for _, tc := range tests { @@ -7452,7 +8522,19 @@ deny { } else if !strings.HasPrefix(c.Errors.Error(), "1 error occurred: 7:2: rego_type_error: undefined ref: input.Something.Y.X.ThisDoesNotExist") { t.Errorf("unexpected error: %v", c.Errors.Error()) } +} +func modules(ms ...string) []*Module { + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + mods := make([]*Module, len(ms)) + for i, m := range ms { + var err error + mods[i], err = ParseModuleWithOpts(fmt.Sprintf("mod%d.rego", i), m, opts) + if err != nil { + panic(err) + } + } + return mods } func TestCompilerWithRecursiveSchemaAvoidRace(t *testing.T) { diff --git a/ast/conflicts.go b/ast/conflicts.go index d1013ccedd..c2713ad576 100644 --- a/ast/conflicts.go +++ b/ast/conflicts.go @@ -27,7 +27,12 @@ func CheckPathConflicts(c *Compiler, exists func([]string) (bool, error)) Errors func checkDocumentConflicts(node *TreeNode, exists func([]string) (bool, error), path []string) Errors { - path = append(path, string(node.Key.(String))) + switch key := node.Key.(type) { + case String: + 
path = append(path, string(key)) + default: // other key types cannot conflict with data + return nil + } if len(node.Values) > 0 { s := strings.Join(path, "/") diff --git a/ast/env.go b/ast/env.go index 60006baafd..5313a595b6 100644 --- a/ast/env.go +++ b/ast/env.go @@ -5,6 +5,8 @@ package ast import ( + "fmt" + "github.com/open-policy-agent/opa/types" "github.com/open-policy-agent/opa/util" ) @@ -195,9 +197,17 @@ func (env *TypeEnv) getRefRecExtent(node *typeTreeNode) types.Type { child := v.(*typeTreeNode) tpe := env.getRefRecExtent(child) - // TODO(tsandall): handle non-string keys? - if s, ok := key.(String); ok { - children = append(children, types.NewStaticProperty(string(s), tpe)) + + // NOTE(sr): Converting to Golang-native types here is an extension of what we did + // before -- only supporting strings. But since we cannot differentiate sets and arrays + // that way, we could reconsider. + switch key.(type) { + case String, Number, Boolean: // skip anything else + propKey, err := JSON(key) + if err != nil { + panic(fmt.Errorf("unreachable, ValueToInterface: %w", err)) + } + children = append(children, types.NewStaticProperty(propKey, tpe)) } return false }) diff --git a/ast/index.go b/ast/index.go index bcbb5c1765..0bd775b062 100644 --- a/ast/index.go +++ b/ast/index.go @@ -32,7 +32,7 @@ type RuleIndex interface { // IndexResult contains the result of an index lookup. type IndexResult struct { - Kind DocKind + Kind RuleKind Rules []*Rule Else map[*Rule][]*Rule Default *Rule @@ -40,7 +40,7 @@ type IndexResult struct { } // NewIndexResult returns a new IndexResult object. 
-func NewIndexResult(kind DocKind) *IndexResult { +func NewIndexResult(kind RuleKind) *IndexResult { return &IndexResult{ Kind: kind, Else: map[*Rule][]*Rule{}, @@ -57,7 +57,7 @@ type baseDocEqIndex struct { isVirtual func(Ref) bool root *trieNode defaultRule *Rule - kind DocKind + kind RuleKind } func newBaseDocEqIndex(isVirtual func(Ref) bool) *baseDocEqIndex { @@ -73,7 +73,7 @@ func (i *baseDocEqIndex) Build(rules []*Rule) bool { return false } - i.kind = rules[0].Head.DocKind() + i.kind = rules[0].Head.RuleKind() indices := newrefindices(i.isVirtual) // build indices for each rule. diff --git a/ast/index_test.go b/ast/index_test.go index 1c387e9c0a..dff4167f57 100644 --- a/ast/index_test.go +++ b/ast/index_test.go @@ -60,6 +60,20 @@ func TestBaseDocEqIndexing(t *testing.T) { input.b = 1 }`, opts) + refMod := MustParseModuleWithOpts(`package test + + ref.single.value.ground = x if x := input.x + + ref.single.value.key[k] = v if { k := input.k; v := input.v } + + ref.multi.value.ground contains x if x := input.x + + ref.multiple.single.value.ground = x if x := input.x + ref.multiple.single.value[y] = x if { x := input.x; y := index.y } + + # ref.multi.value.key[k] contains v if { k := input.k; v := input.v } # not supported yet + `, opts) + module := MustParseModule(` package test @@ -70,6 +84,7 @@ func TestBaseDocEqIndexing(t *testing.T) { input.x = 3 input.y = 4 } + scalars { input.x = 0 @@ -209,6 +224,7 @@ func TestBaseDocEqIndexing(t *testing.T) { note string module *Module ruleset string + ruleRef Ref input string unknowns []string args []Value @@ -632,6 +648,41 @@ func TestBaseDocEqIndexing(t *testing.T) { input: `{"a": [1]}`, expectedRS: RuleSet([]*Rule{everyModWithDomain.Rules[0]}), }, + { + note: "ref: single value, ground ref", + module: refMod, + ruleRef: MustParseRef("ref.single.value.ground"), + input: `{"x": 1}`, + expectedRS: RuleSet([]*Rule{refMod.Rules[0]}), + }, + { + note: "ref: single value, ground ref and non-ground ref", + module: refMod, + 
ruleRef: MustParseRef("ref.multiple.single.value"), + input: `{"x": 1, "y": "Y"}`, + expectedRS: RuleSet([]*Rule{refMod.Rules[3], refMod.Rules[4]}), + }, + { + note: "ref: single value, var in ref", + module: refMod, + ruleRef: MustParseRef("ref.single.value.key[k]"), + input: `{"k": 1, "v": 2}`, + expectedRS: RuleSet([]*Rule{refMod.Rules[1]}), + }, + { + note: "ref: multi value, ground ref", + module: refMod, + ruleRef: MustParseRef("ref.multi.value.ground"), + input: `{"x": 1}`, + expectedRS: RuleSet([]*Rule{refMod.Rules[2]}), + }, + // { + // note: "ref: multi value, var in ref", + // module: refMod, + // ruleRef: MustParseRef("ref.multi.value.key[k]"), + // input: `{"k": 1, "v": 2}`, + // expectedRS: RuleSet([]*Rule{refMod.Rules[3]}), + // }, } for _, tc := range tests { @@ -642,10 +693,19 @@ func TestBaseDocEqIndexing(t *testing.T) { } rules := []*Rule{} for _, rule := range module.Rules { - if rule.Head.Name == Var(tc.ruleset) { - rules = append(rules, rule) + if tc.ruleRef == nil { + if rule.Head.Name == Var(tc.ruleset) { + rules = append(rules, rule) + } + } else { + if rule.Head.Ref().HasPrefix(tc.ruleRef) { + rules = append(rules, rule) + } } } + if len(rules) == 0 { + t.Fatal("selected empty ruleset") + } var input *Term if tc.input != "" { diff --git a/ast/parser.go b/ast/parser.go index fd1b407246..9fe41ca368 100644 --- a/ast/parser.go +++ b/ast/parser.go @@ -567,14 +567,41 @@ func (p *Parser) parseRules() []*Rule { return []*Rule{&rule} } - hasIf := false - if p.s.tok == tokens.If { - hasIf = true + if usesContains && !rule.Head.Reference.IsGround() { + p.error(p.s.Loc(), "multi-value rules need ground refs") + return nil } - if hasIf && !usesContains && rule.Head.Key != nil && rule.Head.Value == nil { - p.illegal("invalid for partial set rule %s (use `contains`)", rule.Head.Name) - return nil + // back-compat with `p[x] { ... }`` + hasIf := p.s.tok == tokens.If + + // p[x] if ... 
becomes a single-value rule p[x] + if hasIf && !usesContains && len(rule.Head.Ref()) == 2 { + if rule.Head.Value == nil { + rule.Head.Value = BooleanTerm(true).SetLocation(rule.Head.Location) + } else { + // p[x] = y if becomes a single-value rule p[x] with value y, but needs name for compat + v, ok := rule.Head.Ref()[0].Value.(Var) + if !ok { + return nil + } + rule.Head.Name = v + } + } + + // p[x] becomes a multi-value rule p + if !hasIf && !usesContains && + len(rule.Head.Args) == 0 && // not a function + len(rule.Head.Ref()) == 2 { // ref like 'p[x]' + v, ok := rule.Head.Ref()[0].Value.(Var) + if !ok { + return nil + } + rule.Head.Name = v + rule.Head.Key = rule.Head.Ref()[1] + if rule.Head.Value == nil { + rule.Head.SetRef(rule.Head.Ref()[:len(rule.Head.Ref())-1]) + } } switch { @@ -743,67 +770,65 @@ func (p *Parser) parseElse(head *Head) *Rule { func (p *Parser) parseHead(defaultRule bool) (*Head, bool) { - var head Head - head.SetLoc(p.s.Loc()) - + head := &Head{} + loc := p.s.Loc() defer func() { - head.Location.Text = p.s.Text(head.Location.Offset, p.s.lastEnd) + if head != nil { + head.SetLoc(loc) + head.Location.Text = p.s.Text(head.Location.Offset, p.s.lastEnd) + } }() - if term := p.parseVar(); term != nil { - head.Name = term.Value.(Var) - } else { - p.illegal("expected rule head name") + term := p.parseVar() + if term == nil { + return nil, false } - p.scan() - - if p.s.tok == tokens.LParen { - p.scan() - if p.s.tok != tokens.RParen { - head.Args = p.parseTermList(tokens.RParen, nil) - if head.Args == nil { - return nil, false - } + ref := p.parseTermFinish(term, true) + if ref == nil { + p.illegal("expected rule head name") + return nil, false + } + + switch x := ref.Value.(type) { + case Var: + head = NewHead(x) + case Ref: + head = RefHead(x) + case Call: + op, args := x[0], x[1:] + var ref Ref + switch y := op.Value.(type) { + case Var: + ref = Ref{op} + case Ref: + ref = y } - p.scan() + head = RefHead(ref) + head.Args = append([]*Term{}, 
args...) - if p.s.tok == tokens.LBrack { - return nil, false - } + default: + return nil, false } - if p.s.tok == tokens.LBrack { - p.scan() - head.Key = p.parseTermInfixCall() - if head.Key == nil { - p.illegal("expected rule key term (e.g., %s[] { ... })", head.Name) - } - if p.s.tok != tokens.RBrack { - if _, ok := futureKeywords[head.Name.String()]; ok { - p.hint("`import future.keywords.%[1]s` for '%[1]s' keyword", head.Name.String()) - } - p.illegal("non-terminated rule key") - } - p.scan() - } + name := head.Ref().String() switch p.s.tok { - case tokens.Contains: + case tokens.Contains: // NOTE: no Value for `contains` heads, we return here p.scan() head.Key = p.parseTermInfixCall() if head.Key == nil { - p.illegal("expected rule key term (e.g., %s contains { ... })", head.Name) + p.illegal("expected rule key term (e.g., %s contains { ... })", name) } + return head, true - return &head, true case tokens.Unify: p.scan() head.Value = p.parseTermInfixCall() if head.Value == nil { - p.illegal("expected rule value term (e.g., %s[%s] = { ... })", head.Name, head.Key) + // FIX HEAD.String() + p.illegal("expected rule value term (e.g., %s[%s] = { ... })", name, head.Key) } - case tokens.Assign: s := p.save() p.scan() @@ -813,22 +838,23 @@ func (p *Parser) parseHead(defaultRule bool) (*Head, bool) { p.restore(s) switch { case len(head.Args) > 0: - p.illegal("expected function value term (e.g., %s(...) := { ... })", head.Name) + p.illegal("expected function value term (e.g., %s(...) := { ... })", name) case head.Key != nil: - p.illegal("expected partial rule value term (e.g., %s[...] := { ... })", head.Name) + p.illegal("expected partial rule value term (e.g., %s[...] := { ... })", name) case defaultRule: - p.illegal("expected default rule value term (e.g., default %s := )", head.Name) + p.illegal("expected default rule value term (e.g., default %s := )", name) default: - p.illegal("expected rule value term (e.g., %s := { ... 
})", head.Name) + p.illegal("expected rule value term (e.g., %s := { ... })", name) } } } if head.Value == nil && head.Key == nil { - head.Value = BooleanTerm(true).SetLocation(head.Location) + if len(head.Ref()) != 2 || len(head.Args) > 0 { + head.Value = BooleanTerm(true).SetLocation(head.Location) + } } - - return &head, false + return head, false } func (p *Parser) parseBody(end tokens.Token) Body { @@ -1348,17 +1374,18 @@ func (p *Parser) parseTerm() *Term { p.illegalToken() } - term = p.parseTermFinish(term) + term = p.parseTermFinish(term, false) p.parsedTermCachePush(term, s0) return term } -func (p *Parser) parseTermFinish(head *Term) *Term { +func (p *Parser) parseTermFinish(head *Term, skipws bool) *Term { if head == nil { return nil } offset := p.s.loc.Offset - p.scanWS() + p.doScan(skipws) + switch p.s.tok { case tokens.LParen, tokens.Dot, tokens.LBrack: return p.parseRef(head, offset) diff --git a/ast/parser_ext.go b/ast/parser_ext.go index 41eb4443ec..7f5a69424b 100644 --- a/ast/parser_ext.go +++ b/ast/parser_ext.go @@ -154,12 +154,15 @@ func ParseRuleFromExpr(module *Module, expr *Expr) (*Rule, error) { } if _, ok := expr.Terms.(*SomeDecl); ok { - return nil, errors.New("some declarations cannot be used for rule head") + return nil, errors.New("'some' declarations cannot be used for rule head") } if term, ok := expr.Terms.(*Term); ok { switch v := term.Value.(type) { case Ref: + if len(v) > 2 { // 2+ dots + return ParseCompleteDocRuleWithDotsFromTerm(module, term) + } return ParsePartialSetDocRuleFromTerm(module, term) default: return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(v)) @@ -207,18 +210,17 @@ func parseCompleteRuleFromEq(module *Module, expr *Expr) (rule *Rule, err error) return nil, errors.New("assignment requires two operands") } - rule, err = ParseCompleteDocRuleFromEqExpr(module, lhs, rhs) - + rule, err = ParseRuleFromCallEqExpr(module, lhs, rhs) if err == nil { return rule, nil } - rule, err = 
ParseRuleFromCallEqExpr(module, lhs, rhs) + rule, err = ParsePartialObjectDocRuleFromEqExpr(module, lhs, rhs) if err == nil { return rule, nil } - return ParsePartialObjectDocRuleFromEqExpr(module, lhs, rhs) + return ParseCompleteDocRuleFromEqExpr(module, lhs, rhs) } // ParseCompleteDocRuleFromAssignmentExpr returns a rule if the expression can @@ -239,39 +241,55 @@ func ParseCompleteDocRuleFromAssignmentExpr(module *Module, lhs, rhs *Term) (*Ru // ParseCompleteDocRuleFromEqExpr returns a rule if the expression can be // interpreted as a complete document definition. func ParseCompleteDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, error) { - - var name Var - - if RootDocumentRefs.Contains(lhs) { - name = lhs.Value.(Ref)[0].Value.(Var) - } else if v, ok := lhs.Value.(Var); ok { - name = v + var head *Head + + if v, ok := lhs.Value.(Var); ok { + head = NewHead(v) + } else if r, ok := lhs.Value.(Ref); ok { // groundness ? + head = RefHead(r) + if len(r) > 1 && !r[len(r)-1].IsGround() { + return nil, fmt.Errorf("ref not ground") + } } else { return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(lhs.Value)) } + head.Value = rhs + head.Location = lhs.Location - rule := &Rule{ + return &Rule{ Location: lhs.Location, - Head: &Head{ - Location: lhs.Location, - Name: name, - Value: rhs, - }, + Head: head, Body: NewBody( NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location), ), Module: module, + }, nil +} + +func ParseCompleteDocRuleWithDotsFromTerm(module *Module, term *Term) (*Rule, error) { + ref, ok := term.Value.(Ref) + if !ok { + return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(term.Value)) } - return rule, nil + head := RefHead(ref, BooleanTerm(true).SetLocation(term.Location)) + head.Location = term.Location + + return &Rule{ + Location: term.Location, + Head: head, + Body: NewBody( + NewExpr(BooleanTerm(true).SetLocation(term.Location)).SetLocation(term.Location), + ), + Module: module, + }, nil } // 
ParsePartialObjectDocRuleFromEqExpr returns a rule if the expression can be // interpreted as a partial object document definition. func ParsePartialObjectDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, error) { - ref, ok := lhs.Value.(Ref) - if !ok || len(ref) != 2 { + if !ok { return nil, fmt.Errorf("%v cannot be used as rule name", TypeName(lhs.Value)) } @@ -279,17 +297,16 @@ func ParsePartialObjectDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, return nil, fmt.Errorf("%vs cannot be used as rule name", TypeName(ref[0].Value)) } - name := ref[0].Value.(Var) - key := ref[1] + head := RefHead(ref, rhs) + if len(ref) == 2 { // backcompat for naked `foo.bar = "baz"` statements + head.Name = ref[0].Value.(Var) + head.Key = ref[1] + } + head.Location = rhs.Location rule := &Rule{ Location: rhs.Location, - Head: &Head{ - Location: rhs.Location, - Name: name, - Key: key, - Value: rhs, - }, + Head: head, Body: NewBody( NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location), ), @@ -304,26 +321,24 @@ func ParsePartialObjectDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, func ParsePartialSetDocRuleFromTerm(module *Module, term *Term) (*Rule, error) { ref, ok := term.Value.(Ref) - if !ok { + if !ok || len(ref) == 1 { return nil, fmt.Errorf("%vs cannot be used for rule head", TypeName(term.Value)) } - if len(ref) != 2 { - return nil, fmt.Errorf("refs cannot be used for rule") - } - - name, ok := ref[0].Value.(Var) - if !ok { - return nil, fmt.Errorf("%vs cannot be used as rule name", TypeName(ref[0].Value)) + head := RefHead(ref) + if len(ref) == 2 { + v, ok := ref[0].Value.(Var) + if !ok { + return nil, fmt.Errorf("%vs cannot be used for rule head", TypeName(term.Value)) + } + head = NewHead(v) + head.Key = ref[1] } + head.Location = term.Location rule := &Rule{ Location: term.Location, - Head: &Head{ - Location: term.Location, - Name: name, - Key: ref[1], - }, + Head: head, Body: NewBody( 
NewExpr(BooleanTerm(true).SetLocation(term.Location)).SetLocation(term.Location), ), @@ -347,21 +362,15 @@ func ParseRuleFromCallEqExpr(module *Module, lhs, rhs *Term) (*Rule, error) { return nil, fmt.Errorf("%vs cannot be used in function signature", TypeName(call[0].Value)) } - name, ok := ref[0].Value.(Var) - if !ok { - return nil, fmt.Errorf("%vs cannot be used in function signature", TypeName(ref[0].Value)) - } + head := RefHead(ref, rhs) + head.Location = lhs.Location + head.Args = Args(call[1:]) rule := &Rule{ Location: lhs.Location, - Head: &Head{ - Location: lhs.Location, - Name: name, - Args: Args(call[1:]), - Value: rhs, - }, - Body: NewBody(NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location)), - Module: module, + Head: head, + Body: NewBody(NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location)), + Module: module, } return rule, nil @@ -376,19 +385,15 @@ func ParseRuleFromCallExpr(module *Module, terms []*Term) (*Rule, error) { } loc := terms[0].Location - args := terms[1:] - value := BooleanTerm(true).SetLocation(loc) + head := RefHead(terms[0].Value.(Ref), BooleanTerm(true).SetLocation(loc)) + head.Location = loc + head.Args = terms[1:] rule := &Rule{ Location: loc, - Head: &Head{ - Location: loc, - Name: Var(terms[0].String()), - Args: args, - Value: value, - }, - Module: module, - Body: NewBody(NewExpr(BooleanTerm(true).SetLocation(loc)).SetLocation(loc)), + Head: head, + Module: module, + Body: NewBody(NewExpr(BooleanTerm(true).SetLocation(loc)).SetLocation(loc)), } return rule, nil } @@ -518,15 +523,15 @@ func ParseRef(input string) (Ref, error) { return ref, nil } -// ParseRule returns exactly one rule. +// ParseRuleWithOpts returns exactly one rule. // If multiple rules are parsed, an error is returned. 
-func ParseRule(input string) (*Rule, error) { - stmts, _, err := ParseStatements("", input) +func ParseRuleWithOpts(input string, opts ParserOptions) (*Rule, error) { + stmts, _, err := ParseStatementsWithOpts("", input, opts) if err != nil { return nil, err } if len(stmts) != 1 { - return nil, fmt.Errorf("expected exactly one statement (rule)") + return nil, fmt.Errorf("expected exactly one statement (rule), got %v = %T, %T", stmts, stmts[0], stmts[1]) } rule, ok := stmts[0].(*Rule) if !ok { @@ -535,6 +540,12 @@ func ParseRule(input string) (*Rule, error) { return rule, nil } +// ParseRule returns exactly one rule. +// If multiple rules are parsed, an error is returned. +func ParseRule(input string) (*Rule, error) { + return ParseRuleWithOpts(input, ParserOptions{}) +} + // ParseStatement returns exactly one statement. // A statement might be a term, expression, rule, etc. Regardless, // this function expects *exactly* one statement. If multiple @@ -611,14 +622,14 @@ func parseModule(filename string, stmts []Statement, comments []*Comment) (*Modu rule, err := ParseRuleFromBody(mod, stmt) if err != nil { errs = append(errs, NewError(ParseErr, stmt[0].Location, err.Error())) - } else { - mod.Rules = append(mod.Rules, rule) - - // NOTE(tsandall): the statement should now be interpreted as a - // rule so update the statement list. This is important for the - // logic below that associates annotations with statements. - stmts[i+1] = rule + continue } + mod.Rules = append(mod.Rules, rule) + + // NOTE(tsandall): the statement should now be interpreted as a + // rule so update the statement list. This is important for the + // logic below that associates annotations with statements. 
+ stmts[i+1] = rule case *Package: errs = append(errs, NewError(ParseErr, stmt.Loc(), "unexpected package")) case *Annotations: diff --git a/ast/parser_test.go b/ast/parser_test.go index b9a0e8199d..4d610f9249 100644 --- a/ast/parser_test.go +++ b/ast/parser_test.go @@ -7,6 +7,7 @@ package ast import ( "bytes" "encoding/json" + "errors" "fmt" "reflect" "strings" @@ -1413,9 +1414,10 @@ func TestRule(t *testing.T) { assertParseRule(t, "default w/ assignment", `default allow := false`, &Rule{ Default: true, Head: &Head{ - Name: "allow", - Value: BooleanTerm(false), - Assign: true, + Name: "allow", + Reference: Ref{VarTerm("allow")}, + Value: BooleanTerm(false), + Assign: true, }, Body: NewBody(NewExpr(BooleanTerm(true))), }) @@ -1439,9 +1441,10 @@ func TestRule(t *testing.T) { }) fxy := &Head{ - Name: Var("f"), - Args: Args{VarTerm("x")}, - Value: VarTerm("y"), + Name: Var("f"), + Reference: Ref{VarTerm("f")}, + Args: Args{VarTerm("x")}, + Value: VarTerm("y"), } assertParseRule(t, "identity", `f(x) = y { y = x }`, &Rule{ @@ -1453,9 +1456,10 @@ func TestRule(t *testing.T) { assertParseRule(t, "composite arg", `f([x, y]) = z { split(x, y, z) }`, &Rule{ Head: &Head{ - Name: Var("f"), - Args: Args{ArrayTerm(VarTerm("x"), VarTerm("y"))}, - Value: VarTerm("z"), + Name: Var("f"), + Reference: Ref{VarTerm("f")}, + Args: Args{ArrayTerm(VarTerm("x"), VarTerm("y"))}, + Value: VarTerm("z"), }, Body: NewBody( Split.Expr(VarTerm("x"), VarTerm("y"), VarTerm("z")), @@ -1464,9 +1468,10 @@ func TestRule(t *testing.T) { assertParseRule(t, "composite result", `f(1) = [x, y] { split("foo.bar", x, y) }`, &Rule{ Head: &Head{ - Name: Var("f"), - Args: Args{IntNumberTerm(1)}, - Value: ArrayTerm(VarTerm("x"), VarTerm("y")), + Name: Var("f"), + Reference: Ref{VarTerm("f")}, + Args: Args{IntNumberTerm(1)}, + Value: ArrayTerm(VarTerm("x"), VarTerm("y")), }, Body: NewBody( Split.Expr(StringTerm("foo.bar"), VarTerm("x"), VarTerm("y")), @@ -1475,7 +1480,8 @@ func TestRule(t *testing.T) { 
assertParseRule(t, "expr terms: key", `p[f(x) + g(x)] { true }`, &Rule{ Head: &Head{ - Name: Var("p"), + Name: Var("p"), + Reference: Ref{VarTerm("p")}, Key: Plus.Call( CallTerm(RefTerm(VarTerm("f")), VarTerm("x")), CallTerm(RefTerm(VarTerm("g")), VarTerm("x")), @@ -1486,7 +1492,8 @@ func TestRule(t *testing.T) { assertParseRule(t, "expr terms: value", `p = f(x) + g(x) { true }`, &Rule{ Head: &Head{ - Name: Var("p"), + Name: Var("p"), + Reference: Ref{VarTerm("p")}, Value: Plus.Call( CallTerm(RefTerm(VarTerm("f")), VarTerm("x")), CallTerm(RefTerm(VarTerm("g")), VarTerm("x")), @@ -1497,7 +1504,8 @@ func TestRule(t *testing.T) { assertParseRule(t, "expr terms: args", `p(f(x) + g(x)) { true }`, &Rule{ Head: &Head{ - Name: Var("p"), + Name: Var("p"), + Reference: Ref{VarTerm("p")}, Args: Args{ Plus.Call( CallTerm(RefTerm(VarTerm("f")), VarTerm("x")), @@ -1511,25 +1519,28 @@ func TestRule(t *testing.T) { assertParseRule(t, "assignment operator", `x := 1 { true }`, &Rule{ Head: &Head{ - Name: Var("x"), - Value: IntNumberTerm(1), - Assign: true, + Name: Var("x"), + Reference: Ref{VarTerm("x")}, + Value: IntNumberTerm(1), + Assign: true, }, Body: NewBody(NewExpr(BooleanTerm(true))), }) assertParseRule(t, "else assignment", `x := 1 { false } else := 2`, &Rule{ Head: &Head{ - Name: "x", - Value: IntNumberTerm(1), - Assign: true, + Name: "x", // ha! clever! 
+ Reference: Ref{VarTerm("x")}, + Value: IntNumberTerm(1), + Assign: true, }, Body: NewBody(NewExpr(BooleanTerm(false))), Else: &Rule{ Head: &Head{ - Name: "x", - Value: IntNumberTerm(2), - Assign: true, + Name: "x", + Reference: Ref{VarTerm("x")}, + Value: IntNumberTerm(2), + Assign: true, }, Body: NewBody(NewExpr(BooleanTerm(true))), }, @@ -1537,18 +1548,20 @@ func TestRule(t *testing.T) { assertParseRule(t, "partial assignment", `p[x] := y { true }`, &Rule{ Head: &Head{ - Name: "p", - Value: VarTerm("y"), - Key: VarTerm("x"), - Assign: true, + Name: "p", + Reference: MustParseRef("p[x]"), + Value: VarTerm("y"), + Key: VarTerm("x"), + Assign: true, }, Body: NewBody(NewExpr(BooleanTerm(true))), }) assertParseRule(t, "function assignment", `f(x) := y { true }`, &Rule{ Head: &Head{ - Name: "f", - Value: VarTerm("y"), + Name: "f", + Reference: Ref{VarTerm("f")}, + Value: VarTerm("y"), Args: Args{ VarTerm("x"), }, @@ -1562,11 +1575,10 @@ func TestRule(t *testing.T) { assertParseErrorContains(t, "empty rule body", "p {}", "rego_parse_error: found empty body") assertParseErrorContains(t, "unmatched braces", `f(x) = y { trim(x, ".", y) `, `rego_parse_error: unexpected eof token: expected \n or ; or }`) - // TODO: how to highlight that assignment is incorrect here? 
assertParseErrorContains(t, "no output", `f(_) = { "foo" = "bar" }`, "rego_parse_error: unexpected eq token: expected rule value term") assertParseErrorContains(t, "no output", `f(_) := { "foo" = "bar" }`, "rego_parse_error: unexpected assign token: expected function value term") assertParseErrorContains(t, "no output", `f := { "foo" = "bar" }`, "rego_parse_error: unexpected assign token: expected rule value term") - assertParseErrorContains(t, "no output", `f[_] := { "foo" = "bar" }`, "rego_parse_error: unexpected assign token: expected partial rule value term") + assertParseErrorContains(t, "no output", `f[_] := { "foo" = "bar" }`, "rego_parse_error: unexpected assign token: expected rule value term") assertParseErrorContains(t, "no output", `default f :=`, "rego_parse_error: unexpected assign token: expected default rule value term") // TODO(tsandall): improve error checking here. This is a common mistake @@ -1575,19 +1587,21 @@ func TestRule(t *testing.T) { assertParseError(t, "dangling semicolon", "p { true; false; }") assertParseErrorContains(t, "default invalid rule name", `default 0[0`, "unexpected default keyword") - assertParseErrorContains(t, "default invalid rule value", `default a[0`, "illegal default rule (must have a value)") + assertParseErrorContains(t, "default invalid rule value", `default a[0]`, "illegal default rule (must have a value)") assertParseRule(t, "default missing value", `default a`, &Rule{ Default: true, Head: &Head{ - Name: Var("a"), - Value: BooleanTerm(true), + Name: Var("a"), + Reference: Ref{VarTerm("a")}, + Value: BooleanTerm(true), }, Body: NewBody(NewExpr(BooleanTerm(true))), }) assertParseRule(t, "empty arguments", `f() { x := 1 }`, &Rule{ Head: &Head{ - Name: "f", - Value: BooleanTerm(true), + Name: "f", + Reference: Ref{VarTerm("f")}, + Value: BooleanTerm(true), }, Body: MustParseBody(`x := 1`), }) @@ -1598,23 +1612,23 @@ func TestRule(t *testing.T) { assertParseErrorContains(t, "default invalid rule head call", `default a 
= b`, "illegal default rule (value cannot contain var)") assertParseError(t, "extra braces", `{ a := 1 }`) - assertParseError(t, "invalid rule name dots", `a.b = x { x := 1 }`) - assertParseError(t, "invalid rule name dots and call", `a.b(x) { x := 1 }`) assertParseError(t, "invalid rule name hyphen", `a-b = x { x := 1 }`) assertParseRule(t, "wildcard name", `_ { x == 1 }`, &Rule{ Head: &Head{ - Name: "$0", - Value: BooleanTerm(true), + Name: "$0", + Reference: Ref{VarTerm("$0")}, + Value: BooleanTerm(true), }, Body: MustParseBody(`x == 1`), }) assertParseRule(t, "partial object array key", `p[[a, 1, 2]] = x { a := 1; x := "foo" }`, &Rule{ Head: &Head{ - Name: "p", - Key: ArrayTerm(VarTerm("a"), NumberTerm("1"), NumberTerm("2")), - Value: VarTerm("x"), + Name: "p", + Reference: MustParseRef("p[[a,1,2]]"), + Key: ArrayTerm(VarTerm("a"), NumberTerm("1"), NumberTerm("2")), + Value: VarTerm("x"), }, Body: MustParseBody(`a := 1; x := "foo"`), }) @@ -1623,7 +1637,7 @@ func TestRule(t *testing.T) { } func TestRuleContains(t *testing.T) { - opts := ParserOptions{FutureKeywords: []string{"contains"}} + opts := ParserOptions{FutureKeywords: []string{"contains", "if"}} tests := []struct { note string @@ -1646,6 +1660,28 @@ func TestRuleContains(t *testing.T) { Body: NewBody(NewExpr(BooleanTerm(true))), }, }, + { + note: "ref head, no body", + rule: `p.q contains "x"`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q"), + Key: StringTerm("x"), + }, + Body: NewBody(NewExpr(BooleanTerm(true))), + }, + }, + { + note: "ref head", + rule: `p.q contains "x" { true }`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q"), + Key: StringTerm("x"), + }, + Body: NewBody(NewExpr(BooleanTerm(true))), + }, + }, { note: "set with var element", rule: `deny contains msg { msg := "nonono" }`, @@ -1699,6 +1735,17 @@ func TestRuleIf(t *testing.T) { }, }, }, + { + note: "ref head, complete", + rule: `p.q if { true }`, + exp: &Rule{ + Head: &Head{ + Reference: 
MustParseRef("p.q"), + Value: BooleanTerm(true), + }, + Body: NewBody(NewExpr(BooleanTerm(true))), + }, + }, { note: "complete, normal body", rule: `p if { x := 10; x > y }`, @@ -1712,16 +1759,18 @@ func TestRuleIf(t *testing.T) { rule: `p := "yes" if { 10 > y } else := "no" { 10 <= y }`, exp: &Rule{ Head: &Head{ - Name: Var("p"), - Value: StringTerm("yes"), - Assign: true, + Reference: Ref{VarTerm("p")}, + Name: Var("p"), + Value: StringTerm("yes"), + Assign: true, }, Body: MustParseBody(`10 > y`), Else: &Rule{ Head: &Head{ - Name: Var("p"), - Value: StringTerm("no"), - Assign: true, + Reference: Ref{VarTerm("p")}, + Name: Var("p"), + Value: StringTerm("no"), + Assign: true, }, Body: MustParseBody(`10 <= y`), }, @@ -1732,16 +1781,18 @@ func TestRuleIf(t *testing.T) { rule: `p := "yes" if { 10 > y } else := "no" if { 10 <= y }`, exp: &Rule{ Head: &Head{ - Name: Var("p"), - Value: StringTerm("yes"), - Assign: true, + Reference: Ref{VarTerm("p")}, + Name: Var("p"), + Value: StringTerm("yes"), + Assign: true, }, Body: MustParseBody(`10 > y`), Else: &Rule{ Head: &Head{ - Name: Var("p"), - Value: StringTerm("no"), - Assign: true, + Reference: Ref{VarTerm("p")}, + Name: Var("p"), + Value: StringTerm("no"), + Assign: true, }, Body: MustParseBody(`10 <= y`), }, @@ -1795,8 +1846,9 @@ func TestRuleIf(t *testing.T) { Body: MustParseBody(`1 > 2`), Else: &Rule{ Head: &Head{ - Name: Var("p"), - Value: NumberTerm("42"), + Reference: Ref{VarTerm("p")}, + Name: Var("p"), + Value: NumberTerm("42"), }, Body: MustParseBody(`2 > 1`), }, @@ -1815,9 +1867,10 @@ func TestRuleIf(t *testing.T) { rule: `f(x) = y if y := x + 1`, exp: &Rule{ Head: &Head{ - Name: Var("f"), - Args: []*Term{VarTerm("x")}, - Value: VarTerm("y"), + Reference: Ref{VarTerm("f")}, + Name: Var("f"), + Args: []*Term{VarTerm("x")}, + Value: VarTerm("y"), }, Body: MustParseBody(`y := x + 1`), }, @@ -1827,9 +1880,10 @@ func TestRuleIf(t *testing.T) { rule: `f(xs) if every x in xs { x != 0 }`, exp: &Rule{ Head: &Head{ - 
Name: Var("f"), - Args: []*Term{VarTerm("xs")}, - Value: BooleanTerm(true), + Reference: Ref{VarTerm("f")}, + Name: Var("f"), + Args: []*Term{VarTerm("xs")}, + Value: BooleanTerm(true), }, Body: MustParseBodyWithOpts(`every x in xs { x != 0 }`, opts), }, @@ -1838,7 +1892,11 @@ func TestRuleIf(t *testing.T) { note: "object", rule: `p["foo"] = "bar" if { true }`, exp: &Rule{ - Head: NewHead(Var("p"), StringTerm("foo"), StringTerm("bar")), + Head: &Head{ + Name: Var("p"), + Reference: MustParseRef("p.foo"), + Value: StringTerm("bar"), + }, Body: NewBody(NewExpr(BooleanTerm(true))), }, }, @@ -1846,7 +1904,11 @@ func TestRuleIf(t *testing.T) { note: "object, shorthand", rule: `p["foo"] = "bar" if true`, exp: &Rule{ - Head: NewHead(Var("p"), StringTerm("foo"), StringTerm("bar")), + Head: &Head{ + Name: Var("p"), + Reference: MustParseRef("p.foo"), + Value: StringTerm("bar"), + }, Body: NewBody(NewExpr(BooleanTerm(true))), }, }, @@ -1857,7 +1919,11 @@ func TestRuleIf(t *testing.T) { y := "bar" }`, exp: &Rule{ - Head: NewHead(Var("p"), VarTerm("x"), VarTerm("y")), + Head: &Head{ + Name: Var("p"), + Reference: MustParseRef("p[x]"), + Value: VarTerm("y"), + }, Body: MustParseBody(`x := "foo"; y := "bar"`), }, }, @@ -1885,6 +1951,28 @@ func TestRuleIf(t *testing.T) { Body: MustParseBody(`x := "foo"`), }, }, + { + note: "partial set+if, shorthand", // these are now Head.Ref rules, previously forbidden + rule: `p[x] if x := 1`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p[x]"), + Value: BooleanTerm(true), + }, + Body: MustParseBody(`x := 1`), + }, + }, + { + note: "partial set+if", // these are now Head.Ref rules, previously forbidden + rule: `p[x] if { x := 1 }`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p[x]"), + Value: BooleanTerm(true), + }, + Body: MustParseBody(`x := 1`), + }, + }, } for _, tc := range tests { @@ -1892,28 +1980,177 @@ func TestRuleIf(t *testing.T) { assertParseRule(t, tc.note, tc.rule, tc.exp, opts) }) } +} - errors := []struct 
{ +func TestRuleRefHeads(t *testing.T) { + opts := ParserOptions{FutureKeywords: []string{"contains", "if", "every"}} + trueBody := NewBody(NewExpr(BooleanTerm(true))) + + tests := []struct { note string rule string - err string + exp *Rule }{ { - note: "partial set+if, shorthand", - rule: `p[x] if x := 1`, - err: "rego_parse_error: unexpected if keyword: invalid for partial set rule p (use `contains`)", + note: "single-value rule", + rule: "p.q.r = 1 if true", + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.r"), + Value: IntNumberTerm(1), + }, + Body: trueBody, + }, }, { - note: "partial set+if", - rule: `p[x] if { x := 1 }`, - err: "rego_parse_error: unexpected if keyword: invalid for partial set rule p (use `contains`)", + note: "single-value with brackets, string key", + rule: `p.q["r"] = 1 if true`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.r"), + Value: IntNumberTerm(1), + }, + Body: trueBody, + }, + }, + { + note: "single-value with brackets, number key", + rule: `p.q[2] = 1 if true`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q[2]"), + Value: IntNumberTerm(1), + }, + Body: trueBody, + }, + }, + { + note: "single-value with brackets, no value", + rule: `p.q[2] if true`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q[2]"), + Value: BooleanTerm(true), + }, + Body: trueBody, + }, + }, + { + note: "single-value with brackets, var key", + rule: `p.q[x] = 1 if x := 2`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q[x]"), + Value: IntNumberTerm(1), + }, + Body: MustParseBody("x := 2"), + }, + }, + { + note: "single-value with brackets, var key, no dot", + rule: `p[x] = 1 if x := 2`, + exp: &Rule{ + Head: &Head{ + Name: Var("p"), + Reference: MustParseRef("p[x]"), + Value: IntNumberTerm(1), + }, + Body: MustParseBody("x := 2"), + }, + }, + { + note: "multi-value, simple", + rule: `p.q.r contains x if x := 2`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.r"), + Key: VarTerm("x"), + 
}, + Body: MustParseBody("x := 2"), + }, + }, + { + note: "backcompat: multi-value, no dot", + rule: `p[x] { x := 2 }`, // no "if", which triggers ref-interpretation + exp: &Rule{ + Head: &Head{ + Name: "p", + Reference: Ref{VarTerm("p")}, // we're defining p as multi-val rule + Key: VarTerm("x"), + }, + Body: MustParseBody("x := 2"), + }, + }, + { + note: "backcompat: single-value, no dot", + rule: `p[x] = 3 { x := 2 }`, + exp: &Rule{ + Head: &Head{ + Name: "p", + Reference: MustParseRef("p[x]"), + Key: VarTerm("x"), // not used + Value: IntNumberTerm(3), + }, + Body: MustParseBody("x := 2"), + }, + }, + { + note: "backcompat: single-value, no dot, complex object", + rule: `partialobj[x] = {"foo": y} { y = "bar"; x = y }`, + exp: &Rule{ + Head: &Head{ + Name: "partialobj", + Reference: MustParseRef("partialobj[x]"), + Key: VarTerm("x"), // not used + Value: MustParseTerm(`{"foo": y}`), + }, + Body: MustParseBody(`y = "bar"; x = y`), + }, + }, + { + note: "function, simple", + rule: `p.q.f(x) = 1 if true`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.f"), + Args: Args([]*Term{VarTerm("x")}), + Value: IntNumberTerm(1), + }, + Body: trueBody, + }, + }, + { + note: "function, no value", + rule: `p.q.f(x) if true`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.f"), + Args: Args([]*Term{VarTerm("x")}), + Value: BooleanTerm(true), + }, + Body: trueBody, + }, + }, + { + note: "function, with value", + rule: `p.q.f(x) = x + 1 if true`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.f"), + Args: Args([]*Term{VarTerm("x")}), + Value: Plus.Call(VarTerm("x"), IntNumberTerm(1)), + }, + Body: trueBody, + }, }, } - for _, tc := range errors { + for _, tc := range tests { t.Run(tc.note, func(t *testing.T) { - assertParseErrorContains(t, tc.note, tc.rule, tc.err, opts) + assertParseRule(t, tc.note, tc.rule, tc.exp, opts) }) } + + // TODO(sr): error cases, non-ground terms anywhere but at the end of the ref } func TestRuleElseKeyword(t 
*testing.T) { @@ -1968,8 +2205,9 @@ func TestRuleElseKeyword(t *testing.T) { } name := Var("p") + ref := Ref{VarTerm("p")} tr := BooleanTerm(true) - head := &Head{Name: name, Value: tr} + head := &Head{Name: name, Reference: ref, Value: tr} expected := &Module{ Package: MustParsePackage(`package test`), @@ -1986,14 +2224,16 @@ func TestRuleElseKeyword(t *testing.T) { Body: MustParseBody(`"p1_e1"`), Else: &Rule{ Head: &Head{ - Name: Var("p"), - Value: ArrayTerm(NullTerm()), + Name: name, + Reference: ref, + Value: ArrayTerm(NullTerm()), }, Body: MustParseBody(`"p1_e2"`), Else: &Rule{ Head: &Head{ - Name: name, - Value: VarTerm("x"), + Name: name, + Reference: ref, + Value: VarTerm("x"), }, Body: MustParseBody(`x = "p1_e3"`), }, @@ -2006,23 +2246,26 @@ func TestRuleElseKeyword(t *testing.T) { }, { Head: &Head{ - Name: Var("f"), - Args: Args{VarTerm("x")}, - Value: BooleanTerm(true), + Name: Var("f"), + Reference: Ref{VarTerm("f")}, + Args: Args{VarTerm("x")}, + Value: BooleanTerm(true), }, Body: MustParseBody(`x < 100`), Else: &Rule{ Head: &Head{ - Name: Var("f"), - Args: Args{VarTerm("x")}, - Value: BooleanTerm(false), + Name: Var("f"), + Reference: Ref{VarTerm("f")}, + Args: Args{VarTerm("x")}, + Value: BooleanTerm(false), }, Body: MustParseBody(`x > 200`), Else: &Rule{ Head: &Head{ - Name: Var("f"), - Args: Args{VarTerm("x")}, - Value: BooleanTerm(true), + Name: Var("f"), + Reference: Ref{VarTerm("f")}, + Args: Args{VarTerm("x")}, + Value: BooleanTerm(true), }, Body: MustParseBody(`x != 150`), }, @@ -2031,20 +2274,23 @@ func TestRuleElseKeyword(t *testing.T) { { Head: &Head{ - Name: Var("$0"), - Value: BooleanTerm(true), + Name: Var("$0"), + Reference: Ref{VarTerm("$0")}, + Value: BooleanTerm(true), }, Body: MustParseBody(`x > 0`), Else: &Rule{ Head: &Head{ - Name: Var("$0"), - Value: BooleanTerm(true), + Name: Var("$0"), + Reference: Ref{VarTerm("$0")}, + Value: BooleanTerm(true), }, Body: MustParseBody(`x == -1`), Else: &Rule{ Head: &Head{ - Name: Var("$0"), - 
Value: BooleanTerm(true), + Name: Var("$0"), + Reference: Ref{VarTerm("$0")}, + Value: BooleanTerm(true), }, Body: MustParseBody(`x > -100`), }, @@ -2052,30 +2298,34 @@ func TestRuleElseKeyword(t *testing.T) { }, { Head: &Head{ - Name: Var("nobody"), - Value: IntNumberTerm(1), + Name: Var("nobody"), + Reference: Ref{VarTerm("nobody")}, + Value: IntNumberTerm(1), }, Body: MustParseBody("false"), Else: &Rule{ Head: &Head{ - Name: Var("nobody"), - Value: IntNumberTerm(7), + Name: Var("nobody"), + Reference: Ref{VarTerm("nobody")}, + Value: IntNumberTerm(7), }, Body: MustParseBody("true"), }, }, { Head: &Head{ - Name: Var("nobody_f"), - Args: Args{VarTerm("x")}, - Value: IntNumberTerm(1), + Name: Var("nobody_f"), + Reference: Ref{VarTerm("nobody_f")}, + Args: Args{VarTerm("x")}, + Value: IntNumberTerm(1), }, Body: MustParseBody("false"), Else: &Rule{ Head: &Head{ - Name: Var("nobody_f"), - Args: Args{VarTerm("x")}, - Value: IntNumberTerm(7), + Name: Var("nobody_f"), + Reference: Ref{VarTerm("nobody_f")}, + Args: Args{VarTerm("x")}, + Value: IntNumberTerm(7), }, Body: MustParseBody("true"), }, @@ -2102,14 +2352,16 @@ func TestRuleElseKeyword(t *testing.T) { Body: MustParseBody(`"p1_e1"`), Else: &Rule{ Head: &Head{ - Name: Var("p"), - Value: ArrayTerm(NullTerm()), + Name: Var("p"), + Reference: Ref{VarTerm("p")}, + Value: ArrayTerm(NullTerm()), }, Body: MustParseBody(`"p1_e2"`), Else: &Rule{ Head: &Head{ - Name: name, - Value: VarTerm("x"), + Name: name, + Reference: ref, + Value: VarTerm("x"), }, Body: MustParseBody(`x = "p1_e4"`), }, @@ -2209,7 +2461,6 @@ func TestEmptyModule(t *testing.T) { } func TestComments(t *testing.T) { - testModule := `package a.b.c import input.e.f as g # end of line @@ -2425,52 +2676,245 @@ func TestLocation(t *testing.T) { } } -func TestRuleFromBody(t *testing.T) { - testModule := `package a.b.c - -pi = 3.14159 -p[x] { x = 1 } -greeting = "hello" -cores = [{0: 1}, {1: 2}] -wrapper = cores[0][1] -pi = [3, 1, 4, x, y, z] -foo["bar"] = "buz" 
-foo["9"] = "10" -foo.buz = "bar" -bar[1] -bar[[{"foo":"baz"}]] -bar.qux -input = 1 -data = 2 -f(1) = 2 -f(1) -d1 := 1234 -` +func TestRuleFromBodyRefs(t *testing.T) { + opts := ParserOptions{FutureKeywords: []string{"if", "contains"}} - assertParseModule(t, "rules from bodies", testModule, &Module{ - Package: MustParseStatement(`package a.b.c`).(*Package), - Rules: []*Rule{ - MustParseRule(`pi = 3.14159 { true }`), - MustParseRule(`p[x] { x = 1 }`), - MustParseRule(`greeting = "hello" { true }`), - MustParseRule(`cores = [{0: 1}, {1: 2}] { true }`), - MustParseRule(`wrapper = cores[0][1] { true }`), - MustParseRule(`pi = [3, 1, 4, x, y, z] { true }`), - MustParseRule(`foo["bar"] = "buz" { true }`), - MustParseRule(`foo["9"] = "10" { true }`), - MustParseRule(`foo["buz"] = "bar" { true }`), - MustParseRule(`bar[1] { true }`), - MustParseRule(`bar[[{"foo":"baz"}]] { true }`), - MustParseRule(`bar["qux"] { true }`), - MustParseRule(`input = 1 { true }`), - MustParseRule(`data = 2 { true }`), - MustParseRule(`f(1) = 2 { true }`), - MustParseRule(`f(1) = true { true }`), - MustParseRule("d1 := 1234 { true }"), + // NOTE(sr): These tests assert that the other code path, parsing a module, and + // then interpreting naked expressions into (shortcut) rule definitions, works + // the same as parsing the string as a Rule directly. Without also passing + // TestRuleRefHeads, these tests are not to be trusted -- if changing something, + // start with getting TestRuleRefHeads to PASS. 
+ tests := []struct { + note string + rule string + exp string + }{ + { + note: "no dots: single-value rule (complete doc)", + rule: `foo["bar"] = 12`, + exp: `foo["bar"] = 12 { true }`, + }, + { + note: "no dots: partial set of numbers", + rule: `foo[1]`, + exp: `foo[1] { true }`, + }, + { + note: "no dots: shorthand set of strings", // back compat + rule: `foo.one`, + exp: `foo["one"] { true }`, + }, + { + note: "no dots: partial set", + rule: `foo[x] { x = 1 }`, + exp: `foo[x] { x = 1 }`, + }, + { + note: "no dots + if: complete doc", + rule: `foo[x] if x := 1`, + exp: `foo[x] if x := 1`, + }, + { + note: "no dots: function", + rule: `foo(x)`, + exp: `foo(x) { true }`, + }, + { + note: "no dots: function with value", + rule: `foo(x) = y`, + exp: `foo(x) = y { true }`, + }, + { + note: "no dots: partial set, ref element", + rule: `test[arr[0]]`, + exp: `test[arr[0]] { true }`, + }, + { + note: "one dot: complete rule shorthand", + rule: `foo.bar = "buz"`, + exp: `foo.bar = "buz" { true }`, + }, + { + note: "one dot, bracket with var: partial object", + rule: `foo.bar[x] = "buz"`, + exp: `foo.bar[x] = "buz" { true }`, + }, + { + note: "one dot, bracket with var: partial set", + rule: `foo.bar[x] { x = 1 }`, + exp: `foo.bar[x] { x = 1 }`, + }, + { + note: "one dot, bracket with string: complete doc", + rule: `foo.bar["baz"] = "buz"`, + exp: `foo.bar.baz = "buz" { true }`, + }, + { + note: "one dot, bracket with var, rule body: partial object", + rule: `foo.bar[x] = "buz" { x = 1 }`, + exp: `foo.bar[x] = "buz" { x = 1 }`, + }, + { + note: "one dot: function", + rule: `foo.bar(x)`, + exp: `foo.bar(x) { true }`, + }, + { + note: "one dot: function with value", + rule: `foo.bar(x) = y`, + exp: `foo.bar(x) = y { true }`, + }, + { + note: "two dots, bracket with var: partial object", + rule: `foo.bar.baz[x] = "buz" { x = 1 }`, + exp: `foo.bar.baz[x] = "buz" { x = 1 }`, + }, + { + note: "two dots, bracket with var: partial set", + rule: `foo.bar.baz[x] { x = 1 }`, + exp: 
`foo.bar.baz[x] { x = 1 }`, + }, + { + note: "one dot, bracket with string, no key: complete doc", + rule: `foo.bar["baz"]`, + exp: `foo.bar.baz { true }`, + }, + { + note: "two dots: function", + rule: `foo.bar("baz")`, + exp: `foo.bar("baz") { true }`, + }, + { + note: "two dots: function with value", + rule: `foo.bar("baz") = y`, + exp: `foo.bar("baz") = y { true }`, }, + { + note: "non-ground ref: complete doc", + rule: `foo.bar[i].baz { i := 1 }`, + exp: `foo.bar[i].baz { i := 1 }`, + }, + { + note: "non-ground ref: partial set", + rule: `foo.bar[i].baz[x] { i := 1; x := 2 }`, + exp: `foo.bar[i].baz[x] { i := 1; x := 2 }`, + }, + { + note: "non-ground ref: partial object", + rule: `foo.bar[i].baz[x] = 3 { i := 1; x := 2 }`, + exp: `foo.bar[i].baz[x] = 3 { i := 1; x := 2 }`, + }, + { + note: "non-ground ref: function", + rule: `foo.bar[i].baz(x) = 3 { i := 1 }`, + exp: `foo.bar[i].baz(x) = 3 { i := 1 }`, + }, + { + note: "last term is number: partial set", + rule: `foo.bar.baz[3] { true }`, + exp: `foo.bar.baz[3] { true }`, + }, + } + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + r, err := ParseRuleWithOpts(tc.exp, opts) + if err != nil { + t.Fatal(err) + } + + testModule := "package a.b.c\n" + tc.rule + m, err := ParseModuleWithOpts("", testModule, opts) + if err != nil { + t.Fatal(err) + } + mr := m.Rules[0] + + if r.Head.Name.Compare(mr.Head.Name) != 0 { + t.Errorf("rule.Head.Name differs:\n exp = %#v\nrule = %#v", r.Head.Name, mr.Head.Name) + } + if r.Head.Ref().Compare(mr.Head.Ref()) != 0 { + t.Errorf("rule.Head.Ref() differs:\n exp = %v\nrule = %v", r.Head.Ref(), mr.Head.Ref()) + } + exp, err := ParseRuleWithOpts(tc.exp, opts) + if err != nil { + t.Fatal(err) + } + assertParseModule(t, tc.note, testModule, &Module{ + Package: MustParseStatement(`package a.b.c`).(*Package), + Rules: []*Rule{exp}, + }, opts) + }) + } + + // edge cases + t.Run("errors", func(t *testing.T) { + t.Run("naked 'data' ref", func(t *testing.T) { + _, err := 
ParseModuleWithOpts("", "package a.b.c\ndata", opts) + assertErrorWithMessage(t, err, "refs cannot be used for rule head") + }) + t.Run("naked 'input' ref", func(t *testing.T) { + _, err := ParseModuleWithOpts("", "package a.b.c\ninput", opts) + assertErrorWithMessage(t, err, "refs cannot be used for rule head") + }) }) +} + +func assertErrorWithMessage(t *testing.T, err error, msg string) { + t.Helper() + var errs Errors + if !errors.As(err, &errs) { + t.Fatalf("expected Errors, got %v %[1]T", err) + } + if exp, act := 1, len(errs); exp != act { + t.Fatalf("expected %d errors, got %d", exp, act) + } + e := errs[0] + if exp, act := msg, e.Message; exp != act { + t.Fatalf("expected error message %q, got %q", exp, act) + } +} + +func TestRuleFromBody(t *testing.T) { + tests := []struct { + input string + exp string + }{ + {`pi = 3.14159`, `pi = 3.14159 { true }`}, + {`p[x] { x = 1 }`, `p[x] { x = 1 }`}, + {`greeting = "hello"`, `greeting = "hello" { true }`}, + {`cores = [{0: 1}, {1: 2}]`, `cores = [{0: 1}, {1: 2}] { true }`}, + {`wrapper = cores[0][1]`, `wrapper = cores[0][1] { true }`}, + {`pi = [3, 1, 4, x, y, z]`, `pi = [3, 1, 4, x, y, z] { true }`}, + {`foo["bar"] = "buz"`, `foo["bar"] = "buz" { true }`}, + {`foo["9"] = "10"`, `foo["9"] = "10" { true }`}, + {`foo.buz = "bar"`, `foo["buz"] = "bar" { true }`}, + {`bar[1]`, `bar[1] { true }`}, + {`bar[[{"foo":"baz"}]]`, `bar[[{"foo":"baz"}]] { true }`}, + {`bar.qux`, `bar["qux"] { true }`}, + {`input = 1`, `input = 1 { true }`}, + {`data = 2`, `data = 2 { true }`}, + {`f(1) = 2`, `f(1) = 2 { true }`}, + {`f(1)`, `f(1) = true { true }`}, + {`d1 := 1234`, "d1 := 1234 { true }"}, + } + + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + testModule := "package a.b.c\n" + tc.input + assertParseModule(t, tc.input, testModule, &Module{ + Package: MustParseStatement(`package a.b.c`).(*Package), + Rules: []*Rule{ + MustParseRule(tc.exp), + }, + }) + }) + } // Verify the rule and rule and rule head col/loc 
values + testModule := "package a.b.c\n\n" + for _, tc := range tests { + testModule += tc.input + "\n" + } module, err := ParseModule("test.rego", testModule) if err != nil { t.Fatal(err) @@ -2479,19 +2923,19 @@ d1 := 1234 for i := range module.Rules { col := module.Rules[i].Location.Col if col != 1 { - t.Fatalf("expected rule %v column to be 1 but got %v", module.Rules[i].Head.Name, col) + t.Errorf("expected rule %v column to be 1 but got %v", module.Rules[i].Head.Name, col) } row := module.Rules[i].Location.Row - if row != 3+i { // 'pi' rule stats on row 3 - t.Fatalf("expected rule %v row to be %v but got %v", module.Rules[i].Head.Name, 3+i, row) + if row != 3+i { // 'pi' rule starts on row 3 + t.Errorf("expected rule %v row to be %v but got %v", module.Rules[i].Head.Name, 3+i, row) } col = module.Rules[i].Head.Location.Col if col != 1 { - t.Fatalf("expected rule head %v column to be 1 but got %v", module.Rules[i].Head.Name, col) + t.Errorf("expected rule head %v column to be 1 but got %v", module.Rules[i].Head.Name, col) } row = module.Rules[i].Head.Location.Row - if row != 3+i { // 'pi' rule stats on row 3 - t.Fatalf("expected rule head %v row to be %v but got %v", module.Rules[i].Head.Name, 3+i, row) + if row != 3+i { // 'pi' rule starts on row 3 + t.Errorf("expected rule head %v row to be %v but got %v", module.Rules[i].Head.Name, 3+i, row) } } @@ -2532,16 +2976,6 @@ data = {"bar": 2}` foo = input with input as 1 ` - badRefLen1 := ` - package a.b.c - - p["x"].y = 1` - - badRefLen2 := ` - package a.b.c - - p["x"].y` - negated := ` package a.b.c @@ -2600,8 +3034,6 @@ data = {"bar": 2}` assertParseModuleError(t, "non-equality", nonEquality) assertParseModuleError(t, "non-var name", nonVarName) assertParseModuleError(t, "with expr", withExpr) - assertParseModuleError(t, "bad ref (too long)", badRefLen1) - assertParseModuleError(t, "bad ref (too long)", badRefLen2) assertParseModuleError(t, "negated", negated) assertParseModuleError(t, "non ref term", nonRefTerm) 
assertParseModuleError(t, "zero args", zeroArgs) @@ -2665,7 +3097,8 @@ func TestWildcards(t *testing.T) { assertParseRule(t, "functions", `f(_) = y { true }`, &Rule{ Head: &Head{ - Name: Var("f"), + Name: Var("f"), + Reference: Ref{VarTerm("f")}, Args: Args{ VarTerm("$0"), }, @@ -4384,6 +4817,12 @@ func assertParseRule(t *testing.T, msg string, input string, correct *Rule, opts assertParseOne(t, msg, input, func(parsed interface{}) { t.Helper() rule := parsed.(*Rule) + if rule.Head.Name != correct.Head.Name { + t.Errorf("Error on test \"%s\": rule heads not equal: name = %v (parsed), name = %v (correct)", msg, rule.Head.Name, correct.Head.Name) + } + if !rule.Head.Ref().Equal(correct.Head.Ref()) { + t.Errorf("Error on test \"%s\": rule heads not equal: ref = %v (parsed), ref = %v (correct)", msg, rule.Head.Ref(), correct.Head.Ref()) + } if !rule.Equal(correct) { t.Errorf("Error on test \"%s\": rules not equal: %v (parsed), %v (correct)", msg, rule, correct) } diff --git a/ast/policy.go b/ast/policy.go index 04ebe56f10..4f56aa7edb 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -186,12 +186,13 @@ type ( // Head represents the head of a rule. Head struct { - Location *Location `json:"-"` - Name Var `json:"name"` - Args Args `json:"args,omitempty"` - Key *Term `json:"key,omitempty"` - Value *Term `json:"value,omitempty"` - Assign bool `json:"assign,omitempty"` + Location *Location `json:"-"` + Name Var `json:"name,omitempty"` + Reference Ref `json:"ref,omitempty"` + Args Args `json:"args,omitempty"` + Key *Term `json:"key,omitempty"` + Value *Term `json:"value,omitempty"` + Assign bool `json:"assign,omitempty"` } // Args represents zero or more arguments to a rule. @@ -598,11 +599,22 @@ func (rule *Rule) SetLoc(loc *Location) { // Path returns a ref referring to the document produced by this rule. If rule // is not contained in a module, this function panics. +// Deprecated: Poor handling of ref rules. Use `(*Rule).Ref()` instead. 
func (rule *Rule) Path() Ref { if rule.Module == nil { panic("assertion failed") } - return rule.Module.Package.Path.Append(StringTerm(string(rule.Head.Name))) + return rule.Module.Package.Path.Extend(rule.Head.Ref().GroundPrefix()) +} + +// Ref returns a ref referring to the document produced by this rule. If rule +// is not contained in a module, this function panics. The returned ref may +// contain variables in the last position. +func (rule *Rule) Ref() Ref { + if rule.Module == nil { + panic("assertion failed") + } + return rule.Module.Package.Path.Extend(rule.Head.Ref()) } func (rule *Rule) String() string { @@ -648,7 +660,8 @@ func (rule *Rule) elseString() string { // used for the key and the second will be used for the value. func NewHead(name Var, args ...*Term) *Head { head := &Head{ - Name: name, + Name: name, // backcompat + Reference: []*Term{NewTerm(name)}, } if len(args) == 0 { return head @@ -658,6 +671,23 @@ func NewHead(name Var, args ...*Term) *Head { return head } head.Value = args[1] + if head.Key != nil && head.Value != nil { + head.Reference = head.Reference.Append(args[0]) + } + return head +} + +// RefHead returns a new Head object with the passed Ref. If args are provided, +// the first will be used for the value. +func RefHead(ref Ref, args ...*Term) *Head { + head := &Head{} + head.SetRef(ref) + if len(ref) < 2 { + head.Name = ref[0].Value.(Var) + } + if len(args) >= 1 { + head.Value = args[0] + } return head } @@ -673,7 +703,7 @@ const ( // PartialObjectDoc represents an object document that is partially defined by the rule. PartialObjectDoc -) +) // TODO(sr): Deprecate? // DocKind returns the type of document produced by this rule. 
func (head *Head) DocKind() DocKind { @@ -686,6 +716,41 @@ func (head *Head) DocKind() DocKind { return CompleteDoc } +type RuleKind int + +const ( + SingleValue = iota + MultiValue +) + +// RuleKind returns the type of rule this is +func (head *Head) RuleKind() RuleKind { + // NOTE(sr): This is bit verbose, since the key is irrelevant for single vs + // multi value, but as good a spot as to assert the invariant. + switch { + case head.Value != nil: + return SingleValue + case head.Key != nil: + return MultiValue + default: + panic("unreachable") + } +} + +// Ref returns the Ref of the rule. If it doesn't have one, it's filled in +// via the Head's Name. +func (head *Head) Ref() Ref { + if len(head.Reference) > 0 { + return head.Reference + } + return Ref{&Term{Value: head.Name}} +} + +// SetRef can be used to set a rule head's Reference +func (head *Head) SetRef(r Ref) { + head.Reference = r +} + // Compare returns an integer indicating whether head is less than, equal to, // or greater than other. func (head *Head) Compare(other *Head) int { @@ -705,6 +770,9 @@ func (head *Head) Compare(other *Head) int { if cmp := Compare(head.Args, other.Args); cmp != 0 { return cmp } + if cmp := Compare(head.Reference, other.Reference); cmp != 0 { + return cmp + } if cmp := Compare(head.Name, other.Name); cmp != 0 { return cmp } @@ -717,6 +785,7 @@ func (head *Head) Compare(other *Head) int { // Copy returns a deep copy of head. 
func (head *Head) Copy() *Head { cpy := *head + cpy.Reference = head.Reference.Copy() cpy.Args = head.Args.Copy() cpy.Key = head.Key.Copy() cpy.Value = head.Value.Copy() @@ -729,23 +798,43 @@ func (head *Head) Equal(other *Head) bool { } func (head *Head) String() string { - var buf []string - if len(head.Args) != 0 { - buf = append(buf, head.Name.String()+head.Args.String()) - } else if head.Key != nil { - buf = append(buf, head.Name.String()+"["+head.Key.String()+"]") - } else { - buf = append(buf, head.Name.String()) + buf := strings.Builder{} + buf.WriteString(head.Ref().String()) + + switch { + case len(head.Args) != 0: + buf.WriteString(head.Args.String()) + case len(head.Reference) == 1 && head.Key != nil: + buf.WriteRune('[') + buf.WriteString(head.Key.String()) + buf.WriteRune(']') } if head.Value != nil { if head.Assign { - buf = append(buf, ":=") + buf.WriteString(" := ") } else { - buf = append(buf, "=") + buf.WriteString(" = ") } - buf = append(buf, head.Value.String()) - } - return strings.Join(buf, " ") + buf.WriteString(head.Value.String()) + } else if head.Name == "" && head.Key != nil { + buf.WriteString(" contains ") + buf.WriteString(head.Key.String()) + } + return buf.String() +} + +func (head *Head) MarshalJSON() ([]byte, error) { + // NOTE(sr): we do this to override the rendering of `head.Reference`. + // It's still what'll be used via the default means of encoding/json + // for unmarshaling a json object into a Head struct! + type h Head + return json.Marshal(struct { + h + Ref Ref `json:"ref"` + }{ + h: h(*head), + Ref: head.Ref(), + }) } // Vars returns a set of vars found in the head. 
@@ -761,6 +850,9 @@ func (head *Head) Vars() VarSet { if head.Value != nil { vis.Walk(head.Value) } + if len(head.Reference) > 0 { + vis.Walk(head.Reference[1:]) + } return vis.vars } diff --git a/ast/policy_test.go b/ast/policy_test.go index 684ece6ec3..2645a72d86 100644 --- a/ast/policy_test.go +++ b/ast/policy_test.go @@ -20,6 +20,7 @@ func TestModuleJSONRoundTrip(t *testing.T) { mod, err := ParseModuleWithOpts("test.rego", `package a.b.c +import future.keywords import data.x.y as z import data.u.i @@ -42,6 +43,7 @@ a = true { xs = {a: b | input.y[a] = "foo"; b = input.z["bar"]} } b = true { xs = {{"x": a[i].a} | a[i].n = "bob"; b[x]} } call_values { f(x) != g(x) } assigned := 1 +rule.having.ref.head[1] = x if x := 2 # METADATA # scope: rule @@ -392,17 +394,61 @@ func TestExprEveryCopy(t *testing.T) { } } +func TestRuleHeadJSON(t *testing.T) { + // NOTE(sr): we may get to see Rule objects that aren't the result of parsing, but + // fed as-is into the compiler. We need to be able to make sense of their refs, too. 
+ head := Head{ + Name: Var("allow"), + } + + rule := Rule{ + Head: &head, + } + bs, err := json.Marshal(&rule) + if err != nil { + t.Fatal(err) + } + if exp, act := `{"head":{"name":"allow","ref":[{"type":"var","value":"allow"}]},"body":[]}`, string(bs); act != exp { + t.Errorf("expected %q, got %q", exp, act) + } + + var readRule Rule + if err := json.Unmarshal(bs, &readRule); err != nil { + t.Fatal(err) + } + if exp, act := 1, len(readRule.Head.Reference); act != exp { + t.Errorf("expected unmarshalled rule to have Reference, got %v", readRule.Head.Reference) + } + bs0, err := json.Marshal(&readRule) + if err != nil { + t.Fatal(err) + } + if exp, act := string(bs), string(bs0); exp != act { + t.Errorf("expected json repr to match %q, got %q", exp, act) + } + + var readAgainRule Rule + if err := json.Unmarshal(bs, &readAgainRule); err != nil { + t.Fatal(err) + } + if !readAgainRule.Equal(&readRule) { + t.Errorf("expected roundtripped rule reference to match %v, got %v", readRule.Head.Reference, readAgainRule.Head.Reference) + } +} + func TestRuleHeadEquals(t *testing.T) { assertHeadsEqual(t, &Head{}, &Head{}) - // Same name/key/value + // Same name/ref/key/value assertHeadsEqual(t, &Head{Name: Var("p")}, &Head{Name: Var("p")}) + assertHeadsEqual(t, &Head{Reference: Ref{VarTerm("p"), StringTerm("r")}}, &Head{Reference: Ref{VarTerm("p"), StringTerm("r")}}) // TODO: string for first section assertHeadsEqual(t, &Head{Key: VarTerm("x")}, &Head{Key: VarTerm("x")}) assertHeadsEqual(t, &Head{Value: VarTerm("x")}, &Head{Value: VarTerm("x")}) assertHeadsEqual(t, &Head{Args: []*Term{VarTerm("x"), VarTerm("y")}}, &Head{Args: []*Term{VarTerm("x"), VarTerm("y")}}) - // Different name/key/value + // Different name/ref/key/value assertHeadsNotEqual(t, &Head{Name: Var("p")}, &Head{Name: Var("q")}) + assertHeadsNotEqual(t, &Head{Reference: Ref{VarTerm("p")}}, &Head{Reference: Ref{VarTerm("q")}}) // TODO: string for first section assertHeadsNotEqual(t, &Head{Key: VarTerm("x")}, 
&Head{Key: VarTerm("y")}) assertHeadsNotEqual(t, &Head{Value: VarTerm("x")}, &Head{Value: VarTerm("y")}) assertHeadsNotEqual(t, &Head{Args: []*Term{VarTerm("x"), VarTerm("z")}}, &Head{Args: []*Term{VarTerm("x"), VarTerm("y")}}) @@ -438,48 +484,139 @@ func TestRuleBodyEquals(t *testing.T) { } func TestRuleString(t *testing.T) { + trueBody := NewBody(NewExpr(BooleanTerm(true))) - rule1 := &Rule{ - Head: NewHead(Var("p"), nil, BooleanTerm(true)), - Body: NewBody( - Equality.Expr(StringTerm("foo"), StringTerm("bar")), - ), - } - - rule2 := &Rule{ - Head: NewHead(Var("p"), VarTerm("x"), VarTerm("y")), - Body: NewBody( - Equality.Expr(StringTerm("foo"), VarTerm("x")), - &Expr{ - Negated: true, - Terms: RefTerm(VarTerm("a"), StringTerm("b"), VarTerm("x")), + tests := []struct { + rule *Rule + exp string + }{ + { + rule: &Rule{ + Head: NewHead(Var("p"), nil, BooleanTerm(true)), + Body: NewBody( + Equality.Expr(StringTerm("foo"), StringTerm("bar")), + ), + }, + exp: `p = true { "foo" = "bar" }`, + }, + { + rule: &Rule{ + Head: NewHead(Var("p"), VarTerm("x")), + Body: trueBody, + }, + exp: `p[x] { true }`, + }, + { + rule: &Rule{ + Head: RefHead(MustParseRef("p[x]"), BooleanTerm(true)), + Body: MustParseBody("x = 1"), + }, + exp: `p[x] = true { x = 1 }`, + }, + { + rule: &Rule{ + Head: RefHead(MustParseRef("p.q.r[x]"), BooleanTerm(true)), + Body: MustParseBody("x = 1"), }, - Equality.Expr(StringTerm("b"), VarTerm("y")), - ), + exp: `p.q.r[x] = true { x = 1 }`, + }, + { + rule: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.r"), + Key: VarTerm("1"), + }, + Body: MustParseBody("x = 1"), + }, + exp: `p.q.r contains 1 { x = 1 }`, + }, + { + rule: &Rule{ + Head: NewHead(Var("p"), VarTerm("x"), VarTerm("y")), + Body: NewBody( + Equality.Expr(StringTerm("foo"), VarTerm("x")), + &Expr{ + Negated: true, + Terms: RefTerm(VarTerm("a"), StringTerm("b"), VarTerm("x")), + }, + Equality.Expr(StringTerm("b"), VarTerm("y")), + ), + }, + exp: `p[x] = y { "foo" = x; not a.b[x]; "b" = y 
}`, + }, + { + rule: &Rule{ + Default: true, + Head: NewHead("p", nil, BooleanTerm(true)), + }, + exp: `default p = true`, + }, + { + rule: &Rule{ + Head: &Head{ + Name: Var("f"), + Args: Args{VarTerm("x"), VarTerm("y")}, + Value: VarTerm("z"), + }, + Body: NewBody(Plus.Expr(VarTerm("x"), VarTerm("y"), VarTerm("z"))), + }, + exp: "f(x, y) = z { plus(x, y, z) }", + }, + { + rule: &Rule{ + Head: &Head{ + Name: Var("p"), + Value: BooleanTerm(true), + Assign: true, + }, + Body: NewBody( + Equality.Expr(StringTerm("foo"), StringTerm("bar")), + ), + }, + exp: `p := true { "foo" = "bar" }`, + }, + { + rule: &Rule{ + Head: RefHead(MustParseRef("p.q.r")), + Body: trueBody, + }, + exp: `p.q.r { true }`, + }, + { + rule: &Rule{ + Head: RefHead(MustParseRef("p.q.r"), StringTerm("foo")), + Body: trueBody, + }, + exp: `p.q.r = "foo" { true }`, + }, + { + rule: &Rule{ + Head: RefHead(MustParseRef("p.q.r[x]"), StringTerm("foo")), + Body: MustParseBody(`x := 1`), + }, + exp: `p.q.r[x] = "foo" { assign(x, 1) }`, + }, } - rule3 := &Rule{ - Default: true, - Head: NewHead("p", nil, BooleanTerm(true)), + for _, tc := range tests { + t.Run(tc.exp, func(t *testing.T) { + assertRuleString(t, tc.rule, tc.exp) + }) } +} - rule4 := &Rule{ - Head: &Head{ - Name: Var("f"), - Args: Args{VarTerm("x"), VarTerm("y")}, - Value: VarTerm("z"), - }, - Body: NewBody(Plus.Expr(VarTerm("x"), VarTerm("y"), VarTerm("z"))), +func TestRulePath(t *testing.T) { + ruleWithMod := func(r string) Ref { + mod := MustParseModule("package pkg\n" + r) + return mod.Rules[0].Path() + } + if exp, act := MustParseRef("data.pkg.p.q.r"), ruleWithMod("p.q.r { true }"); !exp.Equal(act) { + t.Errorf("expected %v, got %v", exp, act) } - rule5 := rule1.Copy() - rule5.Head.Assign = true - - assertRuleString(t, rule1, `p = true { "foo" = "bar" }`) - assertRuleString(t, rule2, `p[x] = y { "foo" = x; not a.b[x]; "b" = y }`) - assertRuleString(t, rule3, `default p = true`) - assertRuleString(t, rule4, "f(x, y) = z { plus(x, y, z) }") 
- assertRuleString(t, rule5, `p := true { "foo" = "bar" }`) + if exp, act := MustParseRef("data.pkg.p"), ruleWithMod("p { true }"); !exp.Equal(act) { + t.Errorf("expected %v, got %v", exp, act) + } } func TestModuleString(t *testing.T) { @@ -791,7 +928,7 @@ func assertPackagesNotEqual(t *testing.T, a, b *Package) { func assertRulesEqual(t *testing.T, a, b *Rule) { t.Helper() if !a.Equal(b) { - t.Errorf("Rules are not equal (expected equal): a=%v b=%v", a, b) + t.Errorf("Rules are not equal (expected equal):\na=%v\nb=%v", a, b) } } diff --git a/ast/pretty_test.go b/ast/pretty_test.go index d59f7cb809..50fe97c142 100644 --- a/ast/pretty_test.go +++ b/ast/pretty_test.go @@ -41,8 +41,9 @@ func TestPretty(t *testing.T) { qux rule head - p - x + ref + p + x y body expr index=0 @@ -67,7 +68,8 @@ func TestPretty(t *testing.T) { true rule head - f + ref + f args x call diff --git a/ast/term.go b/ast/term.go index 73131b51e3..0aa800dc18 100644 --- a/ast/term.go +++ b/ast/term.go @@ -134,12 +134,12 @@ func As(v Value, x interface{}) error { // Resolver defines the interface for resolving references to native Go values. type Resolver interface { - Resolve(ref Ref) (interface{}, error) + Resolve(Ref) (interface{}, error) } // ValueResolver defines the interface for resolving references to AST values. type ValueResolver interface { - Resolve(ref Ref) (Value, error) + Resolve(Ref) (Value, error) } // UnknownValueErr indicates a ValueResolver was unable to resolve a reference diff --git a/ast/transform.go b/ast/transform.go index c25f7bc32d..391a164860 100644 --- a/ast/transform.go +++ b/ast/transform.go @@ -13,7 +13,7 @@ import ( // be set to nil and no transformations will be applied to children of the // element. 
type Transformer interface { - Transform(v interface{}) (interface{}, error) + Transform(interface{}) (interface{}, error) } // Transform iterates the AST and calls the Transform function on the @@ -116,6 +116,9 @@ func Transform(t Transformer, x interface{}) (interface{}, error) { } return y, nil case *Head: + if y.Reference, err = transformRef(t, y.Reference); err != nil { + return nil, err + } if y.Name, err = transformVar(t, y.Name); err != nil { return nil, err } @@ -327,7 +330,7 @@ func TransformComprehensions(x interface{}, f func(interface{}) (Value, error)) // GenericTransformer implements the Transformer interface to provide a utility // to transform AST nodes using a closure. type GenericTransformer struct { - f func(x interface{}) (interface{}, error) + f func(interface{}) (interface{}, error) } // NewGenericTransformer returns a new GenericTransformer that will transform @@ -414,3 +417,15 @@ func transformVar(t Transformer, v Var) (Var, error) { } return r, nil } + +func transformRef(t Transformer, r Ref) (Ref, error) { + r1, err := Transform(t, r) + if err != nil { + return nil, err + } + r2, ok := r1.(Ref) + if !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", r, r2) + } + return r2, nil +} diff --git a/ast/transform_test.go b/ast/transform_test.go index c102e47f34..8c5a051aae 100644 --- a/ast/transform_test.go +++ b/ast/transform_test.go @@ -24,6 +24,7 @@ p = n { count({"this", "that"}, n) with input.foo.this as {"this": true} } p { false } else = "this" { "this" } else = ["this"] { true } foo(x) = y { split(x, "this", y) } p { every x in ["this"] { x == "this" } } +a.b.c.this["this"] = d { d := "this" } `) result, err := Transform(&GenericTransformer{ @@ -59,6 +60,7 @@ p = n { count({"that"}, n) with input.foo.that as {"that": true} } p { false } else = "that" { "that" } else = ["that"] { true } foo(x) = y { split(x, "that", y) } p { every x in ["that"] { x == "that" } } +a.b.c.that["that"] = d { d := "that" } `) if 
!expected.Equal(resultMod) { @@ -113,3 +115,23 @@ p := 7`, ParserOptions{ProcessAnnotation: true}) } } + +func TestTransformRefsAndRuleHeads(t *testing.T) { + module := MustParseModule(`package test +p.q.this.fo[x] = y { x := "x"; y := "y" }`) + + result, err := TransformRefs(module, func(r Ref) (Value, error) { + if r[0].Value.Compare(Var("p")) == 0 { + r[2] = StringTerm("that") + } + return r, nil + }) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + resultMod := result.(*Module) + if exp, act := MustParseRef("p.q.that.fo[x]"), resultMod.Rules[0].Head.Reference; !act.Equal(exp) { + t.Errorf("expected %v, got %v", exp, act) + } +} diff --git a/ast/visit.go b/ast/visit.go index 02c393cdbc..4c60401e96 100644 --- a/ast/visit.go +++ b/ast/visit.go @@ -432,11 +432,15 @@ func (vis *BeforeAfterVisitor) Walk(x interface{}) { vis.Walk(x.Else) } case *Head: - vis.Walk(x.Name) - vis.Walk(x.Args) - if x.Key != nil { - vis.Walk(x.Key) + if len(x.Reference) > 0 { + vis.Walk(x.Reference) + } else { + vis.Walk(x.Name) + if x.Key != nil { + vis.Walk(x.Key) + } } + vis.Walk(x.Args) if x.Value != nil { vis.Walk(x.Value) } @@ -662,11 +666,16 @@ func (vis *VarVisitor) Walk(x interface{}) { vis.Walk(x.Else) } case *Head: - vis.Walk(x.Name) - vis.Walk(x.Args) - if x.Key != nil { - vis.Walk(x.Key) + if len(x.Reference) > 0 { + vis.Walk(x.Reference) + } else { + vis.Walk(x.Name) + if x.Key != nil { + vis.Walk(x.Key) + } } + vis.Walk(x.Args) + if x.Value != nil { vis.Walk(x.Value) } diff --git a/ast/visit_test.go b/ast/visit_test.go index 1b3694508b..0f2b663275 100644 --- a/ast/visit_test.go +++ b/ast/visit_test.go @@ -393,12 +393,12 @@ fn([x, y]) = z { json.unmarshal(x, z); z > y } }) vis.Walk(rule) - if len(before) != 246 { - t.Errorf("Expected exactly 246 before elements in AST but got %d: %v", len(before), before) + if exp, act := 256, len(before); exp != act { + t.Errorf("Expected exactly %d before elements in AST but got %d: %v", exp, act, before) } - if len(after) 
!= 246 {
-		t.Errorf("Expected exactly 246 after elements in AST but got %d: %v", len(after), after)
+	if exp, act := 256, len(after); exp != act {
+		t.Errorf("Expected exactly %d after elements in AST but got %d: %v", exp, act, after)
 	}
 }
 
diff --git a/cmd/build_test.go b/cmd/build_test.go
index b708417366..cc188cbf5e 100644
--- a/cmd/build_test.go
+++ b/cmd/build_test.go
@@ -3,6 +3,7 @@ package cmd
 import (
 	"archive/tar"
 	"compress/gzip"
+	"fmt"
 	"io"
 	"os"
 	"path"
@@ -162,12 +163,14 @@ func TestBuildErrorDoesNotWriteFile(t *testing.T) {
 			params.outputFile = path.Join(root, "bundle.tar.gz")
 
 			err := dobuild(params, []string{root})
-			if err == nil || !strings.Contains(err.Error(), "rule p is recursive") {
-				t.Fatal("expected recursion error but got:", err)
+			exp := fmt.Sprintf("1 error occurred: %s/test.rego:3: rego_recursion_error: rule data.test.p is recursive: data.test.p -> data.test.p",
+				root)
+			if err == nil || err.Error() != exp {
+				t.Fatalf("expected recursion error %q but got: %q", exp, err)
 			}
 
-			if _, err := os.Stat(params.outputFile); err == nil {
-				t.Fatal("expected stat error")
+			if _, err := os.Stat(params.outputFile); !os.IsNotExist(err) {
+				t.Fatalf("expected stat \"not found\" error, got %v", err)
 			}
 		})
 	}
diff --git a/compile/compile.go b/compile/compile.go
index 8114ada6e3..ea9dfadc32 100644
--- a/compile/compile.go
+++ b/compile/compile.go
@@ -887,7 +887,7 @@ func (o *optimizer) getSupportForEntrypoint(queries []ast.Body, e *ast.Term, res
 		o.debug.Printf("optimizer: entrypoint: %v: discard due to self-reference", e)
 		return nil
 	}
-	module.Rules = append(module.Rules, &ast.Rule{
+	module.Rules = append(module.Rules, &ast.Rule{ // TODO(sr): use RefHead instead?
 		Head: ast.NewHead(name, nil, resultsym),
 		Body: query,
 		Module: module,
@@ -900,6 +900,8 @@
 // merge combines two sets of modules and returns the result. The rules from modules
 // in 'b' override rules from modules in 'a'. 
If all rules in a module in 'a' are overridden // by rules in modules in 'b' then the module from 'a' is discarded. +// NOTE(sr): This function assumes that `b` is the result of partial eval, and thus does NOT +// contain any rules that genuinely need their ref heads. func (o *optimizer) merge(a, b []bundle.ModuleFile) []bundle.ModuleFile { prefixes := ast.NewSet() @@ -911,12 +913,19 @@ func (o *optimizer) merge(a, b []bundle.ModuleFile) []bundle.ModuleFile { // of rules.) seen := ast.NewVarSet() for _, rule := range b[i].Parsed.Rules { - if _, ok := seen[rule.Head.Name]; !ok { - prefixes.Add(ast.NewTerm(rule.Path())) - seen.Add(rule.Head.Name) + // NOTE(sr): we're relying on the fact that PE never emits ref rules (so far)! + // The rule + // p.a = 1 { ... } + // will be recorded in prefixes as `data.test.p`, and that'll be checked later on against `data.test.p[k]` + if len(rule.Head.Ref()) > 2 { + panic("expected a module without ref rules") + } + name := rule.Head.Name + if !seen.Contains(name) { + prefixes.Add(ast.NewTerm(b[i].Parsed.Package.Path.Append(ast.StringTerm(string(name))))) + seen.Add(name) } } - } for i := range a { @@ -925,30 +934,28 @@ func (o *optimizer) merge(a, b []bundle.ModuleFile) []bundle.ModuleFile { // NOTE(tsandall): same as above--memoize keep/discard decision. If multiple // entrypoints are provided the dst module may contain a large number of rules. 
- seen := ast.NewVarSet() - discard := ast.NewVarSet() - + seen, discarded := ast.NewSet(), ast.NewSet() for _, rule := range a[i].Parsed.Rules { - - if _, ok := discard[rule.Head.Name]; ok { - continue - } else if _, ok := seen[rule.Head.Name]; ok { + refT := ast.NewTerm(rule.Ref()) + switch { + case seen.Contains(refT): keep = append(keep, rule) continue + case discarded.Contains(refT): + continue } - path := rule.Path() + path := rule.Ref() overlap := prefixes.Until(func(x *ast.Term) bool { - ref := x.Value.(ast.Ref) - return path.HasPrefix(ref) + r := x.Value.(ast.Ref) + return path.HasPrefix(r) }) - if overlap { - discard.Add(rule.Head.Name) - } else { - seen.Add(rule.Head.Name) - keep = append(keep, rule) + discarded.Add(refT) + continue } + seen.Add(refT) + keep = append(keep, rule) } if len(keep) > 0 { @@ -956,7 +963,6 @@ func (o *optimizer) merge(a, b []bundle.ModuleFile) []bundle.ModuleFile { a[i].Raw = nil b = append(b, a[i]) } - } return b diff --git a/compile/compile_test.go b/compile/compile_test.go index 221a4b6997..376313b7bf 100644 --- a/compile/compile_test.go +++ b/compile/compile_test.go @@ -968,7 +968,7 @@ func TestOptimizerErrors(t *testing.T) { p { data.test.p } `, }, - wantErr: fmt.Errorf("1 error occurred: test.rego:3: rego_recursion_error: rule p is recursive: p -> p"), + wantErr: fmt.Errorf("1 error occurred: test.rego:3: rego_recursion_error: rule data.test.p is recursive: data.test.p -> data.test.p"), }, { note: "partial eval error", @@ -1072,6 +1072,36 @@ func TestOptimizerOutput(t *testing.T) { `, }, }, + { + note: "support rules, ref heads", + entrypoints: []string{"data.test.p.q.r"}, + modules: map[string]string{ + "test.rego": ` + package test + + default p.q.r = false + p.q.r { q[input.x] } + + q[1] + q[2]`, + }, + wantModules: map[string]string{ + "optimized/test/p/q.rego": ` + package test.p.q + + default r = false + r = true { 1 = input.x } + r = true { 2 = input.x } + + `, + "test.rego": ` + package test + + q[1] + q[2] + `, + 
}, + }, { note: "multiple entrypoints", entrypoints: []string{"data.test.p", "data.test.r", "data.test.s"}, @@ -1526,7 +1556,6 @@ func getOptimizer(modules map[string]string, data string, entries []string, root func getModuleFiles(src map[string]string, includeRaw bool) []bundle.ModuleFile { keys := make([]string, 0, len(src)) - for k := range src { keys = append(keys, k) } diff --git a/docs/content/policy-language.md b/docs/content/policy-language.md index eaa0b517e3..87615dd5fd 100644 --- a/docs/content/policy-language.md +++ b/docs/content/policy-language.md @@ -909,6 +909,24 @@ max_memory := 32 { power_users[user] } max_memory := 4 { restricted_users[user] } ``` +### Rule Heads containing References + +As a shorthand for defining nested rule structures, it's valid to use references as rule heads: + +```live:eg/ref_heads:module +fruit.apple.seeds = 12 + +fruit.orange.color = "orange" +``` + +This module defines _two complete rules_, `data.example.fruit.apple.seeds` and `data.example.fruit.orange.color`: + +```live:eg/ref_heads:query:merge_down +data.example +``` +```live:eg/ref_heads:output +``` + ### Functions Rego supports user-defined functions that can be called with the same semantics as [Built-in Functions](#built-in-functions). They have access to both the [the data Document](../philosophy/#the-opa-document-model) and [the input Document](../philosophy/#the-opa-document-model). 
diff --git a/docs/content/policy-reference.md b/docs/content/policy-reference.md index cc8eb3ba72..2bab586251 100644 --- a/docs/content/policy-reference.md +++ b/docs/content/policy-reference.md @@ -282,6 +282,43 @@ f(x) := "B" if { x >= 80; x < 90 } f(x) := "C" if { x >= 70; x < 80 } ``` +### Reference Heads + +```live:rules/ref_heads:module:read_only +fruit.apple.seeds = 12 if input == "apple" # complete document (single value rule) + +fruit.pineapple.colors contains x if x := "yellow" # multi-value rule + +fruit.banana.phone[x] = "bananular" if x := "cellular" # single value rule +fruit.banana.phone.cellular = "bananular" if true # equivalent single value rule + +fruit.orange.color(x) = true if x == "orange" # function +``` + +For reasons of backwards-compatibility, partial sets need to use `contains` in +their rule heads, i.e. + +```live:rules/ref_heads/set:module:read_only +fruit.box contains "apples" if true +``` + +whereas + +```live:rules/ref_heads/complete:module:read_only +fruit.box[x] if { x := "apples" } +``` + +defines a _complete document rule_ `fruit.box.apples` with value `true`. +The same is the case for rules with brackets that don't contain dots, like + +```live:rules/ref_heads/simple:module:read_only +box[x] if { x := "apples" } # => {"box": {"apples": true }} +box2[x] { x := "apples" } # => {"box": ["apples"]} +``` + +For backwards-compatibility, rules _without_ if and without _dots_ will be interpreted +as defining partial sets, like `box2`. 
+ ## Tests ```live:tests:module:read_only @@ -1133,7 +1170,7 @@ package = "package" ref import = "import" ref [ "as" var ] policy = { rule } rule = [ "default" ] rule-head { rule-body } -rule-head = var ( rule-head-set | rule-head-obj | rule-head-func | rule-head-comp | "if" ) +rule-head = ( ref | var ) ( rule-head-set | rule-head-obj | rule-head-func | rule-head-comp ) rule-head-comp = [ assign-operator term ] [ "if" ] rule-head-obj = "[" term "]" [ assign-operator term ] [ "if" ] rule-head-func = "(" rule-args ")" [ assign-operator term ] [ "if" ] diff --git a/format/format.go b/format/format.go index f549f4ec5c..745efa05c2 100644 --- a/format/format.go +++ b/format/format.go @@ -406,7 +406,8 @@ func (w *writer) writeElse(rule *ast.Rule, useContainsKW, useIf bool, comments [ w.startLine() } - rule.Else.Head.Name = "else" + rule.Else.Head.Name = "else" // NOTE(sr): whaaat + rule.Else.Head.Reference = ast.Ref{ast.VarTerm("else")} rule.Else.Head.Args = nil comments = w.insertComments(comments, rule.Else.Head.Location) @@ -427,7 +428,12 @@ func (w *writer) writeElse(rule *ast.Rule, useContainsKW, useIf bool, comments [ } func (w *writer) writeHead(head *ast.Head, isDefault, isExpandedConst, useContainsKW, useIf bool, comments []*ast.Comment) []*ast.Comment { - w.write(head.Name.String()) + ref := head.Ref() + if head.Key != nil && head.Value == nil { + ref = ref.GroundPrefix() + } + w.write(ref.String()) + if len(head.Args) > 0 { w.write("(") var args []interface{} @@ -441,7 +447,7 @@ func (w *writer) writeHead(head *ast.Head, isDefault, isExpandedConst, useContai if useContainsKW && head.Value == nil { w.write(" contains ") comments = w.writeTerm(head.Key, comments) - } else { // no `if` for p[x] notation + } else if head.Value == nil { // no `if` for p[x] notation w.write("[") comments = w.writeTerm(head.Key, comments) w.write("]") diff --git a/format/testfiles/test.rego.formatted b/format/testfiles/test.rego.formatted index d78027e32b..9197d0bb45 100644 --- 
a/format/testfiles/test.rego.formatted +++ b/format/testfiles/test.rego.formatted @@ -24,11 +24,11 @@ globals = { "fizz": "buzz", } -partial_obj["x"] = 1 +partial_obj.x = 1 -partial_obj["y"] = 2 +partial_obj.y = 2 -partial_obj["z"] = 3 +partial_obj.z = 3 partial_set["x"] @@ -198,7 +198,7 @@ nested_infix { expanded_const = true -partial_obj["why"] = true { +partial_obj.why = true { false } diff --git a/format/testfiles/test_ref_heads.rego b/format/testfiles/test_ref_heads.rego new file mode 100644 index 0000000000..a55386ab59 --- /dev/null +++ b/format/testfiles/test_ref_heads.rego @@ -0,0 +1,13 @@ +package test + +import future.keywords + +a.b.c = "d" if true +a.b.e = "f" if true +a.b.g contains x if some x in numbers.range(1, 3) +a.b.h[x] = 1 if x := "one" + +q[1] = y if true +r[x] if x := 10 +p.q.r[x] if x := 10 +p.q.r[2] if true diff --git a/format/testfiles/test_ref_heads.rego.formatted b/format/testfiles/test_ref_heads.rego.formatted new file mode 100644 index 0000000000..5f23e3681b --- /dev/null +++ b/format/testfiles/test_ref_heads.rego.formatted @@ -0,0 +1,19 @@ +package test + +import future.keywords + +a.b.c = "d" + +a.b.e = "f" + +a.b.g contains x if some x in numbers.range(1, 3) + +a.b.h[x] = 1 if x := "one" + +q[1] = y + +r[x] if x := 10 + +p.q.r[x] if x := 10 + +p.q.r[2] = true diff --git a/internal/oracle/oracle.go b/internal/oracle/oracle.go index af2fca93ed..e3eb3afa1a 100644 --- a/internal/oracle/oracle.go +++ b/internal/oracle/oracle.go @@ -52,6 +52,8 @@ func (o *Oracle) FindDefinition(q DefinitionQuery) (*DefinitionQueryResult, erro // TODO(tsandall): how can we cache the results of compilation and parsing so that // multiple queries can be executed without having to re-compute the same values? // Ditto for caching across runs. Avoid repeating the same work. 
+ + // NOTE(sr): "SetRuleTree" because it's needed for compiler.GetRulesExact() below compiler, parsed, err := compileUpto("SetRuleTree", q.Modules, q.Buffer, q.Filename) if err != nil { return nil, err @@ -220,7 +222,8 @@ func getLocMinMax(x ast.Node) (int, int) { // Special case bodies because location text is only for the first expr. if body, ok := x.(ast.Body); ok { - extraLoc := body[len(body)-1].Loc() + last := findLastExpr(body) + extraLoc := last.Loc() if extraLoc == nil { return -1, -1 } @@ -229,3 +232,20 @@ func getLocMinMax(x ast.Node) (int, int) { return min, min + len(loc.Text) } + +// findLastExpr returns the last expression in an ast.Body that has not been generated +// by the compiler. It's used to cope with the fact that a compiler stage before SetRuleTree +// has rewritten the rule bodies slightly. By ignoring appended generated body expressions, +// we can still use the "circling in on the variable" logic based on node locations. +func findLastExpr(body ast.Body) *ast.Expr { + for i := len(body) - 1; i >= 0; i-- { + if !body[i].Generated { + return body[i] + } + } + // NOTE(sr): I believe this shouldn't happen -- we only ever start circling in on a node + // inside a body if there's something in that body. A body that only consists of generated + // expressions should not appear here. Either way, the caller deals with `nil` returned by + // this helper. 
+ return nil +} diff --git a/internal/oracle/oracle_test.go b/internal/oracle/oracle_test.go index 55636734a0..2822c807d2 100644 --- a/internal/oracle/oracle_test.go +++ b/internal/oracle/oracle_test.go @@ -158,6 +158,13 @@ y[1] s = input bar = 7` + // NOTE(sr): Early ref rewriting adds an expression to the rule body for `x.y` + const varInRuleRefModule = `package foo +q[x.y] = 10 { + x := input + some z + z = 1 +}` cases := []struct { note string modules map[string]string @@ -350,6 +357,20 @@ bar = 7` Text: []byte("i"), }, }, + { + note: "intra-rule: ref head", + modules: map[string]string{ + "buffer.rego": varInRuleRefModule, + }, + pos: 47, // "z" in "z = 1" + exp: &ast.Location{ + File: "buffer.rego", + Row: 4, + Col: 7, + Offset: 44, + Text: []byte("z"), + }, + }, } for _, tc := range cases { @@ -362,10 +383,20 @@ bar = 7` t.Fatal(err) } } + buffer := tc.modules["buffer.rego"] + before := tc.pos - 4 + if before < 0 { + before = 0 + } + after := tc.pos + 5 + if after > len(buffer) { + after = len(buffer) + } + t.Logf("pos is %d: \"%s<%s>%s\"", tc.pos, string(buffer[before:tc.pos]), string(buffer[tc.pos]), string(buffer[tc.pos+1:after])) o := New() result, err := o.FindDefinition(DefinitionQuery{ Modules: modules, - Buffer: []byte(tc.modules["buffer.rego"]), + Buffer: []byte(buffer), Filename: "buffer.rego", Pos: tc.pos, }) diff --git a/internal/planner/planner.go b/internal/planner/planner.go index 62e2b4c37b..070b1a6c1e 100644 --- a/internal/planner/planner.go +++ b/internal/planner/planner.go @@ -147,7 +147,7 @@ func (p *Planner) buildFunctrie() error { } for _, rule := range module.Rules { - val := p.rules.LookupOrInsert(rule.Path()) + val := p.rules.LookupOrInsert(rule.Ref()) val.rules = append(val.rules, rule) } } @@ -155,14 +155,29 @@ func (p *Planner) buildFunctrie() error { return nil } -func (p *Planner) planRules(rules []*ast.Rule) (string, error) { - - pathRef := rules[0].Path() +func (p *Planner) planRules(rules []*ast.Rule, cut bool) (string, 
error) { + pathRef := rules[0].Ref() // NOTE(sr): no longer the same for all those rules, respect `cut`? path := pathRef.String() var pathPieces []string - for i := 1; /* skip `data` */ i < len(pathRef); i++ { - pathPieces = append(pathPieces, string(pathRef[i].Value.(ast.String))) + // TODO(sr): this has to change when allowing `p[v].q.r[w]` ref rules + // including the mapping lookup structure and lookup functions + + // if we're planning both p.q.r and p.q[s], we'll name the function p.q (for the mapping table) + pieces := len(pathRef) + if cut { + pieces-- + } + for i := 1; /* skip `data` */ i < pieces; i++ { + switch q := pathRef[i].Value.(type) { + case ast.String: + pathPieces = append(pathPieces, string(q)) + case ast.Var: + pathPieces = append(pathPieces, fmt.Sprintf("[%s]", q)) + default: + // Needs to be fixed if we allow non-string ref pieces, like `p.q[3][4].r = x` + pathPieces = append(pathPieces, q.String()) + } } if funcName, ok := p.funcs.Get(path); ok { @@ -204,12 +219,27 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) { params := fn.Params[2:] + // control if p.a = 1 is to return 1 directly; or insert 1 under key "a" into an object + buildObject := false + // Initialize return value for partial set/object rules. Complete document // rules assign directly to `fn.Return`. 
- switch rules[0].Head.DocKind() { - case ast.PartialObjectDoc: - fn.Blocks = append(fn.Blocks, p.blockWithStmt(&ir.MakeObjectStmt{Target: fn.Return})) - case ast.PartialSetDoc: + switch rules[0].Head.RuleKind() { + case ast.SingleValue: + // if any rule has a non-ground last key, create an object, insert into it + any := false + for _, rule := range rules { + ref := rule.Head.Ref() + if last := ref[len(ref)-1]; len(ref) > 1 && !last.IsGround() { + any = true + break + } + } + if any { + buildObject = true + fn.Blocks = append(fn.Blocks, p.blockWithStmt(&ir.MakeObjectStmt{Target: fn.Return})) + } + case ast.MultiValue: fn.Blocks = append(fn.Blocks, p.blockWithStmt(&ir.MakeSetStmt{Target: fn.Return})) } @@ -279,6 +309,7 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) { p.appendStmt(&ir.IsUndefinedStmt{Source: lresult}) } else { // The first rule body resets the local, so it can be reused. + // TODO(sr): I don't think we need this anymore. Double-check? Perhaps multi-value rules need it. p.appendStmt(&ir.ResetLocalStmt{Target: lresult}) } @@ -287,11 +318,27 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) { err := p.planFuncParams(params, rule.Head.Args, 0, func() error { // Run planner on the rule body. - err := p.planQuery(rule.Body, 0, func() error { + return p.planQuery(rule.Body, 0, func() error { // Run planner on the result. 
- switch rule.Head.DocKind() { - case ast.CompleteDoc: + switch rule.Head.RuleKind() { + case ast.SingleValue: + if buildObject { + ref := rule.Head.Ref() + last := ref[len(ref)-1] + return p.planTerm(last, func() error { + key := p.ltarget + return p.planTerm(rule.Head.Value, func() error { + value := p.ltarget + p.appendStmt(&ir.ObjectInsertOnceStmt{ + Object: fn.Return, + Key: key, + Value: value, + }) + return nil + }) + }) + } return p.planTerm(rule.Head.Value, func() error { p.appendStmt(&ir.AssignVarOnceStmt{ Target: lresult, @@ -299,7 +346,7 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) { }) return nil }) - case ast.PartialSetDoc: + case ast.MultiValue: return p.planTerm(rule.Head.Key, func() error { p.appendStmt(&ir.SetAddStmt{ Set: fn.Return, @@ -307,29 +354,10 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) { }) return nil }) - case ast.PartialObjectDoc: - return p.planTerm(rule.Head.Key, func() error { - key := p.ltarget - return p.planTerm(rule.Head.Value, func() error { - value := p.ltarget - p.appendStmt(&ir.ObjectInsertOnceStmt{ - Object: fn.Return, - Key: key, - Value: value, - }) - return nil - }) - }) default: return fmt.Errorf("illegal rule kind") } }) - - if err != nil { - return err - } - - return nil }) if err != nil { @@ -338,7 +366,7 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) { } // rule[i] and its else-rule(s), if present, are done - if rules[i].Head.DocKind() == ast.CompleteDoc { + if rules[i].Head.RuleKind() == ast.SingleValue && !buildObject { end := &ir.Block{} p.appendStmtToBlock(&ir.IsDefinedStmt{Source: lresult}, end) p.appendStmtToBlock( @@ -841,7 +869,7 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error { // replacement is a function (rule) if node := p.rules.Lookup(r); node != nil { p.mocks.Push() // new scope - name, err = p.planRules(node.Rules()) + name, err = p.planRules(node.Rules(), false) if err != nil { return err } @@ -862,7 +890,7 @@ func (p 
*Planner) planExprCall(e *ast.Expr, iter planiter) error { } if node := p.rules.Lookup(op); node != nil { - name, err = p.planRules(node.Rules()) + name, err = p.planRules(node.Rules(), false) if err != nil { return err } @@ -1604,7 +1632,7 @@ func (p *Planner) planRefData(virtual *ruletrie, base *baseptr, ref ast.Ref, ind // NOTE(sr): we do it on the first index because later on, the recursion // on subtrees of virtual already lost parts of the path we've taken. if index == 1 && virtual != nil { - rulesets, path, index, optimize := p.optimizeLookup(virtual, ref) + rulesets, path, index, optimize := p.optimizeLookup(virtual, ref.GroundPrefix()) if optimize { // If there are no rulesets in a situation that otherwise would // allow for a call_indirect optimization, then there's nothing @@ -1614,7 +1642,7 @@ func (p *Planner) planRefData(virtual *ruletrie, base *baseptr, ref ast.Ref, ind } // plan rules for _, rules := range rulesets { - if _, err := p.planRules(rules); err != nil { + if _, err := p.planRules(rules, false); err != nil { return err } } @@ -1701,17 +1729,36 @@ func (p *Planner) planRefData(virtual *ruletrie, base *baseptr, ref ast.Ref, ind if ref[index].IsGround() { var vchild *ruletrie + var rules []*ast.Rule + + // If there's any non-ground key among the vchild.Children, like + // p[x] and p.a (x being non-ground), we'll collect all 'p' rules, + // plan them. + anyKeyNonGround := false if virtual != nil { + vchild = virtual.Get(ref[index].Value) - } - rules := vchild.Rules() + for _, key := range vchild.Children() { + if !key.IsGround() { + anyKeyNonGround = true + break + } + } + if anyKeyNonGround { + for _, key := range vchild.Children() { + rules = append(rules, vchild.Get(key).Rules()...) 
+ } + } else { + rules = vchild.Rules() // hit or miss + } + } if len(rules) > 0 { p.ltarget = p.newOperand() - funcName, err := p.planRules(rules) + funcName, err := p.planRules(rules, anyKeyNonGround) if err != nil { return err } @@ -1727,7 +1774,6 @@ func (p *Planner) planRefData(virtual *ruletrie, base *baseptr, ref ast.Ref, ind bchild := *base bchild.path = append(bchild.path, ref[index]) - return p.planRefData(vchild, &bchild, ref, index+1, iter) } @@ -1836,38 +1882,21 @@ func (p *Planner) planRefDataExtent(virtual *ruletrie, base *baseptr, iter plani Target: vtarget, }) + anyKeyNonGround := false for _, key := range virtual.Children() { - child := virtual.Get(key) - - // Skip functions. - if child.Arity() > 0 { - continue + if !key.IsGround() { + anyKeyNonGround = true + break } - - lkey := ir.StringIndex(p.getStringConst(string(key.(ast.String)))) - - rules := child.Rules() - - // Build object hierarchy depth-first. - if len(rules) == 0 { - err := p.planRefDataExtent(child, nil, func() error { - p.appendStmt(&ir.ObjectInsertStmt{ - Object: vtarget, - Key: op(lkey), - Value: p.ltarget, - }) - return nil - }) - if err != nil { - return err - } - continue + } + if anyKeyNonGround { + var rules []*ast.Rule + for _, key := range virtual.Children() { + // TODO(sr): skip functions + rules = append(rules, virtual.Get(key).Rules()...) } - // Generate virtual document for leaf. 
- lvalue := p.newLocal() - - funcName, err := p.planRules(rules) + funcName, err := p.planRules(rules, true) if err != nil { return err } @@ -1877,14 +1906,59 @@ func (p *Planner) planRefDataExtent(virtual *ruletrie, base *baseptr, iter plani p.appendStmtToBlock(&ir.CallStmt{ Func: funcName, Args: p.defaultOperands(), - Result: lvalue, - }, b) - p.appendStmtToBlock(&ir.ObjectInsertStmt{ - Object: vtarget, - Key: op(lkey), - Value: op(lvalue), + Result: vtarget, }, b) p.appendStmt(&ir.BlockStmt{Blocks: []*ir.Block{b}}) + } else { + for _, key := range virtual.Children() { + child := virtual.Get(key) + + // Skip functions. + if child.Arity() > 0 { + continue + } + + lkey := ir.StringIndex(p.getStringConst(string(key.(ast.String)))) + rules := child.Rules() + + // Build object hierarchy depth-first. + if len(rules) == 0 { + err := p.planRefDataExtent(child, nil, func() error { + p.appendStmt(&ir.ObjectInsertStmt{ + Object: vtarget, + Key: op(lkey), + Value: p.ltarget, + }) + return nil + }) + if err != nil { + return err + } + continue + } + + // Generate virtual document for leaf. + lvalue := p.newLocal() + + funcName, err := p.planRules(rules, false) + if err != nil { + return err + } + + // Add leaf to object if defined. 
+ b := &ir.Block{} + p.appendStmtToBlock(&ir.CallStmt{ + Func: funcName, + Args: p.defaultOperands(), + Result: lvalue, + }, b) + p.appendStmtToBlock(&ir.ObjectInsertStmt{ + Object: vtarget, + Key: op(lkey), + Value: op(lvalue), + }, b) + p.appendStmt(&ir.BlockStmt{Blocks: []*ir.Block{b}}) + } } // At this point vtarget refers to the full extent of the virtual diff --git a/internal/planner/planner_test.go b/internal/planner/planner_test.go index bd488972cf..2f589fa697 100644 --- a/internal/planner/planner_test.go +++ b/internal/planner/planner_test.go @@ -86,7 +86,7 @@ func TestPlannerHelloWorld(t *testing.T) { }, { note: "complete rules", - queries: []string{"true"}, + queries: []string{"data.test.p = x"}, modules: []string{` package test p = x { x = 1 } @@ -141,6 +141,15 @@ func TestPlannerHelloWorld(t *testing.T) { p["b"] = 2 `}, }, + { // NOTE(sr): these are handled differently with ref-heads + note: "partial object with var", + queries: []string{`data.test.p = x`}, + modules: []string{` + package test + p["a"] = 1 + p[v] = 2 { v := "b" } + `}, + }, { note: "every", queries: []string{`data.test.p`}, diff --git a/internal/planner/rules.go b/internal/planner/rules.go index 5e94fb1c5e..6ff78f4c69 100644 --- a/internal/planner/rules.go +++ b/internal/planner/rules.go @@ -144,6 +144,9 @@ func (t *ruletrie) LookupOrInsert(key ast.Ref) *ruletrie { } func (t *ruletrie) Children() []ast.Value { + if t == nil { + return nil + } sorted := make([]ast.Value, 0, len(t.children)) for key := range t.children { if t.Get(key) != nil { diff --git a/refactor/refactor_test.go b/refactor/refactor_test.go index ac55d89b05..6a06318142 100644 --- a/refactor/refactor_test.go +++ b/refactor/refactor_test.go @@ -337,7 +337,7 @@ p = 7`) t.Fatal("Expected error but got nil") } - errMsg := "rego_type_error: conflicting rules named p found" + errMsg := "rego_type_error: conflicting rules data.b.p found" if !strings.Contains(err.Error(), errMsg) { t.Fatalf("Expected error message %v but got 
%v", errMsg, err.Error()) } diff --git a/rego/rego_test.go b/rego/rego_test.go index 3f408409b0..d0c1b1c950 100644 --- a/rego/rego_test.go +++ b/rego/rego_test.go @@ -1867,7 +1867,6 @@ func TestRegoCustomBuiltinPartialPropagate(t *testing.T) { } func TestRegoPartialResultRecursiveRefs(t *testing.T) { - r := New(Query("data"), Module("test.rego", `package foo.bar default p = false diff --git a/repl/repl.go b/repl/repl.go index 48a2f669f5..6adb33a652 100644 --- a/repl/repl.go +++ b/repl/repl.go @@ -811,11 +811,9 @@ func (r *REPL) compileRule(ctx context.Context, rule *ast.Rule) error { switch r.outputFormat { case "json": default: - var msg string + msg := "defined" if unset { msg = "re-defined" - } else { - msg = "defined" } fmt.Fprintf(r.output, "Rule '%v' %v in %v. Type 'show' to see rules.\n", rule.Head.Name, msg, mod.Package) } @@ -1177,10 +1175,16 @@ func (r *REPL) interpretAsRule(ctx context.Context, compiler *ast.Compiler, body if expr.IsAssignment() { rule, err := ast.ParseCompleteDocRuleFromAssignmentExpr(r.getCurrentOrDefaultModule(), expr.Operand(0), expr.Operand(1)) - if err == nil { - if err := r.compileRule(ctx, rule); err != nil { - return false, err - } + if err != nil { + return false, nil + } + // TODO(sr): support interactive ref head rule definitions + if len(rule.Head.Ref()) > 1 { + return false, nil + } + + if err := r.compileRule(ctx, rule); err != nil { + return false, err } return rule != nil, nil } @@ -1194,12 +1198,17 @@ func (r *REPL) interpretAsRule(ctx context.Context, compiler *ast.Compiler, body } rule, err := ast.ParseCompleteDocRuleFromEqExpr(r.getCurrentOrDefaultModule(), expr.Operand(0), expr.Operand(1)) - if err == nil { - if err := r.compileRule(ctx, rule); err != nil { - return false, err - } + if err != nil { + return false, nil + } + // TODO(sr): support interactive ref head rule definitions + if len(rule.Head.Ref()) > 1 { + return false, nil } + if err := r.compileRule(ctx, rule); err != nil { + return false, err + } return 
rule != nil, nil } diff --git a/server/server.go b/server/server.go index 2800beaafa..213beafa00 100644 --- a/server/server.go +++ b/server/server.go @@ -2638,13 +2638,7 @@ func stringPathToRef(s string) (r ast.Ref) { } func validateQuery(query string) (ast.Body, error) { - - var body ast.Body - body, err := ast.ParseBody(query) - if err != nil { - return nil, err - } - return body, nil + return ast.ParseBody(query) } func getBoolParam(url *url.URL, name string, ifEmpty bool) bool { diff --git a/test/cases/testdata/jsonpatch/json-patch-tests.yaml b/test/cases/testdata/jsonpatch/json-patch-tests.yaml index 7692887d9a..9d58bc2853 100644 --- a/test/cases/testdata/jsonpatch/json-patch-tests.yaml +++ b/test/cases/testdata/jsonpatch/json-patch-tests.yaml @@ -16,7 +16,7 @@ cases: # Grab all cases from the modules inside `data.json_patch_cases` and # construct an object with readable index names. all_cases[k] = t { - t := data.json_patch_cases[p]["cases"][i] + t := data.json_patch_cases[p].cases[i] k := sprintf("%s/%d", [p, i]) } diff --git a/test/cases/testdata/partialobjectdoc/test-wasm-cases.yaml b/test/cases/testdata/partialobjectdoc/test-wasm-cases.yaml new file mode 100644 index 0000000000..d1da298df4 --- /dev/null +++ b/test/cases/testdata/partialobjectdoc/test-wasm-cases.yaml @@ -0,0 +1,106 @@ + +# NOTE(sr): These test cases stem from cases we run against the wasm +# module but since the ref-heads change made them fail -- because of +# changes to the typing of partial object rules (which are now single- +# valued rules), they're added to the topdown tests, too. 
+cases: + - note: wasm/additive + query: 'data.x.q = x' + modules: + - | + package x + p["a"] = 1 + p["b"] = 2 + + q { + p == {"a": 1, "b": 2} + } + want_result: + - x: true + - note: wasm/additive (negative) + query: 'data.x.p = input' + input: {"a": 1, "b": 2} + modules: + - | + package x + p["a"] = 1 + p["b"] = 2 + p["c"] = 3 + want_result: [] + - note: wasm/input + query: 'data.x.q = x' + modules: + - | + package x + p["a"] = 1 { input.x = 1 } + p["b"] = 2 { input.y = 2 } + + q { + p == {"a": 1, "b": 2} + } + input: {"x": 1, "y": 2} + want_result: + - x: true + - note: wasm/input (negative) + query: 'data.x.q = x' + data: + z: + a: 1 + b: 2 + modules: + - | + package x + p["a"] = 1 { input.x = 1 } + p["b"] = 2 { input.y = 2 } + p["c"] = 3 { input.z = 3 } + + q { + p == data.z + } + want_result: + - x: true + input: {"x": 1, "y": 2} + - note: wasm/composites + query: 'data.x.q = x' + modules: + - | + package x + p[x] = [y] { x = "a"; y = 1 } + p[x] = [y] { x = "b"; y = 2 } + + q { p == {"a": [1], "b": [2]} } + want_result: + - x : true + - note: wasm/conflict error + query: 'data.x.q = x' + data: + z: + a: 1 + modules: + - | + package x + p["x"] = 1 + p["x"] = 2 + q { + p == data.z + } + want_error: "test-0.rego:3: eval_conflict_error: complete rules must not produce multiple outputs" + want_error_code: eval_conflict_error + - note: wasm/object dereference + query: 'data.x.q = x' + modules: + - | + package x + p["a"] = {"b": 1} + q { + p.a.b = 1 + } + want_result: + - x: true + - note: wasm/object dereference (negative) + query: data.x.p.a.b = 1 + modules: + - | + package x + p["a"] = {"b": 2} + want_result: [] diff --git a/test/cases/testdata/refheads/test-refs-as-rule-heads.yaml b/test/cases/testdata/refheads/test-refs-as-rule-heads.yaml new file mode 100644 index 0000000000..f8d00b5a7e --- /dev/null +++ b/test/cases/testdata/refheads/test-refs-as-rule-heads.yaml @@ -0,0 +1,313 @@ +cases: +- modules: + - | + package test + + p.q.r = 1 + p.q.s = 2 + note: 
'refheads/single-value' + query: data.test.p = x + want_result: + - x: + q: + r: 1 + s: 2 +- modules: + - | + package test + + p.q.r = 1 + p.q[s] = 2 { s := "s" } + note: 'refheads/single-value, with var' + query: data.test.p = x + want_result: + - x: + q: + r: 1 + s: 2 +- modules: + - | + package test + + a.b.c.p { true } + note: 'refheads/complete: direct query' + query: data.test.a.b.c.p = x + want_result: + - x: true +- modules: + - | + package test + + q = 0 + note: 'refheads/complete: direct query q' + query: data.test.q = x + want_result: + - x: 0 +- modules: + - | + package test + + a.b.c.p { true } + note: 'refheads/complete: full package extent' + query: data.test = x + want_result: + - x: + a: + b: + c: + p: true +- modules: + - | + package test + + a.b.c.p = 1 + q = 0 + a.b.d = 3 + + p { + q == 0 + a.b.c.p == 1 + a.b.d == 3 + } + note: refheads/complete+mixed + query: data.test.p = x + want_result: + - x: true +- modules: + - | + package test + + a.b[x] = y { x := "c"; y := "d" } + note: refheads/single-value rule + query: data.test.a = x + want_result: + - x: + b: + c: d +- modules: + - | + package test + import future.keywords + + a.b contains x if some x in [1,2,3] + note: refheads/multi-value + query: data.test.a = x + want_result: + - x: + b: [1, 2, 3] +# NOTE(sr): This isn't supported yet +# - modules: +# - | +# package test +# import future.keywords + +# a.b[c] contains x if { c := "c"; some x in [1,2,3] } +# note: refheads/multi-value with var in ref +# query: data.test.a = x +# want_result: +# - x: +# b: +# c: [1, 2, 3] +- modules: + - | + package test + import future.keywords + + a.b[x] = i if some i, x in [1, 2, 3] + note: 'refheads/single-value: previously partial object' + query: data.test.a.b = x + want_result: + - x: + 1: 0 + 2: 1 + 3: 2 +- modules: + - | + package test + import future.keywords + + a.b.c.d contains 1 if true + - | + package test.a + import future.keywords + + b.c.d contains 2 if true + note: 'refheads/multi-value: same 
rule' + query: data.test.a = x + want_result: + - x: + b: + c: + d: [1, 2] +- modules: + - | + package test + + default a.b.c := "d" + note: refheads/single-value default rule + query: data.test.a = x + want_result: + - x: + b: + c: d +- modules: + - | + package test + import future.keywords + + q[7] = 8 if true + a[x] if q[x] + note: refheads/single-value example + query: data.test.a = x + want_result: + - x: + 7: true +- modules: + - | + package test + import future.keywords + + q[7] = 8 if false + a[x] if q[x] + note: refheads/single-value example, false + query: data.test.a = x + want_result: + - x: {} +- modules: + - | + package test + import future.keywords + + a.b.c = "d" if true + a.b.e = "f" if true + a.b.g contains x if some x in numbers.range(1, 3) + a.b.h[x] = 1 if x := "one" + note: refheads/mixed example, multiple rules + query: data.test.a.b = x + want_result: + - x: + c: d + e: f + g: + - 1 + - 2 + - 3 + h: + one: 1 +- modules: + - | + package example + import future.keywords + + apps_by_hostname[hostname] := app if { + some i + server := sites[_].servers[_] + hostname := server.hostname + apps[i].servers[_] == server.name + app := apps[i].name + } + + sites := [ + { + "region": "east", + "name": "prod", + "servers": [ + { + "name": "web-0", + "hostname": "hydrogen" + }, + { + "name": "web-1", + "hostname": "helium" + }, + { + "name": "db-0", + "hostname": "lithium" + } + ] + }, + { + "region": "west", + "name": "smoke", + "servers": [ + { + "name": "web-1000", + "hostname": "beryllium" + }, + { + "name": "web-1001", + "hostname": "boron" + }, + { + "name": "db-1000", + "hostname": "carbon" + } + ] + }, + { + "region": "west", + "name": "dev", + "servers": [ + { + "name": "web-dev", + "hostname": "nitrogen" + }, + { + "name": "db-dev", + "hostname": "oxygen" + } + ] + } + ] + + apps := [ + { + "name": "web", + "servers": ["web-0", "web-1", "web-1000", "web-1001", "web-dev"] + }, + { + "name": "mysql", + "servers": ["db-0", "db-1000"] + }, + { + 
"name": "mongodb", + "servers": ["db-dev"] + } + ] + + containers := [ + { + "image": "redis", + "ipaddress": "10.0.0.1", + "name": "big_stallman" + }, + { + "image": "nginx", + "ipaddress": "10.0.0.2", + "name": "cranky_euclid" + } + ] + note: refheads/website-example/partial-obj + query: data.example.apps_by_hostname.helium = x + want_result: + - x: web +- modules: + - | + package example + import future.keywords + + public_network contains net.id if { + some net in input.networks + net.public + } + note: refheads/website-example/partial-set + query: data.example.public_network = x + input: + networks: + - id: n1 + public: true + - id: n2 + public: false + want_result: + - x: + - n1 \ No newline at end of file diff --git a/test/cases/testdata/refheads/test-regressions.yaml b/test/cases/testdata/refheads/test-regressions.yaml new file mode 100644 index 0000000000..4bce1e3c7a --- /dev/null +++ b/test/cases/testdata/refheads/test-regressions.yaml @@ -0,0 +1,116 @@ +# NOTE(sr): These tests are not really related to ref heads, but collection +# regressions found when introducing ref heads. 
+cases: +- note: regression/ref-not-hashable + modules: + - | + package test + + ms[m.z] = m { + m := input.xs[y] + } + input: + xs: + something: + z: a + query: data.test = x + want_result: + - x: + ms: + a: + z: a +- note: regression/function refs and package extent + modules: + - | + package test + import future.keywords.if + + foo.bar(x) = x+1 + x := foo.bar(2) + query: data.test = x + want_result: + - x: + foo: {} + x: 3 + +- note: regression/rule refs and package extent + modules: + - | + package test + import future.keywords.if + + buz.quz = 3 if input == 3 + x := y { + y := buz.quz with input as 3 + } + query: data.test = x + want_result: + - x: + buz: {} + x: 3 +- note: regression/rule refs and package extents, multiple modules + modules: + - | + package test + import future.keywords.if + + x := y { + y := data.test.buz.quz with input as 3 + } + - | + package test.buz + import future.keywords.if + + quz = 3 if input == 3 + query: data.test = x + want_result: + - x: + buz: {} + x: 3 +- note: regression/type checking with ref rules + modules: + - | + package test + all[0] = [2] + level := 1 + + p := y { y := all[level-1][_] } + query: data.test.p = x + want_result: + - x: 2 +- note: regression/type checking with ref rules, number + modules: + - | + package test + p[0] = 1 + query: 'data.test.p[0] = x' + want_result: + - x: 1 +- note: regression/type checking with ref rules, bool + modules: + - | + package test + p[true] = 1 + query: 'data.test.p[true] = x' + want_result: + - x: 1 +- note: regression/full extent with partial object rule with empty indexer lookup result + modules: + - | + package test + p[x] = 2 { + x := input # we'll get 0 rules for data.test.p + false + } + query: data.test = x + want_result: + - x: + p: {} +- note: regression/obj in ref head query + modules: + - | + package test + p[{"a": "b"}] = true + query: data.test.p[{"a":"b"}] = x + want_result: + - x: true \ No newline at end of file diff --git 
a/test/wasm/assets/012_partialobjects.yaml b/test/wasm/assets/012_partialobjects.yaml index 621ae6be96..39c0f752c7 100644 --- a/test/wasm/assets/012_partialobjects.yaml +++ b/test/wasm/assets/012_partialobjects.yaml @@ -8,7 +8,8 @@ cases: p["b"] = 2 want_defined: true - note: additive (negative) - query: 'data.x.p = {"a": 1, "b": 2}' + query: 'data.x.p = input' + input: {"a": 1, "b": 2} modules: - | package x @@ -18,15 +19,19 @@ cases: want_defined: false - note: input query: 'data.x.p = {"b": 2, "a": 1}' + input: {"x": 1, "y": 2} modules: - | package x p["a"] = 1 { input.x = 1 } p["b"] = 2 { input.y = 2 } want_defined: true - input: {"x": 1, "y": 2} - note: input (negative) - query: 'data.x.p = {"a": 1, "b": 2}' + query: 'data.x.p = data.z' + data: + z: + a: 1 + b: 2 modules: - | package x @@ -44,13 +49,16 @@ cases: p[x] = [y] { x = "b"; y = 2 } want_defined: true - note: conflict error - query: 'data.x.p = {"a": 1}' + data: + z: + a: 1 + query: 'data.x.p = data.z' modules: - | package x p["x"] = 1 p["x"] = 2 - want_error: "module0.rego:3:1: object insert conflict" + want_error: "module0.rego:3:1: var assignment conflict" - note: object dereference query: data.x.p.a.b = 1 modules: diff --git a/tester/reporter_test.go b/tester/reporter_test.go index afaf527e82..56d98b2d57 100644 --- a/tester/reporter_test.go +++ b/tester/reporter_test.go @@ -75,6 +75,14 @@ func TestPrettyReporterVerbose(t *testing.T) { File: "policy3.rego", }, }, + { + Package: "data.foo.baz", + Name: "p.q.r.test_quz", + Trace: getFakeTraceEvents(), + Location: &ast.Location{ + File: "policy4.rego", + }, + }, } r := PrettyReporter{ @@ -109,11 +117,14 @@ data.foo.bar.test_contains_print: PASS (0s) fake print output + +policy4.rego: +data.foo.baz.p.q.r.test_quz: PASS (0s) -------------------------------------------------------------------------------- -PASS: 2/5 -FAIL: 1/5 -SKIPPED: 1/5 -ERROR: 1/5 +PASS: 3/6 +FAIL: 1/6 +SKIPPED: 1/6 +ERROR: 1/6 ` str := buf.String() @@ -181,6 +192,15 @@ func 
TestPrettyReporter(t *testing.T) { File: "policy2.rego", }, }, + { + Package: "data.foo.baz", + Name: "p.q.r.test_quz", + Fail: true, + Trace: getFakeTraceEvents(), + Location: &ast.Location{ + File: "policy3.rego", + }, + }, } r := PrettyReporter{ @@ -203,11 +223,14 @@ data.foo.bar.test_contains_print_fail: FAIL (0s) fake print output2 + +policy3.rego: +data.foo.baz.p.q.r.test_quz: FAIL (0s) -------------------------------------------------------------------------------- -PASS: 2/6 -FAIL: 2/6 -SKIPPED: 1/6 -ERROR: 1/6 +PASS: 2/7 +FAIL: 3/7 +SKIPPED: 1/7 +ERROR: 1/7 ` if exp != buf.String() { @@ -246,6 +269,10 @@ func TestJSONReporter(t *testing.T) { Name: "test_contains_print", Output: []byte("fake print output\n"), }, + { + Package: "data.foo.baz", + Name: "p.q.r.test_quz", + }, } r := JSONReporter{ @@ -406,7 +433,13 @@ func TestJSONReporter(t *testing.T) { "name": "test_contains_print", "output": "ZmFrZSBwcmludCBvdXRwdXQK", "duration": 0 - } + }, + { + "location": null, + "package": "data.foo.baz", + "name": "p.q.r.test_quz", + "duration": 0 +} ] `)) diff --git a/tester/runner.go b/tester/runner.go index 5b08ee9a6b..9852882f8c 100644 --- a/tester/runner.go +++ b/tester/runner.go @@ -299,7 +299,7 @@ func (r *Runner) runTests(ctx context.Context, txn storage.Transaction, enablePr } // rewrite duplicate test_* rule names as we compile modules - r.compiler.WithStageAfter("ResolveRefs", ast.CompilerStageDefinition{ + r.compiler.WithStageAfter("RewriteRuleHeadRefs", ast.CompilerStageDefinition{ Name: "RewriteDuplicateTestNames", MetricName: "rewrite_duplicate_test_names", Stage: rewriteDuplicateTestNames, @@ -340,7 +340,7 @@ func (r *Runner) runTests(ctx context.Context, txn storage.Transaction, enablePr } } - if r.modules != nil && len(r.modules) > 0 { + if len(r.modules) > 0 { if r.compiler.Compile(r.modules); r.compiler.Failed() { return nil, r.compiler.Errors } @@ -380,7 +380,7 @@ func (r *Runner) runTests(ctx context.Context, txn storage.Transaction, enablePr } 
func (r *Runner) shouldRun(rule *ast.Rule, testRegex *regexp.Regexp) bool { - ruleName := string(rule.Head.Name) + ruleName := ruleName(rule.Head) // All tests must have the right prefix if !strings.HasPrefix(ruleName, TestPrefix) && !strings.HasPrefix(ruleName, SkipTestPrefix) { @@ -388,7 +388,7 @@ func (r *Runner) shouldRun(rule *ast.Rule, testRegex *regexp.Regexp) bool { } // Even with the prefix it needs to pass the regex (if applicable) - fullName := fmt.Sprintf("%s.%s", rule.Module.Package.Path.String(), ruleName) + fullName := rule.Ref().String() if testRegex != nil && !testRegex.MatchString(fullName) { return false } @@ -403,13 +403,20 @@ func rewriteDuplicateTestNames(compiler *ast.Compiler) *ast.Error { count := map[string]int{} for _, mod := range compiler.Modules { for _, rule := range mod.Rules { - name := rule.Head.Name.String() + name := ruleName(rule.Head) if !strings.HasPrefix(name, TestPrefix) { continue } - key := rule.Path().String() + key := rule.Ref().String() if k, ok := count[key]; ok { - rule.Head.Name = ast.Var(fmt.Sprintf("%s#%02d", name, k)) + ref := rule.Head.Ref() + newName := fmt.Sprintf("%s#%02d", name, k) + if len(ref) == 1 { + ref[0] = ast.VarTerm(newName) + } else { + ref[len(ref)-1] = ast.StringTerm(newName) + } + rule.Head.SetRef(ref) } count[key]++ } @@ -417,6 +424,23 @@ func rewriteDuplicateTestNames(compiler *ast.Compiler) *ast.Error { return nil } +// ruleName is a helper to be used when checking if a function +// (a) is a test, or +// (b) needs to be skipped +// -- it'll resolve `p.q.r` to `r`. 
For representing results, we'll +// use rule.Head.Ref() +func ruleName(h *ast.Head) string { + ref := h.Ref() + switch last := ref[len(ref)-1].Value.(type) { + case ast.Var: + return string(last) + case ast.String: + return string(last) + default: + panic("unreachable") + } +} + func (r *Runner) runTest(ctx context.Context, txn storage.Transaction, mod *ast.Module, rule *ast.Rule) (*Result, bool) { var bufferTracer *topdown.BufferTracer var bufFailureLineTracer *topdown.BufferTracer @@ -429,10 +453,9 @@ func (r *Runner) runTest(ctx context.Context, txn storage.Transaction, mod *ast. tracer = bufferTracer } - ruleName := string(rule.Head.Name) - - if strings.HasPrefix(ruleName, SkipTestPrefix) { - tr := newResult(rule.Loc(), mod.Package.Path.String(), ruleName, 0*time.Second, nil, nil) + ruleName := ruleName(rule.Head) + if strings.HasPrefix(ruleName, SkipTestPrefix) { // TODO(sr): add test + tr := newResult(rule.Loc(), mod.Package.Path.String(), rule.Head.Ref().String(), 0*time.Second, nil, nil) tr.Skip = true return tr, false } @@ -465,7 +488,7 @@ func (r *Runner) runTest(ctx context.Context, txn storage.Transaction, mod *ast. 
trace = *bufferTracer } - tr := newResult(rule.Loc(), mod.Package.Path.String(), ruleName, dt, trace, printbuf.Bytes()) + tr := newResult(rule.Loc(), mod.Package.Path.String(), rule.Head.Ref().String(), dt, trace, printbuf.Bytes()) tr.Error = err var stop bool @@ -489,7 +512,7 @@ func (r *Runner) runBenchmark(ctx context.Context, txn storage.Transaction, mod tr := &Result{ Location: rule.Loc(), Package: mod.Package.Path.String(), - Name: string(rule.Head.Name), + Name: rule.Head.Ref().String(), // TODO(sr): test } var stop bool diff --git a/tester/runner_test.go b/tester/runner_test.go index c3679c7cb2..33241fc2be 100644 --- a/tester/runner_test.go +++ b/tester/runner_test.go @@ -75,19 +75,27 @@ func testRun(t *testing.T, conf testRunConfig) map[string]*ast.Module { `, "/b_test.rego": `package bar - test_duplicate { true }`, + test_duplicate { true }`, + "/c_test.rego": `package baz + + a.b.test_duplicate { false } + a.b.test_duplicate { true } + a.b.test_duplicate { true }`, } tests := expectedTestResults{ - {"data.foo", "test_pass"}: {false, false, false}, - {"data.foo", "test_fail"}: {false, true, false}, - {"data.foo", "test_fail_non_bool"}: {false, true, false}, - {"data.foo", "test_duplicate"}: {false, true, false}, - {"data.foo", "test_duplicate#01"}: {false, false, false}, - {"data.foo", "test_duplicate#02"}: {false, false, false}, - {"data.foo", "test_err"}: {true, false, false}, - {"data.foo", "todo_test_skip"}: {false, false, true}, - {"data.bar", "test_duplicate"}: {false, false, false}, + {"data.foo", "test_pass"}: {false, false, false}, + {"data.foo", "test_fail"}: {false, true, false}, + {"data.foo", "test_fail_non_bool"}: {false, true, false}, + {"data.foo", "test_duplicate"}: {false, true, false}, + {"data.foo", "test_duplicate#01"}: {false, false, false}, + {"data.foo", "test_duplicate#02"}: {false, false, false}, + {"data.foo", "test_err"}: {true, false, false}, + {"data.foo", "todo_test_skip"}: {false, false, true}, + {"data.bar", 
"test_duplicate"}: {false, false, false}, + {"data.baz", "a.b.test_duplicate"}: {false, true, false}, + {"data.baz", "a.b[\"test_duplicate#01\"]"}: {false, false, false}, + {"data.baz", "a.b[\"test_duplicate#02\"]"}: {false, false, false}, } var modules map[string]*ast.Module @@ -186,6 +194,11 @@ func TestRunWithFilterRegex(t *testing.T) { "/b_test.rego": `package bar test_duplicate { true }`, + "/c_test.rego": `package baz + + a.b.test_duplicate { false } + a.b.test_duplicate { true } + a.b.test_duplicate { true }`, } cases := []struct { @@ -197,32 +210,38 @@ func TestRunWithFilterRegex(t *testing.T) { note: "all tests match", regex: ".*", tests: expectedTestResults{ - {"data.foo", "test_pass"}: {false, false, false}, - {"data.foo", "test_fail"}: {false, true, false}, - {"data.foo", "test_fail_non_bool"}: {false, true, false}, - {"data.foo", "test_duplicate"}: {false, true, false}, - {"data.foo", "test_duplicate#01"}: {false, false, false}, - {"data.foo", "test_duplicate#02"}: {false, false, false}, - {"data.foo", "test_err"}: {true, false, false}, - {"data.foo", "todo_test_skip"}: {false, false, true}, - {"data.foo", "todo_test_skip_too"}: {false, false, true}, - {"data.bar", "test_duplicate"}: {false, false, false}, + {"data.foo", "test_pass"}: {false, false, false}, + {"data.foo", "test_fail"}: {false, true, false}, + {"data.foo", "test_fail_non_bool"}: {false, true, false}, + {"data.foo", "test_duplicate"}: {false, true, false}, + {"data.foo", "test_duplicate#01"}: {false, false, false}, + {"data.foo", "test_duplicate#02"}: {false, false, false}, + {"data.foo", "test_err"}: {true, false, false}, + {"data.foo", "todo_test_skip"}: {false, false, true}, + {"data.foo", "todo_test_skip_too"}: {false, false, true}, + {"data.bar", "test_duplicate"}: {false, false, false}, + {"data.baz", "a.b.test_duplicate"}: {false, true, false}, + {"data.baz", "a.b[\"test_duplicate#01\"]"}: {false, false, false}, + {"data.baz", "a.b[\"test_duplicate#02\"]"}: {false, false, false}, 
}, }, { note: "no filter", regex: "", tests: expectedTestResults{ - {"data.foo", "test_pass"}: {false, false, false}, - {"data.foo", "test_fail"}: {false, true, false}, - {"data.foo", "test_fail_non_bool"}: {false, true, false}, - {"data.foo", "test_duplicate"}: {false, true, false}, - {"data.foo", "test_duplicate#01"}: {false, false, false}, - {"data.foo", "test_duplicate#02"}: {false, false, false}, - {"data.foo", "test_err"}: {true, false, false}, - {"data.foo", "todo_test_skip"}: {false, false, true}, - {"data.foo", "todo_test_skip_too"}: {false, false, true}, - {"data.bar", "test_duplicate"}: {false, false, false}, + {"data.foo", "test_pass"}: {false, false, false}, + {"data.foo", "test_fail"}: {false, true, false}, + {"data.foo", "test_fail_non_bool"}: {false, true, false}, + {"data.foo", "test_duplicate"}: {false, true, false}, + {"data.foo", "test_duplicate#01"}: {false, false, false}, + {"data.foo", "test_duplicate#02"}: {false, false, false}, + {"data.foo", "test_err"}: {true, false, false}, + {"data.foo", "todo_test_skip"}: {false, false, true}, + {"data.foo", "todo_test_skip_too"}: {false, false, true}, + {"data.bar", "test_duplicate"}: {false, false, false}, + {"data.baz", "a.b.test_duplicate"}: {false, true, false}, + {"data.baz", "a.b[\"test_duplicate#01\"]"}: {false, false, false}, + {"data.baz", "a.b[\"test_duplicate#02\"]"}: {false, false, false}, }, }, { @@ -288,6 +307,15 @@ func TestRunWithFilterRegex(t *testing.T) { {"data.bar", "test_duplicate"}: {false, false, false}, }, }, + { + note: "matching ref rule halfways", + regex: "data.baz.a", + tests: expectedTestResults{ + {"data.baz", "a.b.test_duplicate"}: {false, true, false}, + {"data.baz", "a.b[\"test_duplicate#01\"]"}: {false, false, false}, + {"data.baz", "a.b[\"test_duplicate#02\"]"}: {false, false, false}, + }, + }, } test.WithTempFS(files, func(d string) { @@ -436,7 +464,8 @@ func TestRunnerPrintOutput(t *testing.T) { test_a { print("A") } test_b { false; print("B") } - test_c { 
print("C"); false }`, + test_c { print("C"); false } + p.q.r.test_d { print("D") }`, } ctx := context.Background() @@ -461,9 +490,10 @@ func TestRunnerPrintOutput(t *testing.T) { } exp := map[string]string{ - "test_a": "A\n", - "test_b": "", - "test_c": "C\n", + "test_a": "A\n", + "test_b": "", + "test_c": "C\n", + "p.q.r.test_d": "D\n", } got := map[string]string{} diff --git a/topdown/eval.go b/topdown/eval.go index 55e2c34696..3f824e6e34 100644 --- a/topdown/eval.go +++ b/topdown/eval.go @@ -17,6 +17,7 @@ import ( "github.com/open-policy-agent/opa/topdown/print" "github.com/open-policy-agent/opa/tracing" "github.com/open-policy-agent/opa/types" + "github.com/open-policy-agent/opa/util" ) type evalIterator func(*eval) error @@ -2008,9 +2009,10 @@ func (e evalFunc) partialEvalSupportRule(rule *ast.Rule, path ast.Ref) error { // Type-checking failure means the rule body will never succeed. if e.e.compiler.PassesTypeCheck(plugged) { head := &ast.Head{ - Name: rule.Head.Name, - Value: child.bindings.PlugNamespaced(rule.Head.Value, e.e.caller.bindings), - Args: make([]*ast.Term, len(rule.Head.Args)), + Name: rule.Head.Name, + Reference: rule.Head.Reference, + Value: child.bindings.PlugNamespaced(rule.Head.Value, e.e.caller.bindings), + Args: make([]*ast.Term, len(rule.Head.Args)), } for i, a := range rule.Head.Args { head.Args[i] = child.bindings.PlugNamespaced(a, e.e.caller.bindings) @@ -2089,13 +2091,14 @@ func (e evalTree) next(iter unifyIterator, plugged *ast.Term) error { node = e.node.Child(plugged.Value) if node != nil && len(node.Values) > 0 { r := evalVirtual{ - e: e.e, - ref: e.ref, - plugged: e.plugged, - pos: e.pos, - bindings: e.bindings, - rterm: e.rterm, - rbindings: e.rbindings, + onlyGroundRefs: onlyGroundRefs(node.Values), + e: e.e, + ref: e.ref, + plugged: e.plugged, + pos: e.pos, + bindings: e.bindings, + rterm: e.rterm, + rbindings: e.rbindings, } r.plugged[e.pos] = plugged return r.eval(iter) @@ -2107,6 +2110,16 @@ func (e evalTree) next(iter 
unifyIterator, plugged *ast.Term) error { return cpy.eval(iter) } +func onlyGroundRefs(values []util.T) bool { + for _, v := range values { + rule := v.(*ast.Rule) + if !rule.Head.Reference.IsGround() { + return false + } + } + return true +} + func (e evalTree) enumerate(iter unifyIterator) error { if e.e.inliningControl.Disabled(e.plugged[:e.pos], true) { @@ -2196,6 +2209,8 @@ func (e evalTree) extent() (*ast.Term, error) { return ast.NewTerm(virtual), nil } +// leaves builds a tree from evaluating the full rule tree extent, by recursing into all +// branches, and building up objects as it goes. func (e evalTree) leaves(plugged ast.Ref, node *ast.TreeNode) (ast.Object, error) { if e.node == nil { @@ -2243,13 +2258,14 @@ func (e evalTree) leaves(plugged ast.Ref, node *ast.TreeNode) (ast.Object, error } type evalVirtual struct { - e *eval - ref ast.Ref - plugged ast.Ref - pos int - bindings *bindings - rterm *ast.Term - rbindings *bindings + onlyGroundRefs bool + e *eval + ref ast.Ref + plugged ast.Ref + pos int + bindings *bindings + rterm *ast.Term + rbindings *bindings } func (e evalVirtual) eval(iter unifyIterator) error { @@ -2266,7 +2282,7 @@ func (e evalVirtual) eval(iter unifyIterator) error { } switch ir.Kind { - case ast.PartialSetDoc: + case ast.MultiValue: eval := evalVirtualPartial{ e: e.e, ref: e.ref, @@ -2279,7 +2295,22 @@ func (e evalVirtual) eval(iter unifyIterator) error { empty: ast.SetTerm(), } return eval.eval(iter) - case ast.PartialObjectDoc: + case ast.SingleValue: + // NOTE(sr): If we allow vars in others than the last position of a ref, we need + // to start reworking things here + if e.onlyGroundRefs { + eval := evalVirtualComplete{ + e: e.e, + ref: e.ref, + plugged: e.plugged, + pos: e.pos, + ir: ir, + bindings: e.bindings, + rterm: e.rterm, + rbindings: e.rbindings, + } + return eval.eval(iter) + } eval := evalVirtualPartial{ e: e.e, ref: e.ref, @@ -2293,17 +2324,7 @@ func (e evalVirtual) eval(iter unifyIterator) error { } return 
eval.eval(iter) default: - eval := evalVirtualComplete{ - e: e.e, - ref: e.ref, - plugged: e.plugged, - pos: e.pos, - ir: ir, - bindings: e.bindings, - rterm: e.rterm, - rbindings: e.rbindings, - } - return eval.eval(iter) + panic("unreachable") } } @@ -2437,13 +2458,17 @@ func (e evalVirtualPartial) evalOneRulePreUnify(iter unifyIterator, rule *ast.Ru child.traceEnter(rule) var defined bool - err := child.biunify(rule.Head.Key, key, child.bindings, e.bindings, func() error { + headKey := rule.Head.Key + if headKey == nil { + headKey = rule.Head.Reference[len(rule.Head.Reference)-1] + } + err := child.biunify(headKey, key, child.bindings, e.bindings, func() error { defined = true return child.eval(func(child *eval) error { term := rule.Head.Value if term == nil { - term = rule.Head.Key + term = headKey } if hint.key != nil { @@ -2550,7 +2575,8 @@ func (e evalVirtualPartial) partialEvalSupport(iter unifyIterator) error { ok, err := e.partialEvalSupportRule(e.ir.Rules[i], path) if err != nil { return err - } else if ok { + } + if ok { defined = true } } @@ -2669,13 +2695,15 @@ func (e evalVirtualPartial) evalCache(iter unifyIterator) (evalVirtualPartialCac func (e evalVirtualPartial) reduce(head *ast.Head, b *bindings, result *ast.Term) (*ast.Term, bool, error) { var exists bool - key := b.Plug(head.Key) switch v := result.Value.(type) { - case ast.Set: + case ast.Set: // MultiValue + key := b.Plug(head.Key) exists = v.Contains(key) v.Add(key) - case ast.Object: + case ast.Object: // SingleValue + key := head.Reference[len(head.Reference)-1] // NOTE(sr): multiple vars in ref heads need to deal with this better + key = b.Plug(key) value := b.Plug(head.Value) if curr := v.Get(key); curr != nil { if !curr.Equal(value) { @@ -2782,6 +2810,7 @@ func (e evalVirtualComplete) evalValueRule(iter unifyIterator, rule *ast.Rule, p var result *ast.Term err := child.eval(func(child *eval) error { child.traceExit(rule) + result = child.bindings.Plug(rule.Head.Value) if prev != nil { 
@@ -2794,8 +2823,8 @@ func (e evalVirtualComplete) evalValueRule(iter unifyIterator, rule *ast.Rule, p prev = result e.e.virtualCache.Put(e.plugged[:e.pos+1], result) - term, termbindings := child.bindings.apply(rule.Head.Value) + term, termbindings := child.bindings.apply(rule.Head.Value) err := e.evalTerm(iter, term, termbindings) if err != nil { return err @@ -2840,42 +2869,65 @@ func (e evalVirtualComplete) partialEvalSupport(iter unifyIterator) error { path := e.e.namespaceRef(e.plugged[:e.pos+1]) term := ast.NewTerm(e.e.namespaceRef(e.ref)) - if !e.e.saveSupport.Exists(path) { + var defined bool + if e.e.saveSupport.Exists(path) { + defined = true + } else { for i := range e.ir.Rules { - err := e.partialEvalSupportRule(e.ir.Rules[i], path) + ok, err := e.partialEvalSupportRule(e.ir.Rules[i], path) if err != nil { return err } + if ok { + defined = true + } } if e.ir.Default != nil { - err := e.partialEvalSupportRule(e.ir.Default, path) + ok, err := e.partialEvalSupportRule(e.ir.Default, path) if err != nil { return err } + if ok { + defined = true + } } } + if !defined { + return nil + } + return e.e.saveUnify(term, e.rterm, e.bindings, e.rbindings, iter) } -func (e evalVirtualComplete) partialEvalSupportRule(rule *ast.Rule, path ast.Ref) error { +func (e evalVirtualComplete) partialEvalSupportRule(rule *ast.Rule, path ast.Ref) (bool, error) { child := e.e.child(rule.Body) child.traceEnter(rule) e.e.saveStack.PushQuery(nil) + var defined bool err := child.eval(func(child *eval) error { child.traceExit(rule) + defined = true current := e.e.saveStack.PopQuery() plugged := current.Plug(e.e.caller.bindings) // Skip this rule body if it fails to type-check. // Type-checking failure means the rule body will never succeed. 
if e.e.compiler.PassesTypeCheck(plugged) { - head := ast.NewHead(rule.Head.Name, nil, child.bindings.PlugNamespaced(rule.Head.Value, e.e.caller.bindings)) + var name ast.Var + switch ref := rule.Head.Ref().GroundPrefix(); len(ref) { + case 1: + name = ref[0].Value.(ast.Var) + default: + s := ref[len(ref)-1].Value.(ast.String) + name = ast.Var(s) + } + head := ast.NewHead(name, nil, child.bindings.PlugNamespaced(rule.Head.Value, e.e.caller.bindings)) if !e.e.inliningControl.shallow { cp := copypropagation.New(head.Vars()). @@ -2895,7 +2947,7 @@ func (e evalVirtualComplete) partialEvalSupportRule(rule *ast.Rule, path ast.Ref return nil }) e.e.saveStack.PopQuery() - return err + return defined, err } func (e evalVirtualComplete) evalTerm(iter unifyIterator, term *ast.Term, termbindings *bindings) error { diff --git a/wasm/src/value.c b/wasm/src/value.c index 07f3c30e9d..d9d03d4035 100644 --- a/wasm/src/value.c +++ b/wasm/src/value.c @@ -1629,8 +1629,13 @@ opa_errc opa_value_remove_path(opa_value *data, opa_value *path) // be found, or of there's no function index leaf when we've run out // of path pieces. int opa_lookup(opa_value *mapping, opa_value *path) { - int path_len = _validate_json_path(path); - if (path_len < 1) + if (path == NULL || opa_value_type(path) != OPA_ARRAY) + { + return 0; + } + + int path_len = opa_value_length(path); + if (path_len == 0) { return 0; }