diff --git a/ast/compile_test.go b/ast/compile_test.go index 06890bf7a18..8340a515b75 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -1844,6 +1844,41 @@ bar.baz contains "quz" if true`, assertCompilerErrorStrings(t, c, expected) } +func TestCompilerCheckRuleConflictsDefaultFunction(t *testing.T) { + tests := []struct { + note string + modules []*Module + err string + }{ + { + note: "conflicting rules", + modules: modules( + `package pkg + default f(_) = 100 + f(x, y) = x { + x == y + }`), + err: "rego_type_error: conflicting rules data.pkg.f found", + }, + } + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + mods := make(map[string]*Module, len(tc.modules)) + for i, m := range tc.modules { + mods[fmt.Sprint(i)] = m + } + c := NewCompiler() + c.Modules = mods + compileStages(c, c.checkRuleConflicts) + if tc.err != "" { + assertCompilerErrorStrings(t, c, []string{tc.err}) + } else { + assertCompilerErrorStrings(t, c, []string{}) + } + }) + } +} + func TestCompilerCheckRuleConflictsDotsInRuleHeads(t *testing.T) { tests := []struct { diff --git a/ast/parser.go b/ast/parser.go index 58e9e73c8a6..3337a964e45 100644 --- a/ast/parser.go +++ b/ast/parser.go @@ -614,6 +614,12 @@ func (p *Parser) parseRules() []*Rule { return nil } + if len(rule.Head.Args) > 0 { + if !p.validateDefaultRuleArgs(&rule) { + return nil + } + } + rule.Body = NewBody(NewExpr(BooleanTerm(true).SetLocation(rule.Location)).SetLocation(rule.Location)) return []*Rule{&rule} } @@ -2176,6 +2182,38 @@ func (p *Parser) validateDefaultRuleValue(rule *Rule) bool { return valid } +func (p *Parser) validateDefaultRuleArgs(rule *Rule) bool { + + valid := true + vars := NewVarSet() + + vis := NewGenericVisitor(func(x interface{}) bool { + switch x := x.(type) { + case Var: + if vars.Contains(x) { + p.error(rule.Loc(), fmt.Sprintf("illegal default rule (arguments cannot be repeated %v)", x)) + valid = false + return true + } + vars.Add(x) + + case *Term: + switch v := x.Value.(type) { + case Var: // do nothing + default: + p.error(rule.Loc(), fmt.Sprintf("illegal default rule (arguments cannot contain %v)", TypeName(v))) + valid = false + return true + } + } + + return false + }) + + vis.Walk(rule.Head.Args) + return valid +} + // We explicitly use yaml unmarshalling, to accommodate for the '_' in 'related_resources', // which isn't handled properly by json for some reason. type rawAnnotation struct { diff --git a/ast/parser_test.go b/ast/parser_test.go index 5c8422cc7d1..11d522fd9c4 100644 --- a/ast/parser_test.go +++ b/ast/parser_test.go @@ -1611,6 +1611,14 @@ func TestRule(t *testing.T) { assertParseErrorContains(t, "default invalid rule head builtin call", `default a = upper("foo")`, "illegal default rule (value cannot contain call)") assertParseErrorContains(t, "default invalid rule head call", `default a = b`, "illegal default rule (value cannot contain var)") + assertParseErrorContains(t, "default invalid function head ref", `default f(x) = b.c.d`, "illegal default rule (value cannot contain ref)") + assertParseErrorContains(t, "default invalid function head call", `default f(x) = g(x)`, "illegal default rule (value cannot contain call)") + assertParseErrorContains(t, "default invalid function head builtin call", `default f(x) = upper("foo")`, "illegal default rule (value cannot contain call)") + assertParseErrorContains(t, "default invalid function head call", `default f(x) = b`, "illegal default rule (value cannot contain var)") + assertParseErrorContains(t, "default invalid function composite argument", `default f([x]) = 1`, "illegal default rule (arguments cannot contain array)") + assertParseErrorContains(t, "default invalid function number argument", `default f(1) = 1`, "illegal default rule (arguments cannot contain number)") + assertParseErrorContains(t, "default invalid function repeated vars", `default f(x, x) = 1`, "illegal default rule (arguments cannot be repeated x)") + assertParseError(t, "extra braces", `{ a := 1 }`) assertParseError(t, "invalid rule name hyphen", `a-b = x { x := 1 }`) diff --git a/internal/wasm/sdk/test/e2e/exceptions.yaml b/internal/wasm/sdk/test/e2e/exceptions.yaml index 714af1ba93b..664e29b20f7 100644 --- a/internal/wasm/sdk/test/e2e/exceptions.yaml +++ b/internal/wasm/sdk/test/e2e/exceptions.yaml @@ -1,5 +1,4 @@ # Exception Format is : -"functions/default": "not supported in topdown, https://github.com/open-policy-agent/opa/issues/2445" "data/toplevel integer": "https://github.com/open-policy-agent/opa/issues/3711" "data/nested integer": "https://github.com/open-policy-agent/opa/issues/3711" "withkeyword/function: indirect call, arity 1, replacement is value that needs eval (array comprehension)": "https://github.com/open-policy-agent/opa/issues/5311" diff --git a/test/cases/testdata/defaultkeyword/test-default-functions.yaml b/test/cases/testdata/defaultkeyword/test-default-functions.yaml index 6ece1772036..81350089c9f 100644 --- a/test/cases/testdata/defaultkeyword/test-default-functions.yaml +++ b/test/cases/testdata/defaultkeyword/test-default-functions.yaml @@ -1,5 +1,3 @@ -# NOTE(sr): default functions are not supported, but they had sneaked into -# the full extent of a package. These cases assert that this won't happen. cases: - note: defaultkeyword/function with var arg modules: @@ -10,24 +8,32 @@ cases: query: data.test = x want_result: - x: {} -- note: defaultkeyword/function with ground arg +- note: defaultkeyword/function with var arg, ref head modules: - | package test - default f(10) := 100 - query: data.test = x - want_result: - - x: {} -- note: defaultkeyword/function with ground arg, ref head - modules: - - | - package test - - default p.q.r.f(10) := 100 + default p.q.r.f(x) := 100 query: data.test = x want_result: - x: p: q: - r: {} \ No newline at end of file + r: {} +- note: defaultkeyword/function with var arg, ref head query + modules: + - | + package test + + default p.q.r.f(x) := 100 + + p.q.r.f(x) = x { + x == 2 + } + + foo { + p.q.r.f(3) == 100 + } + query: data.test.foo = x + want_result: + - x: true \ No newline at end of file diff --git a/test/cases/testdata/functions/test-functions-default.yaml b/test/cases/testdata/functions/test-functions-default.yaml index 00c7abcd291..fbeaf84b8ed 100644 --- a/test/cases/testdata/functions/test-functions-default.yaml +++ b/test/cases/testdata/functions/test-functions-default.yaml @@ -1,21 +1,99 @@ cases: - data: modules: - - | - package p.m - - default hello = false - - hello() = m { - m = input.message - 1 == 2 - m = "world" - } - h = m { - m = hello() - } - note: functions/default # not supported but shouldn't panic - query: data.p.m = x + - | + package test + + default f(x) = 1 + + f(x) = x { + x > 0 + } + + p { + f(-1) == 1 + } + + note: functions/default + query: data.test.p = x + want_result: + - x: true + +- data: + modules: + - | + package test + + default f(x) = 1 + + f(x) = x { + x > 0 + } + + p { + f(2) == 2 + } + + note: functions/non default + query: data.test.p = x + want_result: + - x: true + +- data: + modules: + - | + package test + + default f(x) = 1 + + p { + f(2) == 1 + } + + note: functions/only default + query: data.test.p = x + want_result: + - x: true + +- data: + modules: + - | + package test + + default f(_, _) = 1 + + f(x, y) = x { + x == y + } + + p { + f(2, 2) == 2 + } + + note: functions/only default + query: data.test.p = x + want_result: + - x: true + +- data: + modules: + - | + package test + + default f(x) = 1000 + + f(x) = x { + x > 0 + } + + p = xs { + xs := [y | x = [1, -2, 3][_]; y := f(x)] + } + + note: functions/comprehensions + query: data.test.p = x want_result: - - x: - hello: false + - x: + - 1 + - 1000 + - 3 diff --git a/topdown/eval.go b/topdown/eval.go index 545dcff8fe7..49cf0e5dcf6 100644 --- a/topdown/eval.go +++ b/topdown/eval.go @@ -1795,13 +1795,16 @@ type evalFunc struct { func (e evalFunc) eval(iter unifyIterator) error { - // default functions aren't supported: - // https://github.com/open-policy-agent/opa/issues/2445 - if len(e.ir.Rules) == 0 { + if e.ir.Empty() { return nil } - argCount := len(e.ir.Rules[0].Head.Args) + var argCount int + if len(e.ir.Rules) > 0 { + argCount = len(e.ir.Rules[0].Head.Args) + } else if e.ir.Default != nil { + argCount = len(e.ir.Default.Head.Args) + } if len(e.ir.Else) > 0 && e.e.unknown(e.e.query[e.e.index], e.e.bindings) { // Partial evaluation of ordered rules is not supported currently. Save the @@ -1820,6 +1823,7 @@ func (e evalFunc) eval(iter unifyIterator) error { return e.partialEvalSupport(argCount, iter) } } + return suppressEarlyExit(e.evalValue(iter, argCount, e.ir.EarlyExit)) } @@ -1859,6 +1863,11 @@ func (e evalFunc) evalValue(iter unifyIterator, argCount int, findOne bool) erro } } + if e.ir.Default != nil && prev == nil { + _, err := e.evalOneRule(iter, e.ir.Default, cacheKey, prev, findOne) + return err + } + return nil }