diff --git a/features/wasm/wasm.go b/features/wasm/wasm.go index 009768ed28..8e8e52ea2d 100644 --- a/features/wasm/wasm.go +++ b/features/wasm/wasm.go @@ -64,6 +64,7 @@ func (o *OPA) Eval(ctx context.Context, opts opa.EvalOpts) (*opa.Result, error) Time: opts.Time, Seed: opts.Seed, InterQueryBuiltinCache: opts.InterQueryBuiltinCache, + PrintHook: opts.PrintHook, } res, err := o.opa.Eval(ctx, evalOptions) diff --git a/internal/compiler/wasm/wasm.go b/internal/compiler/wasm/wasm.go index 2e5c064c55..e415f5f3fd 100644 --- a/internal/compiler/wasm/wasm.go +++ b/internal/compiler/wasm/wasm.go @@ -21,6 +21,7 @@ import ( "github.com/open-policy-agent/opa/internal/wasm/instruction" "github.com/open-policy-agent/opa/internal/wasm/module" "github.com/open-policy-agent/opa/internal/wasm/types" + opatypes "github.com/open-policy-agent/opa/types" ) // Record Wasm ABI version in exported global variable @@ -179,6 +180,11 @@ var builtinsUsingRE2 = [...]string{ builtinsFunctions[ast.GlobMatch.Name], } +type externalFunc struct { + ID int32 + Decl *opatypes.Function +} + var builtinDispatchers = [...]string{ "opa_builtin0", "opa_builtin1", @@ -198,17 +204,17 @@ type Compiler struct { funcsCode []funcCode // compile functions' code - builtinStringAddrs map[int]uint32 // addresses of built-in string constants - externalFuncNameAddrs map[string]int32 // addresses of required built-in function names for listing - externalFuncs map[string]int32 // required built-in function ids - entrypointNameAddrs map[string]int32 // addresses of available entrypoint names for listing - entrypoints map[string]int32 // available entrypoint ids - stringOffset int32 // null-terminated string data base offset - stringAddrs []uint32 // null-terminated string constant addresses - opaStringAddrs []uint32 // addresses of interned opa_string_t - opaBoolAddrs map[ir.Bool]uint32 // addresses of interned opa_boolean_t - fileAddrs []uint32 // null-terminated string constant addresses, used for file names - funcs map[string]uint32 // maps imported and exported function names to function indices + builtinStringAddrs map[int]uint32 // addresses of built-in string constants + externalFuncNameAddrs map[string]int32 // addresses of required built-in function names for listing + externalFuncs map[string]externalFunc // required built-in function ids and types + entrypointNameAddrs map[string]int32 // addresses of available entrypoint names for listing + entrypoints map[string]int32 // available entrypoint ids + stringOffset int32 // null-terminated string data base offset + stringAddrs []uint32 // null-terminated string constant addresses + opaStringAddrs []uint32 // addresses of interned opa_string_t + opaBoolAddrs map[ir.Bool]uint32 // addresses of interned opa_boolean_t + fileAddrs []uint32 // null-terminated string constant addresses, used for file names + funcs map[string]uint32 // maps imported and exported function names to function indices nextLocal uint32 locals map[ir.Local]uint32 @@ -567,7 +573,7 @@ func (c *Compiler) compileExternalFuncDecls() error { c.appendInstr(instruction.Call{Index: c.function(opaObject)}) c.appendInstr(instruction.SetLocal{Index: lobj}) - c.externalFuncs = make(map[string]int32) + c.externalFuncs = make(map[string]externalFunc) for index, decl := range c.policy.Static.BuiltinFuncs { if _, ok := builtinsFunctions[decl.Name]; !ok { @@ -577,7 +583,7 @@ func (c *Compiler) compileExternalFuncDecls() error { c.appendInstr(instruction.I64Const{Value: int64(index)}) c.appendInstr(instruction.Call{Index: c.function(opaNumberInt)}) c.appendInstr(instruction.Call{Index: c.function(opaObjectInsert)}) - c.externalFuncs[decl.Name] = int32(index) + c.externalFuncs[decl.Name] = externalFunc{ID: int32(index), Decl: decl.Decl} } } @@ -1428,8 +1434,8 @@ func (c *Compiler) compileCallStmt(stmt *ir.CallStmt, result *[]instruction.Inst return c.compileInternalCall(stmt, index, result) } - if id, ok := c.externalFuncs[fn]; ok { - return c.compileExternalCall(stmt, id, result) + if ef, ok := c.externalFuncs[fn]; ok { + return c.compileExternalCall(stmt, ef, result) } c.errors = append(c.errors, fmt.Errorf("undefined function: %q", fn)) @@ -1457,7 +1463,7 @@ func (c *Compiler) compileInternalCall(stmt *ir.CallStmt, index uint32, result * return nil } -func (c *Compiler) compileExternalCall(stmt *ir.CallStmt, id int32, result *[]instruction.Instruction) error { +func (c *Compiler) compileExternalCall(stmt *ir.CallStmt, ef externalFunc, result *[]instruction.Instruction) error { if len(stmt.Args) >= len(builtinDispatchers) { c.errors = append(c.errors, fmt.Errorf("too many built-in call arguments: %q", stmt.Func)) @@ -1465,7 +1471,7 @@ func (c *Compiler) compileExternalCall(stmt *ir.CallStmt, id int32, result *[]in } instrs := *result - instrs = append(instrs, instruction.I32Const{Value: id}) + instrs = append(instrs, instruction.I32Const{Value: ef.ID}) instrs = append(instrs, instruction.I32Const{Value: 0}) // unused context parameter for _, arg := range stmt.Args { @@ -1473,9 +1479,15 @@ func (c *Compiler) compileExternalCall(stmt *ir.CallStmt, id int32, result *[]in } instrs = append(instrs, instruction.Call{Index: c.function(builtinDispatchers[len(stmt.Args)])}) - instrs = append(instrs, instruction.TeeLocal{Index: c.local(stmt.Result)}) - instrs = append(instrs, instruction.I32Eqz{}) - instrs = append(instrs, instruction.BrIf{Index: 0}) + + if ef.Decl.Result() != nil { + instrs = append(instrs, instruction.TeeLocal{Index: c.local(stmt.Result)}) + instrs = append(instrs, instruction.I32Eqz{}) + instrs = append(instrs, instruction.BrIf{Index: 0}) + } else { + instrs = append(instrs, instruction.Drop{}) + } + *result = instrs return nil } diff --git a/internal/ir/ir.go b/internal/ir/ir.go index 6a26af9fb7..decd8f2033 100644 --- a/internal/ir/ir.go +++ b/internal/ir/ir.go @@ -10,6 +10,8 @@ package ir import ( "fmt" + + "github.com/open-policy-agent/opa/types" ) type ( @@ -31,6 +33,7 @@ type ( // policy. BuiltinFunc struct { Name string + Decl *types.Function } // Plans represents a collection of named query plans to expose in the policy. diff --git a/internal/planner/planner.go b/internal/planner/planner.go index 47a54bd063..60810698eb 100644 --- a/internal/planner/planner.go +++ b/internal/planner/planner.go @@ -33,7 +33,7 @@ type Planner struct { modules []*ast.Module // input modules to support queries strings map[string]int // global string constant indices files map[string]int // global file constant indices - externs map[string]struct{} // built-in functions that are required in execution environment + externs map[string]*ast.Builtin // built-in functions that are required in execution environment decls map[string]*ast.Builtin // built-in functions that may be provided in execution environment rules *ruletrie // rules that may be planned funcs *funcstack // functions that have been planned @@ -68,7 +68,7 @@ func New() *Planner { }, strings: map[string]int{}, files: map[string]int{}, - externs: map[string]struct{}{}, + externs: map[string]*ast.Builtin{}, lnext: ir.Unused, vars: newVarstack(map[ast.Var]ir.Local{ ast.InputRootDocument.Value.(ast.Var): ir.Input, @@ -725,6 +725,7 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error { var relation bool var name string var arity int + var void bool var args []ir.LocalOrConst node := p.rules.Lookup(e.Operator()) @@ -743,8 +744,9 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error { } else if decl, ok := p.decls[operator]; ok { relation = decl.Relation arity = len(decl.Decl.Args()) + void = decl.Decl.Result() == nil name = operator - p.externs[operator] = struct{}{} + p.externs[operator] = decl } else { return fmt.Errorf("illegal call: unknown operator %q", operator) } @@ -759,7 +761,7 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error { return p.planExprCallRelation(name, arity, operands, args, iter) } - return p.planExprCallFunc(name, arity, operands, args, iter) + return p.planExprCallFunc(name, arity, void, operands, args, iter) } } @@ -814,7 +816,7 @@ func (p *Planner) planExprCallRelation(name string, arity int, operands []*ast.T }) } -func (p *Planner) planExprCallFunc(name string, arity int, operands []*ast.Term, args []ir.LocalOrConst, iter planiter) error { +func (p *Planner) planExprCallFunc(name string, arity int, void bool, operands []*ast.Term, args []ir.LocalOrConst, iter planiter) error { if len(operands) == arity { // definition: f(x) = y { ... } @@ -828,10 +830,12 @@ func (p *Planner) planExprCallFunc(name string, arity int, operands []*ast.Term, Result: ltarget, }) - p.appendStmt(&ir.NotEqualStmt{ - A: ltarget, - B: ir.Bool(false), - }) + if !void { + p.appendStmt(&ir.NotEqualStmt{ + A: ltarget, + B: ir.Bool(false), + }) + } return iter() }) @@ -1890,8 +1894,8 @@ func (p *Planner) planExterns() error { p.policy.Static.BuiltinFuncs = make([]*ir.BuiltinFunc, 0, len(p.externs)) - for name := range p.externs { - p.policy.Static.BuiltinFuncs = append(p.policy.Static.BuiltinFuncs, &ir.BuiltinFunc{Name: name}) + for name, decl := range p.externs { + p.policy.Static.BuiltinFuncs = append(p.policy.Static.BuiltinFuncs, &ir.BuiltinFunc{Name: name, Decl: decl.Decl}) } sort.Slice(p.policy.Static.BuiltinFuncs, func(i, j int) bool { diff --git a/internal/rego/opa/options.go b/internal/rego/opa/options.go index e74a45fec3..ec6b2f0b5a 100644 --- a/internal/rego/opa/options.go +++ b/internal/rego/opa/options.go @@ -6,6 +6,7 @@ import ( "github.com/open-policy-agent/opa/metrics" "github.com/open-policy-agent/opa/topdown/cache" + "github.com/open-policy-agent/opa/topdown/print" ) // Result holds the evaluation result. @@ -21,4 +22,5 @@ type EvalOpts struct { Time time.Time Seed io.Reader InterQueryBuiltinCache cache.InterQueryCache + PrintHook print.Hook } diff --git a/internal/wasm/sdk/internal/wasm/bindings.go b/internal/wasm/sdk/internal/wasm/bindings.go index 8ed98917a1..6c27e330a4 100644 --- a/internal/wasm/sdk/internal/wasm/bindings.go +++ b/internal/wasm/sdk/internal/wasm/bindings.go @@ -23,6 +23,7 @@ import ( "github.com/open-policy-agent/opa/topdown" "github.com/open-policy-agent/opa/topdown/builtins" "github.com/open-policy-agent/opa/topdown/cache" + "github.com/open-policy-agent/opa/topdown/print" ) func opaFunctions(dispatcher *builtinDispatcher, store *wasmtime.Store) map[string]wasmtime.AsExtern { @@ -80,7 +81,7 @@ func (d *builtinDispatcher) SetMap(m map[int32]topdown.BuiltinFunc) { } // Reset is called in Eval before using the builtinDispatcher. -func (d *builtinDispatcher) Reset(ctx context.Context, seed io.Reader, ns time.Time, iqbCache cache.InterQueryCache) { +func (d *builtinDispatcher) Reset(ctx context.Context, seed io.Reader, ns time.Time, iqbCache cache.InterQueryCache, ph print.Hook) { if ns.IsZero() { ns = time.Now() } @@ -101,6 +102,7 @@ func (d *builtinDispatcher) Reset(ctx context.Context, seed io.Reader, ns time.T QueryID: 0, ParentID: 0, InterQueryBuiltinCache: iqbCache, + PrintHook: ph, } } diff --git a/internal/wasm/sdk/internal/wasm/pool_test.go b/internal/wasm/sdk/internal/wasm/pool_test.go index 39a1783d85..24cb09265d 100644 --- a/internal/wasm/sdk/internal/wasm/pool_test.go +++ b/internal/wasm/sdk/internal/wasm/pool_test.go @@ -158,7 +158,7 @@ func ensurePoolResults(t *testing.T, ctx context.Context, testPool *wasm.Pool, p toRelease = append(toRelease, vm) cfg, _ := cache.ParseCachingConfig(nil) - result, err := vm.Eval(ctx, 0, nil, metrics.New(), rand.New(rand.NewSource(0)), time.Now(), cache.NewInterQueryCache(cfg)) + result, err := vm.Eval(ctx, 0, nil, metrics.New(), rand.New(rand.NewSource(0)), time.Now(), cache.NewInterQueryCache(cfg), nil) if err != nil { t.Fatalf("Unexpected error: %s", err) } diff --git a/internal/wasm/sdk/internal/wasm/vm.go b/internal/wasm/sdk/internal/wasm/vm.go index 0b2ad33c00..ae7ee6cd3a 100644 --- a/internal/wasm/sdk/internal/wasm/vm.go +++ b/internal/wasm/sdk/internal/wasm/vm.go @@ -21,6 +21,7 @@ import ( "github.com/open-policy-agent/opa/metrics" "github.com/open-policy-agent/opa/topdown" "github.com/open-policy-agent/opa/topdown/cache" + "github.com/open-policy-agent/opa/topdown/print" ) // VM is a wrapper around a Wasm VM instance @@ -269,9 +270,10 @@ func (i *VM) Eval(ctx context.Context, metrics metrics.Metrics, seed io.Reader, ns time.Time, - iqbCache cache.InterQueryCache) ([]byte, error) { + iqbCache cache.InterQueryCache, + ph print.Hook) ([]byte, error) { if i.abiMinorVersion < int32(2) { - return i.evalCompat(ctx, entrypoint, input, metrics, seed, ns, iqbCache) + return i.evalCompat(ctx, entrypoint, input, metrics, seed, ns, iqbCache, ph) } metrics.Timer("wasm_vm_eval").Start() @@ -313,7 +315,7 @@ func (i *VM) Eval(ctx context.Context, // make use of it (e.g. `http.send`); and it will spawn a go routine // cancelling the builtins that use topdown.Cancel, when the context is // cancelled. - i.dispatcher.Reset(ctx, seed, ns, iqbCache) + i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph) metrics.Timer("wasm_vm_eval_call").Start() resultAddr, err := i.evalOneOff(ctx, int32(entrypoint), i.dataAddr, inputAddr, inputLen, heapPtr) @@ -341,7 +343,8 @@ func (i *VM) evalCompat(ctx context.Context, metrics metrics.Metrics, seed io.Reader, ns time.Time, - iqbCache cache.InterQueryCache) ([]byte, error) { + iqbCache cache.InterQueryCache, + ph print.Hook) ([]byte, error) { metrics.Timer("wasm_vm_eval").Start() defer metrics.Timer("wasm_vm_eval").Stop() @@ -351,7 +354,7 @@ func (i *VM) evalCompat(ctx context.Context, // make use of it (e.g. `http.send`); and it will spawn a go routine // cancelling the builtins that use topdown.Cancel, when the context is // cancelled. - i.dispatcher.Reset(ctx, seed, ns, iqbCache) + i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph) err := i.setHeapState(ctx, i.evalHeapPtr) if err != nil { diff --git a/internal/wasm/sdk/opa/opa.go b/internal/wasm/sdk/opa/opa.go index e1a2d97383..95130cc715 100644 --- a/internal/wasm/sdk/opa/opa.go +++ b/internal/wasm/sdk/opa/opa.go @@ -17,6 +17,7 @@ import ( sdk_errors "github.com/open-policy-agent/opa/internal/wasm/sdk/opa/errors" "github.com/open-policy-agent/opa/metrics" "github.com/open-policy-agent/opa/topdown/cache" + "github.com/open-policy-agent/opa/topdown/print" ) var errNotReady = errors.New(errors.NotReadyErr, "") @@ -163,6 +164,7 @@ type EvalOpts struct { Time time.Time Seed io.Reader InterQueryBuiltinCache cache.InterQueryCache + PrintHook print.Hook } // Eval evaluates the policy with the given input, returning the @@ -186,7 +188,7 @@ func (o *OPA) Eval(ctx context.Context, opts EvalOpts) (*Result, error) { defer o.pool.Release(instance, m) - result, err := instance.Eval(ctx, opts.Entrypoint, opts.Input, m, opts.Seed, opts.Time, opts.InterQueryBuiltinCache) + result, err := instance.Eval(ctx, opts.Entrypoint, opts.Input, m, opts.Seed, opts.Time, opts.InterQueryBuiltinCache, opts.PrintHook) if err != nil { return nil, err } diff --git a/rego/rego.go b/rego/rego.go index 1e5ac8641f..a676fc73d7 100644 --- a/rego/rego.go +++ b/rego/rego.go @@ -1920,6 +1920,7 @@ func (r *Rego) evalWasm(ctx context.Context, ectx *EvalContext) (ResultSet, erro Time: ectx.time, Seed: ectx.seed, InterQueryBuiltinCache: ectx.interQueryBuiltinCache, + PrintHook: ectx.printHook, }) if err != nil { return nil, err