diff --git a/test/cases/testdata/withkeyword/test-with-and-ndbcache-issue.yaml b/test/cases/testdata/withkeyword/test-with-and-ndbcache-issue.yaml new file mode 100644 index 0000000000..aaa125e0e9 --- /dev/null +++ b/test/cases/testdata/withkeyword/test-with-and-ndbcache-issue.yaml @@ -0,0 +1,16 @@ +cases: +- modules: + - | + package rules + + p { + time.now_ns(now) + } + + q { p with data.x as 7 } + note: "with: ndb_cache-issue" + query: data.rules = x + want_result: + - x: + p: true + q: true \ No newline at end of file diff --git a/topdown/eval.go b/topdown/eval.go index fc501899ee..56bbb2c7d4 100644 --- a/topdown/eval.go +++ b/topdown/eval.go @@ -1696,7 +1696,7 @@ func (e evalBuiltin) eval(iter unifyIterator) error { operands := make([]*ast.Term, len(e.terms)) - for i := 0; i < len(e.terms); i++ { + for i := range e.terms { operands[i] = e.e.bindings.Plug(e.terms[i]) } @@ -1723,25 +1723,19 @@ func (e evalBuiltin) eval(iter unifyIterator) error { if v, ok := e.bctx.NDBuiltinCache.Get(e.bi.Name, ast.NewArray(operands[:endIndex]...)); ok { switch { case e.bi.Decl.Result() == nil: - err = iter() + return iter() case len(operands) == numDeclArgs: - if v.Compare(ast.Boolean(false)) != 0 { - err = iter() - } // else: nothing to do, don't iter() + if v.Compare(ast.Boolean(false)) == 0 { + return nil // nothing to do + } + return iter() default: - err = e.e.unify(e.terms[endIndex], ast.NewTerm(v), iter) - } - - if err != nil { - return Halt{Err: err} + return e.e.unify(e.terms[endIndex], ast.NewTerm(v), iter) } - - return nil } - e.e.instr.startTimer(evalOpBuiltinCall) - // Otherwise, we'll need to go through the normal unify flow. + e.e.instr.startTimer(evalOpBuiltinCall) } // Normal unification flow for builtins: @@ -1770,6 +1764,9 @@ func (e evalBuiltin) eval(iter unifyIterator) error { } if err != nil { + // NOTE(sr): We wrap the errors here into Halt{} because we don't want to + // record them into builtinErrors below. The errors set here are coming from + // the call to iter(), not from the builtin implementation. err = Halt{Err: err} } diff --git a/topdown/exported_test.go b/topdown/exported_test.go index eaf3c66a77..675132d423 100644 --- a/topdown/exported_test.go +++ b/topdown/exported_test.go @@ -16,6 +16,7 @@ import ( "github.com/open-policy-agent/opa/storage" inmem "github.com/open-policy-agent/opa/storage/inmem/test" "github.com/open-policy-agent/opa/test/cases" + "github.com/open-policy-agent/opa/topdown/builtins" ) func TestRego(t *testing.T) { @@ -34,7 +35,19 @@ func TestOPARego(t *testing.T) { } } -func testRun(t *testing.T, tc cases.TestCase) { +func TestRegoWithNDBCache(t *testing.T) { + for _, tc := range cases.MustLoad("../test/cases/testdata").Sorted().Cases { + t.Run(tc.Note, func(t *testing.T) { + testRun(t, tc, func(q *Query) *Query { + return q.WithNDBuiltinCache(builtins.NDBCache{}) + }) + }) + } +} + +type opt func(*Query) *Query + +func testRun(t *testing.T, tc cases.TestCase, opts ...opt) { ctx := context.Background() @@ -69,14 +82,19 @@ func testRun(t *testing.T, tc cases.TestCase) { } buf := NewBufferTracer() - rs, err := NewQuery(query). + q := NewQuery(query). WithCompiler(compiler). WithStore(store). WithTransaction(txn). WithInput(input). WithStrictBuiltinErrors(tc.StrictError). - WithTracer(buf). - Run(ctx) + WithTracer(buf) + + for _, o := range opts { + q = o(q) + } + + rs, err := q.Run(ctx) if tc.WantError != nil { testAssertErrorText(t, *tc.WantError, err)