Skip to content

Commit

Permalink
topdown: Honor default keyword on functions
Browse files Browse the repository at this point in the history
Default functions satisfy the following properties:

* Same arity as other functions with the same name
* Arguments should only be plain variables ie. no composite values. For ex, default f([x]) = 1 is an invalid default function
* Variable names should not be repeated ie. default f(x, x) = 1 is an invalid default function

Fixes: open-policy-agent#2445

Signed-off-by: Ashutosh Narkar <anarkar4387@gmail.com>
  • Loading branch information
ashutosh-narkar committed Jul 19, 2023
1 parent 57c4daa commit a1ca32a
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 36 deletions.
35 changes: 35 additions & 0 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1844,6 +1844,41 @@ bar.baz contains "quz" if true`,
assertCompilerErrorStrings(t, c, expected)
}

func TestCompilerCheckRuleConflictsDefaultFunction(t *testing.T) {
tests := []struct {
note string
modules []*Module
err string
}{
{
note: "conflicting rules",
modules: modules(
`package pkg
default f(_) = 100
f(x, y) = x {
x == y
}`),
err: "rego_type_error: conflicting rules data.pkg.f found",
},
}
for _, tc := range tests {
t.Run(tc.note, func(t *testing.T) {
mods := make(map[string]*Module, len(tc.modules))
for i, m := range tc.modules {
mods[fmt.Sprint(i)] = m
}
c := NewCompiler()
c.Modules = mods
compileStages(c, c.checkRuleConflicts)
if tc.err != "" {
assertCompilerErrorStrings(t, c, []string{tc.err})
} else {
assertCompilerErrorStrings(t, c, []string{})
}
})
}
}

func TestCompilerCheckRuleConflictsDotsInRuleHeads(t *testing.T) {

tests := []struct {
Expand Down
38 changes: 38 additions & 0 deletions ast/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,12 @@ func (p *Parser) parseRules() []*Rule {
return nil
}

if len(rule.Head.Args) > 0 {
if !p.validateDefaultRuleArgs(&rule) {
return nil
}
}

rule.Body = NewBody(NewExpr(BooleanTerm(true).SetLocation(rule.Location)).SetLocation(rule.Location))
return []*Rule{&rule}
}
Expand Down Expand Up @@ -2176,6 +2182,38 @@ func (p *Parser) validateDefaultRuleValue(rule *Rule) bool {
return valid
}

func (p *Parser) validateDefaultRuleArgs(rule *Rule) bool {

valid := true
vars := NewVarSet()

vis := NewGenericVisitor(func(x interface{}) bool {
switch x := x.(type) {
case Var:
if vars.Contains(x) {
p.error(rule.Loc(), fmt.Sprintf("illegal default rule (arguments cannot be repeated %v)", x))
valid = false
return true
}
vars.Add(x)

case *Term:
switch v := x.Value.(type) {
case Var: // do nothing
default:
p.error(rule.Loc(), fmt.Sprintf("illegal default rule (arguments cannot contain %v)", TypeName(v)))
valid = false
return true
}
}

return false
})

vis.Walk(rule.Head.Args)
return valid
}

// We explicitly use yaml unmarshalling, to accommodate for the '_' in 'related_resources',
// which isn't handled properly by json for some reason.
type rawAnnotation struct {
Expand Down
8 changes: 8 additions & 0 deletions ast/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1611,6 +1611,14 @@ func TestRule(t *testing.T) {
assertParseErrorContains(t, "default invalid rule head builtin call", `default a = upper("foo")`, "illegal default rule (value cannot contain call)")
assertParseErrorContains(t, "default invalid rule head call", `default a = b`, "illegal default rule (value cannot contain var)")

assertParseErrorContains(t, "default invalid function head ref", `default f(x) = b.c.d`, "illegal default rule (value cannot contain ref)")
assertParseErrorContains(t, "default invalid function head call", `default f(x) = g(x)`, "illegal default rule (value cannot contain call)")
assertParseErrorContains(t, "default invalid function head builtin call", `default f(x) = upper("foo")`, "illegal default rule (value cannot contain call)")
assertParseErrorContains(t, "default invalid function head call", `default f(x) = b`, "illegal default rule (value cannot contain var)")
assertParseErrorContains(t, "default invalid function composite argument", `default f([x]) = 1`, "illegal default rule (arguments cannot contain array)")
assertParseErrorContains(t, "default invalid function number argument", `default f(1) = 1`, "illegal default rule (arguments cannot contain number)")
assertParseErrorContains(t, "default invalid function repeated vars", `default f(x, x) = 1`, "illegal default rule (arguments cannot be repeated x)")

assertParseError(t, "extra braces", `{ a := 1 }`)
assertParseError(t, "invalid rule name hyphen", `a-b = x { x := 1 }`)

Expand Down
1 change: 0 additions & 1 deletion internal/wasm/sdk/test/e2e/exceptions.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Exception Format is <test name>: <reason>
"functions/default": "not supported in topdown, https://github.com/open-policy-agent/opa/issues/2445"
"data/toplevel integer": "https://github.com/open-policy-agent/opa/issues/3711"
"data/nested integer": "https://github.com/open-policy-agent/opa/issues/3711"
"withkeyword/function: indirect call, arity 1, replacement is value that needs eval (array comprehension)": "https://github.com/open-policy-agent/opa/issues/5311"
Expand Down
34 changes: 20 additions & 14 deletions test/cases/testdata/defaultkeyword/test-default-functions.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# NOTE(sr): default functions are not supported, but they had sneaked into
# the full extent of a package. These cases assert that this won't happen.
cases:
- note: defaultkeyword/function with var arg
modules:
Expand All @@ -10,24 +8,32 @@ cases:
query: data.test = x
want_result:
- x: {}
- note: defaultkeyword/function with ground arg
- note: defaultkeyword/function with var arg, ref head
modules:
- |
package test
default f(10) := 100
query: data.test = x
want_result:
- x: {}
- note: defaultkeyword/function with ground arg, ref head
modules:
- |
package test
default p.q.r.f(10) := 100
default p.q.r.f(x) := 100
query: data.test = x
want_result:
- x:
p:
q:
r: {}
r: {}
- note: defaultkeyword/function with var arg, ref head query
modules:
- |
package test
default p.q.r.f(x) := 100
p.q.r.f(x) = x {
x == 2
}
foo {
p.q.r.f(3) == 100
}
query: data.test.foo = x
want_result:
- x: true
112 changes: 95 additions & 17 deletions test/cases/testdata/functions/test-functions-default.yaml
Original file line number Diff line number Diff line change
@@ -1,21 +1,99 @@
cases:
- data:
modules:
- |
package p.m
default hello = false
hello() = m {
m = input.message
1 == 2
m = "world"
}
h = m {
m = hello()
}
note: functions/default # not supported but shouldn't panic
query: data.p.m = x
- |
package test
default f(x) = 1
f(x) = x {
x > 0
}
p {
f(-1) == 1
}
note: functions/default
query: data.test.p = x
want_result:
- x: true

- data:
modules:
- |
package test
default f(x) = 1
f(x) = x {
x > 0
}
p {
f(2) == 2
}
note: functions/non default
query: data.test.p = x
want_result:
- x: true

- data:
modules:
- |
package test
default f(x) = 1
p {
f(2) == 1
}
note: functions/only default
query: data.test.p = x
want_result:
- x: true

- data:
modules:
- |
package test
default f(_, _) = 1
f(x, y) = x {
x == y
}
p {
f(2, 2) == 2
}
note: functions/only default
query: data.test.p = x
want_result:
- x: true

- data:
modules:
- |
package test
default f(x) = 1000
f(x) = x {
x > 0
}
p = xs {
xs := [y | x = [1, -2, 3][_]; y := f(x)]
}
note: functions/comprehensions
query: data.test.p = x
want_result:
- x:
hello: false
- x:
- 1
- 1000
- 3
17 changes: 13 additions & 4 deletions topdown/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -1795,13 +1795,16 @@ type evalFunc struct {

func (e evalFunc) eval(iter unifyIterator) error {

// default functions aren't supported:
// https://github.com/open-policy-agent/opa/issues/2445
if len(e.ir.Rules) == 0 {
if e.ir.Empty() {
return nil
}

argCount := len(e.ir.Rules[0].Head.Args)
var argCount int
if len(e.ir.Rules) > 0 {
argCount = len(e.ir.Rules[0].Head.Args)
} else if e.ir.Default != nil {
argCount = len(e.ir.Default.Head.Args)
}

if len(e.ir.Else) > 0 && e.e.unknown(e.e.query[e.e.index], e.e.bindings) {
// Partial evaluation of ordered rules is not supported currently. Save the
Expand All @@ -1820,6 +1823,7 @@ func (e evalFunc) eval(iter unifyIterator) error {
return e.partialEvalSupport(argCount, iter)
}
}

return suppressEarlyExit(e.evalValue(iter, argCount, e.ir.EarlyExit))
}

Expand Down Expand Up @@ -1859,6 +1863,11 @@ func (e evalFunc) evalValue(iter unifyIterator, argCount int, findOne bool) erro
}
}

if e.ir.Default != nil && prev == nil {
_, err := e.evalOneRule(iter, e.ir.Default, cacheKey, prev, findOne)
return err
}

return nil
}

Expand Down

0 comments on commit a1ca32a

Please sign in to comment.