diff --git a/CHANGELOG.md b/CHANGELOG.md
index a67ee8f226..0633393050 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,117 @@ project adheres to [Semantic Versioning](http://semver.org/).
 
 ## Unreleased
 
+### Refs in Rule Heads
+
+With this version of OPA, we can use a shorthand for defining deeply-nested structures
+in Rego:
+
+Before, we had to use multiple packages, and hence multiple files to define a structure
+like this:
+```json
+{
+  "method": {
+    "get": {
+      "allowed": true
+    },
+    "post": {
+      "allowed": true
+    }
+  }
+}
+```
+
+```rego
+package method.get
+default allowed := false
+allowed { ... }
+```
+
+
+```rego
+package method.post
+default allowed := false
+allowed { ... }
+```
+
+Now, we can define those rules in a single package (and file):
+
+```rego
+package method
+import future.keywords.if
+default get.allowed := false
+get.allowed if { ... }
+
+default post.allowed := false
+post.allowed if { ... }
+```
+
+Note that in this example, the use of the future keyword `if` is mandatory
+for backwards-compatibility: without it, `get.allowed` would be interpreted
+as `get["allowed"]`, a definition of a partial set rule.
+
+Currently, variables may only appear in the last part of the rule head:
+
+```rego
+package method
+import future.keywords.if
+
+endpoints[ep].allowed if ep := "/v1/data" # invalid
+repos.get.endpoint[x] if x := "/v1/data" # valid
+```
+
+The valid rule defines this structure:
+```json
+{
+  "method": {
+    "repos": {
+      "get": {
+        "endpoint": {
+          "/v1/data": true
+        }
+      }
+    }
+  }
+}
+```
+
+To define a nested key-value pair, we would use
+
+```rego
+package method
+import future.keywords.if
+
+repos.get.endpoint[x] = y if {
+  x := "/v1/data"
+  y := "example"
+}
+```
+
+Multi-value rules (previously referred to as "partial set rules") that are
+nested like this need to use the `contains` future keyword, to differentiate them
+from the "last part is a variable" case mentioned just above:
+
+```rego
+package method
+import future.keywords.contains
+
+repos.get.endpoint contains x if x := "/v1/data"
+```
+
+This rule defines the same structure, but with multiple values instead of a key:
+```json
+{
+  "method": {
+    "repos": {
+      "get": {
+        "endpoint": ["/v1/data"]
+      }
+    }
+  }
+}
+```
+
 ## 0.45.0
 
 This release contains a mix of bugfixes, optimizations, and new features.
 
@@ -319,7 +430,6 @@ This is a security release fixing the following vulnerabilities:
 
 Note that CVE-2022-32190 is most likely not relevant for OPA's usage of net/url.
 But since these CVEs tend to come up in security assessment tooling regardless, it's better to get it out of the way.
-
 ## 0.43.0
 
 This release contains a number of fixes, enhancements, and performance improvements.
diff --git a/ast/annotations_test.go b/ast/annotations_test.go index 322cfc00fd..249d0daf8a 100644 --- a/ast/annotations_test.go +++ b/ast/annotations_test.go @@ -421,6 +421,37 @@ p[v] {v = 2}`, }, }, }, + { + note: "overlapping rule paths (different modules, rule head refs)", + modules: map[string]string{ + "mod1": `package test.a +# METADATA +# title: P1 +b.c.p[v] {v = 1}`, + "mod2": `package test +# METADATA +# title: P2 +a.b.c.p[v] {v = 2}`, + }, + expected: []AnnotationsRef{ + { + Path: MustParseRef("data.test.a.b.c.p"), + Location: &Location{File: "mod1", Row: 4}, + Annotations: &Annotations{ + Scope: "rule", + Title: "P1", + }, + }, + { + Path: MustParseRef("data.test.a.b.c.p"), + Location: &Location{File: "mod2", Row: 4}, + Annotations: &Annotations{ + Scope: "rule", + Title: "P2", + }, + }, + }, + }, } for _, tc := range tests { @@ -737,6 +768,78 @@ p = 1`, }, }, }, + { + note: "multiple subpackages, refs in rule heads", // NOTE(sr): same as above, but last module's rule is `foo.bar.p` in package `root` + modules: map[string]string{ + "root": `# METADATA +# scope: subpackages +# title: ROOT +package root`, + "root.foo": `# METADATA +# title: FOO +# scope: subpackages +package root.foo`, + "root.foo.bar": `# METADATA +# scope: subpackages +# description: subpackages scope applied to rule in other module +# title: BAR-sub + +# METADATA +# title: BAR-other +# description: This metadata is on the path of the queried rule, but shouldn't show up in the result as it's in a different module. 
+package root.foo.bar + +# METADATA +# scope: document +# description: document scope applied to rule in other module +# title: P-doc +p = 1`, + "rule": `# METADATA +# title: BAR +package root + +# METADATA +# title: P +foo.bar.p = 1`, + }, + moduleToAnalyze: "rule", + ruleOnLineToAnalyze: 7, + expected: []AnnotationsRef{ + { + Path: MustParseRef("data.root.foo.bar.p"), + Location: &Location{File: "rule", Row: 7}, + Annotations: &Annotations{ + Scope: "rule", + Title: "P", + }, + }, + { + Path: MustParseRef("data.root.foo.bar.p"), + Location: &Location{File: "root.foo.bar", Row: 15}, + Annotations: &Annotations{ + Scope: "document", + Title: "P-doc", + Description: "document scope applied to rule in other module", + }, + }, + { + Path: MustParseRef("data.root"), + Location: &Location{File: "rule", Row: 3}, + Annotations: &Annotations{ + Scope: "package", + Title: "BAR", + }, + }, + { + Path: MustParseRef("data.root"), + Location: &Location{File: "root", Row: 4}, + Annotations: &Annotations{ + Scope: "subpackages", + Title: "ROOT", + }, + }, + }, + }, { note: "multiple metadata blocks for single rule (order)", modules: map[string]string{ @@ -824,6 +927,7 @@ p = true`, chain := as.Chain(rule) if len(chain) != len(tc.expected) { + t.Errorf("expected %d elements, got %d:", len(tc.expected), len(chain)) t.Fatalf("chained AnnotationSet\n%v\n\ndoesn't match expected\n\n%v", toJSON(chain), toJSON(tc.expected)) } @@ -1022,7 +1126,7 @@ func TestAnnotations_toObject(t *testing.T) { } func toJSON(v interface{}) string { - b, _ := json.Marshal(v) + b, _ := json.MarshalIndent(v, "", " ") return string(b) } diff --git a/ast/check.go b/ast/check.go index fd35d017aa..b671e82999 100644 --- a/ast/check.go +++ b/ast/check.go @@ -200,7 +200,7 @@ func (tc *typeChecker) checkRule(env *TypeEnv, as *AnnotationSet, rule *Rule) { cpy, err := tc.CheckBody(env, rule.Body) env = env.next - path := rule.Path() + path := rule.Ref() if len(err) > 0 { // if the rule/function contains an error, add 
it to the type env so @@ -235,23 +235,28 @@ func (tc *typeChecker) checkRule(env *TypeEnv, as *AnnotationSet, rule *Rule) { tpe = types.Or(exist, f) } else { - switch rule.Head.DocKind() { - case CompleteDoc: + switch rule.Head.RuleKind() { + case SingleValue: typeV := cpy.Get(rule.Head.Value) - if typeV != nil { - exist := env.tree.Get(path) - tpe = types.Or(typeV, exist) - } - case PartialObjectDoc: - typeK := cpy.Get(rule.Head.Key) - typeV := cpy.Get(rule.Head.Value) - if typeK != nil && typeV != nil { - exist := env.tree.Get(path) - typeV = types.Or(types.Values(exist), typeV) - typeK = types.Or(types.Keys(exist), typeK) - tpe = types.NewObject(nil, types.NewDynamicProperty(typeK, typeV)) + if last := path[len(path)-1]; !last.IsGround() { + + // e.g. store object[string: whatever] at data.p.q.r, not data.p.q.r[x] + path = path.GroundPrefix() + + typeK := cpy.Get(last) + if typeK != nil && typeV != nil { + exist := env.tree.Get(path) + typeV = types.Or(types.Values(exist), typeV) + typeK = types.Or(types.Keys(exist), typeK) + tpe = types.NewObject(nil, types.NewDynamicProperty(typeK, typeV)) + } + } else { + if typeV != nil { + exist := env.tree.Get(path) + tpe = types.Or(typeV, exist) + } } - case PartialSetDoc: + case MultiValue: typeK := cpy.Get(rule.Head.Key) if typeK != nil { exist := env.tree.Get(path) @@ -652,72 +657,57 @@ func (rc *refChecker) checkRef(curr *TypeEnv, node *typeTreeNode, ref Ref, idx i head := ref[idx] - // Handle constant ref operands, i.e., strings or the ref head. 
- if _, ok := head.Value.(String); ok || idx == 0 { - - child := node.Child(head.Value) - if child == nil { - - if curr.next != nil { - next := curr.next - return rc.checkRef(next, next.tree, ref, 0) - } - - if RootDocumentNames.Contains(ref[0]) { - return rc.checkRefLeaf(types.A, ref, 1) - } - - return rc.checkRefLeaf(types.A, ref, 0) - } - - if child.Leaf() { - return rc.checkRefLeaf(child.Value(), ref, idx+1) + // NOTE(sr): as long as package statements are required, this isn't possible: + // the shortest possible rule ref is data.a.b (b is idx 2), idx 1 and 2 need to + // be strings or vars. + if idx == 1 || idx == 2 { + switch head.Value.(type) { + case Var, String: // OK + default: + have := rc.env.Get(head.Value) + return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, have, types.S, getOneOfForNode(node)) } - - return rc.checkRef(curr, child, ref, idx+1) } - // Handle dynamic ref operands. - switch value := head.Value.(type) { - - case Var: - - if exist := rc.env.Get(value); exist != nil { - if !unifies(types.S, exist) { - return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, types.S, getOneOfForNode(node)) + if v, ok := head.Value.(Var); ok && idx != 0 { + tpe := types.Keys(rc.env.getRefRecExtent(node)) + if exist := rc.env.Get(v); exist != nil { + if !unifies(tpe, exist) { + return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, tpe, getOneOfForNode(node)) } } else { - rc.env.tree.PutOne(value, types.S) + rc.env.tree.PutOne(v, tpe) } + } - case Ref: + child := node.Child(head.Value) + if child == nil { + // NOTE(sr): idx is reset on purpose: we start over + switch { + case curr.next != nil: + next := curr.next + return rc.checkRef(next, next.tree, ref, 0) - exist := rc.env.Get(value) - if exist == nil { - // If ref type is unknown, an error will already be reported so - // stop here. 
- return nil - } + case RootDocumentNames.Contains(ref[0]): + if idx != 0 { + node.Children().Iter(func(_, child util.T) bool { + _ = rc.checkRef(curr, child.(*typeTreeNode), ref, idx+1) // ignore error + return false + }) + return nil + } + return rc.checkRefLeaf(types.A, ref, 1) - if !unifies(types.S, exist) { - return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, exist, types.S, getOneOfForNode(node)) + default: + return rc.checkRefLeaf(types.A, ref, 0) } - - // Catch other ref operand types here. Non-leaf nodes must be referred to - // with string values. - default: - return newRefErrInvalid(ref[0].Location, rc.varRewriter(ref), idx, nil, types.S, getOneOfForNode(node)) } - // Run checking on remaining portion of the ref. Note, since the ref - // potentially refers to data for which no type information exists, - // checking should never fail. - node.Children().Iter(func(_, child util.T) bool { - _ = rc.checkRef(curr, child.(*typeTreeNode), ref, idx+1) // ignore error - return false - }) + if child.Leaf() { + return rc.checkRefLeaf(child.Value(), ref, idx+1) + } - return nil + return rc.checkRef(curr, child, ref, idx+1) } func (rc *refChecker) checkRefLeaf(tpe types.Type, ref Ref, idx int) *Error { diff --git a/ast/check_test.go b/ast/check_test.go index 33a51ca063..56162c5a97 100644 --- a/ast/check_test.go +++ b/ast/check_test.go @@ -344,6 +344,12 @@ func TestCheckInferenceRules(t *testing.T) { {`number_key`, `q[x] = y { a = ["a", "b"]; y = a[x] }`}, {`non_leaf`, `p[x] { data.prefix.i[x][_] }`}, } + ruleset2 := [][2]string{ + {`ref_rule_single`, `p.q.r { true }`}, + {`ref_rule_single_with_number_key`, `p.q[3] { true }`}, + {`ref_regression_array_key`, + `walker[[p, v]] = o { l = input; walk(l, k); [p, v] = k; o = {} }`}, + } tests := []struct { note string @@ -465,6 +471,37 @@ func TestCheckInferenceRules(t *testing.T) { {"non-leaf", ruleset1, "data.non_leaf.p", types.NewSet( types.S, )}, + + {"ref-rules single value, full ref", ruleset2, 
"data.ref_rule_single.p.q.r", types.B}, + {"ref-rules single value, prefix", ruleset2, "data.ref_rule_single.p", + types.NewObject( + []*types.StaticProperty{{ + Key: "q", Value: types.NewObject( + []*types.StaticProperty{{Key: "r", Value: types.B}}, + types.NewDynamicProperty(types.S, types.A), + ), + }}, + types.NewDynamicProperty(types.S, types.A), + )}, + + {"ref-rules single value, number key, full ref", ruleset2, "data.ref_rule_single_with_number_key.p.q[3]", types.B}, + {"ref-rules single value, number key, prefix", ruleset2, "data.ref_rule_single_with_number_key.p", + types.NewObject( + []*types.StaticProperty{{ + Key: "q", Value: types.NewObject( + []*types.StaticProperty{{Key: json.Number("3"), Value: types.B}}, + types.NewDynamicProperty(types.S, types.A), + ), + }}, + types.NewDynamicProperty(types.S, types.A), + )}, + + {"ref_regression_array_key", ruleset2, "data.ref_regression_array_key.walker", + types.NewObject( + nil, + types.NewDynamicProperty(types.NewArray([]types.Type{types.NewArray(types.A, types.A), types.A}, nil), + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A))), + )}, } for _, tc := range tests { @@ -489,7 +526,7 @@ func TestCheckInferenceRules(t *testing.T) { ref := MustParseRef(tc.ref) checker := newTypeChecker() - env, err := checker.CheckTypes(nil, elems, nil) + env, err := checker.CheckTypes(newTypeChecker().Env(map[string]*Builtin{"walk": BuiltinMap["walk"]}), elems, nil) if err != nil { t.Fatalf("Unexpected error %v:", err) @@ -512,6 +549,87 @@ func TestCheckInferenceRules(t *testing.T) { } +func TestCheckInferenceOverlapWithRules(t *testing.T) { + ruleset1 := [][2]string{ + {`prefix.i.j.k`, `p = 1 { true }`}, + {`prefix.i.j.k`, `p = "foo" { true }`}, + } + tests := []struct { + note string + rules [][2]string + ref string + expected types.Type // ref's type + query string + extra map[Var]types.Type + }{ + { + note: "non-leaf, extra vars", + rules: ruleset1, + ref: "data.prefix.i.j[k]", + expected: types.A, + 
query: "data.prefix.i.j[k][b]", + extra: map[Var]types.Type{ + Var("k"): types.S, + Var("b"): types.S, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + var elems []util.T + + // Convert test rules into rule slice for "warmup" call. + for i := range tc.rules { + pkg := MustParsePackage(`package ` + tc.rules[i][0]) + rule := MustParseRule(tc.rules[i][1]) + module := &Module{ + Package: pkg, + Rules: []*Rule{rule}, + } + rule.Module = module + elems = append(elems, rule) + for next := rule.Else; next != nil; next = next.Else { + next.Module = module + elems = append(elems, next) + } + } + + ref := MustParseRef(tc.ref) + checker := newTypeChecker() + env, err := checker.CheckTypes(nil, elems, nil) + if err != nil { + t.Fatalf("Unexpected error %v:", err) + } + + result := env.Get(ref) + if tc.expected == nil { + if result != nil { + t.Errorf("Expected %v type to be unset but got: %v", ref, result) + } + } else { + if result == nil { + t.Errorf("Expected to infer %v => %v but got nil", ref, tc.expected) + } else if types.Compare(tc.expected, result) != 0 { + t.Errorf("Expected to infer %v => %v but got %v", ref, tc.expected, result) + } + } + + body := MustParseBody(tc.query) + env, err = checker.CheckBody(env, body) + if len(err) != 0 { + t.Fatalf("Unexpected error: %v", err) + } + for ex, exp := range tc.extra { + act := env.Get(ex) + if types.Compare(act, exp) != 0 { + t.Errorf("Expected to infer extra %v => %v but got %v", ex, exp, act) + } + } + }) + } +} + func TestCheckErrorSuppression(t *testing.T) { query := `arr = [1,2,3]; arr[0].deadbeef = 1` @@ -642,7 +760,7 @@ func TestCheckBuiltinErrors(t *testing.T) { {"objects-any", `fake_builtin_2({"a": a, "c": c})`}, {"objects-bad-input", `sum({"a": 1, "b": 2}, x)`}, {"sets-any", `sum({1,2,"3",4}, x)`}, - {"virtual-ref", `plus(data.test.p, data.deabeef, 0)`}, + {"virtual-ref", `plus(data.test.p, data.coffee, 0)`}, } env := newTestEnv([]string{ @@ -781,6 +899,7 @@ func 
TestCheckRefErrInvalid(t *testing.T) { env := newTestEnv([]string{ `p { true }`, `q = {"foo": 1, "bar": 2} { true }`, + `a.b.c[3] = x { x = {"x": {"y": 2}} }`, }) tests := []struct { @@ -799,7 +918,7 @@ func TestCheckRefErrInvalid(t *testing.T) { pos: 2, have: types.N, want: types.S, - oneOf: []Value{String("p"), String("q")}, + oneOf: []Value{String("a"), String("p"), String("q")}, }, { note: "bad non-leaf ref", @@ -808,7 +927,7 @@ func TestCheckRefErrInvalid(t *testing.T) { pos: 2, have: types.N, want: types.S, - oneOf: []Value{String("p"), String("q")}, + oneOf: []Value{String("a"), String("p"), String("q")}, }, { note: "bad leaf ref", @@ -819,6 +938,24 @@ func TestCheckRefErrInvalid(t *testing.T) { want: types.S, oneOf: []Value{String("bar"), String("foo")}, }, + { + note: "bad ref hitting last term", + query: `x = true; data.test.a.b.c[x][_]`, + ref: `data.test.a.b.c[x][_]`, + pos: 5, + have: types.B, + want: types.Any{types.N, types.S}, + oneOf: []Value{Number("3")}, + }, + { + note: "bad ref hitting dynamic part", + query: `s = true; data.test.a.b.c[3].x[s][_] = _`, + ref: `data.test.a.b.c[3].x[s][_]`, + pos: 7, + have: types.B, + want: types.S, + oneOf: []Value{String("y")}, + }, { note: "bad leaf var", query: `x = 1; data.test.q[x]`, @@ -851,12 +988,25 @@ func TestCheckRefErrInvalid(t *testing.T) { oneOf: []Value{String("a"), String("c")}, }, { + // NOTE(sr): Thins one and the next are special: it cannot work with ref heads, either, since we need at + // least ONE string term after data.test: a module needs a package line, and the shortest head ref + // possible is thus data.x.y. 
note: "bad non-leaf value", query: `data.test[1]`, ref: "data.test[1]", pos: 2, + have: types.N, + want: types.S, + oneOf: []Value{String("a"), String("p"), String("q")}, + }, + { + note: "bad non-leaf value (package)", // See note above ^^ + query: `data[1]`, + ref: "data[1]", + pos: 1, + have: types.N, want: types.S, - oneOf: []Value{String("p"), String("q")}, + oneOf: []Value{String("test")}, }, { note: "composite ref operand", diff --git a/ast/compile.go b/ast/compile.go index a650033b13..81b3d6110c 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -64,6 +64,7 @@ type Compiler struct { // p[1] { true } // p[2] { true } // q = true + // a.b.c = 3 // // root // | @@ -74,6 +75,12 @@ type Compiler struct { // +--- p (2 rules) // | // +--- q (1 rule) + // | + // +--- a + // | + // +--- b + // | + // +--- c (1 rule) RuleTree *TreeNode // Graph contains dependencies between rules. An edge (u,v) is added to the @@ -265,15 +272,16 @@ func NewCompiler() *Compiler { // load additional modules. If any stages run before resolution, they // need to be re-run after resolution. {"ResolveRefs", "compile_stage_resolve_refs", c.resolveAllRefs}, - {"CheckKeywordOverrides", "compile_stage_check_keyword_overrides", c.checkKeywordOverrides}, - {"CheckDuplicateImports", "compile_stage_check_duplicate_imports", c.checkDuplicateImports}, - {"RemoveImports", "compile_stage_remove_imports", c.removeImports}, - {"SetModuleTree", "compile_stage_set_module_tree", c.setModuleTree}, - {"SetRuleTree", "compile_stage_set_rule_tree", c.setRuleTree}, // The local variable generator must be initialized after references are // resolved and the dynamic module loader has run but before subsequent // stages that need to generate variables. 
{"InitLocalVarGen", "compile_stage_init_local_var_gen", c.initLocalVarGen}, + {"RewriteRuleHeadRefs", "compile_stage_rewrite_rule_head_refs", c.rewriteRuleHeadRefs}, + {"CheckKeywordOverrides", "compile_stage_check_keyword_overrides", c.checkKeywordOverrides}, + {"CheckDuplicateImports", "compile_stage_check_duplicate_imports", c.checkDuplicateImports}, + {"RemoveImports", "compile_stage_remove_imports", c.removeImports}, + {"SetModuleTree", "compile_stage_set_module_tree", c.setModuleTree}, + {"SetRuleTree", "compile_stage_set_rule_tree", c.setRuleTree}, // depends on RewriteRuleHeadRefs {"RewriteLocalVars", "compile_stage_rewrite_local_vars", c.rewriteLocalVars}, {"CheckVoidCalls", "compile_stage_check_void_calls", c.checkVoidCalls}, {"RewritePrintCalls", "compile_stage_rewrite_print_calls", c.rewritePrintCalls}, @@ -570,9 +578,10 @@ func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) { return rules } -func extractRules(s []util.T) (rules []*Rule) { - for _, r := range s { - rules = append(rules, r.(*Rule)) +func extractRules(s []util.T) []*Rule { + rules := make([]*Rule, len(s)) + for i := range s { + rules[i] = s[i].(*Rule) } return rules } @@ -768,13 +777,30 @@ func (c *Compiler) buildRuleIndices() { if len(node.Values) == 0 { return false } + rules := extractRules(node.Values) + hasNonGroundKey := false + for _, r := range rules { + if ref := r.Head.Ref(); len(ref) > 1 { + if !ref[len(ref)-1].IsGround() { + hasNonGroundKey = true + } + } + } + if hasNonGroundKey { + // collect children: as of now, this cannot go deeper than one level, + // so we grab those, and abort the DepthFirst processing for this branch + for _, n := range node.Children { + rules = append(rules, extractRules(n.Values)...) 
+ } + } + index := newBaseDocEqIndex(func(ref Ref) bool { return isVirtual(c.RuleTree, ref.GroundPrefix()) }) - if rules := extractRules(node.Values); index.Build(rules) { - c.ruleIndices.Put(rules[0].Path(), index) + if index.Build(rules) { + c.ruleIndices.Put(rules[0].Ref().GroundPrefix(), index) } - return false + return hasNonGroundKey // currently, we don't allow those branches to go deeper }) } @@ -811,7 +837,7 @@ func (c *Compiler) checkRecursion() { func (c *Compiler) checkSelfPath(loc *Location, eq func(a, b util.T) bool, a, b util.T) { tr := NewGraphTraversal(c.Graph) if p := util.DFSPath(tr, eq, a, b); len(p) > 0 { - n := []string{} + n := make([]string, 0, len(p)) for _, x := range p { n = append(n, astNodeToString(x)) } @@ -820,40 +846,69 @@ func (c *Compiler) checkSelfPath(loc *Location, eq func(a, b util.T) bool, a, b } func astNodeToString(x interface{}) string { - switch x := x.(type) { - case *Rule: - return string(x.Head.Name) - default: - panic("not reached") - } + return x.(*Rule).Ref().String() } // checkRuleConflicts ensures that rules definitions are not in conflict. 
func (c *Compiler) checkRuleConflicts() { + rw := rewriteVarsInRef(c.RewrittenVars) + c.RuleTree.DepthFirst(func(node *TreeNode) bool { if len(node.Values) == 0 { - return false + return false // go deeper } - kinds := map[DocKind]struct{}{} + kinds := map[RuleKind]struct{}{} defaultRules := 0 arities := map[int]struct{}{} + name := "" + var singleValueConflicts []Ref for _, rule := range node.Values { r := rule.(*Rule) - kinds[r.Head.DocKind()] = struct{}{} + ref := r.Ref() + name = rw(ref.Copy()).String() // varRewriter operates in-place + kinds[r.Head.RuleKind()] = struct{}{} arities[len(r.Head.Args)] = struct{}{} if r.Default { defaultRules++ } + + // Single-value rules may not have any other rules in their extent: these pairs are invalid: + // + // data.p.q.r { true } # data.p.q is { "r": true } + // data.p.q.r.s { true } + // + // data.p.q[r] { r := input.r } # data.p.q could be { "r": true } + // data.p.q.r.s { true } + + // But this is allowed: + // data.p.q[r] = 1 { r := "r" } + // data.p.q.s = 2 + + if r.Head.RuleKind() == SingleValue && len(node.Children) > 0 { + if len(ref) > 1 && !ref[len(ref)-1].IsGround() { // p.q[x] and p.q.s.t => check grandchildren + for _, c := range node.Children { + if len(c.Children) > 0 { + singleValueConflicts = node.flattenChildren() + break + } + } + } else { // p.q.s and p.q.s.t => any children are in conflict + singleValueConflicts = node.flattenChildren() + } + } } - name := Var(node.Key.(String)) + switch { + case singleValueConflicts != nil: + c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "single-value rule %v conflicts with %v", name, singleValueConflicts)) - if len(kinds) > 1 || len(arities) > 1 { - c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "conflicting rules named %v found", name)) - } else if defaultRules > 1 { - c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multiple default rules named %s found", name)) + case len(kinds) > 1 || len(arities) > 1: + c.err(NewError(TypeErr, 
node.Values[0].(*Rule).Loc(), "conflicting rules %v found", name)) + + case defaultRules > 1: + c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multiple default rules %s found", name)) } return false @@ -865,13 +920,21 @@ func (c *Compiler) checkRuleConflicts() { } } + // NOTE(sr): depthfirst might better use sorted for stable errs? c.ModuleTree.DepthFirst(func(node *ModuleTreeNode) bool { for _, mod := range node.Modules { for _, rule := range mod.Rules { - if childNode, ok := node.Children[String(rule.Head.Name)]; ok { + ref := rule.Head.Ref().GroundPrefix() + childNode, tail := node.find(ref) + if childNode != nil { for _, childMod := range childNode.Modules { - msg := fmt.Sprintf("%v conflicts with rule defined at %v", childMod.Package, rule.Loc()) - c.err(NewError(TypeErr, mod.Package.Loc(), msg)) + if childMod.Equal(mod) { + continue // don't self-conflict + } + if len(tail) == 0 { + msg := fmt.Sprintf("%v conflicts with rule %v defined at %v", childMod.Package, rule.Head.Ref(), rule.Loc()) + c.err(NewError(TypeErr, mod.Package.Loc(), msg)) + } } } } @@ -1363,21 +1426,29 @@ func (c *Compiler) getExports() *util.HashMap { for _, name := range c.sorted { mod := c.Modules[name] - rv, ok := rules.Get(mod.Package.Path) - if !ok { - rv = []Var{} - } - rvs := rv.([]Var) for _, rule := range mod.Rules { - rvs = append(rvs, rule.Head.Name) + hashMapAdd(rules, mod.Package.Path, rule.Head.Ref().GroundPrefix()) } - rules.Put(mod.Package.Path, rvs) } return rules } +func hashMapAdd(rules *util.HashMap, pkg, rule Ref) { + prev, ok := rules.Get(pkg) + if !ok { + rules.Put(pkg, []Ref{rule}) + return + } + for _, p := range prev.([]Ref) { + if p.Equal(rule) { + return + } + } + rules.Put(pkg, append(prev.([]Ref), rule)) +} + func (c *Compiler) GetAnnotationSet() *AnnotationSet { return c.annotationSet } @@ -1450,6 +1521,15 @@ func checkKeywordOverrides(node interface{}, strict bool) Errors { // p[x] { bar[_] = x } // // The reference "bar[_]" would be resolved to 
"data.foo.bar[_]". +// +// Ref rules are resolved, too: +// +// package a.b +// q { c.d.e == 1 } +// c.d[e] := 1 if e := "e" +// +// The reference "c.d.e" would be resolved to "data.a.b.c.d.e". + func (c *Compiler) resolveAllRefs() { rules := c.getExports() @@ -1457,9 +1537,9 @@ func (c *Compiler) resolveAllRefs() { for _, name := range c.sorted { mod := c.Modules[name] - var ruleExports []Var + var ruleExports []Ref if x, ok := rules.Get(mod.Package.Path); ok { - ruleExports = x.([]Var) + ruleExports = x.([]Ref) } globals := getGlobals(mod.Package, ruleExports, mod.Imports) @@ -1542,6 +1622,52 @@ func (c *Compiler) rewriteExprTerms() { } } +func (c *Compiler) rewriteRuleHeadRefs() { + f := newEqualityFactory(c.localvargen) + for _, name := range c.sorted { + WalkRules(c.Modules[name], func(rule *Rule) bool { + + ref := rule.Head.Ref() + // NOTE(sr): We're backfilling Refs here -- all parser code paths would have them, but + // it's possible to construct Module{} instances from Golang code, so we need + // to accommodate for that, too. + if len(rule.Head.Reference) == 0 { + rule.Head.Reference = ref + } + + for i := 1; i < len(ref); i++ { + // NOTE(sr): In the first iteration, non-string values in the refs are forbidden + // except for the last position, e.g. + // OK: p.q.r[s] + // NOT OK: p[q].r.s + // TODO(sr): This is stricter than necessary. We could allow any non-var values there, + // but we'll also have to adjust the type tree, for example. + if i != len(ref)-1 { // last + if _, ok := ref[i].Value.(String); !ok { + c.err(NewError(TypeErr, rule.Loc(), "rule head must only contain string terms (except for last): %v", ref[i])) + continue + } + } + + // Rewrite so that any non-scalar elements that in the last position of + // the rule are vars: + // p.q.r[y.z] { ... } => p.q.r[__local0__] { __local0__ = y.z } + // because that's what the RuleTree knows how to deal with. 
+ if _, ok := ref[i].Value.(Var); !ok && !IsScalar(ref[i].Value) { + expr := f.Generate(ref[i]) + if i == len(ref)-1 && rule.Head.Key.Equal(ref[i]) { + rule.Head.Key = expr.Operand(0) + } + rule.Head.Reference[i] = expr.Operand(0) + rule.Body.Append(expr) + } + } + + return true + }) + } +} + func (c *Compiler) checkVoidCalls() { for _, name := range c.sorted { mod := c.Modules[name] @@ -2044,6 +2170,9 @@ func (c *Compiler) rewriteLocalVars() { // Rewrite assignments in body. used := NewVarSet() + last := rule.Head.Ref()[len(rule.Head.Ref())-1] + used.Update(last.Vars()) + if rule.Head.Key != nil { used.Update(rule.Head.Key.Vars()) } @@ -2076,6 +2205,9 @@ func (c *Compiler) rewriteLocalVars() { rule.Head.Args[i], _ = transformTerm(localXform, rule.Head.Args[i]) } + for i := 1; i < len(rule.Head.Ref()); i++ { + rule.Head.Reference[i], _ = transformTerm(localXform, rule.Head.Ref()[i]) + } if rule.Head.Key != nil { rule.Head.Key, _ = transformTerm(localXform, rule.Head.Key) } @@ -2406,10 +2538,10 @@ func (qc *queryCompiler) resolveRefs(qctx *QueryContext, body Body) (Body, error pkg = &Package{Path: RefTerm(VarTerm("")).Value.(Ref)} } if pkg != nil { - var ruleExports []Var + var ruleExports []Ref rules := qc.compiler.getExports() if exist, ok := rules.Get(pkg.Path); ok { - ruleExports = exist.([]Var) + ruleExports = exist.([]Ref) } globals = getGlobals(qctx.Package, ruleExports, qctx.Imports) @@ -2792,6 +2924,16 @@ type ModuleTreeNode struct { Hide bool } +func (n *ModuleTreeNode) String() string { + var rules []string + for _, m := range n.Modules { + for _, r := range m.Rules { + rules = append(rules, r.Head.String()) + } + } + return fmt.Sprintf("", n.Key, n.Children, rules, n.Hide) +} + // NewModuleTree returns a new ModuleTreeNode that represents the root // of the module tree populated with the given modules. 
func NewModuleTree(mods map[string]*Module) *ModuleTreeNode { @@ -2836,13 +2978,43 @@ func (n *ModuleTreeNode) Size() int { return s } +// Child returns n's child with key k. +func (n *ModuleTreeNode) child(k Value) *ModuleTreeNode { + switch k.(type) { + case String, Var: + return n.Children[k] + } + return nil +} + +// Find dereferences ref along the tree. ref[0] is converted to a String +// for convenience. +func (n *ModuleTreeNode) find(ref Ref) (*ModuleTreeNode, Ref) { + if v, ok := ref[0].Value.(Var); ok { + ref = Ref{StringTerm(string(v))}.Concat(ref[1:]) + } + node := n + for i, r := range ref { + next := node.child(r.Value) + if next == nil { + tail := make(Ref, len(ref)-i) + tail[0] = VarTerm(string(ref[i].Value.(String))) + copy(tail[1:], ref[i+1:]) + return node, tail + } + node = next + } + return node, nil +} + // DepthFirst performs a depth-first traversal of the module tree rooted at n. // If f returns true, traversal will not continue to the children of n. -func (n *ModuleTreeNode) DepthFirst(f func(node *ModuleTreeNode) bool) { - if !f(n) { - for _, node := range n.Children { - node.DepthFirst(f) - } +func (n *ModuleTreeNode) DepthFirst(f func(*ModuleTreeNode) bool) { + if f(n) { + return + } + for _, node := range n.Children { + node.DepthFirst(f) } } @@ -2856,49 +3028,56 @@ type TreeNode struct { Hide bool } +func (n *TreeNode) String() string { + return fmt.Sprintf("", n.Key, n.Values, n.Sorted, n.Hide) +} + // NewRuleTree returns a new TreeNode that represents the root // of the rule tree populated with the given rules. func NewRuleTree(mtree *ModuleTreeNode) *TreeNode { - - ruleSets := map[String][]util.T{} - - // Build rule sets for this package. - for _, mod := range mtree.Modules { - for _, rule := range mod.Rules { - key := String(rule.Head.Name) - ruleSets[key] = append(ruleSets[key], rule) - } + root := TreeNode{ + Key: mtree.Key, } - // Each rule set becomes a leaf node. 
- children := map[Value]*TreeNode{} - sorted := make([]Value, 0, len(ruleSets)) - - for key, rules := range ruleSets { - sorted = append(sorted, key) - children[key] = &TreeNode{ - Key: key, - Children: nil, - Values: rules, + mtree.DepthFirst(func(m *ModuleTreeNode) bool { + for _, mod := range m.Modules { + if len(mod.Rules) == 0 { + root.add(mod.Package.Path, nil) + } + for _, rule := range mod.Rules { + root.add(rule.Ref().GroundPrefix(), rule) + } } - } + return false + }) - // Each module in subpackage becomes child node. - for key, child := range mtree.Children { - sorted = append(sorted, key) - children[child.Key] = NewRuleTree(child) + // ensure that data.system's TreeNode is hidden + node, tail := root.find(DefaultRootRef.Append(NewTerm(SystemDocumentKey))) + if len(tail) == 0 { // found + node.Hide = true } - sort.Slice(sorted, func(i, j int) bool { - return sorted[i].Compare(sorted[j]) < 0 + root.DepthFirst(func(x *TreeNode) bool { + x.sort() + return false }) - return &TreeNode{ - Key: mtree.Key, - Values: nil, - Children: children, - Sorted: sorted, - Hide: mtree.Hide, + return &root +} + +func (n *TreeNode) add(path Ref, rule *Rule) { + node, tail := n.find(path) + if len(tail) > 0 { + sub := treeNodeFromRef(tail, rule) + if node.Children == nil { + node.Children = make(map[Value]*TreeNode, 1) + } + node.Children[sub.Key] = sub + node.Sorted = append(node.Sorted, sub.Key) + } else { + if rule != nil { + node.Values = append(node.Values, rule) + } } } @@ -2914,33 +3093,95 @@ func (n *TreeNode) Size() int { // Child returns n's child with key k. 
func (n *TreeNode) Child(k Value) *TreeNode { switch k.(type) { - case String, Var: + case Ref: + return nil + default: return n.Children[k] } - return nil } // Find dereferences ref along the tree func (n *TreeNode) Find(ref Ref) *TreeNode { node := n for _, r := range ref { - child := node.Child(r.Value) - if child == nil { + node = node.Child(r.Value) + if node == nil { return nil } - node = child } return node } +func (n *TreeNode) find(ref Ref) (*TreeNode, Ref) { + node := n + for i := range ref { + next := node.Child(ref[i].Value) + if next == nil { + tail := make(Ref, len(ref)-i) + copy(tail, ref[i:]) + return node, tail + } + node = next + } + return node, nil +} + // DepthFirst performs a depth-first traversal of the rule tree rooted at n. If // f returns true, traversal will not continue to the children of n. -func (n *TreeNode) DepthFirst(f func(node *TreeNode) bool) { - if !f(n) { - for _, node := range n.Children { - node.DepthFirst(f) +func (n *TreeNode) DepthFirst(f func(*TreeNode) bool) { + if f(n) { + return + } + for _, node := range n.Children { + node.DepthFirst(f) + } +} + +func (n *TreeNode) sort() { + sort.Slice(n.Sorted, func(i, j int) bool { + return n.Sorted[i].Compare(n.Sorted[j]) < 0 + }) +} + +func treeNodeFromRef(ref Ref, rule *Rule) *TreeNode { + depth := len(ref) - 1 + key := ref[depth].Value + node := &TreeNode{ + Key: key, + Children: nil, + } + if rule != nil { + node.Values = []util.T{rule} + } + + for i := len(ref) - 2; i >= 0; i-- { + key := ref[i].Value + node = &TreeNode{ + Key: key, + Children: map[Value]*TreeNode{ref[i+1].Value: node}, + Sorted: []Value{ref[i+1].Value}, } } + return node +} + +// flattenChildren flattens all children's rule refs into a sorted array. 
+func (n *TreeNode) flattenChildren() []Ref { + ret := newRefSet() + for _, sub := range n.Children { // we only want the children, so don't use n.DepthFirst() right away + sub.DepthFirst(func(x *TreeNode) bool { + for _, r := range x.Values { + rule := r.(*Rule) + ret.AddPrefix(rule.Ref()) + } + return false + }) + } + + sort.Slice(ret.s, func(i, j int) bool { + return ret.s[i].Compare(ret.s[j]) < 0 + }) + return ret.s } // Graph represents the graph of dependencies between rules. @@ -3554,15 +3795,14 @@ func (l *localVarGenerator) Generate() Var { } } -func getGlobals(pkg *Package, rules []Var, imports []*Import) map[Var]*usedRef { +func getGlobals(pkg *Package, rules []Ref, imports []*Import) map[Var]*usedRef { - globals := map[Var]*usedRef{} + globals := make(map[Var]*usedRef, len(rules)) // NB: might grow bigger with imports // Populate globals with exports within the package. - for _, v := range rules { - global := append(Ref{}, pkg.Path...) - global = append(global, &Term{Value: String(v)}) - globals[v] = &usedRef{ref: global} + for _, ref := range rules { + v := ref[0].Value.(Var) + globals[v] = &usedRef{ref: pkg.Path.Append(StringTerm(string(v)))} } // Populate globals with imports. 
@@ -3670,6 +3910,10 @@ func resolveRefsInRule(globals map[Var]*usedRef, rule *Rule) error { ignore.Push(vars) ignore.Push(declaredVars(rule.Body)) + ref := rule.Head.Ref() + for i := 1; i < len(ref); i++ { + ref[i] = resolveRefsInTerm(globals, ignore, ref[i]) + } if rule.Head.Key != nil { rule.Head.Key = resolveRefsInTerm(globals, ignore, rule.Head.Key) } @@ -4985,7 +5229,7 @@ func isBuiltinRefOrVar(bs map[string]*Builtin, unsafeBuiltinsMap map[string]stru } func isVirtual(node *TreeNode, ref Ref) bool { - for i := 0; i < len(ref); i++ { + for i := range ref { child := node.Child(ref[i].Value) if child == nil { return false @@ -5095,3 +5339,57 @@ func rewriteVarsInRef(vars ...map[Var]Var) varRewriter { return i.(Ref) } } + +// NOTE(sr): This is duplicated with compile/compile.go; but moving it into another location +// would cause a circular dependency -- the refSet definition needs ast.Ref. If we make it +// public in the ast package, the compile package could take it from there, but it would also +// increase our public interface. Let's reconsider if we need it in a third place. +type refSet struct { + s []Ref +} + +func newRefSet(x ...Ref) *refSet { + result := &refSet{} + for i := range x { + result.AddPrefix(x[i]) + } + return result +} + +// ContainsPrefix returns true if r is prefixed by any of the existing refs in the set. +func (rs *refSet) ContainsPrefix(r Ref) bool { + for i := range rs.s { + if r.HasPrefix(rs.s[i]) { + return true + } + } + return false +} + +// AddPrefix inserts r into the set if r is not prefixed by any existing +// refs in the set. If any existing refs are prefixed by r, those existing +// refs are removed. +func (rs *refSet) AddPrefix(r Ref) { + if rs.ContainsPrefix(r) { + return + } + cpy := []Ref{r} + for i := range rs.s { + if !rs.s[i].HasPrefix(r) { + cpy = append(cpy, rs.s[i]) + } + } + rs.s = cpy +} + +// Sorted returns a sorted slice of terms for refs in the set. 
+func (rs *refSet) Sorted() []*Term { + terms := make([]*Term, len(rs.s)) + for i := range rs.s { + terms[i] = NewTerm(rs.s[i]) + } + sort.Slice(terms, func(i, j int) bool { + return terms[i].Value.Compare(terms[j].Value) < 0 + }) + return terms +} diff --git a/ast/compile_test.go b/ast/compile_test.go index 0de8baae5e..8b02e49609 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -263,7 +263,7 @@ func TestOutputVarsForNode(t *testing.T) { func TestModuleTree(t *testing.T) { - mods := getCompilerTestModules() + mods := getCompilerTestModules() // 7 modules mods["system-mod"] = MustParseModule(` package system.foo @@ -274,8 +274,14 @@ func TestModuleTree(t *testing.T) { p = 1 `) + mods["dots-in-heads"] = MustParseModule(` + package dots + + a.b.c = 12 + d.e.f.g = 34 + `) tree := NewModuleTree(mods) - expectedSize := 9 + expectedSize := 10 if tree.Size() != expectedSize { t.Fatalf("Expected %v but got %v modules", expectedSize, tree.Size()) @@ -294,6 +300,727 @@ func TestModuleTree(t *testing.T) { } } +func TestCompilerGetExports(t *testing.T) { + tests := []struct { + note string + modules []*Module + exports map[string][]string + }{ + { + note: "simple", + modules: modules(`package p + r = 1`), + exports: map[string][]string{"data.p": {"r"}}, + }, + { + note: "simple single-value ref rule", + modules: modules(`package p + q.r.s = 1`), + exports: map[string][]string{"data.p": {"q.r.s"}}, + }, + { + note: "var key single-value ref rule", + modules: modules(`package p + q.r[s] = 1 { s := "foo" }`), + exports: map[string][]string{"data.p": {"q.r"}}, + }, + { + note: "simple multi-value ref rule", + modules: modules(`package p + import future.keywords + + q.r.s contains 1 { true }`), + exports: map[string][]string{"data.p": {"q.r.s"}}, + }, + { + note: "two simple, multiple rules", + modules: modules(`package p + r = 1 + s = 11`, + `package q + x = 2 + y = 22`), + exports: map[string][]string{"data.p": {"r", "s"}, "data.q": {"x", "y"}}, + }, + { + note: "ref 
head + simple, multiple rules", + modules: modules(`package p.a.b.c + r = 1 + s = 11`, + `package q + a.b.x = 2 + a.b.c.y = 22`), + exports: map[string][]string{ + "data.p.a.b.c": {"r", "s"}, + "data.q": {"a.b.x", "a.b.c.y"}, + }, + }, + { + note: "two ref head, multiple rules", + modules: modules(`package p.a.b.c + r = 1 + s = 11`, + `package p + a.b.x = 2 + a.b.c.y = 22`), + exports: map[string][]string{ + "data.p.a.b.c": {"r", "s"}, + "data.p": {"a.b.x", "a.b.c.y"}, + }, + }, + { + note: "single-value rule with number key", + modules: modules(`package p + q[1] = 1 + q[2] = 2`), + exports: map[string][]string{ + "data.p": {"q[1]", "q[2]"}, // TODO(sr): is this really what we want? + }, + }, + { + note: "single-value (ref) rule with number key", + modules: modules(`package p + a.b.q[1] = 1 + a.b.q[2] = 2`), + exports: map[string][]string{ + "data.p": {"a.b.q[1]", "a.b.q[2]"}, + }, + }, + { + note: "single-value (ref) rule with var key", + modules: modules(`package p + a.b.q[x] = y { x := 1; y := true } + a.b.q[2] = 2`), + exports: map[string][]string{ + "data.p": {"a.b.q", "a.b.q[2]"}, // TODO(sr): GroundPrefix? right thing here? + }, + }, + { // NOTE(sr): An ast.Module can be constructed in various ways, this is to assert that + // our compilation process doesn't explode here if we're fed a Rule that has no Ref. + note: "synthetic", + modules: func() []*Module { + ms := modules(`package p + r = 1`) + ms[0].Rules[0].Head.Reference = nil + return ms + }(), + exports: map[string][]string{"data.p": {"r"}}, + }, + // TODO(sr): add multi-val rule, and ref-with-var single-value rule. 
+ } + + hashMap := func(ms map[string][]string) *util.HashMap { + rules := util.NewHashMap(func(a, b util.T) bool { + switch a := a.(type) { + case Ref: + return a.Equal(b.(Ref)) + case []Ref: + b := b.([]Ref) + if len(b) != len(a) { + return false + } + for i := range a { + if !a[i].Equal(b[i]) { + return false + } + } + return true + default: + panic("unreachable") + } + }, func(v util.T) int { + return v.(Ref).Hash() + }) + for r, rs := range ms { + refs := make([]Ref, len(rs)) + for i := range rs { + refs[i] = toRef(rs[i]) + } + rules.Put(MustParseRef(r), refs) + } + return rules + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + c := NewCompiler() + for i, m := range tc.modules { + c.Modules[fmt.Sprint(i)] = m + c.sorted = append(c.sorted, fmt.Sprint(i)) + } + if exp, act := hashMap(tc.exports), c.getExports(); !exp.Equal(act) { + t.Errorf("expected %v, got %v", exp, act) + } + }) + } +} + +func toRef(s string) Ref { + switch t := MustParseTerm(s).Value.(type) { + case Var: + return Ref{NewTerm(t)} + case Ref: + return t + default: + panic("unreachable") + } +} + +func TestCompilerCheckRuleHeadRefs(t *testing.T) { + + tests := []struct { + note string + modules []*Module + expected *Rule + err string + }{ + { + note: "ref contains var", + modules: modules( + `package x + p.q[i].r = 1 { i := 10 }`, + ), + err: "rego_type_error: rule head must only contain string terms (except for last): i", + }, + { + note: "valid: ref is single-value rule with var key", + modules: modules( + `package x + p.q.r[i] { i := 10 }`, + ), + }, + { + note: "valid: ref is single-value rule with var key and value", + modules: modules( + `package x + p.q.r[i] = j { i := 10; j := 11 }`, + ), + }, + { + note: "valid: ref is single-value rule with var key and static value", + modules: modules( + `package x + p.q.r[i] = "ten" { i := 10 }`, + ), + }, + { + note: "valid: ref is single-value rule with number key", + modules: modules( + `package x + p.q.r[1] { true }`, + 
), + }, + { + note: "valid: ref is single-value rule with boolean key", + modules: modules( + `package x + p.q.r[true] { true }`, + ), + }, + { + note: "valid: ref is single-value rule with null key", + modules: modules( + `package x + p.q.r[null] { true }`, + ), + }, + { + note: "valid: ref is single-value rule with set literal key", + modules: modules( + `package x + p.q.r[set()] { true }`, + ), + }, + { + note: "valid: ref is single-value rule with array literal key", + modules: modules( + `package x + p.q.r[[]] { true }`, + ), + }, + { + note: "valid: ref is single-value rule with object literal key", + modules: modules( + `package x + p.q.r[{}] { true }`, + ), + }, + { + note: "valid: ref is single-value rule with ref key", + modules: modules( + `package x + x := [1,2,3] + p.q.r[x[i]] { i := 0}`, + ), + }, + { + note: "invalid: ref in ref", + modules: modules( + `package x + p.q[arr[0]].r { i := 10 }`, + ), + err: "rego_type_error: rule head must only contain string terms (except for last): arr[0]", + }, + { + note: "invalid: non-string in ref (not last position)", + modules: modules( + `package x + p.q[10].r { true }`, + ), + err: "rego_type_error: rule head must only contain string terms (except for last): 10", + }, + { + note: "valid: multi-value with var key", + modules: modules( + `package x + p.q.r contains i if i := 10`, + ), + }, + { + note: "rewrite: single-value with non-var key (ref)", + modules: modules( + `package x + p.q.r[y.z] if y := {"z": "a"}`, + ), + expected: MustParseRule(`p.q.r[__local0__] { y := {"z": "a"}; __local0__ = y.z }`), + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + mods := make(map[string]*Module, len(tc.modules)) + for i, m := range tc.modules { + mods[fmt.Sprint(i)] = m + } + c := NewCompiler() + c.Modules = mods + compileStages(c, c.rewriteRuleHeadRefs) + if tc.err != "" { + assertCompilerErrorStrings(t, c, []string{tc.err}) + } else { + if len(c.Errors) > 0 { + t.Fatalf("expected no errors, 
got %v", c.Errors) + } + if tc.expected != nil { + assertRulesEqual(t, tc.expected, mods["0"].Rules[0]) + } + } + }) + } +} + +func TestRuleTreeWithDotsInHeads(t *testing.T) { + + // TODO(sr): multi-val with var key in ref + tests := []struct { + note string + modules []*Module + size int // expected tree size = number of leaves + depth int // expected tree depth + }{ + { + note: "two modules, same package, one rule each", + modules: modules( + `package x + p.q.r = 1`, + `package x + p.q.w = 2`, + ), + size: 2, + }, + { + note: "two modules, sub-package, one rule each", + modules: modules( + `package x + p.q.r = 1`, + `package x.p + q.w.z = 2`, + ), + size: 2, + }, + { + note: "three modules, sub-package, incl simple rule", + modules: modules( + `package x + p.q.r = 1`, + `package x.p + q.w.z = 2`, + `package x.p.q.w + y = 3`, + ), + size: 3, + }, + { + note: "simple: two modules", + modules: modules( + `package x + p.q.r = 1`, + `package y + p.q.w = 2`, + ), + size: 2, + }, + { + note: "conflict: one module", + modules: modules( + `package q + p[x] = 1 + p = 2`, + ), + size: 2, + }, + { + note: "conflict: two modules", + modules: modules( + `package q + p.r.s[x] = 1`, + `package q.p + r.s = 2 if true`, + ), + size: 2, + }, + { + note: "simple: two modules, one using ref head, one package path", + modules: modules( + `package x + p.q.r = 1 { input == 1 }`, + `package x.p.q + r = 2 { input == 2 }`, + ), + size: 2, + }, + { + note: "conflict: two modules, both using ref head, different package paths", + modules: modules( + `package x + p.q.r = 1 { input == 1 }`, // x.p.q.r = 1 + `package x.p + q.r.s = 2 { input == 2 }`, // x.p.q.r.s = 2 + ), + size: 2, + }, + { + note: "overlapping: one module, two ref head", + modules: modules( + `package x + p.q.r = 1 + p.q.w.v = 2`, + ), + size: 2, + depth: 6, + }, + { + note: "last ref term != string", + modules: modules( + `package x + p.q.w[1] = 2 + p.q.w[{"foo": "baz"}] = 20 + p.q.x[true] = false + p.q.x[y] = y { y := "y" }`, 
+ ), + size: 4, + depth: 6, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + c := NewCompiler() + for i, m := range tc.modules { + c.Modules[fmt.Sprint(i)] = m + c.sorted = append(c.sorted, fmt.Sprint(i)) + } + compileStages(c, c.setRuleTree) + if len(c.Errors) > 0 { + t.Fatal(c.Errors) + } + tree := c.RuleTree + tree.DepthFirst(func(n *TreeNode) bool { + t.Log(n) + if !sort.SliceIsSorted(n.Sorted, func(i, j int) bool { + return n.Sorted[i].Compare(n.Sorted[j]) < 0 + }) { + t.Errorf("expected sorted to be sorted: %v", n.Sorted) + } + return false + }) + if tc.depth > 0 { + if exp, act := tc.depth, depth(tree); exp != act { + t.Errorf("expected tree depth %d, got %d", exp, act) + } + } + if exp, act := tc.size, tree.Size(); exp != act { + t.Errorf("expected tree size %d, got %d", exp, act) + } + }) + } +} + +func TestRuleTreeWithVars(t *testing.T) { + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + + t.Run("simple single-value rule", func(t *testing.T) { + mod0 := `package a.b +c.d.e = 1 if true` + + mods := map[string]*Module{"0.rego": MustParseModuleWithOpts(mod0, opts)} + tree := NewRuleTree(NewModuleTree(mods)) + + node := tree.Find(MustParseRef("data.a.b.c.d.e")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 1, len(node.Values); exp != act { + t.Errorf("expected %d values, found %d", exp, act) + } + if exp, act := 0, len(node.Children); exp != act { + t.Errorf("expected %d children, found %d", exp, act) + } + if exp, act := MustParseRef("c.d.e"), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + }) + + t.Run("two single-value rules", func(t *testing.T) { + mod0 := `package a.b +c.d.e = 1 if true` + mod1 := `package a.b.c +d.e = 2 if true` + + mods := map[string]*Module{ + "0.rego": MustParseModuleWithOpts(mod0, opts), + "1.rego": MustParseModuleWithOpts(mod1, opts), + } + tree := 
NewRuleTree(NewModuleTree(mods)) + + node := tree.Find(MustParseRef("data.a.b.c.d.e")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 2, len(node.Values); exp != act { + t.Errorf("expected %d values, found %d", exp, act) + } + if exp, act := 0, len(node.Children); exp != act { + t.Errorf("expected %d children, found %d", exp, act) + } + if exp, act := MustParseRef("c.d.e"), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if exp, act := MustParseRef("d.e"), node.Values[1].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + }) + + t.Run("one multi-value rule, one single-value, with var", func(t *testing.T) { + mod0 := `package a.b +c.d.e.g contains 1 if true` + mod1 := `package a.b.c +d.e.f = 2 if true` + + mods := map[string]*Module{ + "0.rego": MustParseModuleWithOpts(mod0, opts), + "1.rego": MustParseModuleWithOpts(mod1, opts), + } + tree := NewRuleTree(NewModuleTree(mods)) + + // var-key rules should be included in the results + node := tree.Find(MustParseRef("data.a.b.c.d.e.g")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 1, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d", exp, act) + } + if exp, act := 0, len(node.Children); exp != act { + t.Fatalf("expected %d children, found %d", exp, act) + } + node = tree.Find(MustParseRef("data.a.b.c.d.e.f")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 1, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d", exp, act) + } + if exp, act := MustParseRef("d.e.f"), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + }) + + t.Run("two multi-value rules, back compat", func(t *testing.T) { + mod0 := `package a +b[c] { c := "foo" }` + mod1 := `package a +b[d] { d := "bar" }` + + mods := map[string]*Module{ + 
"0.rego": MustParseModuleWithOpts(mod0, opts), + "1.rego": MustParseModuleWithOpts(mod1, opts), + } + tree := NewRuleTree(NewModuleTree(mods)) + + node := tree.Find(MustParseRef("data.a.b")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 2, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d: %v", exp, act, node.Values) + } + if exp, act := 0, len(node.Children); exp != act { + t.Errorf("expected %d children, found %d", exp, act) + } + if exp, act := (Ref{VarTerm("b")}), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if act := node.Values[0].(*Rule).Head.Value; act != nil { + t.Errorf("expected rule value nil, found %v", act) + } + if exp, act := VarTerm("c"), node.Values[0].(*Rule).Head.Key; !exp.Equal(act) { + t.Errorf("expected rule key %v, found %v", exp, act) + } + if exp, act := (Ref{VarTerm("b")}), node.Values[1].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if act := node.Values[1].(*Rule).Head.Value; act != nil { + t.Errorf("expected rule value nil, found %v", act) + } + if exp, act := VarTerm("d"), node.Values[1].(*Rule).Head.Key; !exp.Equal(act) { + t.Errorf("expected rule key %v, found %v", exp, act) + } + }) + + t.Run("two multi-value rules, back compat with short style", func(t *testing.T) { + mod0 := `package a +b[1]` + mod1 := `package a +b[2]` + mods := map[string]*Module{ + "0.rego": MustParseModuleWithOpts(mod0, opts), + "1.rego": MustParseModuleWithOpts(mod1, opts), + } + tree := NewRuleTree(NewModuleTree(mods)) + + node := tree.Find(MustParseRef("data.a.b")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 2, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d: %v", exp, act, node.Values) + } + if exp, act := 0, len(node.Children); exp != act { + t.Errorf("expected %d children, found %d", exp, act) + } + if exp, act := 
(Ref{VarTerm("b")}), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if act := node.Values[0].(*Rule).Head.Value; act != nil { + t.Errorf("expected rule value nil, found %v", act) + } + if exp, act := IntNumberTerm(1), node.Values[0].(*Rule).Head.Key; !exp.Equal(act) { + t.Errorf("expected rule key %v, found %v", exp, act) + } + if exp, act := (Ref{VarTerm("b")}), node.Values[1].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if act := node.Values[1].(*Rule).Head.Value; act != nil { + t.Errorf("expected rule value nil, found %v", act) + } + if exp, act := IntNumberTerm(2), node.Values[1].(*Rule).Head.Key; !exp.Equal(act) { + t.Errorf("expected rule key %v, found %v", exp, act) + } + }) + + t.Run("two single-value rules, back compat with short style", func(t *testing.T) { + mod0 := `package a +b[1] = 1` + mod1 := `package a +b[2] = 2` + mods := map[string]*Module{ + "0.rego": MustParseModuleWithOpts(mod0, opts), + "1.rego": MustParseModuleWithOpts(mod1, opts), + } + tree := NewRuleTree(NewModuleTree(mods)) + + // branch point + node := tree.Find(MustParseRef("data.a.b")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 0, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d: %v", exp, act, node.Values) + } + if exp, act := 2, len(node.Children); exp != act { + t.Fatalf("expected %d children, found %d", exp, act) + } + + // branch 1 + node = tree.Find(MustParseRef("data.a.b[1]")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 1, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d: %v", exp, act, node.Values) + } + if exp, act := MustParseRef("b[1]"), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if exp, act := IntNumberTerm(1), node.Values[0].(*Rule).Head.Value; !exp.Equal(act) { + 
t.Errorf("expected rule value %v, found %v", exp, act) + } + if exp, act := IntNumberTerm(1), node.Values[0].(*Rule).Head.Key; !exp.Equal(act) { + t.Errorf("expected rule key %v, found %v", exp, act) + } + + // branch 2 + node = tree.Find(MustParseRef("data.a.b[2]")) + if node == nil { + t.Fatal("expected non-nil leaf node") + } + if exp, act := 1, len(node.Values); exp != act { + t.Fatalf("expected %d values, found %d: %v", exp, act, node.Values) + } + if exp, act := MustParseRef("b[2]"), node.Values[0].(*Rule).Head.Ref(); !exp.Equal(act) { + t.Errorf("expected rule ref %v, found %v", exp, act) + } + if exp, act := IntNumberTerm(2), node.Values[0].(*Rule).Head.Value; !exp.Equal(act) { + t.Errorf("expected rule value %v, found %v", exp, act) + } + if exp, act := IntNumberTerm(2), node.Values[0].(*Rule).Head.Key; !exp.Equal(act) { + t.Errorf("expected rule key %v, found %v", exp, act) + } + }) + + // NOTE(sr): Now this test seems obvious, but it's a bug that had snuck into the + // NewRuleTree code during development. 
+ t.Run("root node and data node unhidden if there are no system nodes", func(t *testing.T) { + mod0 := `package a +p = 1` + mods := map[string]*Module{ + "0.rego": MustParseModuleWithOpts(mod0, opts), + } + tree := NewRuleTree(NewModuleTree(mods)) + + if exp, act := false, tree.Hide; act != exp { + t.Errorf("expected tree.Hide=%v, got %v", exp, act) + } + dataNode := tree.Child(Var("data")) + if dataNode == nil { + t.Fatal("expected data node") + } + if exp, act := false, dataNode.Hide; act != exp { + t.Errorf("expected dataNode.Hide=%v, got %v", exp, act) + } + }) +} + +func depth(n *TreeNode) int { + d := -1 + for _, m := range n.Children { + if d0 := depth(m); d0 > d { + d = d0 + } + } + return d + 1 +} func TestModuleTreeFilenameOrder(t *testing.T) { // NOTE(sr): It doesn't matter that these are conflicting; but that's where it @@ -328,16 +1055,23 @@ func TestRuleTree(t *testing.T) { mods["non-system-mod"] = MustParseModule(` package user.system - p = 1 - `) - mods["mod-incr"] = MustParseModule(`package a.b.c + p = 1`) + mods["mod-incr"] = MustParseModule(` + package a.b.c -s[1] { true } -s[2] { true }`, + s[1] { true } + s[2] { true }`, ) + mods["dots-in-heads"] = MustParseModule(` + package dots + + a.b.c = 12 + d.e.f.g = 34 + `) + tree := NewRuleTree(NewModuleTree(mods)) - expectedNumRules := 23 + expectedNumRules := 25 if tree.Size() != expectedNumRules { t.Errorf("Expected %v but got %v rules", expectedNumRules, tree.Size()) @@ -345,14 +1079,18 @@ s[2] { true }`, // Check that empty packages are represented as leaves with no rules. 
node := tree.Children[Var("data")].Children[String("a")].Children[String("b")].Children[String("empty")] - if node == nil || len(node.Children) != 0 || len(node.Values) != 0 { t.Fatalf("Unexpected nil value or non-empty leaf of non-leaf node: %v", node) } + // Check that root node is not hidden + if exp, act := false, tree.Hide; act != exp { + t.Errorf("expected tree.Hide=%v, got %v", exp, act) + } + system := tree.Child(Var("data")).Child(String("system")) if !system.Hide { - t.Fatalf("Expected system node to be hidden") + t.Fatalf("Expected system node to be hidden: %v", system) } if system.Child(String("foo")).Hide { @@ -692,16 +1430,19 @@ func TestCompilerErrorLimit(t *testing.T) { func TestCompilerCheckSafetyHead(t *testing.T) { c := NewCompiler() c.Modules = getCompilerTestModules() - c.Modules["newMod"] = MustParseModule(`package a.b - -unboundKey[x] = y { q[y] = {"foo": [1, 2, [{"bar": y}]]} } -unboundVal[y] = x { q[y] = {"foo": [1, 2, [{"bar": y}]]} } -unboundCompositeVal[y] = [{"foo": x, "bar": y}] { q[y] = {"foo": [1, 2, [{"bar": y}]]} } -unboundCompositeKey[[{"x": x}]] { q[y] } -unboundBuiltinOperator = eq { x = 1 } + popts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + c.Modules["newMod"] = MustParseModuleWithOpts(`package a.b + +unboundKey[x1] = y { q[y] = {"foo": [1, 2, [{"bar": y}]]} } +unboundVal[y] = x2 { q[y] = {"foo": [1, 2, [{"bar": y}]]} } +unboundCompositeVal[y] = [{"foo": x3, "bar": y}] { q[y] = {"foo": [1, 2, [{"bar": y}]]} } +unboundCompositeKey[[{"x": x4}]] { q[y] } +unboundBuiltinOperator = eq { 4 = 1 } unboundElse { false } else = else_var { true } -`, - ) +c.d.e[x5] if true +f.g.h[y] = x6 if y := "y" +i.j.k contains x7 if true +`, popts) compileStages(c, c.checkSafetyRuleHeads) makeErrMsg := func(v string) string { @@ -709,10 +1450,13 @@ unboundElse { false } else = else_var { true } } expected := []string{ - makeErrMsg("x"), - makeErrMsg("x"), - makeErrMsg("x"), - makeErrMsg("x"), + makeErrMsg("x1"), + 
makeErrMsg("x2"), + makeErrMsg("x3"), + makeErrMsg("x4"), + makeErrMsg("x5"), + makeErrMsg("x6"), + makeErrMsg("x7"), makeErrMsg("eq"), makeErrMsg("else_var"), } @@ -1018,20 +1762,37 @@ q[1] { true }`, default foo = 1 default foo = 2 -foo = 3 { true }`, - "mod4.rego": `package adrules.arity +foo = 3 { true } + +default p.q.bar = 1 +default p.q.bar = 2 +p.q.bar = 3 { true } +`, + "mod4.rego": `package badrules.arity f(1) { true } f { true } g(1) { true } -g(1,2) { true }`, +g(1,2) { true } + +p.q.h(1) { true } +p.q.h { true } + +p.q.i(1) { true } +p.q.i(1,2) { true }`, "mod5.rego": `package badrules.dataoverlap p { true }`, "mod6.rego": `package badrules.existserr -p { true }`}) +p { true }`, + + "mod7.rego": `package badrules.foo +import future.keywords + +bar.baz contains "quz" if true`, + }) c.WithPathConflictsCheck(func(path []string) (bool, error) { if reflect.DeepEqual(path, []string{"badrules", "dataoverlap", "p"}) { @@ -1047,18 +1808,155 @@ p { true }`}) expected := []string{ "rego_compile_error: conflict check for data path badrules/existserr/p: unexpected error", "rego_compile_error: conflicting rule for data path badrules/dataoverlap/p found", - "rego_type_error: conflicting rules named f found", - "rego_type_error: conflicting rules named g found", - "rego_type_error: conflicting rules named p found", - "rego_type_error: conflicting rules named q found", - "rego_type_error: multiple default rules named foo found", - "rego_type_error: package badrules.r conflicts with rule defined at mod1.rego:7", - "rego_type_error: package badrules.r conflicts with rule defined at mod1.rego:8", + "rego_type_error: conflicting rules data.badrules.arity.f found", + "rego_type_error: conflicting rules data.badrules.arity.g found", + "rego_type_error: conflicting rules data.badrules.arity.p.q.h found", + "rego_type_error: conflicting rules data.badrules.arity.p.q.i found", + "rego_type_error: conflicting rules data.badrules.p[x] found", + "rego_type_error: conflicting rules 
data.badrules.q found", + "rego_type_error: multiple default rules data.badrules.defkw.foo found", + "rego_type_error: multiple default rules data.badrules.defkw.p.q.bar found", + "rego_type_error: package badrules.r conflicts with rule r[x] defined at mod1.rego:7", + "rego_type_error: package badrules.r conflicts with rule r[x] defined at mod1.rego:8", } assertCompilerErrorStrings(t, c, expected) } +func TestCompilerCheckRuleConflictsDotsInRuleHeads(t *testing.T) { + + tests := []struct { + note string + modules []*Module + err string + }{ + { + note: "arity mismatch, ref and non-ref rule", + modules: modules( + `package pkg + p.q.r { true }`, + `package pkg.p.q + r(_) = 2`), + err: "rego_type_error: conflicting rules data.pkg.p.q.r found", + }, + { + note: "two default rules, ref and non-ref rule", + modules: modules( + `package pkg + default p.q.r = 3 + p.q.r { true }`, + `package pkg.p.q + default r = 4 + r = 2`), + err: "rego_type_error: multiple default rules data.pkg.p.q.r found", + }, + { + note: "arity mismatch, ref and ref rule", + modules: modules( + `package pkg.a.b + p.q.r { true }`, + `package pkg.a + b.p.q.r(_) = 2`), + err: "rego_type_error: conflicting rules data.pkg.a.b.p.q.r found", + }, + { + note: "two default rules, ref and ref rule", + modules: modules( + `package pkg + default p.q.w.r = 3 + p.q.w.r { true }`, + `package pkg.p + default q.w.r = 4 + q.w.r = 2`), + err: "rego_type_error: multiple default rules data.pkg.p.q.w.r found", + }, + { + note: "multi-value + single-value rules, both with same ref prefix", + modules: modules( + `package pkg + p.q.w[x] = 1 if x := "foo"`, + `package pkg + p.q.w contains "bar"`), + err: "rego_type_error: conflicting rules data.pkg.p.q.w found", + }, + { + note: "two multi-value rules, both with same ref", + modules: modules( + `package pkg + p.q.w contains "baz"`, + `package pkg + p.q.w contains "bar"`), + }, + { + note: "module conflict: non-ref rule", + modules: modules( + `package pkg.q + r { true }`, + 
`package pkg.q.r`), + err: "rego_type_error: package pkg.q.r conflicts with rule r defined at mod0.rego:2", + }, + { + note: "module conflict: ref rule", + modules: modules( + `package pkg + p.q.r { true }`, + `package pkg.p.q.r`), + err: "rego_type_error: package pkg.p.q.r conflicts with rule p.q.r defined at mod0.rego:2", + }, + { + note: "single-value with other rule overlap", + modules: modules( + `package pkg + p.q.r { true }`, + `package pkg + p.q.r.s { true }`), + err: "rego_type_error: single-value rule data.pkg.p.q.r conflicts with [data.pkg.p.q.r.s]", + }, + { + note: "single-value with other rule overlap", + modules: modules( + `package pkg + p.q.r { true } + p.q.r.s { true } + p.q.r.t { true }`), + err: "rego_type_error: single-value rule data.pkg.p.q.r conflicts with [data.pkg.p.q.r.s data.pkg.p.q.r.t]", + }, + { + note: "single-value with other rule overlap, unknown key", + modules: modules( + `package pkg + p.q[r] = x { r = input.key; x = input.foo } + p.q.r.s = x { true } + `), + err: "rego_type_error: single-value rule data.pkg.p.q[r] conflicts with [data.pkg.p.q.r.s]", + }, + { + note: "single-value rule with known and unknown key", + modules: modules( + `package pkg + p.q[r] = x { r = input.key; x = input.foo } + p.q.s = "x" { true } + `), + }, + } + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + mods := make(map[string]*Module, len(tc.modules)) + for i, m := range tc.modules { + mods[fmt.Sprint(i)] = m + } + c := NewCompiler() + c.Modules = mods + compileStages(c, c.checkRuleConflicts) + if tc.err != "" { + assertCompilerErrorStrings(t, c, []string{tc.err}) + } else { + assertCompilerErrorStrings(t, c, []string{}) + } + }) + } +} + func TestCompilerCheckUndefinedFuncs(t *testing.T) { module := ` @@ -2072,6 +2970,84 @@ func assertErrors(t *testing.T, actual Errors, expected Errors, assertLocation b } } +// NOTE(sr): the tests below this function are unwieldy, let's keep adding new ones to this one +func 
TestCompilerResolveAllRefsNewTests(t *testing.T) { + tests := []struct { + note string + mod string + exp string + extra string + }{ + { + note: "ref-rules referenced in body", + mod: `package test +a.b.c = 1 +q if a.b.c == 1 +`, + exp: `package test +a.b.c = 1 { true } +q if data.test.a.b.c = 1 +`, + }, + { + // NOTE(sr): This is a conservative extension of how it worked before: + // we will not automatically extend references to other parts of the rule tree, + // only to ref rules defined on the same level. + note: "ref-rules from other module referenced in body", + mod: `package test +q if a.b.c == 1 +`, + extra: `package test +a.b.c = 1 +`, + exp: `package test +q if data.test.a.b.c = 1 +`, + }, + { + note: "single-value rule in comprehension in call", // NOTE(sr): this is TestRego/partialiter/objects_conflict + mod: `package test +p := count([x | q[x]]) +q[1] = 1 +`, + exp: `package test +p := __local0__ { true; __local1__ = [x | data.test.q[x]]; count(__local1__, __local0__) } +q[1] = 1 +`, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + c := NewCompiler() + mod, err := ParseModuleWithOpts("test.rego", tc.mod, opts) + if err != nil { + t.Fatal(err) + } + exp, err := ParseModuleWithOpts("test.rego", tc.exp, opts) + if err != nil { + t.Fatal(err) + } + mods := map[string]*Module{"test": mod} + if tc.extra != "" { + extra, err := ParseModuleWithOpts("test.rego", tc.extra, opts) + if err != nil { + t.Fatal(err) + } + mods["extra"] = extra + } + c.Compile(mods) + if err := c.Errors; len(err) > 0 { + t.Errorf("compile module: %v", err) + } + if act := c.Modules["test"]; !exp.Equal(act) { + t.Errorf("compiled: expected %v, got %v", exp, act) + } + }) + } +} + func TestCompilerResolveAllRefs(t *testing.T) { c := NewCompiler() c.Modules = getCompilerTestModules() @@ -2180,6 +3156,12 @@ p[foo[bar[i]]] = {"baz": baz} { true }`) } }`, 
ParserOptions{unreleasedKeywords: true, FutureKeywords: []string{"every", "in"}}) + c.Modules["heads_with_dots"] = MustParseModule(`package heads_with_dots + + this_is_not = true + this.is.dotted { this_is_not } + `) + compileStages(c, c.resolveAllRefs) assertNotFailed(t, c) @@ -2317,6 +3299,15 @@ p[foo[bar[i]]] = {"baz": baz} { true }`) gt10 := MustParseExpr("x > 10") gt10.Index++ // TODO(sr): why? assertExprEqual(t, everyExpr.Body[1], gt10) + + // head refs are kept as-is, but their bodies are replaced. + mod := c.Modules["heads_with_dots"] + rule := mod.Rules[1] + body := rule.Body[0].Terms.(*Term) + assertTermEqual(t, body, MustParseTerm("data.heads_with_dots.this_is_not")) + if act, exp := rule.Head.Ref(), MustParseRef("this.is.dotted"); act.Compare(exp) != 0 { + t.Errorf("expected %v to match %v", act, exp) + } } func TestCompilerResolveErrors(t *testing.T) { @@ -2340,48 +3331,91 @@ func TestCompilerResolveErrors(t *testing.T) { } func TestCompilerRewriteTermsInHead(t *testing.T) { - c := NewCompiler() - c.Modules["head"] = MustParseModule(`package head + popts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + tests := []struct { + note string + mod *Module + exp *Rule + }{ + { + note: "imports", + mod: MustParseModule(`package head import data.doc1 as bar import data.doc2 as corge import input.x.y.foo import input.qux as baz p[foo[bar[i]]] = {"baz": baz, "corge": corge} { true } +`), + exp: MustParseRule(`p[__local0__] = __local1__ { true; __local0__ = input.x.y.foo[data.doc1[i]]; __local1__ = {"baz": input.qux, "corge": data.doc2} }`), + }, + { + note: "array comprehension value", + mod: MustParseModule(`package head q = [true | true] { true } +`), + exp: MustParseRule(`q = __local0__ { true; __local0__ = [true | true] }`), + }, + { + note: "object comprehension value", + mod: MustParseModule(`package head r = {"true": true | true} { true } +`), + exp: MustParseRule(`r = __local0__ { true; __local0__ = {"true": true | true} }`), + }, + { 
+ note: "set comprehension value", + mod: MustParseModule(`package head s = {true | true} { true } - +`), + exp: MustParseRule(`s = __local0__ { true; __local0__ = {true | true} }`), + }, + { + note: "import in else value", + mod: MustParseModule(`package head +import input.qux as baz elsekw { false } else = baz { true } -`) - - compileStages(c, c.rewriteRefsInHead) - assertNotFailed(t, c) - - rule1 := c.Modules["head"].Rules[0] - expected1 := MustParseRule(`p[__local0__] = __local1__ { true; __local0__ = input.x.y.foo[data.doc1[i]]; __local1__ = {"baz": input.qux, "corge": data.doc2} }`) - assertRulesEqual(t, rule1, expected1) - - rule2 := c.Modules["head"].Rules[1] - expected2 := MustParseRule(`q = __local2__ { true; __local2__ = [true | true] }`) - assertRulesEqual(t, rule2, expected2) - - rule3 := c.Modules["head"].Rules[2] - expected3 := MustParseRule(`r = __local3__ { true; __local3__ = {"true": true | true} }`) - assertRulesEqual(t, rule3, expected3) - - rule4 := c.Modules["head"].Rules[3] - expected4 := MustParseRule(`s = __local4__ { true; __local4__ = {true | true} }`) - assertRulesEqual(t, rule4, expected4) +`), + exp: MustParseRule(`elsekw { false } else = __local0__ { true; __local0__ = input.qux }`), + }, + { + note: "import ref in last ref head term", + mod: MustParseModule(`package head +import data.doc1 as bar +x.y.z[bar[i]] = true +`), + exp: MustParseRule(`x.y.z[__local0__] = true { true; __local0__ = data.doc1[i] }`), + }, + { + note: "import ref in multi-value ref rule", + mod: MustParseModule(`package head +import future.keywords.if +import future.keywords.contains +import data.doc1 as bar +x.y.w contains bar[i] if true +`), + exp: func() *Rule { + exp, _ := ParseRuleWithOpts(`x.y.w contains __local0__ if {true; __local0__ = data.doc1[i] }`, popts) + return exp + }(), + }, + } - rule5 := c.Modules["head"].Rules[4] - expected5 := MustParseRule(`elsekw { false } else = __local5__ { true; __local5__ = input.qux }`) - assertRulesEqual(t, rule5, 
expected5) + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + c := NewCompiler() + c.Modules["head"] = tc.mod + compileStages(c, c.rewriteRefsInHead) + assertNotFailed(t, c) + act := c.Modules["head"].Rules[0] + assertRulesEqual(t, act, tc.exp) + }) + } } func TestCompilerRewriteRegoMetadataCalls(t *testing.T) { @@ -3384,6 +4418,25 @@ func TestRewriteDeclaredVars(t *testing.T) { p { __local1__ = data.test.y; data.test.q[[__local1__, __local0__]] } `, }, + { + note: "single-value rule with ref head", + module: ` + package test + + p.r.q[s] = t { + t := 1 + s := input.foo + } + `, + exp: ` + package test + + p.r.q[__local1__] = __local0__ { + __local0__ = 1 + __local1__ = input.foo + } + `, + }, { note: "rewrite some x in xs", module: ` @@ -3873,7 +4926,7 @@ func TestRewriteDeclaredVars(t *testing.T) { for _, tc := range tests { t.Run(tc.note, func(t *testing.T) { - opts := CompileOpts{ParserOptions: ParserOptions{FutureKeywords: []string{"in", "every"}, unreleasedKeywords: true}} + opts := CompileOpts{ParserOptions: ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true}} compiler, err := CompileModulesWithOpt(map[string]string{"test.rego": tc.module}, opts) if tc.wantErr != nil { if err == nil { @@ -5560,37 +6613,41 @@ dataref = true { data }`, compileStages(c, c.checkRecursion) - makeRuleErrMsg := func(rule string, loop ...string) string { - return fmt.Sprintf("rego_recursion_error: rule %v is recursive: %v", rule, strings.Join(loop, " -> ")) + makeRuleErrMsg := func(pkg, rule string, loop ...string) string { + l := make([]string, len(loop)) + for i, lo := range loop { + l[i] = "data." + pkg + "." 
+ lo + } + return fmt.Sprintf("rego_recursion_error: rule data.%s.%s is recursive: %v", pkg, rule, strings.Join(l, " -> ")) } expected := []string{ - makeRuleErrMsg("s", "s", "t", "s"), - makeRuleErrMsg("t", "t", "s", "t"), - makeRuleErrMsg("a", "a", "b", "c", "e", "a"), - makeRuleErrMsg("b", "b", "c", "e", "a", "b"), - makeRuleErrMsg("c", "c", "e", "a", "b", "c"), - makeRuleErrMsg("e", "e", "a", "b", "c", "e"), - makeRuleErrMsg("p", "p", "q", "p"), - makeRuleErrMsg("q", "q", "p", "q"), - makeRuleErrMsg("acq", "acq", "acp", "acq"), - makeRuleErrMsg("acp", "acp", "acq", "acp"), - makeRuleErrMsg("np", "np", "nq", "np"), - makeRuleErrMsg("nq", "nq", "np", "nq"), - makeRuleErrMsg("prefix", "prefix", "prefix"), - makeRuleErrMsg("dataref", "dataref", "dataref"), - makeRuleErrMsg("else_self", "else_self", "else_self"), - makeRuleErrMsg("elsetop", "elsetop", "elsemid", "elsebottom", "elsetop"), - makeRuleErrMsg("elsemid", "elsemid", "elsebottom", "elsetop", "elsemid"), - makeRuleErrMsg("elsebottom", "elsebottom", "elsetop", "elsemid", "elsebottom"), - makeRuleErrMsg("fn", "fn", "fn"), - makeRuleErrMsg("foo", "foo", "bar", "foo"), - makeRuleErrMsg("bar", "bar", "foo", "bar"), - makeRuleErrMsg("bar", "bar", "p", "foo", "bar"), - makeRuleErrMsg("foo", "foo", "bar", "p", "foo"), - makeRuleErrMsg("p", "p", "foo", "bar", "p"), - makeRuleErrMsg("everyp", "everyp", "everyp"), - makeRuleErrMsg("everyq", "everyq", "everyq"), + makeRuleErrMsg("rec", "s", "s", "t", "s"), + makeRuleErrMsg("rec", "t", "t", "s", "t"), + makeRuleErrMsg("rec", "a", "a", "b", "c", "e", "a"), + makeRuleErrMsg("rec", "b", "b", "c", "e", "a", "b"), + makeRuleErrMsg("rec", "c", "c", "e", "a", "b", "c"), + makeRuleErrMsg("rec", "e", "e", "a", "b", "c", "e"), + `rego_recursion_error: rule data.rec3.p[x] is recursive: data.rec3.p[x] -> data.rec4.q[x] -> data.rec3.p[x]`, // NOTE(sr): these two are hardcoded: they are + `rego_recursion_error: rule data.rec4.q[x] is recursive: data.rec4.q[x] -> data.rec3.p[x] -> 
data.rec4.q[x]`, // the only ones not fitting the pattern. + makeRuleErrMsg("rec5", "acq", "acq", "acp", "acq"), + makeRuleErrMsg("rec5", "acp", "acp", "acq", "acp"), + makeRuleErrMsg("rec6", "np[x]", "np[x]", "nq[x]", "np[x]"), + makeRuleErrMsg("rec6", "nq[x]", "nq[x]", "np[x]", "nq[x]"), + makeRuleErrMsg("rec7", "prefix", "prefix", "prefix"), + makeRuleErrMsg("rec8", "dataref", "dataref", "dataref"), + makeRuleErrMsg("rec9", "else_self", "else_self", "else_self"), + makeRuleErrMsg("rec9", "elsetop", "elsetop", "elsemid", "elsebottom", "elsetop"), + makeRuleErrMsg("rec9", "elsemid", "elsemid", "elsebottom", "elsetop", "elsemid"), + makeRuleErrMsg("rec9", "elsebottom", "elsebottom", "elsetop", "elsemid", "elsebottom"), + makeRuleErrMsg("f0", "fn", "fn", "fn"), + makeRuleErrMsg("f1", "foo", "foo", "bar", "foo"), + makeRuleErrMsg("f1", "bar", "bar", "foo", "bar"), + makeRuleErrMsg("f2", "bar", "bar", "p[x]", "foo", "bar"), + makeRuleErrMsg("f2", "foo", "foo", "bar", "p[x]", "foo"), + makeRuleErrMsg("f2", "p[x]", "p[x]", "foo", "bar", "p[x]"), + makeRuleErrMsg("everymod", "everyp", "everyp", "everyp"), + makeRuleErrMsg("everymod", "everyq", "everyq", "everyq"), } result := compilerErrsToStringSlice(c.Errors) @@ -5612,28 +6669,38 @@ func TestCompilerCheckDynamicRecursion(t *testing.T) { // references. For more background info, see // . 
- for note, mod := range map[string]*Module{ - "recursion": MustParseModule(` + for _, tc := range []struct { + note, err string + mod *Module + }{ + { + note: "recursion", + mod: MustParseModule(` package recursion pkg = "recursion" foo[x] { data[pkg]["foo"][x] } `), - "system.main": MustParseModule(` + err: "rego_recursion_error: rule data.recursion.foo is recursive: data.recursion.foo -> data.recursion.foo", + }, + {note: "system.main", + mod: MustParseModule(` package system.main foo { - data[input] + data[input] } `), + err: "rego_recursion_error: rule data.system.main.foo is recursive: data.system.main.foo -> data.system.main.foo", + }, } { - t.Run(note, func(t *testing.T) { + t.Run(tc.note, func(t *testing.T) { c := NewCompiler() - c.Modules = map[string]*Module{note: mod} + c.Modules = map[string]*Module{tc.note: tc.mod} compileStages(c, c.checkRecursion) result := compilerErrsToStringSlice(c.Errors) - expected := "rego_recursion_error: rule foo is recursive: foo -> foo" + expected := tc.err if len(result) != 1 || result[0] != expected { t.Errorf("Expected %v but got: %v", expected, result) @@ -5909,6 +6976,7 @@ func TestCompilerGetRulesDynamic(t *testing.T) { "mod1": `package a.b.c.d r1 = 1`, "mod2": `package a.b.c.e +default r2 = false r2 = 2`, "mod3": `package a.b r3 = 3`, @@ -5919,7 +6987,8 @@ r4 = 4`, compileStages(compiler, nil) rule1 := compiler.Modules["mod1"].Rules[0] - rule2 := compiler.Modules["mod2"].Rules[0] + rule2d := compiler.Modules["mod2"].Rules[0] + rule2 := compiler.Modules["mod2"].Rules[1] rule3 := compiler.Modules["mod3"].Rules[0] rule4 := compiler.Modules["hidden"].Rules[0] @@ -5929,15 +6998,16 @@ r4 = 4`, excludeHidden bool }{ {input: "data.a.b.c.d.r1", expected: []*Rule{rule1}}, - {input: "data.a.b[x]", expected: []*Rule{rule1, rule2, rule3}}, + {input: "data.a.b[x]", expected: []*Rule{rule1, rule2d, rule2, rule3}}, {input: "data.a.b[x].d", expected: []*Rule{rule1, rule3}}, - {input: "data.a.b.c", expected: []*Rule{rule1, rule2}}, + 
{input: "data.a.b.c", expected: []*Rule{rule1, rule2d, rule2}}, {input: "data.a.b.d"}, - {input: "data[x]", expected: []*Rule{rule1, rule2, rule3, rule4}}, - {input: "data[data.complex_computation].b[y]", expected: []*Rule{rule1, rule2, rule3}}, - {input: "data[x][y].c.e", expected: []*Rule{rule2}}, + {input: "data", expected: []*Rule{rule1, rule2d, rule2, rule3, rule4}}, + {input: "data[x]", expected: []*Rule{rule1, rule2d, rule2, rule3, rule4}}, + {input: "data[data.complex_computation].b[y]", expected: []*Rule{rule1, rule2d, rule2, rule3}}, + {input: "data[x][y].c.e", expected: []*Rule{rule2d, rule2}}, {input: "data[x][y].r3", expected: []*Rule{rule3}}, - {input: "data[x][y]", expected: []*Rule{rule1, rule2, rule3}, excludeHidden: true}, // old behaviour of GetRulesDynamic + {input: "data[x][y]", expected: []*Rule{rule1, rule2d, rule2, rule3}, excludeHidden: true}, // old behaviour of GetRulesDynamic } for _, tc := range tests { @@ -7452,7 +8522,19 @@ deny { } else if !strings.HasPrefix(c.Errors.Error(), "1 error occurred: 7:2: rego_type_error: undefined ref: input.Something.Y.X.ThisDoesNotExist") { t.Errorf("unexpected error: %v", c.Errors.Error()) } +} +func modules(ms ...string) []*Module { + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + mods := make([]*Module, len(ms)) + for i, m := range ms { + var err error + mods[i], err = ParseModuleWithOpts(fmt.Sprintf("mod%d.rego", i), m, opts) + if err != nil { + panic(err) + } + } + return mods } func TestCompilerWithRecursiveSchemaAvoidRace(t *testing.T) { diff --git a/ast/conflicts.go b/ast/conflicts.go index d1013ccedd..c2713ad576 100644 --- a/ast/conflicts.go +++ b/ast/conflicts.go @@ -27,7 +27,12 @@ func CheckPathConflicts(c *Compiler, exists func([]string) (bool, error)) Errors func checkDocumentConflicts(node *TreeNode, exists func([]string) (bool, error), path []string) Errors { - path = append(path, string(node.Key.(String))) + switch key := node.Key.(type) { + case String: + 
path = append(path, string(key)) + default: // other key types cannot conflict with data + return nil + } if len(node.Values) > 0 { s := strings.Join(path, "/") diff --git a/ast/env.go b/ast/env.go index 60006baafd..5313a595b6 100644 --- a/ast/env.go +++ b/ast/env.go @@ -5,6 +5,8 @@ package ast import ( + "fmt" + "github.com/open-policy-agent/opa/types" "github.com/open-policy-agent/opa/util" ) @@ -195,9 +197,17 @@ func (env *TypeEnv) getRefRecExtent(node *typeTreeNode) types.Type { child := v.(*typeTreeNode) tpe := env.getRefRecExtent(child) - // TODO(tsandall): handle non-string keys? - if s, ok := key.(String); ok { - children = append(children, types.NewStaticProperty(string(s), tpe)) + + // NOTE(sr): Converting to Golang-native types here is an extension of what we did + // before -- only supporting strings. But since we cannot differentiate sets and arrays + // that way, we could reconsider. + switch key.(type) { + case String, Number, Boolean: // skip anything else + propKey, err := JSON(key) + if err != nil { + panic(fmt.Errorf("unreachable, ValueToInterface: %w", err)) + } + children = append(children, types.NewStaticProperty(propKey, tpe)) } return false }) diff --git a/ast/index.go b/ast/index.go index bcbb5c1765..0bd775b062 100644 --- a/ast/index.go +++ b/ast/index.go @@ -32,7 +32,7 @@ type RuleIndex interface { // IndexResult contains the result of an index lookup. type IndexResult struct { - Kind DocKind + Kind RuleKind Rules []*Rule Else map[*Rule][]*Rule Default *Rule @@ -40,7 +40,7 @@ type IndexResult struct { } // NewIndexResult returns a new IndexResult object. 
-func NewIndexResult(kind DocKind) *IndexResult { +func NewIndexResult(kind RuleKind) *IndexResult { return &IndexResult{ Kind: kind, Else: map[*Rule][]*Rule{}, @@ -57,7 +57,7 @@ type baseDocEqIndex struct { isVirtual func(Ref) bool root *trieNode defaultRule *Rule - kind DocKind + kind RuleKind } func newBaseDocEqIndex(isVirtual func(Ref) bool) *baseDocEqIndex { @@ -73,7 +73,7 @@ func (i *baseDocEqIndex) Build(rules []*Rule) bool { return false } - i.kind = rules[0].Head.DocKind() + i.kind = rules[0].Head.RuleKind() indices := newrefindices(i.isVirtual) // build indices for each rule. diff --git a/ast/index_test.go b/ast/index_test.go index 1c387e9c0a..dff4167f57 100644 --- a/ast/index_test.go +++ b/ast/index_test.go @@ -60,6 +60,20 @@ func TestBaseDocEqIndexing(t *testing.T) { input.b = 1 }`, opts) + refMod := MustParseModuleWithOpts(`package test + + ref.single.value.ground = x if x := input.x + + ref.single.value.key[k] = v if { k := input.k; v := input.v } + + ref.multi.value.ground contains x if x := input.x + + ref.multiple.single.value.ground = x if x := input.x + ref.multiple.single.value[y] = x if { x := input.x; y := index.y } + + # ref.multi.value.key[k] contains v if { k := input.k; v := input.v } # not supported yet + `, opts) + module := MustParseModule(` package test @@ -70,6 +84,7 @@ func TestBaseDocEqIndexing(t *testing.T) { input.x = 3 input.y = 4 } + scalars { input.x = 0 @@ -209,6 +224,7 @@ func TestBaseDocEqIndexing(t *testing.T) { note string module *Module ruleset string + ruleRef Ref input string unknowns []string args []Value @@ -632,6 +648,41 @@ func TestBaseDocEqIndexing(t *testing.T) { input: `{"a": [1]}`, expectedRS: RuleSet([]*Rule{everyModWithDomain.Rules[0]}), }, + { + note: "ref: single value, ground ref", + module: refMod, + ruleRef: MustParseRef("ref.single.value.ground"), + input: `{"x": 1}`, + expectedRS: RuleSet([]*Rule{refMod.Rules[0]}), + }, + { + note: "ref: single value, ground ref and non-ground ref", + module: refMod, + 
ruleRef: MustParseRef("ref.multiple.single.value"), + input: `{"x": 1, "y": "Y"}`, + expectedRS: RuleSet([]*Rule{refMod.Rules[3], refMod.Rules[4]}), + }, + { + note: "ref: single value, var in ref", + module: refMod, + ruleRef: MustParseRef("ref.single.value.key[k]"), + input: `{"k": 1, "v": 2}`, + expectedRS: RuleSet([]*Rule{refMod.Rules[1]}), + }, + { + note: "ref: multi value, ground ref", + module: refMod, + ruleRef: MustParseRef("ref.multi.value.ground"), + input: `{"x": 1}`, + expectedRS: RuleSet([]*Rule{refMod.Rules[2]}), + }, + // { + // note: "ref: multi value, var in ref", + // module: refMod, + // ruleRef: MustParseRef("ref.multi.value.key[k]"), + // input: `{"k": 1, "v": 2}`, + // expectedRS: RuleSet([]*Rule{refMod.Rules[3]}), + // }, } for _, tc := range tests { @@ -642,10 +693,19 @@ func TestBaseDocEqIndexing(t *testing.T) { } rules := []*Rule{} for _, rule := range module.Rules { - if rule.Head.Name == Var(tc.ruleset) { - rules = append(rules, rule) + if tc.ruleRef == nil { + if rule.Head.Name == Var(tc.ruleset) { + rules = append(rules, rule) + } + } else { + if rule.Head.Ref().HasPrefix(tc.ruleRef) { + rules = append(rules, rule) + } } } + if len(rules) == 0 { + t.Fatal("selected empty ruleset") + } var input *Term if tc.input != "" { diff --git a/ast/parser.go b/ast/parser.go index fd1b407246..9fe41ca368 100644 --- a/ast/parser.go +++ b/ast/parser.go @@ -567,14 +567,41 @@ func (p *Parser) parseRules() []*Rule { return []*Rule{&rule} } - hasIf := false - if p.s.tok == tokens.If { - hasIf = true + if usesContains && !rule.Head.Reference.IsGround() { + p.error(p.s.Loc(), "multi-value rules need ground refs") + return nil } - if hasIf && !usesContains && rule.Head.Key != nil && rule.Head.Value == nil { - p.illegal("invalid for partial set rule %s (use `contains`)", rule.Head.Name) - return nil + // back-compat with `p[x] { ... }`` + hasIf := p.s.tok == tokens.If + + // p[x] if ... 
becomes a single-value rule p[x] + if hasIf && !usesContains && len(rule.Head.Ref()) == 2 { + if rule.Head.Value == nil { + rule.Head.Value = BooleanTerm(true).SetLocation(rule.Head.Location) + } else { + // p[x] = y if becomes a single-value rule p[x] with value y, but needs name for compat + v, ok := rule.Head.Ref()[0].Value.(Var) + if !ok { + return nil + } + rule.Head.Name = v + } + } + + // p[x] becomes a multi-value rule p + if !hasIf && !usesContains && + len(rule.Head.Args) == 0 && // not a function + len(rule.Head.Ref()) == 2 { // ref like 'p[x]' + v, ok := rule.Head.Ref()[0].Value.(Var) + if !ok { + return nil + } + rule.Head.Name = v + rule.Head.Key = rule.Head.Ref()[1] + if rule.Head.Value == nil { + rule.Head.SetRef(rule.Head.Ref()[:len(rule.Head.Ref())-1]) + } } switch { @@ -743,67 +770,65 @@ func (p *Parser) parseElse(head *Head) *Rule { func (p *Parser) parseHead(defaultRule bool) (*Head, bool) { - var head Head - head.SetLoc(p.s.Loc()) - + head := &Head{} + loc := p.s.Loc() defer func() { - head.Location.Text = p.s.Text(head.Location.Offset, p.s.lastEnd) + if head != nil { + head.SetLoc(loc) + head.Location.Text = p.s.Text(head.Location.Offset, p.s.lastEnd) + } }() - if term := p.parseVar(); term != nil { - head.Name = term.Value.(Var) - } else { - p.illegal("expected rule head name") + term := p.parseVar() + if term == nil { + return nil, false } - p.scan() - - if p.s.tok == tokens.LParen { - p.scan() - if p.s.tok != tokens.RParen { - head.Args = p.parseTermList(tokens.RParen, nil) - if head.Args == nil { - return nil, false - } + ref := p.parseTermFinish(term, true) + if ref == nil { + p.illegal("expected rule head name") + return nil, false + } + + switch x := ref.Value.(type) { + case Var: + head = NewHead(x) + case Ref: + head = RefHead(x) + case Call: + op, args := x[0], x[1:] + var ref Ref + switch y := op.Value.(type) { + case Var: + ref = Ref{op} + case Ref: + ref = y } - p.scan() + head = RefHead(ref) + head.Args = append([]*Term{}, 
args...) - if p.s.tok == tokens.LBrack { - return nil, false - } + default: + return nil, false } - if p.s.tok == tokens.LBrack { - p.scan() - head.Key = p.parseTermInfixCall() - if head.Key == nil { - p.illegal("expected rule key term (e.g., %s[] { ... })", head.Name) - } - if p.s.tok != tokens.RBrack { - if _, ok := futureKeywords[head.Name.String()]; ok { - p.hint("`import future.keywords.%[1]s` for '%[1]s' keyword", head.Name.String()) - } - p.illegal("non-terminated rule key") - } - p.scan() - } + name := head.Ref().String() switch p.s.tok { - case tokens.Contains: + case tokens.Contains: // NOTE: no Value for `contains` heads, we return here p.scan() head.Key = p.parseTermInfixCall() if head.Key == nil { - p.illegal("expected rule key term (e.g., %s contains { ... })", head.Name) + p.illegal("expected rule key term (e.g., %s contains { ... })", name) } + return head, true - return &head, true case tokens.Unify: p.scan() head.Value = p.parseTermInfixCall() if head.Value == nil { - p.illegal("expected rule value term (e.g., %s[%s] = { ... })", head.Name, head.Key) + // FIX HEAD.String() + p.illegal("expected rule value term (e.g., %s[%s] = { ... })", name, head.Key) } - case tokens.Assign: s := p.save() p.scan() @@ -813,22 +838,23 @@ func (p *Parser) parseHead(defaultRule bool) (*Head, bool) { p.restore(s) switch { case len(head.Args) > 0: - p.illegal("expected function value term (e.g., %s(...) := { ... })", head.Name) + p.illegal("expected function value term (e.g., %s(...) := { ... })", name) case head.Key != nil: - p.illegal("expected partial rule value term (e.g., %s[...] := { ... })", head.Name) + p.illegal("expected partial rule value term (e.g., %s[...] := { ... })", name) case defaultRule: - p.illegal("expected default rule value term (e.g., default %s := )", head.Name) + p.illegal("expected default rule value term (e.g., default %s := )", name) default: - p.illegal("expected rule value term (e.g., %s := { ... 
})", head.Name) + p.illegal("expected rule value term (e.g., %s := { ... })", name) } } } if head.Value == nil && head.Key == nil { - head.Value = BooleanTerm(true).SetLocation(head.Location) + if len(head.Ref()) != 2 || len(head.Args) > 0 { + head.Value = BooleanTerm(true).SetLocation(head.Location) + } } - - return &head, false + return head, false } func (p *Parser) parseBody(end tokens.Token) Body { @@ -1348,17 +1374,18 @@ func (p *Parser) parseTerm() *Term { p.illegalToken() } - term = p.parseTermFinish(term) + term = p.parseTermFinish(term, false) p.parsedTermCachePush(term, s0) return term } -func (p *Parser) parseTermFinish(head *Term) *Term { +func (p *Parser) parseTermFinish(head *Term, skipws bool) *Term { if head == nil { return nil } offset := p.s.loc.Offset - p.scanWS() + p.doScan(skipws) + switch p.s.tok { case tokens.LParen, tokens.Dot, tokens.LBrack: return p.parseRef(head, offset) diff --git a/ast/parser_ext.go b/ast/parser_ext.go index 41eb4443ec..7f5a69424b 100644 --- a/ast/parser_ext.go +++ b/ast/parser_ext.go @@ -154,12 +154,15 @@ func ParseRuleFromExpr(module *Module, expr *Expr) (*Rule, error) { } if _, ok := expr.Terms.(*SomeDecl); ok { - return nil, errors.New("some declarations cannot be used for rule head") + return nil, errors.New("'some' declarations cannot be used for rule head") } if term, ok := expr.Terms.(*Term); ok { switch v := term.Value.(type) { case Ref: + if len(v) > 2 { // 2+ dots + return ParseCompleteDocRuleWithDotsFromTerm(module, term) + } return ParsePartialSetDocRuleFromTerm(module, term) default: return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(v)) @@ -207,18 +210,17 @@ func parseCompleteRuleFromEq(module *Module, expr *Expr) (rule *Rule, err error) return nil, errors.New("assignment requires two operands") } - rule, err = ParseCompleteDocRuleFromEqExpr(module, lhs, rhs) - + rule, err = ParseRuleFromCallEqExpr(module, lhs, rhs) if err == nil { return rule, nil } - rule, err = 
ParseRuleFromCallEqExpr(module, lhs, rhs) + rule, err = ParsePartialObjectDocRuleFromEqExpr(module, lhs, rhs) if err == nil { return rule, nil } - return ParsePartialObjectDocRuleFromEqExpr(module, lhs, rhs) + return ParseCompleteDocRuleFromEqExpr(module, lhs, rhs) } // ParseCompleteDocRuleFromAssignmentExpr returns a rule if the expression can @@ -239,39 +241,55 @@ func ParseCompleteDocRuleFromAssignmentExpr(module *Module, lhs, rhs *Term) (*Ru // ParseCompleteDocRuleFromEqExpr returns a rule if the expression can be // interpreted as a complete document definition. func ParseCompleteDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, error) { - - var name Var - - if RootDocumentRefs.Contains(lhs) { - name = lhs.Value.(Ref)[0].Value.(Var) - } else if v, ok := lhs.Value.(Var); ok { - name = v + var head *Head + + if v, ok := lhs.Value.(Var); ok { + head = NewHead(v) + } else if r, ok := lhs.Value.(Ref); ok { // groundness ? + head = RefHead(r) + if len(r) > 1 && !r[len(r)-1].IsGround() { + return nil, fmt.Errorf("ref not ground") + } } else { return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(lhs.Value)) } + head.Value = rhs + head.Location = lhs.Location - rule := &Rule{ + return &Rule{ Location: lhs.Location, - Head: &Head{ - Location: lhs.Location, - Name: name, - Value: rhs, - }, + Head: head, Body: NewBody( NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location), ), Module: module, + }, nil +} + +func ParseCompleteDocRuleWithDotsFromTerm(module *Module, term *Term) (*Rule, error) { + ref, ok := term.Value.(Ref) + if !ok { + return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(term.Value)) } - return rule, nil + head := RefHead(ref, BooleanTerm(true).SetLocation(term.Location)) + head.Location = term.Location + + return &Rule{ + Location: term.Location, + Head: head, + Body: NewBody( + NewExpr(BooleanTerm(true).SetLocation(term.Location)).SetLocation(term.Location), + ), + Module: module, + }, nil } // 
ParsePartialObjectDocRuleFromEqExpr returns a rule if the expression can be // interpreted as a partial object document definition. func ParsePartialObjectDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, error) { - ref, ok := lhs.Value.(Ref) - if !ok || len(ref) != 2 { + if !ok { return nil, fmt.Errorf("%v cannot be used as rule name", TypeName(lhs.Value)) } @@ -279,17 +297,16 @@ func ParsePartialObjectDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, return nil, fmt.Errorf("%vs cannot be used as rule name", TypeName(ref[0].Value)) } - name := ref[0].Value.(Var) - key := ref[1] + head := RefHead(ref, rhs) + if len(ref) == 2 { // backcompat for naked `foo.bar = "baz"` statements + head.Name = ref[0].Value.(Var) + head.Key = ref[1] + } + head.Location = rhs.Location rule := &Rule{ Location: rhs.Location, - Head: &Head{ - Location: rhs.Location, - Name: name, - Key: key, - Value: rhs, - }, + Head: head, Body: NewBody( NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location), ), @@ -304,26 +321,24 @@ func ParsePartialObjectDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, func ParsePartialSetDocRuleFromTerm(module *Module, term *Term) (*Rule, error) { ref, ok := term.Value.(Ref) - if !ok { + if !ok || len(ref) == 1 { return nil, fmt.Errorf("%vs cannot be used for rule head", TypeName(term.Value)) } - if len(ref) != 2 { - return nil, fmt.Errorf("refs cannot be used for rule") - } - - name, ok := ref[0].Value.(Var) - if !ok { - return nil, fmt.Errorf("%vs cannot be used as rule name", TypeName(ref[0].Value)) + head := RefHead(ref) + if len(ref) == 2 { + v, ok := ref[0].Value.(Var) + if !ok { + return nil, fmt.Errorf("%vs cannot be used for rule head", TypeName(term.Value)) + } + head = NewHead(v) + head.Key = ref[1] } + head.Location = term.Location rule := &Rule{ Location: term.Location, - Head: &Head{ - Location: term.Location, - Name: name, - Key: ref[1], - }, + Head: head, Body: NewBody( 
NewExpr(BooleanTerm(true).SetLocation(term.Location)).SetLocation(term.Location), ), @@ -347,21 +362,15 @@ func ParseRuleFromCallEqExpr(module *Module, lhs, rhs *Term) (*Rule, error) { return nil, fmt.Errorf("%vs cannot be used in function signature", TypeName(call[0].Value)) } - name, ok := ref[0].Value.(Var) - if !ok { - return nil, fmt.Errorf("%vs cannot be used in function signature", TypeName(ref[0].Value)) - } + head := RefHead(ref, rhs) + head.Location = lhs.Location + head.Args = Args(call[1:]) rule := &Rule{ Location: lhs.Location, - Head: &Head{ - Location: lhs.Location, - Name: name, - Args: Args(call[1:]), - Value: rhs, - }, - Body: NewBody(NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location)), - Module: module, + Head: head, + Body: NewBody(NewExpr(BooleanTerm(true).SetLocation(rhs.Location)).SetLocation(rhs.Location)), + Module: module, } return rule, nil @@ -376,19 +385,15 @@ func ParseRuleFromCallExpr(module *Module, terms []*Term) (*Rule, error) { } loc := terms[0].Location - args := terms[1:] - value := BooleanTerm(true).SetLocation(loc) + head := RefHead(terms[0].Value.(Ref), BooleanTerm(true).SetLocation(loc)) + head.Location = loc + head.Args = terms[1:] rule := &Rule{ Location: loc, - Head: &Head{ - Location: loc, - Name: Var(terms[0].String()), - Args: args, - Value: value, - }, - Module: module, - Body: NewBody(NewExpr(BooleanTerm(true).SetLocation(loc)).SetLocation(loc)), + Head: head, + Module: module, + Body: NewBody(NewExpr(BooleanTerm(true).SetLocation(loc)).SetLocation(loc)), } return rule, nil } @@ -518,15 +523,15 @@ func ParseRef(input string) (Ref, error) { return ref, nil } -// ParseRule returns exactly one rule. +// ParseRuleWithOpts returns exactly one rule. // If multiple rules are parsed, an error is returned. 
-func ParseRule(input string) (*Rule, error) { - stmts, _, err := ParseStatements("", input) +func ParseRuleWithOpts(input string, opts ParserOptions) (*Rule, error) { + stmts, _, err := ParseStatementsWithOpts("", input, opts) if err != nil { return nil, err } if len(stmts) != 1 { - return nil, fmt.Errorf("expected exactly one statement (rule)") + return nil, fmt.Errorf("expected exactly one statement (rule), got %v = %T, %T", stmts, stmts[0], stmts[1]) } rule, ok := stmts[0].(*Rule) if !ok { @@ -535,6 +540,12 @@ func ParseRule(input string) (*Rule, error) { return rule, nil } +// ParseRule returns exactly one rule. +// If multiple rules are parsed, an error is returned. +func ParseRule(input string) (*Rule, error) { + return ParseRuleWithOpts(input, ParserOptions{}) +} + // ParseStatement returns exactly one statement. // A statement might be a term, expression, rule, etc. Regardless, // this function expects *exactly* one statement. If multiple @@ -611,14 +622,14 @@ func parseModule(filename string, stmts []Statement, comments []*Comment) (*Modu rule, err := ParseRuleFromBody(mod, stmt) if err != nil { errs = append(errs, NewError(ParseErr, stmt[0].Location, err.Error())) - } else { - mod.Rules = append(mod.Rules, rule) - - // NOTE(tsandall): the statement should now be interpreted as a - // rule so update the statement list. This is important for the - // logic below that associates annotations with statements. - stmts[i+1] = rule + continue } + mod.Rules = append(mod.Rules, rule) + + // NOTE(tsandall): the statement should now be interpreted as a + // rule so update the statement list. This is important for the + // logic below that associates annotations with statements. 
+ stmts[i+1] = rule case *Package: errs = append(errs, NewError(ParseErr, stmt.Loc(), "unexpected package")) case *Annotations: diff --git a/ast/parser_test.go b/ast/parser_test.go index b9a0e8199d..4d610f9249 100644 --- a/ast/parser_test.go +++ b/ast/parser_test.go @@ -7,6 +7,7 @@ package ast import ( "bytes" "encoding/json" + "errors" "fmt" "reflect" "strings" @@ -1413,9 +1414,10 @@ func TestRule(t *testing.T) { assertParseRule(t, "default w/ assignment", `default allow := false`, &Rule{ Default: true, Head: &Head{ - Name: "allow", - Value: BooleanTerm(false), - Assign: true, + Name: "allow", + Reference: Ref{VarTerm("allow")}, + Value: BooleanTerm(false), + Assign: true, }, Body: NewBody(NewExpr(BooleanTerm(true))), }) @@ -1439,9 +1441,10 @@ func TestRule(t *testing.T) { }) fxy := &Head{ - Name: Var("f"), - Args: Args{VarTerm("x")}, - Value: VarTerm("y"), + Name: Var("f"), + Reference: Ref{VarTerm("f")}, + Args: Args{VarTerm("x")}, + Value: VarTerm("y"), } assertParseRule(t, "identity", `f(x) = y { y = x }`, &Rule{ @@ -1453,9 +1456,10 @@ func TestRule(t *testing.T) { assertParseRule(t, "composite arg", `f([x, y]) = z { split(x, y, z) }`, &Rule{ Head: &Head{ - Name: Var("f"), - Args: Args{ArrayTerm(VarTerm("x"), VarTerm("y"))}, - Value: VarTerm("z"), + Name: Var("f"), + Reference: Ref{VarTerm("f")}, + Args: Args{ArrayTerm(VarTerm("x"), VarTerm("y"))}, + Value: VarTerm("z"), }, Body: NewBody( Split.Expr(VarTerm("x"), VarTerm("y"), VarTerm("z")), @@ -1464,9 +1468,10 @@ func TestRule(t *testing.T) { assertParseRule(t, "composite result", `f(1) = [x, y] { split("foo.bar", x, y) }`, &Rule{ Head: &Head{ - Name: Var("f"), - Args: Args{IntNumberTerm(1)}, - Value: ArrayTerm(VarTerm("x"), VarTerm("y")), + Name: Var("f"), + Reference: Ref{VarTerm("f")}, + Args: Args{IntNumberTerm(1)}, + Value: ArrayTerm(VarTerm("x"), VarTerm("y")), }, Body: NewBody( Split.Expr(StringTerm("foo.bar"), VarTerm("x"), VarTerm("y")), @@ -1475,7 +1480,8 @@ func TestRule(t *testing.T) { 
assertParseRule(t, "expr terms: key", `p[f(x) + g(x)] { true }`, &Rule{ Head: &Head{ - Name: Var("p"), + Name: Var("p"), + Reference: Ref{VarTerm("p")}, Key: Plus.Call( CallTerm(RefTerm(VarTerm("f")), VarTerm("x")), CallTerm(RefTerm(VarTerm("g")), VarTerm("x")), @@ -1486,7 +1492,8 @@ func TestRule(t *testing.T) { assertParseRule(t, "expr terms: value", `p = f(x) + g(x) { true }`, &Rule{ Head: &Head{ - Name: Var("p"), + Name: Var("p"), + Reference: Ref{VarTerm("p")}, Value: Plus.Call( CallTerm(RefTerm(VarTerm("f")), VarTerm("x")), CallTerm(RefTerm(VarTerm("g")), VarTerm("x")), @@ -1497,7 +1504,8 @@ func TestRule(t *testing.T) { assertParseRule(t, "expr terms: args", `p(f(x) + g(x)) { true }`, &Rule{ Head: &Head{ - Name: Var("p"), + Name: Var("p"), + Reference: Ref{VarTerm("p")}, Args: Args{ Plus.Call( CallTerm(RefTerm(VarTerm("f")), VarTerm("x")), @@ -1511,25 +1519,28 @@ func TestRule(t *testing.T) { assertParseRule(t, "assignment operator", `x := 1 { true }`, &Rule{ Head: &Head{ - Name: Var("x"), - Value: IntNumberTerm(1), - Assign: true, + Name: Var("x"), + Reference: Ref{VarTerm("x")}, + Value: IntNumberTerm(1), + Assign: true, }, Body: NewBody(NewExpr(BooleanTerm(true))), }) assertParseRule(t, "else assignment", `x := 1 { false } else := 2`, &Rule{ Head: &Head{ - Name: "x", - Value: IntNumberTerm(1), - Assign: true, + Name: "x", // ha! clever! 
+ Reference: Ref{VarTerm("x")}, + Value: IntNumberTerm(1), + Assign: true, }, Body: NewBody(NewExpr(BooleanTerm(false))), Else: &Rule{ Head: &Head{ - Name: "x", - Value: IntNumberTerm(2), - Assign: true, + Name: "x", + Reference: Ref{VarTerm("x")}, + Value: IntNumberTerm(2), + Assign: true, }, Body: NewBody(NewExpr(BooleanTerm(true))), }, @@ -1537,18 +1548,20 @@ func TestRule(t *testing.T) { assertParseRule(t, "partial assignment", `p[x] := y { true }`, &Rule{ Head: &Head{ - Name: "p", - Value: VarTerm("y"), - Key: VarTerm("x"), - Assign: true, + Name: "p", + Reference: MustParseRef("p[x]"), + Value: VarTerm("y"), + Key: VarTerm("x"), + Assign: true, }, Body: NewBody(NewExpr(BooleanTerm(true))), }) assertParseRule(t, "function assignment", `f(x) := y { true }`, &Rule{ Head: &Head{ - Name: "f", - Value: VarTerm("y"), + Name: "f", + Reference: Ref{VarTerm("f")}, + Value: VarTerm("y"), Args: Args{ VarTerm("x"), }, @@ -1562,11 +1575,10 @@ func TestRule(t *testing.T) { assertParseErrorContains(t, "empty rule body", "p {}", "rego_parse_error: found empty body") assertParseErrorContains(t, "unmatched braces", `f(x) = y { trim(x, ".", y) `, `rego_parse_error: unexpected eof token: expected \n or ; or }`) - // TODO: how to highlight that assignment is incorrect here? 
assertParseErrorContains(t, "no output", `f(_) = { "foo" = "bar" }`, "rego_parse_error: unexpected eq token: expected rule value term") assertParseErrorContains(t, "no output", `f(_) := { "foo" = "bar" }`, "rego_parse_error: unexpected assign token: expected function value term") assertParseErrorContains(t, "no output", `f := { "foo" = "bar" }`, "rego_parse_error: unexpected assign token: expected rule value term") - assertParseErrorContains(t, "no output", `f[_] := { "foo" = "bar" }`, "rego_parse_error: unexpected assign token: expected partial rule value term") + assertParseErrorContains(t, "no output", `f[_] := { "foo" = "bar" }`, "rego_parse_error: unexpected assign token: expected rule value term") assertParseErrorContains(t, "no output", `default f :=`, "rego_parse_error: unexpected assign token: expected default rule value term") // TODO(tsandall): improve error checking here. This is a common mistake @@ -1575,19 +1587,21 @@ func TestRule(t *testing.T) { assertParseError(t, "dangling semicolon", "p { true; false; }") assertParseErrorContains(t, "default invalid rule name", `default 0[0`, "unexpected default keyword") - assertParseErrorContains(t, "default invalid rule value", `default a[0`, "illegal default rule (must have a value)") + assertParseErrorContains(t, "default invalid rule value", `default a[0]`, "illegal default rule (must have a value)") assertParseRule(t, "default missing value", `default a`, &Rule{ Default: true, Head: &Head{ - Name: Var("a"), - Value: BooleanTerm(true), + Name: Var("a"), + Reference: Ref{VarTerm("a")}, + Value: BooleanTerm(true), }, Body: NewBody(NewExpr(BooleanTerm(true))), }) assertParseRule(t, "empty arguments", `f() { x := 1 }`, &Rule{ Head: &Head{ - Name: "f", - Value: BooleanTerm(true), + Name: "f", + Reference: Ref{VarTerm("f")}, + Value: BooleanTerm(true), }, Body: MustParseBody(`x := 1`), }) @@ -1598,23 +1612,23 @@ func TestRule(t *testing.T) { assertParseErrorContains(t, "default invalid rule head call", `default a 
= b`, "illegal default rule (value cannot contain var)") assertParseError(t, "extra braces", `{ a := 1 }`) - assertParseError(t, "invalid rule name dots", `a.b = x { x := 1 }`) - assertParseError(t, "invalid rule name dots and call", `a.b(x) { x := 1 }`) assertParseError(t, "invalid rule name hyphen", `a-b = x { x := 1 }`) assertParseRule(t, "wildcard name", `_ { x == 1 }`, &Rule{ Head: &Head{ - Name: "$0", - Value: BooleanTerm(true), + Name: "$0", + Reference: Ref{VarTerm("$0")}, + Value: BooleanTerm(true), }, Body: MustParseBody(`x == 1`), }) assertParseRule(t, "partial object array key", `p[[a, 1, 2]] = x { a := 1; x := "foo" }`, &Rule{ Head: &Head{ - Name: "p", - Key: ArrayTerm(VarTerm("a"), NumberTerm("1"), NumberTerm("2")), - Value: VarTerm("x"), + Name: "p", + Reference: MustParseRef("p[[a,1,2]]"), + Key: ArrayTerm(VarTerm("a"), NumberTerm("1"), NumberTerm("2")), + Value: VarTerm("x"), }, Body: MustParseBody(`a := 1; x := "foo"`), }) @@ -1623,7 +1637,7 @@ func TestRule(t *testing.T) { } func TestRuleContains(t *testing.T) { - opts := ParserOptions{FutureKeywords: []string{"contains"}} + opts := ParserOptions{FutureKeywords: []string{"contains", "if"}} tests := []struct { note string @@ -1646,6 +1660,28 @@ func TestRuleContains(t *testing.T) { Body: NewBody(NewExpr(BooleanTerm(true))), }, }, + { + note: "ref head, no body", + rule: `p.q contains "x"`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q"), + Key: StringTerm("x"), + }, + Body: NewBody(NewExpr(BooleanTerm(true))), + }, + }, + { + note: "ref head", + rule: `p.q contains "x" { true }`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q"), + Key: StringTerm("x"), + }, + Body: NewBody(NewExpr(BooleanTerm(true))), + }, + }, { note: "set with var element", rule: `deny contains msg { msg := "nonono" }`, @@ -1699,6 +1735,17 @@ func TestRuleIf(t *testing.T) { }, }, }, + { + note: "ref head, complete", + rule: `p.q if { true }`, + exp: &Rule{ + Head: &Head{ + Reference: 
MustParseRef("p.q"), + Value: BooleanTerm(true), + }, + Body: NewBody(NewExpr(BooleanTerm(true))), + }, + }, { note: "complete, normal body", rule: `p if { x := 10; x > y }`, @@ -1712,16 +1759,18 @@ func TestRuleIf(t *testing.T) { rule: `p := "yes" if { 10 > y } else := "no" { 10 <= y }`, exp: &Rule{ Head: &Head{ - Name: Var("p"), - Value: StringTerm("yes"), - Assign: true, + Reference: Ref{VarTerm("p")}, + Name: Var("p"), + Value: StringTerm("yes"), + Assign: true, }, Body: MustParseBody(`10 > y`), Else: &Rule{ Head: &Head{ - Name: Var("p"), - Value: StringTerm("no"), - Assign: true, + Reference: Ref{VarTerm("p")}, + Name: Var("p"), + Value: StringTerm("no"), + Assign: true, }, Body: MustParseBody(`10 <= y`), }, @@ -1732,16 +1781,18 @@ func TestRuleIf(t *testing.T) { rule: `p := "yes" if { 10 > y } else := "no" if { 10 <= y }`, exp: &Rule{ Head: &Head{ - Name: Var("p"), - Value: StringTerm("yes"), - Assign: true, + Reference: Ref{VarTerm("p")}, + Name: Var("p"), + Value: StringTerm("yes"), + Assign: true, }, Body: MustParseBody(`10 > y`), Else: &Rule{ Head: &Head{ - Name: Var("p"), - Value: StringTerm("no"), - Assign: true, + Reference: Ref{VarTerm("p")}, + Name: Var("p"), + Value: StringTerm("no"), + Assign: true, }, Body: MustParseBody(`10 <= y`), }, @@ -1795,8 +1846,9 @@ func TestRuleIf(t *testing.T) { Body: MustParseBody(`1 > 2`), Else: &Rule{ Head: &Head{ - Name: Var("p"), - Value: NumberTerm("42"), + Reference: Ref{VarTerm("p")}, + Name: Var("p"), + Value: NumberTerm("42"), }, Body: MustParseBody(`2 > 1`), }, @@ -1815,9 +1867,10 @@ func TestRuleIf(t *testing.T) { rule: `f(x) = y if y := x + 1`, exp: &Rule{ Head: &Head{ - Name: Var("f"), - Args: []*Term{VarTerm("x")}, - Value: VarTerm("y"), + Reference: Ref{VarTerm("f")}, + Name: Var("f"), + Args: []*Term{VarTerm("x")}, + Value: VarTerm("y"), }, Body: MustParseBody(`y := x + 1`), }, @@ -1827,9 +1880,10 @@ func TestRuleIf(t *testing.T) { rule: `f(xs) if every x in xs { x != 0 }`, exp: &Rule{ Head: &Head{ - 
Name: Var("f"), - Args: []*Term{VarTerm("xs")}, - Value: BooleanTerm(true), + Reference: Ref{VarTerm("f")}, + Name: Var("f"), + Args: []*Term{VarTerm("xs")}, + Value: BooleanTerm(true), }, Body: MustParseBodyWithOpts(`every x in xs { x != 0 }`, opts), }, @@ -1838,7 +1892,11 @@ func TestRuleIf(t *testing.T) { note: "object", rule: `p["foo"] = "bar" if { true }`, exp: &Rule{ - Head: NewHead(Var("p"), StringTerm("foo"), StringTerm("bar")), + Head: &Head{ + Name: Var("p"), + Reference: MustParseRef("p.foo"), + Value: StringTerm("bar"), + }, Body: NewBody(NewExpr(BooleanTerm(true))), }, }, @@ -1846,7 +1904,11 @@ func TestRuleIf(t *testing.T) { note: "object, shorthand", rule: `p["foo"] = "bar" if true`, exp: &Rule{ - Head: NewHead(Var("p"), StringTerm("foo"), StringTerm("bar")), + Head: &Head{ + Name: Var("p"), + Reference: MustParseRef("p.foo"), + Value: StringTerm("bar"), + }, Body: NewBody(NewExpr(BooleanTerm(true))), }, }, @@ -1857,7 +1919,11 @@ func TestRuleIf(t *testing.T) { y := "bar" }`, exp: &Rule{ - Head: NewHead(Var("p"), VarTerm("x"), VarTerm("y")), + Head: &Head{ + Name: Var("p"), + Reference: MustParseRef("p[x]"), + Value: VarTerm("y"), + }, Body: MustParseBody(`x := "foo"; y := "bar"`), }, }, @@ -1885,6 +1951,28 @@ func TestRuleIf(t *testing.T) { Body: MustParseBody(`x := "foo"`), }, }, + { + note: "partial set+if, shorthand", // these are now Head.Ref rules, previously forbidden + rule: `p[x] if x := 1`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p[x]"), + Value: BooleanTerm(true), + }, + Body: MustParseBody(`x := 1`), + }, + }, + { + note: "partial set+if", // these are now Head.Ref rules, previously forbidden + rule: `p[x] if { x := 1 }`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p[x]"), + Value: BooleanTerm(true), + }, + Body: MustParseBody(`x := 1`), + }, + }, } for _, tc := range tests { @@ -1892,28 +1980,177 @@ func TestRuleIf(t *testing.T) { assertParseRule(t, tc.note, tc.rule, tc.exp, opts) }) } +} - errors := []struct 
{ +func TestRuleRefHeads(t *testing.T) { + opts := ParserOptions{FutureKeywords: []string{"contains", "if", "every"}} + trueBody := NewBody(NewExpr(BooleanTerm(true))) + + tests := []struct { note string rule string - err string + exp *Rule }{ { - note: "partial set+if, shorthand", - rule: `p[x] if x := 1`, - err: "rego_parse_error: unexpected if keyword: invalid for partial set rule p (use `contains`)", + note: "single-value rule", + rule: "p.q.r = 1 if true", + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.r"), + Value: IntNumberTerm(1), + }, + Body: trueBody, + }, }, { - note: "partial set+if", - rule: `p[x] if { x := 1 }`, - err: "rego_parse_error: unexpected if keyword: invalid for partial set rule p (use `contains`)", + note: "single-value with brackets, string key", + rule: `p.q["r"] = 1 if true`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.r"), + Value: IntNumberTerm(1), + }, + Body: trueBody, + }, + }, + { + note: "single-value with brackets, number key", + rule: `p.q[2] = 1 if true`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q[2]"), + Value: IntNumberTerm(1), + }, + Body: trueBody, + }, + }, + { + note: "single-value with brackets, no value", + rule: `p.q[2] if true`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q[2]"), + Value: BooleanTerm(true), + }, + Body: trueBody, + }, + }, + { + note: "single-value with brackets, var key", + rule: `p.q[x] = 1 if x := 2`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q[x]"), + Value: IntNumberTerm(1), + }, + Body: MustParseBody("x := 2"), + }, + }, + { + note: "single-value with brackets, var key, no dot", + rule: `p[x] = 1 if x := 2`, + exp: &Rule{ + Head: &Head{ + Name: Var("p"), + Reference: MustParseRef("p[x]"), + Value: IntNumberTerm(1), + }, + Body: MustParseBody("x := 2"), + }, + }, + { + note: "multi-value, simple", + rule: `p.q.r contains x if x := 2`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.r"), + Key: VarTerm("x"), + 
}, + Body: MustParseBody("x := 2"), + }, + }, + { + note: "backcompat: multi-value, no dot", + rule: `p[x] { x := 2 }`, // no "if", which triggers ref-interpretation + exp: &Rule{ + Head: &Head{ + Name: "p", + Reference: Ref{VarTerm("p")}, // we're defining p as multi-val rule + Key: VarTerm("x"), + }, + Body: MustParseBody("x := 2"), + }, + }, + { + note: "backcompat: single-value, no dot", + rule: `p[x] = 3 { x := 2 }`, + exp: &Rule{ + Head: &Head{ + Name: "p", + Reference: MustParseRef("p[x]"), + Key: VarTerm("x"), // not used + Value: IntNumberTerm(3), + }, + Body: MustParseBody("x := 2"), + }, + }, + { + note: "backcompat: single-value, no dot, complex object", + rule: `partialobj[x] = {"foo": y} { y = "bar"; x = y }`, + exp: &Rule{ + Head: &Head{ + Name: "partialobj", + Reference: MustParseRef("partialobj[x]"), + Key: VarTerm("x"), // not used + Value: MustParseTerm(`{"foo": y}`), + }, + Body: MustParseBody(`y = "bar"; x = y`), + }, + }, + { + note: "function, simple", + rule: `p.q.f(x) = 1 if true`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.f"), + Args: Args([]*Term{VarTerm("x")}), + Value: IntNumberTerm(1), + }, + Body: trueBody, + }, + }, + { + note: "function, no value", + rule: `p.q.f(x) if true`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.f"), + Args: Args([]*Term{VarTerm("x")}), + Value: BooleanTerm(true), + }, + Body: trueBody, + }, + }, + { + note: "function, with value", + rule: `p.q.f(x) = x + 1 if true`, + exp: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.f"), + Args: Args([]*Term{VarTerm("x")}), + Value: Plus.Call(VarTerm("x"), IntNumberTerm(1)), + }, + Body: trueBody, + }, }, } - for _, tc := range errors { + for _, tc := range tests { t.Run(tc.note, func(t *testing.T) { - assertParseErrorContains(t, tc.note, tc.rule, tc.err, opts) + assertParseRule(t, tc.note, tc.rule, tc.exp, opts) }) } + + // TODO(sr): error cases, non-ground terms anywhere but at the end of the ref } func TestRuleElseKeyword(t 
*testing.T) { @@ -1968,8 +2205,9 @@ func TestRuleElseKeyword(t *testing.T) { } name := Var("p") + ref := Ref{VarTerm("p")} tr := BooleanTerm(true) - head := &Head{Name: name, Value: tr} + head := &Head{Name: name, Reference: ref, Value: tr} expected := &Module{ Package: MustParsePackage(`package test`), @@ -1986,14 +2224,16 @@ func TestRuleElseKeyword(t *testing.T) { Body: MustParseBody(`"p1_e1"`), Else: &Rule{ Head: &Head{ - Name: Var("p"), - Value: ArrayTerm(NullTerm()), + Name: name, + Reference: ref, + Value: ArrayTerm(NullTerm()), }, Body: MustParseBody(`"p1_e2"`), Else: &Rule{ Head: &Head{ - Name: name, - Value: VarTerm("x"), + Name: name, + Reference: ref, + Value: VarTerm("x"), }, Body: MustParseBody(`x = "p1_e3"`), }, @@ -2006,23 +2246,26 @@ func TestRuleElseKeyword(t *testing.T) { }, { Head: &Head{ - Name: Var("f"), - Args: Args{VarTerm("x")}, - Value: BooleanTerm(true), + Name: Var("f"), + Reference: Ref{VarTerm("f")}, + Args: Args{VarTerm("x")}, + Value: BooleanTerm(true), }, Body: MustParseBody(`x < 100`), Else: &Rule{ Head: &Head{ - Name: Var("f"), - Args: Args{VarTerm("x")}, - Value: BooleanTerm(false), + Name: Var("f"), + Reference: Ref{VarTerm("f")}, + Args: Args{VarTerm("x")}, + Value: BooleanTerm(false), }, Body: MustParseBody(`x > 200`), Else: &Rule{ Head: &Head{ - Name: Var("f"), - Args: Args{VarTerm("x")}, - Value: BooleanTerm(true), + Name: Var("f"), + Reference: Ref{VarTerm("f")}, + Args: Args{VarTerm("x")}, + Value: BooleanTerm(true), }, Body: MustParseBody(`x != 150`), }, @@ -2031,20 +2274,23 @@ func TestRuleElseKeyword(t *testing.T) { { Head: &Head{ - Name: Var("$0"), - Value: BooleanTerm(true), + Name: Var("$0"), + Reference: Ref{VarTerm("$0")}, + Value: BooleanTerm(true), }, Body: MustParseBody(`x > 0`), Else: &Rule{ Head: &Head{ - Name: Var("$0"), - Value: BooleanTerm(true), + Name: Var("$0"), + Reference: Ref{VarTerm("$0")}, + Value: BooleanTerm(true), }, Body: MustParseBody(`x == -1`), Else: &Rule{ Head: &Head{ - Name: Var("$0"), - 
Value: BooleanTerm(true), + Name: Var("$0"), + Reference: Ref{VarTerm("$0")}, + Value: BooleanTerm(true), }, Body: MustParseBody(`x > -100`), }, @@ -2052,30 +2298,34 @@ func TestRuleElseKeyword(t *testing.T) { }, { Head: &Head{ - Name: Var("nobody"), - Value: IntNumberTerm(1), + Name: Var("nobody"), + Reference: Ref{VarTerm("nobody")}, + Value: IntNumberTerm(1), }, Body: MustParseBody("false"), Else: &Rule{ Head: &Head{ - Name: Var("nobody"), - Value: IntNumberTerm(7), + Name: Var("nobody"), + Reference: Ref{VarTerm("nobody")}, + Value: IntNumberTerm(7), }, Body: MustParseBody("true"), }, }, { Head: &Head{ - Name: Var("nobody_f"), - Args: Args{VarTerm("x")}, - Value: IntNumberTerm(1), + Name: Var("nobody_f"), + Reference: Ref{VarTerm("nobody_f")}, + Args: Args{VarTerm("x")}, + Value: IntNumberTerm(1), }, Body: MustParseBody("false"), Else: &Rule{ Head: &Head{ - Name: Var("nobody_f"), - Args: Args{VarTerm("x")}, - Value: IntNumberTerm(7), + Name: Var("nobody_f"), + Reference: Ref{VarTerm("nobody_f")}, + Args: Args{VarTerm("x")}, + Value: IntNumberTerm(7), }, Body: MustParseBody("true"), }, @@ -2102,14 +2352,16 @@ func TestRuleElseKeyword(t *testing.T) { Body: MustParseBody(`"p1_e1"`), Else: &Rule{ Head: &Head{ - Name: Var("p"), - Value: ArrayTerm(NullTerm()), + Name: Var("p"), + Reference: Ref{VarTerm("p")}, + Value: ArrayTerm(NullTerm()), }, Body: MustParseBody(`"p1_e2"`), Else: &Rule{ Head: &Head{ - Name: name, - Value: VarTerm("x"), + Name: name, + Reference: ref, + Value: VarTerm("x"), }, Body: MustParseBody(`x = "p1_e4"`), }, @@ -2209,7 +2461,6 @@ func TestEmptyModule(t *testing.T) { } func TestComments(t *testing.T) { - testModule := `package a.b.c import input.e.f as g # end of line @@ -2425,52 +2676,245 @@ func TestLocation(t *testing.T) { } } -func TestRuleFromBody(t *testing.T) { - testModule := `package a.b.c - -pi = 3.14159 -p[x] { x = 1 } -greeting = "hello" -cores = [{0: 1}, {1: 2}] -wrapper = cores[0][1] -pi = [3, 1, 4, x, y, z] -foo["bar"] = "buz" 
-foo["9"] = "10" -foo.buz = "bar" -bar[1] -bar[[{"foo":"baz"}]] -bar.qux -input = 1 -data = 2 -f(1) = 2 -f(1) -d1 := 1234 -` +func TestRuleFromBodyRefs(t *testing.T) { + opts := ParserOptions{FutureKeywords: []string{"if", "contains"}} - assertParseModule(t, "rules from bodies", testModule, &Module{ - Package: MustParseStatement(`package a.b.c`).(*Package), - Rules: []*Rule{ - MustParseRule(`pi = 3.14159 { true }`), - MustParseRule(`p[x] { x = 1 }`), - MustParseRule(`greeting = "hello" { true }`), - MustParseRule(`cores = [{0: 1}, {1: 2}] { true }`), - MustParseRule(`wrapper = cores[0][1] { true }`), - MustParseRule(`pi = [3, 1, 4, x, y, z] { true }`), - MustParseRule(`foo["bar"] = "buz" { true }`), - MustParseRule(`foo["9"] = "10" { true }`), - MustParseRule(`foo["buz"] = "bar" { true }`), - MustParseRule(`bar[1] { true }`), - MustParseRule(`bar[[{"foo":"baz"}]] { true }`), - MustParseRule(`bar["qux"] { true }`), - MustParseRule(`input = 1 { true }`), - MustParseRule(`data = 2 { true }`), - MustParseRule(`f(1) = 2 { true }`), - MustParseRule(`f(1) = true { true }`), - MustParseRule("d1 := 1234 { true }"), + // NOTE(sr): These tests assert that the other code path, parsing a module, and + // then interpreting naked expressions into (shortcut) rule definitions, works + // the same as parsing the string as a Rule directly. Without also passing + // TestRuleRefHeads, these tests are not to be trusted -- if changing something, + // start with getting TestRuleRefHeads to PASS. 
+ tests := []struct { + note string + rule string + exp string + }{ + { + note: "no dots: single-value rule (complete doc)", + rule: `foo["bar"] = 12`, + exp: `foo["bar"] = 12 { true }`, + }, + { + note: "no dots: partial set of numbers", + rule: `foo[1]`, + exp: `foo[1] { true }`, + }, + { + note: "no dots: shorthand set of strings", // back compat + rule: `foo.one`, + exp: `foo["one"] { true }`, + }, + { + note: "no dots: partial set", + rule: `foo[x] { x = 1 }`, + exp: `foo[x] { x = 1 }`, + }, + { + note: "no dots + if: complete doc", + rule: `foo[x] if x := 1`, + exp: `foo[x] if x := 1`, + }, + { + note: "no dots: function", + rule: `foo(x)`, + exp: `foo(x) { true }`, + }, + { + note: "no dots: function with value", + rule: `foo(x) = y`, + exp: `foo(x) = y { true }`, + }, + { + note: "no dots: partial set, ref element", + rule: `test[arr[0]]`, + exp: `test[arr[0]] { true }`, + }, + { + note: "one dot: complete rule shorthand", + rule: `foo.bar = "buz"`, + exp: `foo.bar = "buz" { true }`, + }, + { + note: "one dot, bracket with var: partial object", + rule: `foo.bar[x] = "buz"`, + exp: `foo.bar[x] = "buz" { true }`, + }, + { + note: "one dot, bracket with var: partial set", + rule: `foo.bar[x] { x = 1 }`, + exp: `foo.bar[x] { x = 1 }`, + }, + { + note: "one dot, bracket with string: complete doc", + rule: `foo.bar["baz"] = "buz"`, + exp: `foo.bar.baz = "buz" { true }`, + }, + { + note: "one dot, bracket with var, rule body: partial object", + rule: `foo.bar[x] = "buz" { x = 1 }`, + exp: `foo.bar[x] = "buz" { x = 1 }`, + }, + { + note: "one dot: function", + rule: `foo.bar(x)`, + exp: `foo.bar(x) { true }`, + }, + { + note: "one dot: function with value", + rule: `foo.bar(x) = y`, + exp: `foo.bar(x) = y { true }`, + }, + { + note: "two dots, bracket with var: partial object", + rule: `foo.bar.baz[x] = "buz" { x = 1 }`, + exp: `foo.bar.baz[x] = "buz" { x = 1 }`, + }, + { + note: "two dots, bracket with var: partial set", + rule: `foo.bar.baz[x] { x = 1 }`, + exp: 
`foo.bar.baz[x] { x = 1 }`, + }, + { + note: "one dot, bracket with string, no key: complete doc", + rule: `foo.bar["baz"]`, + exp: `foo.bar.baz { true }`, + }, + { + note: "two dots: function", + rule: `foo.bar("baz")`, + exp: `foo.bar("baz") { true }`, + }, + { + note: "two dots: function with value", + rule: `foo.bar("baz") = y`, + exp: `foo.bar("baz") = y { true }`, }, + { + note: "non-ground ref: complete doc", + rule: `foo.bar[i].baz { i := 1 }`, + exp: `foo.bar[i].baz { i := 1 }`, + }, + { + note: "non-ground ref: partial set", + rule: `foo.bar[i].baz[x] { i := 1; x := 2 }`, + exp: `foo.bar[i].baz[x] { i := 1; x := 2 }`, + }, + { + note: "non-ground ref: partial object", + rule: `foo.bar[i].baz[x] = 3 { i := 1; x := 2 }`, + exp: `foo.bar[i].baz[x] = 3 { i := 1; x := 2 }`, + }, + { + note: "non-ground ref: function", + rule: `foo.bar[i].baz(x) = 3 { i := 1 }`, + exp: `foo.bar[i].baz(x) = 3 { i := 1 }`, + }, + { + note: "last term is number: partial set", + rule: `foo.bar.baz[3] { true }`, + exp: `foo.bar.baz[3] { true }`, + }, + } + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + r, err := ParseRuleWithOpts(tc.exp, opts) + if err != nil { + t.Fatal(err) + } + + testModule := "package a.b.c\n" + tc.rule + m, err := ParseModuleWithOpts("", testModule, opts) + if err != nil { + t.Fatal(err) + } + mr := m.Rules[0] + + if r.Head.Name.Compare(mr.Head.Name) != 0 { + t.Errorf("rule.Head.Name differs:\n exp = %#v\nrule = %#v", r.Head.Name, mr.Head.Name) + } + if r.Head.Ref().Compare(mr.Head.Ref()) != 0 { + t.Errorf("rule.Head.Ref() differs:\n exp = %v\nrule = %v", r.Head.Ref(), mr.Head.Ref()) + } + exp, err := ParseRuleWithOpts(tc.exp, opts) + if err != nil { + t.Fatal(err) + } + assertParseModule(t, tc.note, testModule, &Module{ + Package: MustParseStatement(`package a.b.c`).(*Package), + Rules: []*Rule{exp}, + }, opts) + }) + } + + // edge cases + t.Run("errors", func(t *testing.T) { + t.Run("naked 'data' ref", func(t *testing.T) { + _, err := 
ParseModuleWithOpts("", "package a.b.c\ndata", opts) + assertErrorWithMessage(t, err, "refs cannot be used for rule head") + }) + t.Run("naked 'input' ref", func(t *testing.T) { + _, err := ParseModuleWithOpts("", "package a.b.c\ninput", opts) + assertErrorWithMessage(t, err, "refs cannot be used for rule head") + }) }) +} + +func assertErrorWithMessage(t *testing.T, err error, msg string) { + t.Helper() + var errs Errors + if !errors.As(err, &errs) { + t.Fatalf("expected Errors, got %v %[1]T", err) + } + if exp, act := 1, len(errs); exp != act { + t.Fatalf("expected %d errors, got %d", exp, act) + } + e := errs[0] + if exp, act := msg, e.Message; exp != act { + t.Fatalf("expected error message %q, got %q", exp, act) + } +} + +func TestRuleFromBody(t *testing.T) { + tests := []struct { + input string + exp string + }{ + {`pi = 3.14159`, `pi = 3.14159 { true }`}, + {`p[x] { x = 1 }`, `p[x] { x = 1 }`}, + {`greeting = "hello"`, `greeting = "hello" { true }`}, + {`cores = [{0: 1}, {1: 2}]`, `cores = [{0: 1}, {1: 2}] { true }`}, + {`wrapper = cores[0][1]`, `wrapper = cores[0][1] { true }`}, + {`pi = [3, 1, 4, x, y, z]`, `pi = [3, 1, 4, x, y, z] { true }`}, + {`foo["bar"] = "buz"`, `foo["bar"] = "buz" { true }`}, + {`foo["9"] = "10"`, `foo["9"] = "10" { true }`}, + {`foo.buz = "bar"`, `foo["buz"] = "bar" { true }`}, + {`bar[1]`, `bar[1] { true }`}, + {`bar[[{"foo":"baz"}]]`, `bar[[{"foo":"baz"}]] { true }`}, + {`bar.qux`, `bar["qux"] { true }`}, + {`input = 1`, `input = 1 { true }`}, + {`data = 2`, `data = 2 { true }`}, + {`f(1) = 2`, `f(1) = 2 { true }`}, + {`f(1)`, `f(1) = true { true }`}, + {`d1 := 1234`, "d1 := 1234 { true }"}, + } + + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + testModule := "package a.b.c\n" + tc.input + assertParseModule(t, tc.input, testModule, &Module{ + Package: MustParseStatement(`package a.b.c`).(*Package), + Rules: []*Rule{ + MustParseRule(tc.exp), + }, + }) + }) + } // Verify the rule and rule and rule head col/loc 
values + testModule := "package a.b.c\n\n" + for _, tc := range tests { + testModule += tc.input + "\n" + } module, err := ParseModule("test.rego", testModule) if err != nil { t.Fatal(err) @@ -2479,19 +2923,19 @@ d1 := 1234 for i := range module.Rules { col := module.Rules[i].Location.Col if col != 1 { - t.Fatalf("expected rule %v column to be 1 but got %v", module.Rules[i].Head.Name, col) + t.Errorf("expected rule %v column to be 1 but got %v", module.Rules[i].Head.Name, col) } row := module.Rules[i].Location.Row - if row != 3+i { // 'pi' rule stats on row 3 - t.Fatalf("expected rule %v row to be %v but got %v", module.Rules[i].Head.Name, 3+i, row) + if row != 3+i { // 'pi' rule starts on row 3 + t.Errorf("expected rule %v row to be %v but got %v", module.Rules[i].Head.Name, 3+i, row) } col = module.Rules[i].Head.Location.Col if col != 1 { - t.Fatalf("expected rule head %v column to be 1 but got %v", module.Rules[i].Head.Name, col) + t.Errorf("expected rule head %v column to be 1 but got %v", module.Rules[i].Head.Name, col) } row = module.Rules[i].Head.Location.Row - if row != 3+i { // 'pi' rule stats on row 3 - t.Fatalf("expected rule head %v row to be %v but got %v", module.Rules[i].Head.Name, 3+i, row) + if row != 3+i { // 'pi' rule starts on row 3 + t.Errorf("expected rule head %v row to be %v but got %v", module.Rules[i].Head.Name, 3+i, row) } } @@ -2532,16 +2976,6 @@ data = {"bar": 2}` foo = input with input as 1 ` - badRefLen1 := ` - package a.b.c - - p["x"].y = 1` - - badRefLen2 := ` - package a.b.c - - p["x"].y` - negated := ` package a.b.c @@ -2600,8 +3034,6 @@ data = {"bar": 2}` assertParseModuleError(t, "non-equality", nonEquality) assertParseModuleError(t, "non-var name", nonVarName) assertParseModuleError(t, "with expr", withExpr) - assertParseModuleError(t, "bad ref (too long)", badRefLen1) - assertParseModuleError(t, "bad ref (too long)", badRefLen2) assertParseModuleError(t, "negated", negated) assertParseModuleError(t, "non ref term", nonRefTerm) 
assertParseModuleError(t, "zero args", zeroArgs) @@ -2665,7 +3097,8 @@ func TestWildcards(t *testing.T) { assertParseRule(t, "functions", `f(_) = y { true }`, &Rule{ Head: &Head{ - Name: Var("f"), + Name: Var("f"), + Reference: Ref{VarTerm("f")}, Args: Args{ VarTerm("$0"), }, @@ -4384,6 +4817,12 @@ func assertParseRule(t *testing.T, msg string, input string, correct *Rule, opts assertParseOne(t, msg, input, func(parsed interface{}) { t.Helper() rule := parsed.(*Rule) + if rule.Head.Name != correct.Head.Name { + t.Errorf("Error on test \"%s\": rule heads not equal: name = %v (parsed), name = %v (correct)", msg, rule.Head.Name, correct.Head.Name) + } + if !rule.Head.Ref().Equal(correct.Head.Ref()) { + t.Errorf("Error on test \"%s\": rule heads not equal: ref = %v (parsed), ref = %v (correct)", msg, rule.Head.Ref(), correct.Head.Ref()) + } if !rule.Equal(correct) { t.Errorf("Error on test \"%s\": rules not equal: %v (parsed), %v (correct)", msg, rule, correct) } diff --git a/ast/policy.go b/ast/policy.go index 04ebe56f10..4f56aa7edb 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -186,12 +186,13 @@ type ( // Head represents the head of a rule. Head struct { - Location *Location `json:"-"` - Name Var `json:"name"` - Args Args `json:"args,omitempty"` - Key *Term `json:"key,omitempty"` - Value *Term `json:"value,omitempty"` - Assign bool `json:"assign,omitempty"` + Location *Location `json:"-"` + Name Var `json:"name,omitempty"` + Reference Ref `json:"ref,omitempty"` + Args Args `json:"args,omitempty"` + Key *Term `json:"key,omitempty"` + Value *Term `json:"value,omitempty"` + Assign bool `json:"assign,omitempty"` } // Args represents zero or more arguments to a rule. @@ -598,11 +599,22 @@ func (rule *Rule) SetLoc(loc *Location) { // Path returns a ref referring to the document produced by this rule. If rule // is not contained in a module, this function panics. +// Deprecated: Poor handling of ref rules. Use `(*Rule).Ref()` instead. 
func (rule *Rule) Path() Ref { if rule.Module == nil { panic("assertion failed") } - return rule.Module.Package.Path.Append(StringTerm(string(rule.Head.Name))) + return rule.Module.Package.Path.Extend(rule.Head.Ref().GroundPrefix()) +} + +// Ref returns a ref referring to the document produced by this rule. If rule +// is not contained in a module, this function panics. The returned ref may +// contain variables in the last position. +func (rule *Rule) Ref() Ref { + if rule.Module == nil { + panic("assertion failed") + } + return rule.Module.Package.Path.Extend(rule.Head.Ref()) } func (rule *Rule) String() string { @@ -648,7 +660,8 @@ func (rule *Rule) elseString() string { // used for the key and the second will be used for the value. func NewHead(name Var, args ...*Term) *Head { head := &Head{ - Name: name, + Name: name, // backcompat + Reference: []*Term{NewTerm(name)}, } if len(args) == 0 { return head @@ -658,6 +671,23 @@ func NewHead(name Var, args ...*Term) *Head { return head } head.Value = args[1] + if head.Key != nil && head.Value != nil { + head.Reference = head.Reference.Append(args[0]) + } + return head +} + +// RefHead returns a new Head object with the passed Ref. If args are provided, +// the first will be used for the value. +func RefHead(ref Ref, args ...*Term) *Head { + head := &Head{} + head.SetRef(ref) + if len(ref) < 2 { + head.Name = ref[0].Value.(Var) + } + if len(args) >= 1 { + head.Value = args[0] + } return head } @@ -673,7 +703,7 @@ const ( // PartialObjectDoc represents an object document that is partially defined by the rule. PartialObjectDoc -) +) // TODO(sr): Deprecate? // DocKind returns the type of document produced by this rule. 
func (head *Head) DocKind() DocKind { @@ -686,6 +716,41 @@ return CompleteDoc } +type RuleKind int + +const ( + SingleValue = iota + MultiValue +) + +// RuleKind returns the kind of rule this head declares: SingleValue or MultiValue. +func (head *Head) RuleKind() RuleKind { + // NOTE(sr): This is a bit verbose, since the key is irrelevant for single vs + // multi value, but as good a spot as any to assert the invariant. + switch { + case head.Value != nil: + return SingleValue + case head.Key != nil: + return MultiValue + default: + panic("unreachable") + } +} + +// Ref returns the Ref of the rule. If it doesn't have one, it's filled in +// via the Head's Name. +func (head *Head) Ref() Ref { + if len(head.Reference) > 0 { + return head.Reference + } + return Ref{&Term{Value: head.Name}} +} + +// SetRef can be used to set a rule head's Reference. +func (head *Head) SetRef(r Ref) { + head.Reference = r +} + // Compare returns an integer indicating whether head is less than, equal to, // or greater than other. func (head *Head) Compare(other *Head) int { @@ -705,6 +770,9 @@ func (head *Head) Compare(other *Head) int { if cmp := Compare(head.Args, other.Args); cmp != 0 { return cmp } + if cmp := Compare(head.Reference, other.Reference); cmp != 0 { + return cmp + } if cmp := Compare(head.Name, other.Name); cmp != 0 { return cmp } @@ -717,6 +785,7 @@ // Copy returns a deep copy of head. 
func (head *Head) Copy() *Head { cpy := *head + cpy.Reference = head.Reference.Copy() cpy.Args = head.Args.Copy() cpy.Key = head.Key.Copy() cpy.Value = head.Value.Copy() @@ -729,23 +798,43 @@ func (head *Head) Equal(other *Head) bool { } func (head *Head) String() string { - var buf []string - if len(head.Args) != 0 { - buf = append(buf, head.Name.String()+head.Args.String()) - } else if head.Key != nil { - buf = append(buf, head.Name.String()+"["+head.Key.String()+"]") - } else { - buf = append(buf, head.Name.String()) + buf := strings.Builder{} + buf.WriteString(head.Ref().String()) + + switch { + case len(head.Args) != 0: + buf.WriteString(head.Args.String()) + case len(head.Reference) == 1 && head.Key != nil: + buf.WriteRune('[') + buf.WriteString(head.Key.String()) + buf.WriteRune(']') } if head.Value != nil { if head.Assign { - buf = append(buf, ":=") + buf.WriteString(" := ") } else { - buf = append(buf, "=") + buf.WriteString(" = ") } - buf = append(buf, head.Value.String()) - } - return strings.Join(buf, " ") + buf.WriteString(head.Value.String()) + } else if head.Name == "" && head.Key != nil { + buf.WriteString(" contains ") + buf.WriteString(head.Key.String()) + } + return buf.String() +} + +func (head *Head) MarshalJSON() ([]byte, error) { + // NOTE(sr): we do this to override the rendering of `head.Reference`. + // It's still what'll be used via the default means of encoding/json + // for unmarshaling a json object into a Head struct! + type h Head + return json.Marshal(struct { + h + Ref Ref `json:"ref"` + }{ + h: h(*head), + Ref: head.Ref(), + }) } // Vars returns a set of vars found in the head. 
@@ -761,6 +850,9 @@ func (head *Head) Vars() VarSet { if head.Value != nil { vis.Walk(head.Value) } + if len(head.Reference) > 0 { + vis.Walk(head.Reference[1:]) + } return vis.vars } diff --git a/ast/policy_test.go b/ast/policy_test.go index 684ece6ec3..2645a72d86 100644 --- a/ast/policy_test.go +++ b/ast/policy_test.go @@ -20,6 +20,7 @@ func TestModuleJSONRoundTrip(t *testing.T) { mod, err := ParseModuleWithOpts("test.rego", `package a.b.c +import future.keywords import data.x.y as z import data.u.i @@ -42,6 +43,7 @@ a = true { xs = {a: b | input.y[a] = "foo"; b = input.z["bar"]} } b = true { xs = {{"x": a[i].a} | a[i].n = "bob"; b[x]} } call_values { f(x) != g(x) } assigned := 1 +rule.having.ref.head[1] = x if x := 2 # METADATA # scope: rule @@ -392,17 +394,61 @@ func TestExprEveryCopy(t *testing.T) { } } +func TestRuleHeadJSON(t *testing.T) { + // NOTE(sr): we may get to see Rule objects that aren't the result of parsing, but + // fed as-is into the compiler. We need to be able to make sense of their refs, too. 
+ head := Head{ + Name: Var("allow"), + } + + rule := Rule{ + Head: &head, + } + bs, err := json.Marshal(&rule) + if err != nil { + t.Fatal(err) + } + if exp, act := `{"head":{"name":"allow","ref":[{"type":"var","value":"allow"}]},"body":[]}`, string(bs); act != exp { + t.Errorf("expected %q, got %q", exp, act) + } + + var readRule Rule + if err := json.Unmarshal(bs, &readRule); err != nil { + t.Fatal(err) + } + if exp, act := 1, len(readRule.Head.Reference); act != exp { + t.Errorf("expected unmarshalled rule to have Reference, got %v", readRule.Head.Reference) + } + bs0, err := json.Marshal(&readRule) + if err != nil { + t.Fatal(err) + } + if exp, act := string(bs), string(bs0); exp != act { + t.Errorf("expected json repr to match %q, got %q", exp, act) + } + + var readAgainRule Rule + if err := json.Unmarshal(bs, &readAgainRule); err != nil { + t.Fatal(err) + } + if !readAgainRule.Equal(&readRule) { + t.Errorf("expected roundtripped rule reference to match %v, got %v", readRule.Head.Reference, readAgainRule.Head.Reference) + } +} + func TestRuleHeadEquals(t *testing.T) { assertHeadsEqual(t, &Head{}, &Head{}) - // Same name/key/value + // Same name/ref/key/value assertHeadsEqual(t, &Head{Name: Var("p")}, &Head{Name: Var("p")}) + assertHeadsEqual(t, &Head{Reference: Ref{VarTerm("p"), StringTerm("r")}}, &Head{Reference: Ref{VarTerm("p"), StringTerm("r")}}) // TODO: string for first section assertHeadsEqual(t, &Head{Key: VarTerm("x")}, &Head{Key: VarTerm("x")}) assertHeadsEqual(t, &Head{Value: VarTerm("x")}, &Head{Value: VarTerm("x")}) assertHeadsEqual(t, &Head{Args: []*Term{VarTerm("x"), VarTerm("y")}}, &Head{Args: []*Term{VarTerm("x"), VarTerm("y")}}) - // Different name/key/value + // Different name/ref/key/value assertHeadsNotEqual(t, &Head{Name: Var("p")}, &Head{Name: Var("q")}) + assertHeadsNotEqual(t, &Head{Reference: Ref{VarTerm("p")}}, &Head{Reference: Ref{VarTerm("q")}}) // TODO: string for first section assertHeadsNotEqual(t, &Head{Key: VarTerm("x")}, 
&Head{Key: VarTerm("y")}) assertHeadsNotEqual(t, &Head{Value: VarTerm("x")}, &Head{Value: VarTerm("y")}) assertHeadsNotEqual(t, &Head{Args: []*Term{VarTerm("x"), VarTerm("z")}}, &Head{Args: []*Term{VarTerm("x"), VarTerm("y")}}) @@ -438,48 +484,139 @@ func TestRuleBodyEquals(t *testing.T) { } func TestRuleString(t *testing.T) { + trueBody := NewBody(NewExpr(BooleanTerm(true))) - rule1 := &Rule{ - Head: NewHead(Var("p"), nil, BooleanTerm(true)), - Body: NewBody( - Equality.Expr(StringTerm("foo"), StringTerm("bar")), - ), - } - - rule2 := &Rule{ - Head: NewHead(Var("p"), VarTerm("x"), VarTerm("y")), - Body: NewBody( - Equality.Expr(StringTerm("foo"), VarTerm("x")), - &Expr{ - Negated: true, - Terms: RefTerm(VarTerm("a"), StringTerm("b"), VarTerm("x")), + tests := []struct { + rule *Rule + exp string + }{ + { + rule: &Rule{ + Head: NewHead(Var("p"), nil, BooleanTerm(true)), + Body: NewBody( + Equality.Expr(StringTerm("foo"), StringTerm("bar")), + ), + }, + exp: `p = true { "foo" = "bar" }`, + }, + { + rule: &Rule{ + Head: NewHead(Var("p"), VarTerm("x")), + Body: trueBody, + }, + exp: `p[x] { true }`, + }, + { + rule: &Rule{ + Head: RefHead(MustParseRef("p[x]"), BooleanTerm(true)), + Body: MustParseBody("x = 1"), + }, + exp: `p[x] = true { x = 1 }`, + }, + { + rule: &Rule{ + Head: RefHead(MustParseRef("p.q.r[x]"), BooleanTerm(true)), + Body: MustParseBody("x = 1"), }, - Equality.Expr(StringTerm("b"), VarTerm("y")), - ), + exp: `p.q.r[x] = true { x = 1 }`, + }, + { + rule: &Rule{ + Head: &Head{ + Reference: MustParseRef("p.q.r"), + Key: VarTerm("1"), + }, + Body: MustParseBody("x = 1"), + }, + exp: `p.q.r contains 1 { x = 1 }`, + }, + { + rule: &Rule{ + Head: NewHead(Var("p"), VarTerm("x"), VarTerm("y")), + Body: NewBody( + Equality.Expr(StringTerm("foo"), VarTerm("x")), + &Expr{ + Negated: true, + Terms: RefTerm(VarTerm("a"), StringTerm("b"), VarTerm("x")), + }, + Equality.Expr(StringTerm("b"), VarTerm("y")), + ), + }, + exp: `p[x] = y { "foo" = x; not a.b[x]; "b" = y 
}`, + }, + { + rule: &Rule{ + Default: true, + Head: NewHead("p", nil, BooleanTerm(true)), + }, + exp: `default p = true`, + }, + { + rule: &Rule{ + Head: &Head{ + Name: Var("f"), + Args: Args{VarTerm("x"), VarTerm("y")}, + Value: VarTerm("z"), + }, + Body: NewBody(Plus.Expr(VarTerm("x"), VarTerm("y"), VarTerm("z"))), + }, + exp: "f(x, y) = z { plus(x, y, z) }", + }, + { + rule: &Rule{ + Head: &Head{ + Name: Var("p"), + Value: BooleanTerm(true), + Assign: true, + }, + Body: NewBody( + Equality.Expr(StringTerm("foo"), StringTerm("bar")), + ), + }, + exp: `p := true { "foo" = "bar" }`, + }, + { + rule: &Rule{ + Head: RefHead(MustParseRef("p.q.r")), + Body: trueBody, + }, + exp: `p.q.r { true }`, + }, + { + rule: &Rule{ + Head: RefHead(MustParseRef("p.q.r"), StringTerm("foo")), + Body: trueBody, + }, + exp: `p.q.r = "foo" { true }`, + }, + { + rule: &Rule{ + Head: RefHead(MustParseRef("p.q.r[x]"), StringTerm("foo")), + Body: MustParseBody(`x := 1`), + }, + exp: `p.q.r[x] = "foo" { assign(x, 1) }`, + }, } - rule3 := &Rule{ - Default: true, - Head: NewHead("p", nil, BooleanTerm(true)), + for _, tc := range tests { + t.Run(tc.exp, func(t *testing.T) { + assertRuleString(t, tc.rule, tc.exp) + }) } +} - rule4 := &Rule{ - Head: &Head{ - Name: Var("f"), - Args: Args{VarTerm("x"), VarTerm("y")}, - Value: VarTerm("z"), - }, - Body: NewBody(Plus.Expr(VarTerm("x"), VarTerm("y"), VarTerm("z"))), +func TestRulePath(t *testing.T) { + ruleWithMod := func(r string) Ref { + mod := MustParseModule("package pkg\n" + r) + return mod.Rules[0].Path() + } + if exp, act := MustParseRef("data.pkg.p.q.r"), ruleWithMod("p.q.r { true }"); !exp.Equal(act) { + t.Errorf("expected %v, got %v", exp, act) } - rule5 := rule1.Copy() - rule5.Head.Assign = true - - assertRuleString(t, rule1, `p = true { "foo" = "bar" }`) - assertRuleString(t, rule2, `p[x] = y { "foo" = x; not a.b[x]; "b" = y }`) - assertRuleString(t, rule3, `default p = true`) - assertRuleString(t, rule4, "f(x, y) = z { plus(x, y, z) }") 
- assertRuleString(t, rule5, `p := true { "foo" = "bar" }`) + if exp, act := MustParseRef("data.pkg.p"), ruleWithMod("p { true }"); !exp.Equal(act) { + t.Errorf("expected %v, got %v", exp, act) + } } func TestModuleString(t *testing.T) { @@ -791,7 +928,7 @@ func assertPackagesNotEqual(t *testing.T, a, b *Package) { func assertRulesEqual(t *testing.T, a, b *Rule) { t.Helper() if !a.Equal(b) { - t.Errorf("Rules are not equal (expected equal): a=%v b=%v", a, b) + t.Errorf("Rules are not equal (expected equal):\na=%v\nb=%v", a, b) } } diff --git a/ast/pretty_test.go b/ast/pretty_test.go index d59f7cb809..50fe97c142 100644 --- a/ast/pretty_test.go +++ b/ast/pretty_test.go @@ -41,8 +41,9 @@ func TestPretty(t *testing.T) { qux rule head - p - x + ref + p + x y body expr index=0 @@ -67,7 +68,8 @@ func TestPretty(t *testing.T) { true rule head - f + ref + f args x call diff --git a/ast/term.go b/ast/term.go index 73131b51e3..0aa800dc18 100644 --- a/ast/term.go +++ b/ast/term.go @@ -134,12 +134,12 @@ func As(v Value, x interface{}) error { // Resolver defines the interface for resolving references to native Go values. type Resolver interface { - Resolve(ref Ref) (interface{}, error) + Resolve(Ref) (interface{}, error) } // ValueResolver defines the interface for resolving references to AST values. type ValueResolver interface { - Resolve(ref Ref) (Value, error) + Resolve(Ref) (Value, error) } // UnknownValueErr indicates a ValueResolver was unable to resolve a reference diff --git a/ast/transform.go b/ast/transform.go index c25f7bc32d..391a164860 100644 --- a/ast/transform.go +++ b/ast/transform.go @@ -13,7 +13,7 @@ import ( // be set to nil and no transformations will be applied to children of the // element. 
type Transformer interface { - Transform(v interface{}) (interface{}, error) + Transform(interface{}) (interface{}, error) } // Transform iterates the AST and calls the Transform function on the @@ -116,6 +116,9 @@ func Transform(t Transformer, x interface{}) (interface{}, error) { } return y, nil case *Head: + if y.Reference, err = transformRef(t, y.Reference); err != nil { + return nil, err + } if y.Name, err = transformVar(t, y.Name); err != nil { return nil, err } @@ -327,7 +330,7 @@ func TransformComprehensions(x interface{}, f func(interface{}) (Value, error)) // GenericTransformer implements the Transformer interface to provide a utility // to transform AST nodes using a closure. type GenericTransformer struct { - f func(x interface{}) (interface{}, error) + f func(interface{}) (interface{}, error) } // NewGenericTransformer returns a new GenericTransformer that will transform @@ -414,3 +417,15 @@ func transformVar(t Transformer, v Var) (Var, error) { } return r, nil } + +func transformRef(t Transformer, r Ref) (Ref, error) { + r1, err := Transform(t, r) + if err != nil { + return nil, err + } + r2, ok := r1.(Ref) + if !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", r, r2) + } + return r2, nil +} diff --git a/ast/transform_test.go b/ast/transform_test.go index c102e47f34..8c5a051aae 100644 --- a/ast/transform_test.go +++ b/ast/transform_test.go @@ -24,6 +24,7 @@ p = n { count({"this", "that"}, n) with input.foo.this as {"this": true} } p { false } else = "this" { "this" } else = ["this"] { true } foo(x) = y { split(x, "this", y) } p { every x in ["this"] { x == "this" } } +a.b.c.this["this"] = d { d := "this" } `) result, err := Transform(&GenericTransformer{ @@ -59,6 +60,7 @@ p = n { count({"that"}, n) with input.foo.that as {"that": true} } p { false } else = "that" { "that" } else = ["that"] { true } foo(x) = y { split(x, "that", y) } p { every x in ["that"] { x == "that" } } +a.b.c.that["that"] = d { d := "that" } `) if 
!expected.Equal(resultMod) { @@ -113,3 +115,23 @@ p := 7`, ParserOptions{ProcessAnnotation: true}) } } + +func TestTransformRefsAndRuleHeads(t *testing.T) { + module := MustParseModule(`package test +p.q.this.fo[x] = y { x := "x"; y := "y" }`) + + result, err := TransformRefs(module, func(r Ref) (Value, error) { + if r[0].Value.Compare(Var("p")) == 0 { + r[2] = StringTerm("that") + } + return r, nil + }) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + resultMod := result.(*Module) + if exp, act := MustParseRef("p.q.that.fo[x]"), resultMod.Rules[0].Head.Reference; !act.Equal(exp) { + t.Errorf("expected %v, got %v", exp, act) + } +} diff --git a/ast/visit.go b/ast/visit.go index 02c393cdbc..4c60401e96 100644 --- a/ast/visit.go +++ b/ast/visit.go @@ -432,11 +432,15 @@ func (vis *BeforeAfterVisitor) Walk(x interface{}) { vis.Walk(x.Else) } case *Head: - vis.Walk(x.Name) - vis.Walk(x.Args) - if x.Key != nil { - vis.Walk(x.Key) + if len(x.Reference) > 0 { + vis.Walk(x.Reference) + } else { + vis.Walk(x.Name) + if x.Key != nil { + vis.Walk(x.Key) + } } + vis.Walk(x.Args) if x.Value != nil { vis.Walk(x.Value) } @@ -662,11 +666,16 @@ func (vis *VarVisitor) Walk(x interface{}) { vis.Walk(x.Else) } case *Head: - vis.Walk(x.Name) - vis.Walk(x.Args) - if x.Key != nil { - vis.Walk(x.Key) + if len(x.Reference) > 0 { + vis.Walk(x.Reference) + } else { + vis.Walk(x.Name) + if x.Key != nil { + vis.Walk(x.Key) + } } + vis.Walk(x.Args) + if x.Value != nil { vis.Walk(x.Value) } diff --git a/ast/visit_test.go b/ast/visit_test.go index 1b3694508b..0f2b663275 100644 --- a/ast/visit_test.go +++ b/ast/visit_test.go @@ -393,12 +393,12 @@ fn([x, y]) = z { json.unmarshal(x, z); z > y } }) vis.Walk(rule) - if len(before) != 246 { - t.Errorf("Expected exactly 246 before elements in AST but got %d: %v", len(before), before) + if exp, act := 256, len(before); exp != act { + t.Errorf("Expected exactly %d before elements in AST but got %d: %v", exp, act, before) } - if len(after) 
!= 246 { - t.Errorf("Expected exactly 246 after elements in AST but got %d: %v", len(after), after) + if exp, act := 256, len(after); exp != act { + t.Errorf("Expected exactly %d after elements in AST but got %d: %v", exp, act, after) } } diff --git a/cmd/build_test.go b/cmd/build_test.go index b708417366..cc188cbf5e 100644 --- a/cmd/build_test.go +++ b/cmd/build_test.go @@ -3,6 +3,7 @@ package cmd import ( "archive/tar" "compress/gzip" + "fmt" "io" "os" "path" @@ -162,12 +163,14 @@ func TestBuildErrorDoesNotWriteFile(t *testing.T) { params.outputFile = path.Join(root, "bundle.tar.gz") err := dobuild(params, []string{root}) - if err == nil || !strings.Contains(err.Error(), "rule p is recursive") { - t.Fatal("expected recursion error but got:", err) + exp := fmt.Sprintf("1 error occurred: %s/test.rego:3: rego_recursion_error: rule data.test.p is recursive: data.test.p -> data.test.p", + root) + if err == nil || err.Error() != exp { + t.Fatalf("expected recursion error %q but got: %q", exp, err) } - if _, err := os.Stat(params.outputFile); err == nil { - t.Fatal("expected stat error") + if _, err := os.Stat(params.outputFile); !os.IsNotExist(err) { + t.Fatalf("expected stat \"not found\" error, got %v", err) } }) } diff --git a/compile/compile.go b/compile/compile.go index 8114ada6e3..ea9dfadc32 100644 --- a/compile/compile.go +++ b/compile/compile.go @@ -887,7 +887,7 @@ func (o *optimizer) getSupportForEntrypoint(queries []ast.Body, e *ast.Term, res o.debug.Printf("optimizer: entrypoint: %v: discard due to self-reference", e) return nil } - module.Rules = append(module.Rules, &ast.Rule{ + module.Rules = append(module.Rules, &ast.Rule{ // TODO(sr): use RefHead instead? Head: ast.NewHead(name, nil, resultsym), Body: query, Module: module, @@ -900,6 +900,8 @@ // merge combines two sets of modules and returns the result. The rules from modules // in 'b' override rules from modules in 'a'. 
If all rules in a module in 'a' are overridden // by rules in modules in 'b' then the module from 'a' is discarded. +// NOTE(sr): This function assumes that `b` is the result of partial eval, and thus does NOT +// contain any rules that genuinely need their ref heads. func (o *optimizer) merge(a, b []bundle.ModuleFile) []bundle.ModuleFile { prefixes := ast.NewSet() @@ -911,12 +913,19 @@ func (o *optimizer) merge(a, b []bundle.ModuleFile) []bundle.ModuleFile { // of rules.) seen := ast.NewVarSet() for _, rule := range b[i].Parsed.Rules { - if _, ok := seen[rule.Head.Name]; !ok { - prefixes.Add(ast.NewTerm(rule.Path())) - seen.Add(rule.Head.Name) + // NOTE(sr): we're relying on the fact that PE never emits ref rules (so far)! + // The rule + // p.a = 1 { ... } + // will be recorded in prefixes as `data.test.p`, and that'll be checked later on against `data.test.p[k]` + if len(rule.Head.Ref()) > 2 { + panic("expected a module without ref rules") + } + name := rule.Head.Name + if !seen.Contains(name) { + prefixes.Add(ast.NewTerm(b[i].Parsed.Package.Path.Append(ast.StringTerm(string(name))))) + seen.Add(name) } } - } for i := range a { @@ -925,30 +934,28 @@ func (o *optimizer) merge(a, b []bundle.ModuleFile) []bundle.ModuleFile { // NOTE(tsandall): same as above--memoize keep/discard decision. If multiple // entrypoints are provided the dst module may contain a large number of rules. 
- seen := ast.NewVarSet() - discard := ast.NewVarSet() - + seen, discarded := ast.NewSet(), ast.NewSet() for _, rule := range a[i].Parsed.Rules { - - if _, ok := discard[rule.Head.Name]; ok { - continue - } else if _, ok := seen[rule.Head.Name]; ok { + refT := ast.NewTerm(rule.Ref()) + switch { + case seen.Contains(refT): keep = append(keep, rule) continue + case discarded.Contains(refT): + continue } - path := rule.Path() + path := rule.Ref() overlap := prefixes.Until(func(x *ast.Term) bool { - ref := x.Value.(ast.Ref) - return path.HasPrefix(ref) + r := x.Value.(ast.Ref) + return path.HasPrefix(r) }) - if overlap { - discard.Add(rule.Head.Name) - } else { - seen.Add(rule.Head.Name) - keep = append(keep, rule) + discarded.Add(refT) + continue } + seen.Add(refT) + keep = append(keep, rule) } if len(keep) > 0 { @@ -956,7 +963,6 @@ func (o *optimizer) merge(a, b []bundle.ModuleFile) []bundle.ModuleFile { a[i].Raw = nil b = append(b, a[i]) } - } return b diff --git a/compile/compile_test.go b/compile/compile_test.go index 221a4b6997..376313b7bf 100644 --- a/compile/compile_test.go +++ b/compile/compile_test.go @@ -968,7 +968,7 @@ func TestOptimizerErrors(t *testing.T) { p { data.test.p } `, }, - wantErr: fmt.Errorf("1 error occurred: test.rego:3: rego_recursion_error: rule p is recursive: p -> p"), + wantErr: fmt.Errorf("1 error occurred: test.rego:3: rego_recursion_error: rule data.test.p is recursive: data.test.p -> data.test.p"), }, { note: "partial eval error", @@ -1072,6 +1072,36 @@ func TestOptimizerOutput(t *testing.T) { `, }, }, + { + note: "support rules, ref heads", + entrypoints: []string{"data.test.p.q.r"}, + modules: map[string]string{ + "test.rego": ` + package test + + default p.q.r = false + p.q.r { q[input.x] } + + q[1] + q[2]`, + }, + wantModules: map[string]string{ + "optimized/test/p/q.rego": ` + package test.p.q + + default r = false + r = true { 1 = input.x } + r = true { 2 = input.x } + + `, + "test.rego": ` + package test + + q[1] + q[2] + `, + 
}, + }, { note: "multiple entrypoints", entrypoints: []string{"data.test.p", "data.test.r", "data.test.s"}, @@ -1526,7 +1556,6 @@ func getOptimizer(modules map[string]string, data string, entries []string, root func getModuleFiles(src map[string]string, includeRaw bool) []bundle.ModuleFile { keys := make([]string, 0, len(src)) - for k := range src { keys = append(keys, k) } diff --git a/docs/content/policy-language.md b/docs/content/policy-language.md index eaa0b517e3..87615dd5fd 100644 --- a/docs/content/policy-language.md +++ b/docs/content/policy-language.md @@ -909,6 +909,24 @@ max_memory := 32 { power_users[user] } max_memory := 4 { restricted_users[user] } ``` +### Rule Heads containing References + +As a shorthand for defining nested rule structures, it's valid to use references as rule heads: + +```live:eg/ref_heads:module +fruit.apple.seeds = 12 + +fruit.orange.color = "orange" +``` + +This module defines _two complete rules_, `data.example.fruit.apple.seeds` and `data.example.fruit.orange.color`: + +```live:eg/ref_heads:query:merge_down +data.example +``` +```live:eg/ref_heads:output +``` + ### Functions Rego supports user-defined functions that can be called with the same semantics as [Built-in Functions](#built-in-functions). They have access to both the [the data Document](../philosophy/#the-opa-document-model) and [the input Document](../philosophy/#the-opa-document-model). 
diff --git a/docs/content/policy-reference.md b/docs/content/policy-reference.md index cc8eb3ba72..2bab586251 100644 --- a/docs/content/policy-reference.md +++ b/docs/content/policy-reference.md @@ -282,6 +282,43 @@ f(x) := "B" if { x >= 80; x < 90 } f(x) := "C" if { x >= 70; x < 80 } ``` +### Reference Heads + +```live:rules/ref_heads:module:read_only +fruit.apple.seeds = 12 if input == "apple" # complete document (single value rule) + +fruit.pineapple.colors contains x if x := "yellow" # multi-value rule + +fruit.banana.phone[x] = "bananular" if x := "cellular" # single value rule +fruit.banana.phone.cellular = "bananular" if true # equivalent single value rule + +fruit.orange.color(x) = true if x == "orange" # function +``` + +For reasons of backwards-compatibility, partial sets need to use `contains` in +their rule heads, i.e. + +```live:rules/ref_heads/set:module:read_only +fruit.box contains "apples" if true +``` + +whereas + +```live:rules/ref_heads/complete:module:read_only +fruit.box[x] if { x := "apples" } +``` + +defines a _complete document rule_ `fruit.box.apples` with value `true`. +The same is the case for rules with brackets that don't contain dots, like + +```live:rules/ref_heads/simple:module:read_only +box[x] if { x := "apples" } # => {"box": {"apples": true }} +box2[x] { x := "apples" } # => {"box": ["apples"]} +``` + +For backwards-compatibility, rules _without_ if and without _dots_ will be interpreted +as defining partial sets, like `box2`. 
+ ## Tests ```live:tests:module:read_only @@ -1133,7 +1170,7 @@ package = "package" ref import = "import" ref [ "as" var ] policy = { rule } rule = [ "default" ] rule-head { rule-body } -rule-head = var ( rule-head-set | rule-head-obj | rule-head-func | rule-head-comp | "if" ) +rule-head = ( ref | var ) ( rule-head-set | rule-head-obj | rule-head-func | rule-head-comp ) rule-head-comp = [ assign-operator term ] [ "if" ] rule-head-obj = "[" term "]" [ assign-operator term ] [ "if" ] rule-head-func = "(" rule-args ")" [ assign-operator term ] [ "if" ] diff --git a/format/format.go b/format/format.go index f549f4ec5c..745efa05c2 100644 --- a/format/format.go +++ b/format/format.go @@ -406,7 +406,8 @@ func (w *writer) writeElse(rule *ast.Rule, useContainsKW, useIf bool, comments [ w.startLine() } - rule.Else.Head.Name = "else" + rule.Else.Head.Name = "else" // NOTE(sr): whaaat + rule.Else.Head.Reference = ast.Ref{ast.VarTerm("else")} rule.Else.Head.Args = nil comments = w.insertComments(comments, rule.Else.Head.Location) @@ -427,7 +428,12 @@ func (w *writer) writeElse(rule *ast.Rule, useContainsKW, useIf bool, comments [ } func (w *writer) writeHead(head *ast.Head, isDefault, isExpandedConst, useContainsKW, useIf bool, comments []*ast.Comment) []*ast.Comment { - w.write(head.Name.String()) + ref := head.Ref() + if head.Key != nil && head.Value == nil { + ref = ref.GroundPrefix() + } + w.write(ref.String()) + if len(head.Args) > 0 { w.write("(") var args []interface{} @@ -441,7 +447,7 @@ func (w *writer) writeHead(head *ast.Head, isDefault, isExpandedConst, useContai if useContainsKW && head.Value == nil { w.write(" contains ") comments = w.writeTerm(head.Key, comments) - } else { // no `if` for p[x] notation + } else if head.Value == nil { // no `if` for p[x] notation w.write("[") comments = w.writeTerm(head.Key, comments) w.write("]") diff --git a/format/testfiles/test.rego.formatted b/format/testfiles/test.rego.formatted index d78027e32b..9197d0bb45 100644 --- 
a/format/testfiles/test.rego.formatted +++ b/format/testfiles/test.rego.formatted @@ -24,11 +24,11 @@ globals = { "fizz": "buzz", } -partial_obj["x"] = 1 +partial_obj.x = 1 -partial_obj["y"] = 2 +partial_obj.y = 2 -partial_obj["z"] = 3 +partial_obj.z = 3 partial_set["x"] @@ -198,7 +198,7 @@ nested_infix { expanded_const = true -partial_obj["why"] = true { +partial_obj.why = true { false } diff --git a/format/testfiles/test_ref_heads.rego b/format/testfiles/test_ref_heads.rego new file mode 100644 index 0000000000..a55386ab59 --- /dev/null +++ b/format/testfiles/test_ref_heads.rego @@ -0,0 +1,13 @@ +package test + +import future.keywords + +a.b.c = "d" if true +a.b.e = "f" if true +a.b.g contains x if some x in numbers.range(1, 3) +a.b.h[x] = 1 if x := "one" + +q[1] = y if true +r[x] if x := 10 +p.q.r[x] if x := 10 +p.q.r[2] if true diff --git a/format/testfiles/test_ref_heads.rego.formatted b/format/testfiles/test_ref_heads.rego.formatted new file mode 100644 index 0000000000..5f23e3681b --- /dev/null +++ b/format/testfiles/test_ref_heads.rego.formatted @@ -0,0 +1,19 @@ +package test + +import future.keywords + +a.b.c = "d" + +a.b.e = "f" + +a.b.g contains x if some x in numbers.range(1, 3) + +a.b.h[x] = 1 if x := "one" + +q[1] = y + +r[x] if x := 10 + +p.q.r[x] if x := 10 + +p.q.r[2] = true diff --git a/internal/oracle/oracle.go b/internal/oracle/oracle.go index af2fca93ed..e3eb3afa1a 100644 --- a/internal/oracle/oracle.go +++ b/internal/oracle/oracle.go @@ -52,6 +52,8 @@ func (o *Oracle) FindDefinition(q DefinitionQuery) (*DefinitionQueryResult, erro // TODO(tsandall): how can we cache the results of compilation and parsing so that // multiple queries can be executed without having to re-compute the same values? // Ditto for caching across runs. Avoid repeating the same work. 
+ + // NOTE(sr): "SetRuleTree" because it's needed for compiler.GetRulesExact() below compiler, parsed, err := compileUpto("SetRuleTree", q.Modules, q.Buffer, q.Filename) if err != nil { return nil, err @@ -220,7 +222,8 @@ func getLocMinMax(x ast.Node) (int, int) { // Special case bodies because location text is only for the first expr. if body, ok := x.(ast.Body); ok { - extraLoc := body[len(body)-1].Loc() + last := findLastExpr(body) + extraLoc := last.Loc() if extraLoc == nil { return -1, -1 } @@ -229,3 +232,20 @@ func getLocMinMax(x ast.Node) (int, int) { return min, min + len(loc.Text) } + +// findLastExpr returns the last expression in an ast.Body that has not been generated +// by the compiler. It's used to cope with the fact that a compiler stage before SetRuleTree +// has rewritten the rule bodies slightly. By ignoring appended generated body expressions, +// we can still use the "circling in on the variable" logic based on node locations. +func findLastExpr(body ast.Body) *ast.Expr { + for i := len(body) - 1; i >= 0; i-- { + if !body[i].Generated { + return body[i] + } + } + // NOTE(sr): I believe this shouldn't happen -- we only ever start circling in on a node + // inside a body if there's something in that body. A body that only consists of generated + // expressions should not appear here. Either way, the caller deals with `nil` returned by + // this helper. 
+ return nil +} diff --git a/internal/oracle/oracle_test.go b/internal/oracle/oracle_test.go index 55636734a0..2822c807d2 100644 --- a/internal/oracle/oracle_test.go +++ b/internal/oracle/oracle_test.go @@ -158,6 +158,13 @@ y[1] s = input bar = 7` + // NOTE(sr): Early ref rewriting adds an expression to the rule body for `x.y` + const varInRuleRefModule = `package foo +q[x.y] = 10 { + x := input + some z + z = 1 +}` cases := []struct { note string modules map[string]string @@ -350,6 +357,20 @@ bar = 7` Text: []byte("i"), }, }, + { + note: "intra-rule: ref head", + modules: map[string]string{ + "buffer.rego": varInRuleRefModule, + }, + pos: 47, // "z" in "z = 1" + exp: &ast.Location{ + File: "buffer.rego", + Row: 4, + Col: 7, + Offset: 44, + Text: []byte("z"), + }, + }, } for _, tc := range cases { @@ -362,10 +383,20 @@ bar = 7` t.Fatal(err) } } + buffer := tc.modules["buffer.rego"] + before := tc.pos - 4 + if before < 0 { + before = 0 + } + after := tc.pos + 5 + if after > len(buffer) { + after = len(buffer) + } + t.Logf("pos is %d: \"%s<%s>%s\"", tc.pos, string(buffer[before:tc.pos]), string(buffer[tc.pos]), string(buffer[tc.pos+1:after])) o := New() result, err := o.FindDefinition(DefinitionQuery{ Modules: modules, - Buffer: []byte(tc.modules["buffer.rego"]), + Buffer: []byte(buffer), Filename: "buffer.rego", Pos: tc.pos, }) diff --git a/internal/planner/planner.go b/internal/planner/planner.go index 62e2b4c37b..070b1a6c1e 100644 --- a/internal/planner/planner.go +++ b/internal/planner/planner.go @@ -147,7 +147,7 @@ func (p *Planner) buildFunctrie() error { } for _, rule := range module.Rules { - val := p.rules.LookupOrInsert(rule.Path()) + val := p.rules.LookupOrInsert(rule.Ref()) val.rules = append(val.rules, rule) } } @@ -155,14 +155,29 @@ func (p *Planner) buildFunctrie() error { return nil } -func (p *Planner) planRules(rules []*ast.Rule) (string, error) { - - pathRef := rules[0].Path() +func (p *Planner) planRules(rules []*ast.Rule, cut bool) (string, 
error) { + pathRef := rules[0].Ref() // NOTE(sr): no longer the same for all those rules, respect `cut`? path := pathRef.String() var pathPieces []string - for i := 1; /* skip `data` */ i < len(pathRef); i++ { - pathPieces = append(pathPieces, string(pathRef[i].Value.(ast.String))) + // TODO(sr): this has to change when allowing `p[v].q.r[w]` ref rules + // including the mapping lookup structure and lookup functions + + // if we're planning both p.q.r and p.q[s], we'll name the function p.q (for the mapping table) + pieces := len(pathRef) + if cut { + pieces-- + } + for i := 1; /* skip `data` */ i < pieces; i++ { + switch q := pathRef[i].Value.(type) { + case ast.String: + pathPieces = append(pathPieces, string(q)) + case ast.Var: + pathPieces = append(pathPieces, fmt.Sprintf("[%s]", q)) + default: + // Needs to be fixed if we allow non-string ref pieces, like `p.q[3][4].r = x` + pathPieces = append(pathPieces, q.String()) + } } if funcName, ok := p.funcs.Get(path); ok { @@ -204,12 +219,27 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) { params := fn.Params[2:] + // control if p.a = 1 is to return 1 directly; or insert 1 under key "a" into an object + buildObject := false + // Initialize return value for partial set/object rules. Complete document // rules assign directly to `fn.Return`. 
- switch rules[0].Head.DocKind() { - case ast.PartialObjectDoc: - fn.Blocks = append(fn.Blocks, p.blockWithStmt(&ir.MakeObjectStmt{Target: fn.Return})) - case ast.PartialSetDoc: + switch rules[0].Head.RuleKind() { + case ast.SingleValue: + // if any rule has a non-ground last key, create an object, insert into it + any := false + for _, rule := range rules { + ref := rule.Head.Ref() + if last := ref[len(ref)-1]; len(ref) > 1 && !last.IsGround() { + any = true + break + } + } + if any { + buildObject = true + fn.Blocks = append(fn.Blocks, p.blockWithStmt(&ir.MakeObjectStmt{Target: fn.Return})) + } + case ast.MultiValue: fn.Blocks = append(fn.Blocks, p.blockWithStmt(&ir.MakeSetStmt{Target: fn.Return})) } @@ -279,6 +309,7 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) { p.appendStmt(&ir.IsUndefinedStmt{Source: lresult}) } else { // The first rule body resets the local, so it can be reused. + // TODO(sr): I don't think we need this anymore. Double-check? Perhaps multi-value rules need it. p.appendStmt(&ir.ResetLocalStmt{Target: lresult}) } @@ -287,11 +318,27 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) { err := p.planFuncParams(params, rule.Head.Args, 0, func() error { // Run planner on the rule body. - err := p.planQuery(rule.Body, 0, func() error { + return p.planQuery(rule.Body, 0, func() error { // Run planner on the result. 
- switch rule.Head.DocKind() { - case ast.CompleteDoc: + switch rule.Head.RuleKind() { + case ast.SingleValue: + if buildObject { + ref := rule.Head.Ref() + last := ref[len(ref)-1] + return p.planTerm(last, func() error { + key := p.ltarget + return p.planTerm(rule.Head.Value, func() error { + value := p.ltarget + p.appendStmt(&ir.ObjectInsertOnceStmt{ + Object: fn.Return, + Key: key, + Value: value, + }) + return nil + }) + }) + } return p.planTerm(rule.Head.Value, func() error { p.appendStmt(&ir.AssignVarOnceStmt{ Target: lresult, @@ -299,7 +346,7 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) { }) return nil }) - case ast.PartialSetDoc: + case ast.MultiValue: return p.planTerm(rule.Head.Key, func() error { p.appendStmt(&ir.SetAddStmt{ Set: fn.Return, @@ -307,29 +354,10 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) { }) return nil }) - case ast.PartialObjectDoc: - return p.planTerm(rule.Head.Key, func() error { - key := p.ltarget - return p.planTerm(rule.Head.Value, func() error { - value := p.ltarget - p.appendStmt(&ir.ObjectInsertOnceStmt{ - Object: fn.Return, - Key: key, - Value: value, - }) - return nil - }) - }) default: return fmt.Errorf("illegal rule kind") } }) - - if err != nil { - return err - } - - return nil }) if err != nil { @@ -338,7 +366,7 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) { } // rule[i] and its else-rule(s), if present, are done - if rules[i].Head.DocKind() == ast.CompleteDoc { + if rules[i].Head.RuleKind() == ast.SingleValue && !buildObject { end := &ir.Block{} p.appendStmtToBlock(&ir.IsDefinedStmt{Source: lresult}, end) p.appendStmtToBlock( @@ -841,7 +869,7 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error { // replacement is a function (rule) if node := p.rules.Lookup(r); node != nil { p.mocks.Push() // new scope - name, err = p.planRules(node.Rules()) + name, err = p.planRules(node.Rules(), false) if err != nil { return err } @@ -862,7 +890,7 @@ func (p 
*Planner) planExprCall(e *ast.Expr, iter planiter) error { } if node := p.rules.Lookup(op); node != nil { - name, err = p.planRules(node.Rules()) + name, err = p.planRules(node.Rules(), false) if err != nil { return err } @@ -1604,7 +1632,7 @@ func (p *Planner) planRefData(virtual *ruletrie, base *baseptr, ref ast.Ref, ind // NOTE(sr): we do it on the first index because later on, the recursion // on subtrees of virtual already lost parts of the path we've taken. if index == 1 && virtual != nil { - rulesets, path, index, optimize := p.optimizeLookup(virtual, ref) + rulesets, path, index, optimize := p.optimizeLookup(virtual, ref.GroundPrefix()) if optimize { // If there are no rulesets in a situation that otherwise would // allow for a call_indirect optimization, then there's nothing @@ -1614,7 +1642,7 @@ func (p *Planner) planRefData(virtual *ruletrie, base *baseptr, ref ast.Ref, ind } // plan rules for _, rules := range rulesets { - if _, err := p.planRules(rules); err != nil { + if _, err := p.planRules(rules, false); err != nil { return err } } @@ -1701,17 +1729,36 @@ func (p *Planner) planRefData(virtual *ruletrie, base *baseptr, ref ast.Ref, ind if ref[index].IsGround() { var vchild *ruletrie + var rules []*ast.Rule + + // If there's any non-ground key among the vchild.Children, like + // p[x] and p.a (x being non-ground), we'll collect all 'p' rules, + // plan them. + anyKeyNonGround := false if virtual != nil { + vchild = virtual.Get(ref[index].Value) - } - rules := vchild.Rules() + for _, key := range vchild.Children() { + if !key.IsGround() { + anyKeyNonGround = true + break + } + } + if anyKeyNonGround { + for _, key := range vchild.Children() { + rules = append(rules, vchild.Get(key).Rules()...) 
+ } + } else { + rules = vchild.Rules() // hit or miss + } + } if len(rules) > 0 { p.ltarget = p.newOperand() - funcName, err := p.planRules(rules) + funcName, err := p.planRules(rules, anyKeyNonGround) if err != nil { return err } @@ -1727,7 +1774,6 @@ func (p *Planner) planRefData(virtual *ruletrie, base *baseptr, ref ast.Ref, ind bchild := *base bchild.path = append(bchild.path, ref[index]) - return p.planRefData(vchild, &bchild, ref, index+1, iter) } @@ -1836,38 +1882,21 @@ func (p *Planner) planRefDataExtent(virtual *ruletrie, base *baseptr, iter plani Target: vtarget, }) + anyKeyNonGround := false for _, key := range virtual.Children() { - child := virtual.Get(key) - - // Skip functions. - if child.Arity() > 0 { - continue + if !key.IsGround() { + anyKeyNonGround = true + break } - - lkey := ir.StringIndex(p.getStringConst(string(key.(ast.String)))) - - rules := child.Rules() - - // Build object hierarchy depth-first. - if len(rules) == 0 { - err := p.planRefDataExtent(child, nil, func() error { - p.appendStmt(&ir.ObjectInsertStmt{ - Object: vtarget, - Key: op(lkey), - Value: p.ltarget, - }) - return nil - }) - if err != nil { - return err - } - continue + } + if anyKeyNonGround { + var rules []*ast.Rule + for _, key := range virtual.Children() { + // TODO(sr): skip functions + rules = append(rules, virtual.Get(key).Rules()...) } - // Generate virtual document for leaf. 
- lvalue := p.newLocal() - - funcName, err := p.planRules(rules) + funcName, err := p.planRules(rules, true) if err != nil { return err } @@ -1877,14 +1906,59 @@ func (p *Planner) planRefDataExtent(virtual *ruletrie, base *baseptr, iter plani p.appendStmtToBlock(&ir.CallStmt{ Func: funcName, Args: p.defaultOperands(), - Result: lvalue, - }, b) - p.appendStmtToBlock(&ir.ObjectInsertStmt{ - Object: vtarget, - Key: op(lkey), - Value: op(lvalue), + Result: vtarget, }, b) p.appendStmt(&ir.BlockStmt{Blocks: []*ir.Block{b}}) + } else { + for _, key := range virtual.Children() { + child := virtual.Get(key) + + // Skip functions. + if child.Arity() > 0 { + continue + } + + lkey := ir.StringIndex(p.getStringConst(string(key.(ast.String)))) + rules := child.Rules() + + // Build object hierarchy depth-first. + if len(rules) == 0 { + err := p.planRefDataExtent(child, nil, func() error { + p.appendStmt(&ir.ObjectInsertStmt{ + Object: vtarget, + Key: op(lkey), + Value: p.ltarget, + }) + return nil + }) + if err != nil { + return err + } + continue + } + + // Generate virtual document for leaf. + lvalue := p.newLocal() + + funcName, err := p.planRules(rules, false) + if err != nil { + return err + } + + // Add leaf to object if defined. 
+ b := &ir.Block{} + p.appendStmtToBlock(&ir.CallStmt{ + Func: funcName, + Args: p.defaultOperands(), + Result: lvalue, + }, b) + p.appendStmtToBlock(&ir.ObjectInsertStmt{ + Object: vtarget, + Key: op(lkey), + Value: op(lvalue), + }, b) + p.appendStmt(&ir.BlockStmt{Blocks: []*ir.Block{b}}) + } } // At this point vtarget refers to the full extent of the virtual diff --git a/internal/planner/planner_test.go b/internal/planner/planner_test.go index bd488972cf..2f589fa697 100644 --- a/internal/planner/planner_test.go +++ b/internal/planner/planner_test.go @@ -86,7 +86,7 @@ func TestPlannerHelloWorld(t *testing.T) { }, { note: "complete rules", - queries: []string{"true"}, + queries: []string{"data.test.p = x"}, modules: []string{` package test p = x { x = 1 } @@ -141,6 +141,15 @@ func TestPlannerHelloWorld(t *testing.T) { p["b"] = 2 `}, }, + { // NOTE(sr): these are handled differently with ref-heads + note: "partial object with var", + queries: []string{`data.test.p = x`}, + modules: []string{` + package test + p["a"] = 1 + p[v] = 2 { v := "b" } + `}, + }, { note: "every", queries: []string{`data.test.p`}, diff --git a/internal/planner/rules.go b/internal/planner/rules.go index 5e94fb1c5e..6ff78f4c69 100644 --- a/internal/planner/rules.go +++ b/internal/planner/rules.go @@ -144,6 +144,9 @@ func (t *ruletrie) LookupOrInsert(key ast.Ref) *ruletrie { } func (t *ruletrie) Children() []ast.Value { + if t == nil { + return nil + } sorted := make([]ast.Value, 0, len(t.children)) for key := range t.children { if t.Get(key) != nil { diff --git a/refactor/refactor_test.go b/refactor/refactor_test.go index ac55d89b05..6a06318142 100644 --- a/refactor/refactor_test.go +++ b/refactor/refactor_test.go @@ -337,7 +337,7 @@ p = 7`) t.Fatal("Expected error but got nil") } - errMsg := "rego_type_error: conflicting rules named p found" + errMsg := "rego_type_error: conflicting rules data.b.p found" if !strings.Contains(err.Error(), errMsg) { t.Fatalf("Expected error message %v but got 
%v", errMsg, err.Error()) } diff --git a/rego/rego_test.go b/rego/rego_test.go index 3f408409b0..d0c1b1c950 100644 --- a/rego/rego_test.go +++ b/rego/rego_test.go @@ -1867,7 +1867,6 @@ func TestRegoCustomBuiltinPartialPropagate(t *testing.T) { } func TestRegoPartialResultRecursiveRefs(t *testing.T) { - r := New(Query("data"), Module("test.rego", `package foo.bar default p = false diff --git a/repl/repl.go b/repl/repl.go index 48a2f669f5..6adb33a652 100644 --- a/repl/repl.go +++ b/repl/repl.go @@ -811,11 +811,9 @@ func (r *REPL) compileRule(ctx context.Context, rule *ast.Rule) error { switch r.outputFormat { case "json": default: - var msg string + msg := "defined" if unset { msg = "re-defined" - } else { - msg = "defined" } fmt.Fprintf(r.output, "Rule '%v' %v in %v. Type 'show' to see rules.\n", rule.Head.Name, msg, mod.Package) } @@ -1177,10 +1175,16 @@ func (r *REPL) interpretAsRule(ctx context.Context, compiler *ast.Compiler, body if expr.IsAssignment() { rule, err := ast.ParseCompleteDocRuleFromAssignmentExpr(r.getCurrentOrDefaultModule(), expr.Operand(0), expr.Operand(1)) - if err == nil { - if err := r.compileRule(ctx, rule); err != nil { - return false, err - } + if err != nil { + return false, nil + } + // TODO(sr): support interactive ref head rule definitions + if len(rule.Head.Ref()) > 1 { + return false, nil + } + + if err := r.compileRule(ctx, rule); err != nil { + return false, err } return rule != nil, nil } @@ -1194,12 +1198,17 @@ func (r *REPL) interpretAsRule(ctx context.Context, compiler *ast.Compiler, body } rule, err := ast.ParseCompleteDocRuleFromEqExpr(r.getCurrentOrDefaultModule(), expr.Operand(0), expr.Operand(1)) - if err == nil { - if err := r.compileRule(ctx, rule); err != nil { - return false, err - } + if err != nil { + return false, nil + } + // TODO(sr): support interactive ref head rule definitions + if len(rule.Head.Ref()) > 1 { + return false, nil } + if err := r.compileRule(ctx, rule); err != nil { + return false, err + } return 
rule != nil, nil } diff --git a/server/server.go b/server/server.go index 2800beaafa..213beafa00 100644 --- a/server/server.go +++ b/server/server.go @@ -2638,13 +2638,7 @@ func stringPathToRef(s string) (r ast.Ref) { } func validateQuery(query string) (ast.Body, error) { - - var body ast.Body - body, err := ast.ParseBody(query) - if err != nil { - return nil, err - } - return body, nil + return ast.ParseBody(query) } func getBoolParam(url *url.URL, name string, ifEmpty bool) bool { diff --git a/test/cases/testdata/jsonpatch/json-patch-tests.yaml b/test/cases/testdata/jsonpatch/json-patch-tests.yaml index 7692887d9a..9d58bc2853 100644 --- a/test/cases/testdata/jsonpatch/json-patch-tests.yaml +++ b/test/cases/testdata/jsonpatch/json-patch-tests.yaml @@ -16,7 +16,7 @@ cases: # Grab all cases from the modules inside `data.json_patch_cases` and # construct an object with readable index names. all_cases[k] = t { - t := data.json_patch_cases[p]["cases"][i] + t := data.json_patch_cases[p].cases[i] k := sprintf("%s/%d", [p, i]) } diff --git a/test/cases/testdata/partialobjectdoc/test-wasm-cases.yaml b/test/cases/testdata/partialobjectdoc/test-wasm-cases.yaml new file mode 100644 index 0000000000..d1da298df4 --- /dev/null +++ b/test/cases/testdata/partialobjectdoc/test-wasm-cases.yaml @@ -0,0 +1,106 @@ + +# NOTE(sr): These test cases stem from cases we run against the wasm +# module but since the ref-heads change made them fail -- because of +# changes to the typing of partial object rules (which are now single- +# valued rules), they're added to the topdown tests, too. 
+cases: + - note: wasm/additive + query: 'data.x.q = x' + modules: + - | + package x + p["a"] = 1 + p["b"] = 2 + + q { + p == {"a": 1, "b": 2} + } + want_result: + - x: true + - note: wasm/additive (negative) + query: 'data.x.p = input' + input: {"a": 1, "b": 2} + modules: + - | + package x + p["a"] = 1 + p["b"] = 2 + p["c"] = 3 + want_result: [] + - note: wasm/input + query: 'data.x.q = x' + modules: + - | + package x + p["a"] = 1 { input.x = 1 } + p["b"] = 2 { input.y = 2 } + + q { + p == {"a": 1, "b": 2} + } + input: {"x": 1, "y": 2} + want_result: + - x: true + - note: wasm/input (negative) + query: 'data.x.q = x' + data: + z: + a: 1 + b: 2 + modules: + - | + package x + p["a"] = 1 { input.x = 1 } + p["b"] = 2 { input.y = 2 } + p["c"] = 3 { input.z = 3 } + + q { + p == data.z + } + want_result: + - x: true + input: {"x": 1, "y": 2} + - note: wasm/composites + query: 'data.x.q = x' + modules: + - | + package x + p[x] = [y] { x = "a"; y = 1 } + p[x] = [y] { x = "b"; y = 2 } + + q { p == {"a": [1], "b": [2]} } + want_result: + - x : true + - note: wasm/conflict error + query: 'data.x.q = x' + data: + z: + a: 1 + modules: + - | + package x + p["x"] = 1 + p["x"] = 2 + q { + p == data.z + } + want_error: "test-0.rego:3: eval_conflict_error: complete rules must not produce multiple outputs" + want_error_code: eval_conflict_error + - note: wasm/object dereference + query: 'data.x.q = x' + modules: + - | + package x + p["a"] = {"b": 1} + q { + p.a.b = 1 + } + want_result: + - x: true + - note: wasm/object dereference (negative) + query: data.x.p.a.b = 1 + modules: + - | + package x + p["a"] = {"b": 2} + want_result: [] diff --git a/test/cases/testdata/refheads/test-refs-as-rule-heads.yaml b/test/cases/testdata/refheads/test-refs-as-rule-heads.yaml new file mode 100644 index 0000000000..f8d00b5a7e --- /dev/null +++ b/test/cases/testdata/refheads/test-refs-as-rule-heads.yaml @@ -0,0 +1,313 @@ +cases: +- modules: + - | + package test + + p.q.r = 1 + p.q.s = 2 + note: 
'refheads/single-value' + query: data.test.p = x + want_result: + - x: + q: + r: 1 + s: 2 +- modules: + - | + package test + + p.q.r = 1 + p.q[s] = 2 { s := "s" } + note: 'refheads/single-value, with var' + query: data.test.p = x + want_result: + - x: + q: + r: 1 + s: 2 +- modules: + - | + package test + + a.b.c.p { true } + note: 'refheads/complete: direct query' + query: data.test.a.b.c.p = x + want_result: + - x: true +- modules: + - | + package test + + q = 0 + note: 'refheads/complete: direct query q' + query: data.test.q = x + want_result: + - x: 0 +- modules: + - | + package test + + a.b.c.p { true } + note: 'refheads/complete: full package extent' + query: data.test = x + want_result: + - x: + a: + b: + c: + p: true +- modules: + - | + package test + + a.b.c.p = 1 + q = 0 + a.b.d = 3 + + p { + q == 0 + a.b.c.p == 1 + a.b.d == 3 + } + note: refheads/complete+mixed + query: data.test.p = x + want_result: + - x: true +- modules: + - | + package test + + a.b[x] = y { x := "c"; y := "d" } + note: refheads/single-value rule + query: data.test.a = x + want_result: + - x: + b: + c: d +- modules: + - | + package test + import future.keywords + + a.b contains x if some x in [1,2,3] + note: refheads/multi-value + query: data.test.a = x + want_result: + - x: + b: [1, 2, 3] +# NOTE(sr): This isn't supported yet +# - modules: +# - | +# package test +# import future.keywords + +# a.b[c] contains x if { c := "c"; some x in [1,2,3] } +# note: refheads/multi-value with var in ref +# query: data.test.a = x +# want_result: +# - x: +# b: +# c: [1, 2, 3] +- modules: + - | + package test + import future.keywords + + a.b[x] = i if some i, x in [1, 2, 3] + note: 'refheads/single-value: previously partial object' + query: data.test.a.b = x + want_result: + - x: + 1: 0 + 2: 1 + 3: 2 +- modules: + - | + package test + import future.keywords + + a.b.c.d contains 1 if true + - | + package test.a + import future.keywords + + b.c.d contains 2 if true + note: 'refheads/multi-value: same 
rule' + query: data.test.a = x + want_result: + - x: + b: + c: + d: [1, 2] +- modules: + - | + package test + + default a.b.c := "d" + note: refheads/single-value default rule + query: data.test.a = x + want_result: + - x: + b: + c: d +- modules: + - | + package test + import future.keywords + + q[7] = 8 if true + a[x] if q[x] + note: refheads/single-value example + query: data.test.a = x + want_result: + - x: + 7: true +- modules: + - | + package test + import future.keywords + + q[7] = 8 if false + a[x] if q[x] + note: refheads/single-value example, false + query: data.test.a = x + want_result: + - x: {} +- modules: + - | + package test + import future.keywords + + a.b.c = "d" if true + a.b.e = "f" if true + a.b.g contains x if some x in numbers.range(1, 3) + a.b.h[x] = 1 if x := "one" + note: refheads/mixed example, multiple rules + query: data.test.a.b = x + want_result: + - x: + c: d + e: f + g: + - 1 + - 2 + - 3 + h: + one: 1 +- modules: + - | + package example + import future.keywords + + apps_by_hostname[hostname] := app if { + some i + server := sites[_].servers[_] + hostname := server.hostname + apps[i].servers[_] == server.name + app := apps[i].name + } + + sites := [ + { + "region": "east", + "name": "prod", + "servers": [ + { + "name": "web-0", + "hostname": "hydrogen" + }, + { + "name": "web-1", + "hostname": "helium" + }, + { + "name": "db-0", + "hostname": "lithium" + } + ] + }, + { + "region": "west", + "name": "smoke", + "servers": [ + { + "name": "web-1000", + "hostname": "beryllium" + }, + { + "name": "web-1001", + "hostname": "boron" + }, + { + "name": "db-1000", + "hostname": "carbon" + } + ] + }, + { + "region": "west", + "name": "dev", + "servers": [ + { + "name": "web-dev", + "hostname": "nitrogen" + }, + { + "name": "db-dev", + "hostname": "oxygen" + } + ] + } + ] + + apps := [ + { + "name": "web", + "servers": ["web-0", "web-1", "web-1000", "web-1001", "web-dev"] + }, + { + "name": "mysql", + "servers": ["db-0", "db-1000"] + }, + { + 
"name": "mongodb", + "servers": ["db-dev"] + } + ] + + containers := [ + { + "image": "redis", + "ipaddress": "10.0.0.1", + "name": "big_stallman" + }, + { + "image": "nginx", + "ipaddress": "10.0.0.2", + "name": "cranky_euclid" + } + ] + note: refheads/website-example/partial-obj + query: data.example.apps_by_hostname.helium = x + want_result: + - x: web +- modules: + - | + package example + import future.keywords + + public_network contains net.id if { + some net in input.networks + net.public + } + note: refheads/website-example/partial-set + query: data.example.public_network = x + input: + networks: + - id: n1 + public: true + - id: n2 + public: false + want_result: + - x: + - n1 \ No newline at end of file diff --git a/test/cases/testdata/refheads/test-regressions.yaml b/test/cases/testdata/refheads/test-regressions.yaml new file mode 100644 index 0000000000..4bce1e3c7a --- /dev/null +++ b/test/cases/testdata/refheads/test-regressions.yaml @@ -0,0 +1,116 @@ +# NOTE(sr): These tests are not really related to ref heads, but collection +# regressions found when introducing ref heads. 
+cases: +- note: regression/ref-not-hashable + modules: + - | + package test + + ms[m.z] = m { + m := input.xs[y] + } + input: + xs: + something: + z: a + query: data.test = x + want_result: + - x: + ms: + a: + z: a +- note: regression/function refs and package extent + modules: + - | + package test + import future.keywords.if + + foo.bar(x) = x+1 + x := foo.bar(2) + query: data.test = x + want_result: + - x: + foo: {} + x: 3 + +- note: regression/rule refs and package extent + modules: + - | + package test + import future.keywords.if + + buz.quz = 3 if input == 3 + x := y { + y := buz.quz with input as 3 + } + query: data.test = x + want_result: + - x: + buz: {} + x: 3 +- note: regression/rule refs and package extents, multiple modules + modules: + - | + package test + import future.keywords.if + + x := y { + y := data.test.buz.quz with input as 3 + } + - | + package test.buz + import future.keywords.if + + quz = 3 if input == 3 + query: data.test = x + want_result: + - x: + buz: {} + x: 3 +- note: regression/type checking with ref rules + modules: + - | + package test + all[0] = [2] + level := 1 + + p := y { y := all[level-1][_] } + query: data.test.p = x + want_result: + - x: 2 +- note: regression/type checking with ref rules, number + modules: + - | + package test + p[0] = 1 + query: 'data.test.p[0] = x' + want_result: + - x: 1 +- note: regression/type checking with ref rules, bool + modules: + - | + package test + p[true] = 1 + query: 'data.test.p[true] = x' + want_result: + - x: 1 +- note: regression/full extent with partial object rule with empty indexer lookup result + modules: + - | + package test + p[x] = 2 { + x := input # we'll get 0 rules for data.test.p + false + } + query: data.test = x + want_result: + - x: + p: {} +- note: regression/obj in ref head query + modules: + - | + package test + p[{"a": "b"}] = true + query: data.test.p[{"a":"b"}] = x + want_result: + - x: true \ No newline at end of file diff --git 
a/test/wasm/assets/012_partialobjects.yaml b/test/wasm/assets/012_partialobjects.yaml index 621ae6be96..39c0f752c7 100644 --- a/test/wasm/assets/012_partialobjects.yaml +++ b/test/wasm/assets/012_partialobjects.yaml @@ -8,7 +8,8 @@ cases: p["b"] = 2 want_defined: true - note: additive (negative) - query: 'data.x.p = {"a": 1, "b": 2}' + query: 'data.x.p = input' + input: {"a": 1, "b": 2} modules: - | package x @@ -18,15 +19,19 @@ cases: want_defined: false - note: input query: 'data.x.p = {"b": 2, "a": 1}' + input: {"x": 1, "y": 2} modules: - | package x p["a"] = 1 { input.x = 1 } p["b"] = 2 { input.y = 2 } want_defined: true - input: {"x": 1, "y": 2} - note: input (negative) - query: 'data.x.p = {"a": 1, "b": 2}' + query: 'data.x.p = data.z' + data: + z: + a: 1 + b: 2 modules: - | package x @@ -44,13 +49,16 @@ cases: p[x] = [y] { x = "b"; y = 2 } want_defined: true - note: conflict error - query: 'data.x.p = {"a": 1}' + data: + z: + a: 1 + query: 'data.x.p = data.z' modules: - | package x p["x"] = 1 p["x"] = 2 - want_error: "module0.rego:3:1: object insert conflict" + want_error: "module0.rego:3:1: var assignment conflict" - note: object dereference query: data.x.p.a.b = 1 modules: diff --git a/tester/reporter_test.go b/tester/reporter_test.go index afaf527e82..56d98b2d57 100644 --- a/tester/reporter_test.go +++ b/tester/reporter_test.go @@ -75,6 +75,14 @@ func TestPrettyReporterVerbose(t *testing.T) { File: "policy3.rego", }, }, + { + Package: "data.foo.baz", + Name: "p.q.r.test_quz", + Trace: getFakeTraceEvents(), + Location: &ast.Location{ + File: "policy4.rego", + }, + }, } r := PrettyReporter{ @@ -109,11 +117,14 @@ data.foo.bar.test_contains_print: PASS (0s) fake print output + +policy4.rego: +data.foo.baz.p.q.r.test_quz: PASS (0s) -------------------------------------------------------------------------------- -PASS: 2/5 -FAIL: 1/5 -SKIPPED: 1/5 -ERROR: 1/5 +PASS: 3/6 +FAIL: 1/6 +SKIPPED: 1/6 +ERROR: 1/6 ` str := buf.String() @@ -181,6 +192,15 @@ func 
TestPrettyReporter(t *testing.T) { File: "policy2.rego", }, }, + { + Package: "data.foo.baz", + Name: "p.q.r.test_quz", + Fail: true, + Trace: getFakeTraceEvents(), + Location: &ast.Location{ + File: "policy3.rego", + }, + }, } r := PrettyReporter{ @@ -203,11 +223,14 @@ data.foo.bar.test_contains_print_fail: FAIL (0s) fake print output2 + +policy3.rego: +data.foo.baz.p.q.r.test_quz: FAIL (0s) -------------------------------------------------------------------------------- -PASS: 2/6 -FAIL: 2/6 -SKIPPED: 1/6 -ERROR: 1/6 +PASS: 2/7 +FAIL: 3/7 +SKIPPED: 1/7 +ERROR: 1/7 ` if exp != buf.String() { @@ -246,6 +269,10 @@ func TestJSONReporter(t *testing.T) { Name: "test_contains_print", Output: []byte("fake print output\n"), }, + { + Package: "data.foo.baz", + Name: "p.q.r.test_quz", + }, } r := JSONReporter{ @@ -406,7 +433,13 @@ func TestJSONReporter(t *testing.T) { "name": "test_contains_print", "output": "ZmFrZSBwcmludCBvdXRwdXQK", "duration": 0 - } + }, + { + "location": null, + "package": "data.foo.baz", + "name": "p.q.r.test_quz", + "duration": 0 +} ] `)) diff --git a/tester/runner.go b/tester/runner.go index 5b08ee9a6b..9852882f8c 100644 --- a/tester/runner.go +++ b/tester/runner.go @@ -299,7 +299,7 @@ func (r *Runner) runTests(ctx context.Context, txn storage.Transaction, enablePr } // rewrite duplicate test_* rule names as we compile modules - r.compiler.WithStageAfter("ResolveRefs", ast.CompilerStageDefinition{ + r.compiler.WithStageAfter("RewriteRuleHeadRefs", ast.CompilerStageDefinition{ Name: "RewriteDuplicateTestNames", MetricName: "rewrite_duplicate_test_names", Stage: rewriteDuplicateTestNames, @@ -340,7 +340,7 @@ func (r *Runner) runTests(ctx context.Context, txn storage.Transaction, enablePr } } - if r.modules != nil && len(r.modules) > 0 { + if len(r.modules) > 0 { if r.compiler.Compile(r.modules); r.compiler.Failed() { return nil, r.compiler.Errors } @@ -380,7 +380,7 @@ func (r *Runner) runTests(ctx context.Context, txn storage.Transaction, enablePr } 
func (r *Runner) shouldRun(rule *ast.Rule, testRegex *regexp.Regexp) bool { - ruleName := string(rule.Head.Name) + ruleName := ruleName(rule.Head) // All tests must have the right prefix if !strings.HasPrefix(ruleName, TestPrefix) && !strings.HasPrefix(ruleName, SkipTestPrefix) { @@ -388,7 +388,7 @@ func (r *Runner) shouldRun(rule *ast.Rule, testRegex *regexp.Regexp) bool { } // Even with the prefix it needs to pass the regex (if applicable) - fullName := fmt.Sprintf("%s.%s", rule.Module.Package.Path.String(), ruleName) + fullName := rule.Ref().String() if testRegex != nil && !testRegex.MatchString(fullName) { return false } @@ -403,13 +403,20 @@ func rewriteDuplicateTestNames(compiler *ast.Compiler) *ast.Error { count := map[string]int{} for _, mod := range compiler.Modules { for _, rule := range mod.Rules { - name := rule.Head.Name.String() + name := ruleName(rule.Head) if !strings.HasPrefix(name, TestPrefix) { continue } - key := rule.Path().String() + key := rule.Ref().String() if k, ok := count[key]; ok { - rule.Head.Name = ast.Var(fmt.Sprintf("%s#%02d", name, k)) + ref := rule.Head.Ref() + newName := fmt.Sprintf("%s#%02d", name, k) + if len(ref) == 1 { + ref[0] = ast.VarTerm(newName) + } else { + ref[len(ref)-1] = ast.StringTerm(newName) + } + rule.Head.SetRef(ref) } count[key]++ } @@ -417,6 +424,23 @@ func rewriteDuplicateTestNames(compiler *ast.Compiler) *ast.Error { return nil } +// ruleName is a helper to be used when checking if a function +// (a) is a test, or +// (b) needs to be skipped +// -- it'll resolve `p.q.r` to `r`. 
For representing results, we'll +// use rule.Head.Ref() +func ruleName(h *ast.Head) string { + ref := h.Ref() + switch last := ref[len(ref)-1].Value.(type) { + case ast.Var: + return string(last) + case ast.String: + return string(last) + default: + panic("unreachable") + } +} + func (r *Runner) runTest(ctx context.Context, txn storage.Transaction, mod *ast.Module, rule *ast.Rule) (*Result, bool) { var bufferTracer *topdown.BufferTracer var bufFailureLineTracer *topdown.BufferTracer @@ -429,10 +453,9 @@ func (r *Runner) runTest(ctx context.Context, txn storage.Transaction, mod *ast. tracer = bufferTracer } - ruleName := string(rule.Head.Name) - - if strings.HasPrefix(ruleName, SkipTestPrefix) { - tr := newResult(rule.Loc(), mod.Package.Path.String(), ruleName, 0*time.Second, nil, nil) + ruleName := ruleName(rule.Head) + if strings.HasPrefix(ruleName, SkipTestPrefix) { // TODO(sr): add test + tr := newResult(rule.Loc(), mod.Package.Path.String(), rule.Head.Ref().String(), 0*time.Second, nil, nil) tr.Skip = true return tr, false } @@ -465,7 +488,7 @@ func (r *Runner) runTest(ctx context.Context, txn storage.Transaction, mod *ast. 
trace = *bufferTracer } - tr := newResult(rule.Loc(), mod.Package.Path.String(), ruleName, dt, trace, printbuf.Bytes()) + tr := newResult(rule.Loc(), mod.Package.Path.String(), rule.Head.Ref().String(), dt, trace, printbuf.Bytes()) tr.Error = err var stop bool @@ -489,7 +512,7 @@ func (r *Runner) runBenchmark(ctx context.Context, txn storage.Transaction, mod tr := &Result{ Location: rule.Loc(), Package: mod.Package.Path.String(), - Name: string(rule.Head.Name), + Name: rule.Head.Ref().String(), // TODO(sr): test } var stop bool diff --git a/tester/runner_test.go b/tester/runner_test.go index c3679c7cb2..33241fc2be 100644 --- a/tester/runner_test.go +++ b/tester/runner_test.go @@ -75,19 +75,27 @@ func testRun(t *testing.T, conf testRunConfig) map[string]*ast.Module { `, "/b_test.rego": `package bar - test_duplicate { true }`, + test_duplicate { true }`, + "/c_test.rego": `package baz + + a.b.test_duplicate { false } + a.b.test_duplicate { true } + a.b.test_duplicate { true }`, } tests := expectedTestResults{ - {"data.foo", "test_pass"}: {false, false, false}, - {"data.foo", "test_fail"}: {false, true, false}, - {"data.foo", "test_fail_non_bool"}: {false, true, false}, - {"data.foo", "test_duplicate"}: {false, true, false}, - {"data.foo", "test_duplicate#01"}: {false, false, false}, - {"data.foo", "test_duplicate#02"}: {false, false, false}, - {"data.foo", "test_err"}: {true, false, false}, - {"data.foo", "todo_test_skip"}: {false, false, true}, - {"data.bar", "test_duplicate"}: {false, false, false}, + {"data.foo", "test_pass"}: {false, false, false}, + {"data.foo", "test_fail"}: {false, true, false}, + {"data.foo", "test_fail_non_bool"}: {false, true, false}, + {"data.foo", "test_duplicate"}: {false, true, false}, + {"data.foo", "test_duplicate#01"}: {false, false, false}, + {"data.foo", "test_duplicate#02"}: {false, false, false}, + {"data.foo", "test_err"}: {true, false, false}, + {"data.foo", "todo_test_skip"}: {false, false, true}, + {"data.bar", 
"test_duplicate"}: {false, false, false}, + {"data.baz", "a.b.test_duplicate"}: {false, true, false}, + {"data.baz", "a.b[\"test_duplicate#01\"]"}: {false, false, false}, + {"data.baz", "a.b[\"test_duplicate#02\"]"}: {false, false, false}, } var modules map[string]*ast.Module @@ -186,6 +194,11 @@ func TestRunWithFilterRegex(t *testing.T) { "/b_test.rego": `package bar test_duplicate { true }`, + "/c_test.rego": `package baz + + a.b.test_duplicate { false } + a.b.test_duplicate { true } + a.b.test_duplicate { true }`, } cases := []struct { @@ -197,32 +210,38 @@ func TestRunWithFilterRegex(t *testing.T) { note: "all tests match", regex: ".*", tests: expectedTestResults{ - {"data.foo", "test_pass"}: {false, false, false}, - {"data.foo", "test_fail"}: {false, true, false}, - {"data.foo", "test_fail_non_bool"}: {false, true, false}, - {"data.foo", "test_duplicate"}: {false, true, false}, - {"data.foo", "test_duplicate#01"}: {false, false, false}, - {"data.foo", "test_duplicate#02"}: {false, false, false}, - {"data.foo", "test_err"}: {true, false, false}, - {"data.foo", "todo_test_skip"}: {false, false, true}, - {"data.foo", "todo_test_skip_too"}: {false, false, true}, - {"data.bar", "test_duplicate"}: {false, false, false}, + {"data.foo", "test_pass"}: {false, false, false}, + {"data.foo", "test_fail"}: {false, true, false}, + {"data.foo", "test_fail_non_bool"}: {false, true, false}, + {"data.foo", "test_duplicate"}: {false, true, false}, + {"data.foo", "test_duplicate#01"}: {false, false, false}, + {"data.foo", "test_duplicate#02"}: {false, false, false}, + {"data.foo", "test_err"}: {true, false, false}, + {"data.foo", "todo_test_skip"}: {false, false, true}, + {"data.foo", "todo_test_skip_too"}: {false, false, true}, + {"data.bar", "test_duplicate"}: {false, false, false}, + {"data.baz", "a.b.test_duplicate"}: {false, true, false}, + {"data.baz", "a.b[\"test_duplicate#01\"]"}: {false, false, false}, + {"data.baz", "a.b[\"test_duplicate#02\"]"}: {false, false, false}, 
}, }, { note: "no filter", regex: "", tests: expectedTestResults{ - {"data.foo", "test_pass"}: {false, false, false}, - {"data.foo", "test_fail"}: {false, true, false}, - {"data.foo", "test_fail_non_bool"}: {false, true, false}, - {"data.foo", "test_duplicate"}: {false, true, false}, - {"data.foo", "test_duplicate#01"}: {false, false, false}, - {"data.foo", "test_duplicate#02"}: {false, false, false}, - {"data.foo", "test_err"}: {true, false, false}, - {"data.foo", "todo_test_skip"}: {false, false, true}, - {"data.foo", "todo_test_skip_too"}: {false, false, true}, - {"data.bar", "test_duplicate"}: {false, false, false}, + {"data.foo", "test_pass"}: {false, false, false}, + {"data.foo", "test_fail"}: {false, true, false}, + {"data.foo", "test_fail_non_bool"}: {false, true, false}, + {"data.foo", "test_duplicate"}: {false, true, false}, + {"data.foo", "test_duplicate#01"}: {false, false, false}, + {"data.foo", "test_duplicate#02"}: {false, false, false}, + {"data.foo", "test_err"}: {true, false, false}, + {"data.foo", "todo_test_skip"}: {false, false, true}, + {"data.foo", "todo_test_skip_too"}: {false, false, true}, + {"data.bar", "test_duplicate"}: {false, false, false}, + {"data.baz", "a.b.test_duplicate"}: {false, true, false}, + {"data.baz", "a.b[\"test_duplicate#01\"]"}: {false, false, false}, + {"data.baz", "a.b[\"test_duplicate#02\"]"}: {false, false, false}, }, }, { @@ -288,6 +307,15 @@ func TestRunWithFilterRegex(t *testing.T) { {"data.bar", "test_duplicate"}: {false, false, false}, }, }, + { + note: "matching ref rule halfways", + regex: "data.baz.a", + tests: expectedTestResults{ + {"data.baz", "a.b.test_duplicate"}: {false, true, false}, + {"data.baz", "a.b[\"test_duplicate#01\"]"}: {false, false, false}, + {"data.baz", "a.b[\"test_duplicate#02\"]"}: {false, false, false}, + }, + }, } test.WithTempFS(files, func(d string) { @@ -436,7 +464,8 @@ func TestRunnerPrintOutput(t *testing.T) { test_a { print("A") } test_b { false; print("B") } - test_c { 
print("C"); false }`, + test_c { print("C"); false } + p.q.r.test_d { print("D") }`, } ctx := context.Background() @@ -461,9 +490,10 @@ func TestRunnerPrintOutput(t *testing.T) { } exp := map[string]string{ - "test_a": "A\n", - "test_b": "", - "test_c": "C\n", + "test_a": "A\n", + "test_b": "", + "test_c": "C\n", + "p.q.r.test_d": "D\n", } got := map[string]string{} diff --git a/topdown/eval.go b/topdown/eval.go index 55e2c34696..3f824e6e34 100644 --- a/topdown/eval.go +++ b/topdown/eval.go @@ -17,6 +17,7 @@ import ( "github.com/open-policy-agent/opa/topdown/print" "github.com/open-policy-agent/opa/tracing" "github.com/open-policy-agent/opa/types" + "github.com/open-policy-agent/opa/util" ) type evalIterator func(*eval) error @@ -2008,9 +2009,10 @@ func (e evalFunc) partialEvalSupportRule(rule *ast.Rule, path ast.Ref) error { // Type-checking failure means the rule body will never succeed. if e.e.compiler.PassesTypeCheck(plugged) { head := &ast.Head{ - Name: rule.Head.Name, - Value: child.bindings.PlugNamespaced(rule.Head.Value, e.e.caller.bindings), - Args: make([]*ast.Term, len(rule.Head.Args)), + Name: rule.Head.Name, + Reference: rule.Head.Reference, + Value: child.bindings.PlugNamespaced(rule.Head.Value, e.e.caller.bindings), + Args: make([]*ast.Term, len(rule.Head.Args)), } for i, a := range rule.Head.Args { head.Args[i] = child.bindings.PlugNamespaced(a, e.e.caller.bindings) @@ -2089,13 +2091,14 @@ func (e evalTree) next(iter unifyIterator, plugged *ast.Term) error { node = e.node.Child(plugged.Value) if node != nil && len(node.Values) > 0 { r := evalVirtual{ - e: e.e, - ref: e.ref, - plugged: e.plugged, - pos: e.pos, - bindings: e.bindings, - rterm: e.rterm, - rbindings: e.rbindings, + onlyGroundRefs: onlyGroundRefs(node.Values), + e: e.e, + ref: e.ref, + plugged: e.plugged, + pos: e.pos, + bindings: e.bindings, + rterm: e.rterm, + rbindings: e.rbindings, } r.plugged[e.pos] = plugged return r.eval(iter) @@ -2107,6 +2110,16 @@ func (e evalTree) next(iter 
unifyIterator, plugged *ast.Term) error { return cpy.eval(iter) } +func onlyGroundRefs(values []util.T) bool { + for _, v := range values { + rule := v.(*ast.Rule) + if !rule.Head.Reference.IsGround() { + return false + } + } + return true +} + func (e evalTree) enumerate(iter unifyIterator) error { if e.e.inliningControl.Disabled(e.plugged[:e.pos], true) { @@ -2196,6 +2209,8 @@ func (e evalTree) extent() (*ast.Term, error) { return ast.NewTerm(virtual), nil } +// leaves builds a tree from evaluating the full rule tree extent, by recursing into all +// branches, and building up objects as it goes. func (e evalTree) leaves(plugged ast.Ref, node *ast.TreeNode) (ast.Object, error) { if e.node == nil { @@ -2243,13 +2258,14 @@ func (e evalTree) leaves(plugged ast.Ref, node *ast.TreeNode) (ast.Object, error } type evalVirtual struct { - e *eval - ref ast.Ref - plugged ast.Ref - pos int - bindings *bindings - rterm *ast.Term - rbindings *bindings + onlyGroundRefs bool + e *eval + ref ast.Ref + plugged ast.Ref + pos int + bindings *bindings + rterm *ast.Term + rbindings *bindings } func (e evalVirtual) eval(iter unifyIterator) error { @@ -2266,7 +2282,7 @@ func (e evalVirtual) eval(iter unifyIterator) error { } switch ir.Kind { - case ast.PartialSetDoc: + case ast.MultiValue: eval := evalVirtualPartial{ e: e.e, ref: e.ref, @@ -2279,7 +2295,22 @@ func (e evalVirtual) eval(iter unifyIterator) error { empty: ast.SetTerm(), } return eval.eval(iter) - case ast.PartialObjectDoc: + case ast.SingleValue: + // NOTE(sr): If we allow vars in others than the last position of a ref, we need + // to start reworking things here + if e.onlyGroundRefs { + eval := evalVirtualComplete{ + e: e.e, + ref: e.ref, + plugged: e.plugged, + pos: e.pos, + ir: ir, + bindings: e.bindings, + rterm: e.rterm, + rbindings: e.rbindings, + } + return eval.eval(iter) + } eval := evalVirtualPartial{ e: e.e, ref: e.ref, @@ -2293,17 +2324,7 @@ func (e evalVirtual) eval(iter unifyIterator) error { } return 
eval.eval(iter) default: - eval := evalVirtualComplete{ - e: e.e, - ref: e.ref, - plugged: e.plugged, - pos: e.pos, - ir: ir, - bindings: e.bindings, - rterm: e.rterm, - rbindings: e.rbindings, - } - return eval.eval(iter) + panic("unreachable") } } @@ -2437,13 +2458,17 @@ func (e evalVirtualPartial) evalOneRulePreUnify(iter unifyIterator, rule *ast.Ru child.traceEnter(rule) var defined bool - err := child.biunify(rule.Head.Key, key, child.bindings, e.bindings, func() error { + headKey := rule.Head.Key + if headKey == nil { + headKey = rule.Head.Reference[len(rule.Head.Reference)-1] + } + err := child.biunify(headKey, key, child.bindings, e.bindings, func() error { defined = true return child.eval(func(child *eval) error { term := rule.Head.Value if term == nil { - term = rule.Head.Key + term = headKey } if hint.key != nil { @@ -2550,7 +2575,8 @@ func (e evalVirtualPartial) partialEvalSupport(iter unifyIterator) error { ok, err := e.partialEvalSupportRule(e.ir.Rules[i], path) if err != nil { return err - } else if ok { + } + if ok { defined = true } } @@ -2669,13 +2695,15 @@ func (e evalVirtualPartial) evalCache(iter unifyIterator) (evalVirtualPartialCac func (e evalVirtualPartial) reduce(head *ast.Head, b *bindings, result *ast.Term) (*ast.Term, bool, error) { var exists bool - key := b.Plug(head.Key) switch v := result.Value.(type) { - case ast.Set: + case ast.Set: // MultiValue + key := b.Plug(head.Key) exists = v.Contains(key) v.Add(key) - case ast.Object: + case ast.Object: // SingleValue + key := head.Reference[len(head.Reference)-1] // NOTE(sr): multiple vars in ref heads need to deal with this better + key = b.Plug(key) value := b.Plug(head.Value) if curr := v.Get(key); curr != nil { if !curr.Equal(value) { @@ -2782,6 +2810,7 @@ func (e evalVirtualComplete) evalValueRule(iter unifyIterator, rule *ast.Rule, p var result *ast.Term err := child.eval(func(child *eval) error { child.traceExit(rule) + result = child.bindings.Plug(rule.Head.Value) if prev != nil { 
@@ -2794,8 +2823,8 @@ func (e evalVirtualComplete) evalValueRule(iter unifyIterator, rule *ast.Rule, p prev = result e.e.virtualCache.Put(e.plugged[:e.pos+1], result) - term, termbindings := child.bindings.apply(rule.Head.Value) + term, termbindings := child.bindings.apply(rule.Head.Value) err := e.evalTerm(iter, term, termbindings) if err != nil { return err @@ -2840,42 +2869,65 @@ func (e evalVirtualComplete) partialEvalSupport(iter unifyIterator) error { path := e.e.namespaceRef(e.plugged[:e.pos+1]) term := ast.NewTerm(e.e.namespaceRef(e.ref)) - if !e.e.saveSupport.Exists(path) { + var defined bool + if e.e.saveSupport.Exists(path) { + defined = true + } else { for i := range e.ir.Rules { - err := e.partialEvalSupportRule(e.ir.Rules[i], path) + ok, err := e.partialEvalSupportRule(e.ir.Rules[i], path) if err != nil { return err } + if ok { + defined = true + } } if e.ir.Default != nil { - err := e.partialEvalSupportRule(e.ir.Default, path) + ok, err := e.partialEvalSupportRule(e.ir.Default, path) if err != nil { return err } + if ok { + defined = true + } } } + if !defined { + return nil + } + return e.e.saveUnify(term, e.rterm, e.bindings, e.rbindings, iter) } -func (e evalVirtualComplete) partialEvalSupportRule(rule *ast.Rule, path ast.Ref) error { +func (e evalVirtualComplete) partialEvalSupportRule(rule *ast.Rule, path ast.Ref) (bool, error) { child := e.e.child(rule.Body) child.traceEnter(rule) e.e.saveStack.PushQuery(nil) + var defined bool err := child.eval(func(child *eval) error { child.traceExit(rule) + defined = true current := e.e.saveStack.PopQuery() plugged := current.Plug(e.e.caller.bindings) // Skip this rule body if it fails to type-check. // Type-checking failure means the rule body will never succeed. 
if e.e.compiler.PassesTypeCheck(plugged) { - head := ast.NewHead(rule.Head.Name, nil, child.bindings.PlugNamespaced(rule.Head.Value, e.e.caller.bindings)) + var name ast.Var + switch ref := rule.Head.Ref().GroundPrefix(); len(ref) { + case 1: + name = ref[0].Value.(ast.Var) + default: + s := ref[len(ref)-1].Value.(ast.String) + name = ast.Var(s) + } + head := ast.NewHead(name, nil, child.bindings.PlugNamespaced(rule.Head.Value, e.e.caller.bindings)) if !e.e.inliningControl.shallow { cp := copypropagation.New(head.Vars()). @@ -2895,7 +2947,7 @@ func (e evalVirtualComplete) partialEvalSupportRule(rule *ast.Rule, path ast.Ref return nil }) e.e.saveStack.PopQuery() - return err + return defined, err } func (e evalVirtualComplete) evalTerm(iter unifyIterator, term *ast.Term, termbindings *bindings) error { diff --git a/wasm/src/value.c b/wasm/src/value.c index 07f3c30e9d..d9d03d4035 100644 --- a/wasm/src/value.c +++ b/wasm/src/value.c @@ -1629,8 +1629,13 @@ opa_errc opa_value_remove_path(opa_value *data, opa_value *path) // be found, or of there's no function index leaf when we've run out // of path pieces. int opa_lookup(opa_value *mapping, opa_value *path) { - int path_len = _validate_json_path(path); - if (path_len < 1) + if (path == NULL || opa_value_type(path) != OPA_ARRAY) + { + return 0; + } + + int path_len = opa_value_length(path); + if (path_len == 0) { return 0; }