Skip to content

Commit

Permalink
internal/wasm: Enable print calls
Browse files Browse the repository at this point in the history
This commit plumbs the print hook through to the wasm runtime so that
print calls are enabled when the wasm target is on. This commit also
updates the planner and compiler to support calls to void
functions--previously, the planner and wasm backend assumed that
functions returned values so they would perform checks for defined
values, however, with void functions, those checks must be suppressed.

Signed-off-by: Torin Sandall <torinsandall@gmail.com>
  • Loading branch information
tsandall committed Oct 14, 2021
1 parent 64522f9 commit a1f7e30
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 39 deletions.
1 change: 1 addition & 0 deletions features/wasm/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ func (o *OPA) Eval(ctx context.Context, opts opa.EvalOpts) (*opa.Result, error)
Time: opts.Time,
Seed: opts.Seed,
InterQueryBuiltinCache: opts.InterQueryBuiltinCache,
PrintHook: opts.PrintHook,
}

res, err := o.opa.Eval(ctx, evalOptions)
Expand Down
52 changes: 32 additions & 20 deletions internal/compiler/wasm/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/open-policy-agent/opa/internal/wasm/instruction"
"github.com/open-policy-agent/opa/internal/wasm/module"
"github.com/open-policy-agent/opa/internal/wasm/types"
opatypes "github.com/open-policy-agent/opa/types"
)

// Record Wasm ABI version in exported global variable
Expand Down Expand Up @@ -179,6 +180,11 @@ var builtinsUsingRE2 = [...]string{
builtinsFunctions[ast.GlobMatch.Name],
}

type externalFunc struct {
ID int32
Decl *opatypes.Function
}

