Skip to content

Commit

Permalink
topdown+wasm: Verifying host based on allow_net allowlist in built-in…
Browse files Browse the repository at this point in the history
… functions

Adding host allow-listing based on the allow_net capability in the http.send()- and
net.lookup_ip_addr() built-in functions when running the eval command.

Fixes: open-policy-agent#3665

Signed-off-by: Johan Fylling <johan.dev@fylling.se>
  • Loading branch information
johanfylling committed Dec 16, 2021
1 parent a829c21 commit de8875c
Show file tree
Hide file tree
Showing 16 changed files with 303 additions and 34 deletions.
24 changes: 13 additions & 11 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ type Compiler struct {
// with the key being the generated name and value being the original.
RewrittenVars map[Var]Var

// Capabilities is the user-supplied capabilities or features allowed for OPA.
Capabilities *Capabilities

localvargen *localVarGenerator
moduleLoader ModuleLoader
ruleIndices *util.HashMap
Expand All @@ -99,10 +102,9 @@ type Compiler struct {
pathExists func([]string) (bool, error)
after map[string][]CompilerStageDefinition
metrics metrics.Metrics
capabilities *Capabilities // user-supplied capabilities
builtins map[string]*Builtin // universe of built-in functions
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use Capabilities)
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use Capabilities)
enablePrintStatements bool // indicates if print statements should be elided (default)
comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index
initialized bool // indicates if init() has been called
Expand Down Expand Up @@ -325,14 +327,14 @@ func (c *Compiler) WithMetrics(metrics metrics.Metrics) *Compiler {
return c
}

// WithCapabilities sets capabilities to enable during compilation. Capabilities allow the caller
// to specify the set of built-in functions available to the policy. In the future, capabilities
// WithCapabilities sets Capabilities to enable during compilation. Capabilities allow the caller
// to specify the set of built-in functions available to the policy. In the future, Capabilities
// may be able to restrict access to other language features. Capabilities allow callers to check
// if policies are compatible with a particular version of OPA. If policies are a compiled for a
// specific version of OPA, there is no guarantee that _this_ version of OPA can evaluate them
// successfully.
func (c *Compiler) WithCapabilities(capabilities *Capabilities) *Compiler {
c.capabilities = capabilities
c.Capabilities = capabilities
return c
}

Expand Down Expand Up @@ -1203,13 +1205,13 @@ func (c *Compiler) init() {
return
}