var builtinDispatchers = [...]string{
"opa_builtin0",
"opa_builtin1",
Expand All @@ -198,17 +204,17 @@ type Compiler struct {

funcsCode []funcCode // compile functions' code

builtinStringAddrs map[int]uint32 // addresses of built-in string constants
externalFuncNameAddrs map[string]int32 // addresses of required built-in function names for listing
externalFuncs map[string]int32 // required built-in function ids
entrypointNameAddrs map[string]int32 // addresses of available entrypoint names for listing
entrypoints map[string]int32 // available entrypoint ids
stringOffset int32 // null-terminated string data base offset
stringAddrs []uint32 // null-terminated string constant addresses
opaStringAddrs []uint32 // addresses of interned opa_string_t
opaBoolAddrs map[ir.Bool]uint32 // addresses of interned opa_boolean_t
fileAddrs []uint32 // null-terminated string constant addresses, used for file names
funcs map[string]uint32 // maps imported and exported function names to function indices
builtinStringAddrs map[int]uint32 // addresses of built-in string constants
externalFuncNameAddrs map[string]int32 // addresses of required built-in function names for listing
externalFuncs map[string]externalFunc // required built-in function ids and types
entrypointNameAddrs map[string]int32 // addresses of available entrypoint names for listing
entrypoints map[string]int32 // available entrypoint ids
stringOffset int32 // null-terminated string data base offset
stringAddrs []uint32 // null-terminated string constant addresses
opaStringAddrs []uint32 // addresses of interned opa_string_t
opaBoolAddrs map[ir.Bool]uint32 // addresses of interned opa_boolean_t
fileAddrs []uint32 // null-terminated string constant addresses, used for file names
funcs map[string]uint32 // maps imported and exported function names to function indices

nextLocal uint32
locals map[ir.Local]uint32
Expand Down Expand Up @@ -567,7 +573,7 @@ func (c *Compiler) compileExternalFuncDecls() error {

c.appendInstr(instruction.Call{Index: c.function(opaObject)})
c.appendInstr(instruction.SetLocal{Index: lobj})
c.externalFuncs = make(map[string]int32)
c.externalFuncs = make(map[string]externalFunc)

for index, decl := range c.policy.Static.BuiltinFuncs {
if _, ok := builtinsFunctions[decl.Name]; !ok {
Expand All @@ -577,7 +583,7 @@ func (c *Compiler) compileExternalFuncDecls() error {
c.appendInstr(instruction.I64Const{Value: int64(index)})
c.appendInstr(instruction.Call{Index: c.function(opaNumberInt)})
c.appendInstr(instruction.Call{Index: c.function(opaObjectInsert)})
c.externalFuncs[decl.Name] = int32(index)
c.externalFuncs[decl.Name] = externalFunc{ID: int32(index), Decl: decl.Decl}
}
}

Expand Down Expand Up @@ -1428,8 +1434,8 @@ func (c *Compiler) compileCallStmt(stmt *ir.CallStmt, result *[]instruction.Inst
return c.compileInternalCall(stmt, index, result)
}

if id, ok := c.externalFuncs[fn]; ok {
return c.compileExternalCall(stmt, id, result)
if ef, ok := c.externalFuncs[fn]; ok {
return c.compileExternalCall(stmt, ef, result)
}

c.errors = append(c.errors, fmt.Errorf("undefined function: %q", fn))
Expand Down Expand Up @@ -1457,25 +1463,31 @@ func (c *Compiler) compileInternalCall(stmt *ir.CallStmt, index uint32, result *
return nil
}

func (c *Compiler) compileExternalCall(stmt *ir.CallStmt, id int32, result *[]instruction.Instruction) error {
func (c *Compiler) compileExternalCall(stmt *ir.CallStmt, ef externalFunc, result *[]instruction.Instruction) error {

if len(stmt.Args) >= len(builtinDispatchers) {
c.errors = append(c.errors, fmt.Errorf("too many built-in call arguments: %q", stmt.Func))
return nil
}

instrs := *result
instrs = append(instrs, instruction.I32Const{Value: id})
instrs = append(instrs, instruction.I32Const{Value: ef.ID})
instrs = append(instrs, instruction.I32Const{Value: 0}) // unused context parameter

for _, arg := range stmt.Args {
instrs = append(instrs, c.instrRead(arg))
}

instrs = append(instrs, instruction.Call{Index: c.function(builtinDispatchers[len(stmt.Args)])})
instrs = append(instrs, instruction.TeeLocal{Index: c.local(stmt.Result)})
instrs = append(instrs, instruction.I32Eqz{})
instrs = append(instrs, instruction.BrIf{Index: 0})

if ef.Decl.Result() != nil {
instrs = append(instrs, instruction.TeeLocal{Index: c.local(stmt.Result)})
instrs = append(instrs, instruction.I32Eqz{})
instrs = append(instrs, instruction.BrIf{Index: 0})
} else {
instrs = append(instrs, instruction.Drop{})
}

*result = instrs
return nil
}
Expand Down
3 changes: 3 additions & 0 deletions internal/ir/ir.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ package ir

import (
"fmt"

"github.com/open-policy-agent/opa/types"
)

type (
Expand All @@ -31,6 +33,7 @@ type (
// policy.
BuiltinFunc struct {
Name string
Decl *types.Function
}

// Plans represents a collection of named query plans to expose in the policy.
Expand Down
26 changes: 15 additions & 11 deletions internal/planner/planner.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type Planner struct {
modules []*ast.Module // input modules to support queries
strings map[string]int // global string constant indices
files map[string]int // global file constant indices
externs map[string]struct{} // built-in functions that are required in execution environment
externs map[string]*ast.Builtin // built-in functions that are required in execution environment
decls map[string]*ast.Builtin // built-in functions that may be provided in execution environment
rules *ruletrie // rules that may be planned
funcs *funcstack // functions that have been planned
Expand Down Expand Up @@ -68,7 +68,7 @@ func New() *Planner {
},
strings: map[string]int{},
files: map[string]int{},
externs: map[string]struct{}{},
externs: map[string]*ast.Builtin{},
lnext: ir.Unused,
vars: newVarstack(map[ast.Var]ir.Local{
ast.InputRootDocument.Value.(ast.Var): ir.Input,
Expand Down Expand Up @@ -725,6 +725,7 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error {
var relation bool
var name string
var arity int
var void bool
var args []ir.LocalOrConst

node := p.rules.Lookup(e.Operator())
Expand All @@ -743,8 +744,9 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error {
} else if decl, ok := p.decls[operator]; ok {
relation = decl.Relation
arity = len(decl.Decl.Args())
void = decl.Decl.Result() == nil
name = operator
p.externs[operator] = struct{}{}
p.externs[operator] = decl
} else {
return fmt.Errorf("illegal call: unknown operator %q", operator)
}
Expand All @@ -759,7 +761,7 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error {
return p.planExprCallRelation(name, arity, operands, args, iter)
}

return p.planExprCallFunc(name, arity, operands, args, iter)
return p.planExprCallFunc(name, arity, void, operands, args, iter)
}
}

Expand Down Expand Up @@ -814,7 +816,7 @@ func (p *Planner) planExprCallRelation(name string, arity int, operands []*ast.T
})
}

func (p *Planner) planExprCallFunc(name string, arity int, operands []*ast.Term, args []ir.LocalOrConst, iter planiter) error {
func (p *Planner) planExprCallFunc(name string, arity int, void bool, operands []*ast.Term, args []ir.LocalOrConst, iter planiter) error {

if len(operands) == arity {
// definition: f(x) = y { ... }
Expand All @@ -828,10 +830,12 @@ func (p *Planner) planExprCallFunc(name string, arity int, operands []*ast.Term,
Result: ltarget,
})

p.appendStmt(&ir.NotEqualStmt{
A: ltarget,
B: ir.Bool(false),
})
if !void {
p.appendStmt(&ir.NotEqualStmt{
A: ltarget,
B: ir.Bool(false),
})
}

return iter()
})
Expand Down Expand Up @@ -1890,8 +1894,8 @@ func (p *Planner) planExterns() error {

p.policy.Static.BuiltinFuncs = make([]*ir.BuiltinFunc, 0, len(p.externs))

for name := range p.externs {
p.policy.Static.BuiltinFuncs = append(p.policy.Static.BuiltinFuncs, &ir.BuiltinFunc{Name: name})
for name, decl := range p.externs {
p.policy.Static.BuiltinFuncs = append(p.policy.Static.BuiltinFuncs, &ir.BuiltinFunc{Name: name, Decl: decl.Decl})
}

sort.Slice(p.policy.Static.BuiltinFuncs, func(i, j int) bool {
Expand Down
2 changes: 2 additions & 0 deletions internal/rego/opa/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/open-policy-agent/opa/metrics"
"github.com/open-policy-agent/opa/topdown/cache"
"github.com/open-policy-agent/opa/topdown/print"
)

// Result holds the evaluation result.
Expand All @@ -21,4 +22,5 @@ type EvalOpts struct {
Time time.Time
Seed io.Reader
InterQueryBuiltinCache cache.InterQueryCache
PrintHook print.Hook
}
4 changes: 3 additions & 1 deletion internal/wasm/sdk/internal/wasm/bindings.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/open-policy-agent/opa/topdown"
"github.com/open-policy-agent/opa/topdown/builtins"
"github.com/open-policy-agent/opa/topdown/cache"
"github.com/open-policy-agent/opa/topdown/print"
)

func opaFunctions(dispatcher *builtinDispatcher, store *wasmtime.Store) map[string]wasmtime.AsExtern {
Expand Down Expand Up @@ -80,7 +81,7 @@ func (d *builtinDispatcher) SetMap(m map[int32]topdown.BuiltinFunc) {
}

// Reset is called in Eval before using the builtinDispatcher.
func (d *builtinDispatcher) Reset(ctx context.Context, seed io.Reader, ns time.Time, iqbCache cache.InterQueryCache) {
func (d *builtinDispatcher) Reset(ctx context.Context, seed io.Reader, ns time.Time, iqbCache cache.InterQueryCache, ph print.Hook) {
if ns.IsZero() {
ns = time.Now()
}
Expand All @@ -101,6 +102,7 @@ func (d *builtinDispatcher) Reset(ctx context.Context, seed io.Reader, ns time.T
QueryID: 0,
ParentID: 0,
InterQueryBuiltinCache: iqbCache,
PrintHook: ph,
}

}
Expand Down
2 changes: 1 addition & 1 deletion internal/wasm/sdk/internal/wasm/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func ensurePoolResults(t *testing.T, ctx context.Context, testPool *wasm.Pool, p
toRelease = append(toRelease, vm)

cfg, _ := cache.ParseCachingConfig(nil)
result, err := vm.Eval(ctx, 0, nil, metrics.New(), rand.New(rand.NewSource(0)), time.Now(), cache.NewInterQueryCache(cfg))
result, err := vm.Eval(ctx, 0, nil, metrics.New(), rand.New(rand.NewSource(0)), time.Now(), cache.NewInterQueryCache(cfg), nil)
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
Expand Down
13 changes: 8 additions & 5 deletions internal/wasm/sdk/internal/wasm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/open-policy-agent/opa/metrics"
"github.com/open-policy-agent/opa/topdown"
"github.com/open-policy-agent/opa/topdown/cache"
"github.com/open-policy-agent/opa/topdown/print"
)

// VM is a wrapper around a Wasm VM instance
Expand Down Expand Up @@ -269,9 +270,10 @@ func (i *VM) Eval(ctx context.Context,
metrics metrics.Metrics,
seed io.Reader,
ns time.Time,
iqbCache cache.InterQueryCache) ([]byte, error) {
iqbCache cache.InterQueryCache,
ph print.Hook) ([]byte, error) {
if i.abiMinorVersion < int32(2) {
return i.evalCompat(ctx, entrypoint, input, metrics, seed, ns, iqbCache)
return i.evalCompat(ctx, entrypoint, input, metrics, seed, ns, iqbCache, ph)
}

metrics.Timer("wasm_vm_eval").Start()
Expand Down Expand Up @@ -313,7 +315,7 @@ func (i *VM) Eval(ctx context.Context,
// make use of it (e.g. `http.send`); and it will spawn a go routine
// cancelling the builtins that use topdown.Cancel, when the context is
// cancelled.
i.dispatcher.Reset(ctx, seed, ns, iqbCache)
i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph)

metrics.Timer("wasm_vm_eval_call").Start()
resultAddr, err := i.evalOneOff(ctx, int32(entrypoint), i.dataAddr, inputAddr, inputLen, heapPtr)
Expand Down Expand Up @@ -341,7 +343,8 @@ func (i *VM) evalCompat(ctx context.Context,
metrics metrics.Metrics,
seed io.Reader,
ns time.Time,
iqbCache cache.InterQueryCache) ([]byte, error) {
iqbCache cache.InterQueryCache,
ph print.Hook) ([]byte, error) {
metrics.Timer("wasm_vm_eval").Start()
defer metrics.Timer("wasm_vm_eval").Stop()

Expand All @@ -351,7 +354,7 @@ func (i *VM) evalCompat(ctx context.Context,
// make use of it (e.g. `http.send`); and it will spawn a go routine
// cancelling the builtins that use topdown.Cancel, when the context is
// cancelled.
i.dispatcher.Reset(ctx, seed, ns, iqbCache)
i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph)

err := i.setHeapState(ctx, i.evalHeapPtr)
if err != nil {
Expand Down
4 changes: 3 additions & 1 deletion internal/wasm/sdk/opa/opa.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
sdk_errors "github.com/open-policy-agent/opa/internal/wasm/sdk/opa/errors"
"github.com/open-policy-agent/opa/metrics"
"github.com/open-policy-agent/opa/topdown/cache"
"github.com/open-policy-agent/opa/topdown/print"
)

var errNotReady = errors.New(errors.NotReadyErr, "")
Expand Down Expand Up @@ -163,6 +164,7 @@ type EvalOpts struct {
Time time.Time
Seed io.Reader
InterQueryBuiltinCache cache.InterQueryCache
PrintHook print.Hook
}

// Eval evaluates the policy with the given input, returning the
Expand All @@ -186,7 +188,7 @@ func (o *OPA) Eval(ctx context.Context, opts EvalOpts) (*Result, error) {

defer o.pool.Release(instance, m)

result, err := instance.Eval(ctx, opts.Entrypoint, opts.Input, m, opts.Seed, opts.Time, opts.InterQueryBuiltinCache)
result, err := instance.Eval(ctx, opts.Entrypoint, opts.Input, m, opts.Seed, opts.Time, opts.InterQueryBuiltinCache, opts.PrintHook)
if err != nil {
return nil, err
}
Expand Down
1 change: 1 addition & 0 deletions rego/rego.go
Original file line number Diff line number Diff line change
Expand Up @@ -1920,6 +1920,7 @@ func (r *Rego) evalWasm(ctx context.Context, ectx *EvalContext) (ResultSet, erro
Time: ectx.time,
Seed: ectx.seed,
InterQueryBuiltinCache: ectx.interQueryBuiltinCache,
PrintHook: ectx.printHook,
})
if err != nil {
return nil, err
Expand Down

0 comments on commit a1f7e30

Please sign in to comment.