if c.capabilities == nil {
c.capabilities = CapabilitiesForThisVersion()
if c.Capabilities == nil {
c.Capabilities = CapabilitiesForThisVersion()
}

c.builtins = make(map[string]*Builtin, len(c.capabilities.Builtins)+len(c.customBuiltins))
c.builtins = make(map[string]*Builtin, len(c.Capabilities.Builtins)+len(c.customBuiltins))

for _, bi := range c.capabilities.Builtins {
for _, bi := range c.Capabilities.Builtins {
c.builtins[bi.Name] = bi
}

Expand All @@ -1220,7 +1222,7 @@ func (c *Compiler) init() {
// Load the global input schema if one was provided.
if c.schemaSet != nil {
if schema := c.schemaSet.Get(SchemaRootRef); schema != nil {
tpe, err := loadSchema(schema, c.capabilities.AllowNet)
tpe, err := loadSchema(schema, c.Capabilities.AllowNet)
if err != nil {
c.err(NewError(TypeErr, nil, err.Error()))
} else {
Expand Down
1 change: 1 addition & 0 deletions features/wasm/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func (o *OPA) Eval(ctx context.Context, opts opa.EvalOpts) (*opa.Result, error)
Seed: opts.Seed,
InterQueryBuiltinCache: opts.InterQueryBuiltinCache,
PrintHook: opts.PrintHook,
Capabilities: opts.Capabilities,
}

res, err := o.opa.Eval(ctx, evalOptions)
Expand Down
2 changes: 2 additions & 0 deletions internal/rego/opa/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"io"
"time"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/metrics"
"github.com/open-policy-agent/opa/topdown/cache"
"github.com/open-policy-agent/opa/topdown/print"
Expand All @@ -23,4 +24,5 @@ type EvalOpts struct {
Seed io.Reader
InterQueryBuiltinCache cache.InterQueryCache
PrintHook print.Hook
Capabilities *ast.Capabilities
}
8 changes: 7 additions & 1 deletion internal/wasm/sdk/internal/wasm/bindings.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,12 @@ func (d *builtinDispatcher) SetMap(m map[int32]topdown.BuiltinFunc) {
}

// Reset is called in Eval before using the builtinDispatcher.
func (d *builtinDispatcher) Reset(ctx context.Context, seed io.Reader, ns time.Time, iqbCache cache.InterQueryCache, ph print.Hook) {
func (d *builtinDispatcher) Reset(ctx context.Context,
seed io.Reader,
ns time.Time,
iqbCache cache.InterQueryCache,
ph print.Hook,
capabilities *ast.Capabilities) {
if ns.IsZero() {
ns = time.Now()
}
Expand All @@ -103,6 +108,7 @@ func (d *builtinDispatcher) Reset(ctx context.Context, seed io.Reader, ns time.T
ParentID: 0,
InterQueryBuiltinCache: iqbCache,
PrintHook: ph,
Capabilities: capabilities,
}

}
Expand Down
2 changes: 1 addition & 1 deletion internal/wasm/sdk/internal/wasm/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func ensurePoolResults(t *testing.T, ctx context.Context, testPool *wasm.Pool, p
toRelease = append(toRelease, vm)

cfg, _ := cache.ParseCachingConfig(nil)
result, err := vm.Eval(ctx, 0, input, metrics.New(), rand.New(rand.NewSource(0)), time.Now(), cache.NewInterQueryCache(cfg), nil)
result, err := vm.Eval(ctx, 0, input, metrics.New(), rand.New(rand.NewSource(0)), time.Now(), cache.NewInterQueryCache(cfg), nil, nil)
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
Expand Down
12 changes: 7 additions & 5 deletions internal/wasm/sdk/internal/wasm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,10 @@ func (i *VM) Eval(ctx context.Context,
seed io.Reader,
ns time.Time,
iqbCache cache.InterQueryCache,
ph print.Hook) ([]byte, error) {
ph print.Hook,
capabilities *ast.Capabilities) ([]byte, error) {
if i.abiMinorVersion < int32(2) {
return i.evalCompat(ctx, entrypoint, input, metrics, seed, ns, iqbCache, ph)
return i.evalCompat(ctx, entrypoint, input, metrics, seed, ns, iqbCache, ph, capabilities)
}

metrics.Timer("wasm_vm_eval").Start()
Expand Down Expand Up @@ -328,7 +329,7 @@ func (i *VM) Eval(ctx context.Context,
// make use of it (e.g. `http.send`); and it will spawn a go routine
// cancelling the builtins that use topdown.Cancel, when the context is
// cancelled.
i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph)
i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph, capabilities)

metrics.Timer("wasm_vm_eval_call").Start()
resultAddr, err := i.evalOneOff(ctx, int32(entrypoint), i.dataAddr, inputAddr, inputLen, heapPtr)
Expand Down Expand Up @@ -357,7 +358,8 @@ func (i *VM) evalCompat(ctx context.Context,
seed io.Reader,
ns time.Time,
iqbCache cache.InterQueryCache,
ph print.Hook) ([]byte, error) {
ph print.Hook,
capabilities *ast.Capabilities) ([]byte, error) {
metrics.Timer("wasm_vm_eval").Start()
defer metrics.Timer("wasm_vm_eval").Stop()

Expand All @@ -367,7 +369,7 @@ func (i *VM) evalCompat(ctx context.Context,
// make use of it (e.g. `http.send`); and it will spawn a go routine
// cancelling the builtins that use topdown.Cancel, when the context is
// cancelled.
i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph)
i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph, capabilities)

err := i.setHeapState(ctx, i.evalHeapPtr)
if err != nil {
Expand Down
5 changes: 4 additions & 1 deletion internal/wasm/sdk/opa/opa.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"sync"
"time"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/internal/wasm/sdk/internal/wasm"
"github.com/open-policy-agent/opa/internal/wasm/sdk/opa/errors"
sdk_errors "github.com/open-policy-agent/opa/internal/wasm/sdk/opa/errors"
Expand Down Expand Up @@ -165,6 +166,7 @@ type EvalOpts struct {
Seed io.Reader
InterQueryBuiltinCache cache.InterQueryCache
PrintHook print.Hook
Capabilities *ast.Capabilities
}

// Eval evaluates the policy with the given input, returning the
Expand All @@ -188,7 +190,8 @@ func (o *OPA) Eval(ctx context.Context, opts EvalOpts) (*Result, error) {

defer o.pool.Release(instance, m)

result, err := instance.Eval(ctx, opts.Entrypoint, opts.Input, m, opts.Seed, opts.Time, opts.InterQueryBuiltinCache, opts.PrintHook)
result, err := instance.Eval(ctx, opts.Entrypoint, opts.Input, m, opts.Seed, opts.Time, opts.InterQueryBuiltinCache,
opts.PrintHook, opts.Capabilities)
if err != nil {
return nil, err
}
Expand Down
4 changes: 4 additions & 0 deletions rego/rego.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ type EvalContext struct {
resolvers []refResolver
sortSets bool
printHook print.Hook
capabilities *ast.Capabilities
}

// EvalOption defines a function to set an option on an EvalConfig
Expand Down Expand Up @@ -311,6 +312,7 @@ func (pq preparedQuery) newEvalContext(ctx context.Context, options []EvalOption
earlyExit: true,
resolvers: pq.r.resolvers,
printHook: pq.r.printHook,
capabilities: pq.r.capabilities,
}

for _, o := range options {
Expand Down Expand Up @@ -1969,6 +1971,7 @@ func (r *Rego) evalWasm(ctx context.Context, ectx *EvalContext) (ResultSet, erro
Seed: ectx.seed,
InterQueryBuiltinCache: ectx.interQueryBuiltinCache,
PrintHook: ectx.printHook,
Capabilities: ectx.capabilities,
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -2075,6 +2078,7 @@ func (r *Rego) partialResult(ctx context.Context, pCfg *PrepareConfig) (PartialR
instrumentation: r.instrumentation,
indexing: true,
resolvers: r.resolvers,
capabilities: r.capabilities,
}

disableInlining := r.disableInlining
Expand Down
52 changes: 52 additions & 0 deletions rego/rego_wasmtarget_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"math/rand"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -338,3 +340,53 @@ func TestEvalWasmWithInterQueryCache(t *testing.T) {
t.Fatal("Expected server to be called only once")
}
}

func TestEvalWasmWithHTTPAllowNet(t *testing.T) {
var requests []*http.Request
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requests = append(requests, r)

w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"x": 1}`))
}))
defer ts.Close()

serverUrl, err := url.Parse(ts.URL)
if err != nil {
panic(err)
}
serverHost := strings.Split(serverUrl.Host, ":")[0]

query := fmt.Sprintf(`http.send({"method": "get", "url": "%s", "force_json_decode": true, "cache": true})`, ts.URL)
capabilities := ast.CapabilitiesForThisVersion()
capabilities.AllowNet = []string{"example.com"}

// add an inter-query cache
config, _ := cache.ParseCachingConfig(nil)
interQueryCache := cache.NewInterQueryCache(config)

ctx := context.Background()
// StrictBuiltinErrors(true) has no effect when target is 'wasm'
// this request should be rejected by the allow_net allowlist
_, err = New(Target("wasm"), Query(query), InterQueryBuiltinCache(interQueryCache), Capabilities(capabilities)).Eval(ctx)
if err != nil {
t.Fatal(err)
}

if len(requests) != 0 {
t.Fatal("Expected server to not be called")
}

capabilities.AllowNet = []string{serverHost}

// eval again with same query
// this request should not be rejected by the allow_net allowlist
_, err = New(Target("wasm"), Query(query), InterQueryBuiltinCache(interQueryCache), Capabilities(capabilities)).Eval(ctx)
if err != nil {
t.Fatal(err)
}

if len(requests) != 1 {
t.Fatal("Expected server to never be called")
}
}
1 change: 1 addition & 0 deletions topdown/builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ type (
PrintHook print.Hook // provides callback function to use for printing
DistributedTracingOpts tracing.Options // options to be used by distributed tracing.
rand *rand.Rand // randomization source for non-security-sensitive operations
Capabilities *ast.Capabilities
}

// BuiltinFunc defines an interface for implementing built-in functions.
Expand Down
6 changes: 6 additions & 0 deletions topdown/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,11 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error {
parentID = e.parent.queryID
}

var capabilities *ast.Capabilities
if e.compiler != nil {
capabilities = e.compiler.Capabilities
}

bctx := BuiltinContext{
Context: e.ctx,
Metrics: e.metrics,
Expand All @@ -717,6 +722,7 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error {
ParentID: parentID,
PrintHook: e.printHook,
DistributedTracingOpts: e.tracingOpts,
Capabilities: capabilities,
}

eval := evalBuiltin{
Expand Down
Loading

0 comments on commit de8875c

Please sign in to comment.