Skip to content

Commit

Permalink
topdown+wasm: Verifying host based on allow_net allowlist in built-in…
Browse files Browse the repository at this point in the history
… functions (#4152)

Adding host allow-listing based on the allow_net capability in the http.send()- and
net.lookup_ip_addr() built-in functions when running the eval command.

Fixes: #3665

Signed-off-by: Johan Fylling <johan.dev@fylling.se>
  • Loading branch information
johanfylling authored Dec 17, 2021
1 parent 9c9051b commit 7b64270
Show file tree
Hide file tree
Showing 17 changed files with 304 additions and 63 deletions.
5 changes: 5 additions & 0 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,11 @@ func (c *Compiler) WithCapabilities(capabilities *Capabilities) *Compiler {
return c
}

// Capabilities returns the capabilities enabled during compilation.
func (c *Compiler) Capabilities() *Capabilities {
return c.capabilities
}

// WithDebug sets where debug messages are written to. Passing `nil` has no
// effect.
func (c *Compiler) WithDebug(sink io.Writer) *Compiler {
Expand Down
2 changes: 2 additions & 0 deletions docs/content/deployments.md
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,8 @@ Not providing a capabilities file, or providing a file without an `allow_net` ke

Note that the metaschemas http://json-schema.org/draft-04/schema, http://json-schema.org/draft-06/schema, and http://json-schema.org/draft-07/schema, are always available, even without network access.

Similarly, the `allow_net` capability restricts what hosts the `http.send` built-in function may send requests to, and what hosts the `net.lookup_ip_addr` built-in function may resolve IP addresses for.

### Future keywords

The availability of future keywords in an OPA version can also be controlled using the capabilities file:
Expand Down
1 change: 1 addition & 0 deletions features/wasm/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func (o *OPA) Eval(ctx context.Context, opts opa.EvalOpts) (*opa.Result, error)
Seed: opts.Seed,
InterQueryBuiltinCache: opts.InterQueryBuiltinCache,
PrintHook: opts.PrintHook,
Capabilities: opts.Capabilities,
}

res, err := o.opa.Eval(ctx, evalOptions)
Expand Down
2 changes: 2 additions & 0 deletions internal/rego/opa/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"io"
"time"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/metrics"
"github.com/open-policy-agent/opa/topdown/cache"
"github.com/open-policy-agent/opa/topdown/print"
Expand All @@ -23,4 +24,5 @@ type EvalOpts struct {
Seed io.Reader
InterQueryBuiltinCache cache.InterQueryCache
PrintHook print.Hook
Capabilities *ast.Capabilities
}
8 changes: 7 additions & 1 deletion internal/wasm/sdk/internal/wasm/bindings.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,12 @@ func (d *builtinDispatcher) SetMap(m map[int32]topdown.BuiltinFunc) {
}

// Reset is called in Eval before using the builtinDispatcher.
func (d *builtinDispatcher) Reset(ctx context.Context, seed io.Reader, ns time.Time, iqbCache cache.InterQueryCache, ph print.Hook) {
func (d *builtinDispatcher) Reset(ctx context.Context,
seed io.Reader,
ns time.Time,
iqbCache cache.InterQueryCache,
ph print.Hook,
capabilities *ast.Capabilities) {
if ns.IsZero() {
ns = time.Now()
}
Expand All @@ -103,6 +108,7 @@ func (d *builtinDispatcher) Reset(ctx context.Context, seed io.Reader, ns time.T
ParentID: 0,
InterQueryBuiltinCache: iqbCache,
PrintHook: ph,
Capabilities: capabilities,
}

}
Expand Down
2 changes: 1 addition & 1 deletion internal/wasm/sdk/internal/wasm/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func ensurePoolResults(t *testing.T, ctx context.Context, testPool *wasm.Pool, p
toRelease = append(toRelease, vm)

cfg, _ := cache.ParseCachingConfig(nil)
result, err := vm.Eval(ctx, 0, input, metrics.New(), rand.New(rand.NewSource(0)), time.Now(), cache.NewInterQueryCache(cfg), nil)
result, err := vm.Eval(ctx, 0, input, metrics.New(), rand.New(rand.NewSource(0)), time.Now(), cache.NewInterQueryCache(cfg), nil, nil)
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
Expand Down
12 changes: 7 additions & 5 deletions internal/wasm/sdk/internal/wasm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,10 @@ func (i *VM) Eval(ctx context.Context,
seed io.Reader,
ns time.Time,
iqbCache cache.InterQueryCache,
ph print.Hook) ([]byte, error) {
ph print.Hook,
capabilities *ast.Capabilities) ([]byte, error) {
if i.abiMinorVersion < int32(2) {
return i.evalCompat(ctx, entrypoint, input, metrics, seed, ns, iqbCache, ph)
return i.evalCompat(ctx, entrypoint, input, metrics, seed, ns, iqbCache, ph, capabilities)
}

metrics.Timer("wasm_vm_eval").Start()
Expand Down Expand Up @@ -328,7 +329,7 @@ func (i *VM) Eval(ctx context.Context,
// make use of it (e.g. `http.send`); and it will spawn a go routine
// cancelling the builtins that use topdown.Cancel, when the context is
// cancelled.
i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph)
i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph, capabilities)

metrics.Timer("wasm_vm_eval_call").Start()
resultAddr, err := i.evalOneOff(ctx, int32(entrypoint), i.dataAddr, inputAddr, inputLen, heapPtr)
Expand Down Expand Up @@ -357,7 +358,8 @@ func (i *VM) evalCompat(ctx context.Context,
seed io.Reader,
ns time.Time,
iqbCache cache.InterQueryCache,
ph print.Hook) ([]byte, error) {
ph print.Hook,
capabilities *ast.Capabilities) ([]byte, error) {
metrics.Timer("wasm_vm_eval").Start()
defer metrics.Timer("wasm_vm_eval").Stop()

Expand All @@ -367,7 +369,7 @@ func (i *VM) evalCompat(ctx context.Context,
// make use of it (e.g. `http.send`); and it will spawn a go routine
// cancelling the builtins that use topdown.Cancel, when the context is
// cancelled.
i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph)
i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph, capabilities)

err := i.setHeapState(ctx, i.evalHeapPtr)
if err != nil {
Expand Down
5 changes: 4 additions & 1 deletion internal/wasm/sdk/opa/opa.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"sync"
"time"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/internal/wasm/sdk/internal/wasm"
"github.com/open-policy-agent/opa/internal/wasm/sdk/opa/errors"
sdk_errors "github.com/open-policy-agent/opa/internal/wasm/sdk/opa/errors"
Expand Down Expand Up @@ -165,6 +166,7 @@ type EvalOpts struct {
Seed io.Reader
InterQueryBuiltinCache cache.InterQueryCache
PrintHook print.Hook
Capabilities *ast.Capabilities
}

// Eval evaluates the policy with the given input, returning the
Expand All @@ -188,7 +190,8 @@ func (o *OPA) Eval(ctx context.Context, opts EvalOpts) (*Result, error) {

defer o.pool.Release(instance, m)

result, err := instance.Eval(ctx, opts.Entrypoint, opts.Input, m, opts.Seed, opts.Time, opts.InterQueryBuiltinCache, opts.PrintHook)
result, err := instance.Eval(ctx, opts.Entrypoint, opts.Input, m, opts.Seed, opts.Time, opts.InterQueryBuiltinCache,
opts.PrintHook, opts.Capabilities)
if err != nil {
return nil, err
}
Expand Down
4 changes: 4 additions & 0 deletions rego/rego.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ type EvalContext struct {
resolvers []refResolver
sortSets bool
printHook print.Hook
capabilities *ast.Capabilities
}

// EvalOption defines a function to set an option on an EvalConfig
Expand Down Expand Up @@ -311,6 +312,7 @@ func (pq preparedQuery) newEvalContext(ctx context.Context, options []EvalOption
earlyExit: true,
resolvers: pq.r.resolvers,
printHook: pq.r.printHook,
capabilities: pq.r.capabilities,
}

for _, o := range options {
Expand Down Expand Up @@ -1969,6 +1971,7 @@ func (r *Rego) evalWasm(ctx context.Context, ectx *EvalContext) (ResultSet, erro
Seed: ectx.seed,
InterQueryBuiltinCache: ectx.interQueryBuiltinCache,
PrintHook: ectx.printHook,
Capabilities: ectx.capabilities,
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -2075,6 +2078,7 @@ func (r *Rego) partialResult(ctx context.Context, pCfg *PrepareConfig) (PartialR
instrumentation: r.instrumentation,
indexing: true,
resolvers: r.resolvers,
capabilities: r.capabilities,
}

disableInlining := r.disableInlining
Expand Down
52 changes: 52 additions & 0 deletions rego/rego_wasmtarget_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"math/rand"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -338,3 +340,53 @@ func TestEvalWasmWithInterQueryCache(t *testing.T) {
t.Fatal("Expected server to be called only once")
}
}

func TestEvalWasmWithHTTPAllowNet(t *testing.T) {
var requests []*http.Request
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requests = append(requests, r)

w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"x": 1}`))
}))
defer ts.Close()

serverUrl, err := url.Parse(ts.URL)
if err != nil {
t.Fatal(err)
}
serverHost := strings.Split(serverUrl.Host, ":")[0]

query := fmt.Sprintf(`http.send({"method": "get", "url": "%s", "force_json_decode": true, "cache": true})`, ts.URL)
capabilities := ast.CapabilitiesForThisVersion()
capabilities.AllowNet = []string{"example.com"}

// add an inter-query cache
config, _ := cache.ParseCachingConfig(nil)
interQueryCache := cache.NewInterQueryCache(config)

ctx := context.Background()
// StrictBuiltinErrors(true) has no effect when target is 'wasm'
// this request should be rejected by the allow_net allowlist
_, err = New(Target("wasm"), Query(query), InterQueryBuiltinCache(interQueryCache), Capabilities(capabilities)).Eval(ctx)
if err != nil {
t.Fatal(err)
}

if len(requests) != 0 {
t.Fatal("Expected server to not be called")
}

capabilities.AllowNet = []string{serverHost}

// eval again with same query
// this request should not be rejected by the allow_net allowlist
_, err = New(Target("wasm"), Query(query), InterQueryBuiltinCache(interQueryCache), Capabilities(capabilities)).Eval(ctx)
if err != nil {
t.Fatal(err)
}

if len(requests) != 1 {
t.Fatal("Expected server to never be called")
}
}
1 change: 1 addition & 0 deletions topdown/builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ type (
PrintHook print.Hook // provides callback function to use for printing
DistributedTracingOpts tracing.Options // options to be used by distributed tracing.
rand *rand.Rand // randomization source for non-security-sensitive operations
Capabilities *ast.Capabilities
}

// BuiltinFunc defines an interface for implementing built-in functions.
Expand Down
6 changes: 6 additions & 0 deletions topdown/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,11 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error {
parentID = e.parent.queryID
}

var capabilities *ast.Capabilities
if e.compiler != nil {
capabilities = e.compiler.Capabilities()
}

bctx := BuiltinContext{
Context: e.ctx,
Metrics: e.metrics,
Expand All @@ -717,6 +722,7 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error {
ParentID: parentID,
PrintHook: e.printHook,
DistributedTracingOpts: e.tracingOpts,
Capabilities: capabilities,
}

eval := evalBuiltin{
Expand Down
60 changes: 47 additions & 13 deletions topdown/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,36 @@ func useSocket(rawURL string, tlsConfig *tls.Config) (bool, string, *http.Transp
return true, rawURL, tr
}

func verifyHost(bctx BuiltinContext, host string) error {
if bctx.Capabilities == nil || bctx.Capabilities.AllowNet == nil {
return nil
}

for _, allowed := range bctx.Capabilities.AllowNet {
if allowed == host {
return nil
}
}

return fmt.Errorf("unallowed host: %s", host)
}

func verifyURLHost(bctx BuiltinContext, unverifiedURL string) error {
// Eager return to avoid unnecessary URL parsing
if bctx.Capabilities == nil || bctx.Capabilities.AllowNet == nil {
return nil
}

parsedURL, err := url.Parse(unverifiedURL)
if err != nil {
return err
}

host := strings.Split(parsedURL.Host, ":")[0]

return verifyHost(bctx, host)
}

func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *http.Client, error) {
var url string
var method string
Expand Down Expand Up @@ -305,7 +335,7 @@ func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *htt
var strVal string

if s, ok := obj.Get(val).Value.(ast.String); ok {
strVal = string(s)
strVal = strings.Trim(string(s), "\"")
} else {
// Most parameters are strings, so consolidate the type checking.
switch key {
Expand All @@ -328,9 +358,13 @@ func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *htt

switch key {
case "method":
method = strings.ToUpper(strings.Trim(strVal, "\""))
method = strings.ToUpper(strVal)
case "url":
url = strings.Trim(strVal, "\"")
err := verifyURLHost(bctx, strVal)
if err != nil {
return nil, nil, err
}
url = strVal
case "enable_redirect":
enableRedirect, err = strconv.ParseBool(obj.Get(val).String())
if err != nil {
Expand All @@ -357,25 +391,25 @@ func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *htt
}
tlsUseSystemCerts = &tempTLSUseSystemCerts
case "tls_ca_cert":
tlsCaCert = bytes.Trim([]byte(strVal), "\"")
tlsCaCert = []byte(strVal)
case "tls_ca_cert_file":
tlsCaCertFile = strings.Trim(strVal, "\"")
tlsCaCertFile = strVal
case "tls_ca_cert_env_variable":
tlsCaCertEnvVar = strings.Trim(strVal, "\"")
tlsCaCertEnvVar = strVal
case "tls_client_cert":
tlsClientCert = bytes.Trim([]byte(strVal), "\"")
tlsClientCert = []byte(strVal)
case "tls_client_cert_file":
tlsClientCertFile = strings.Trim(strVal, "\"")
tlsClientCertFile = strVal
case "tls_client_cert_env_variable":
tlsClientCertEnvVar = strings.Trim(strVal, "\"")
tlsClientCertEnvVar = strVal
case "tls_client_key":
tlsClientKey = bytes.Trim([]byte(strVal), "\"")
tlsClientKey = []byte(strVal)
case "tls_client_key_file":
tlsClientKeyFile = strings.Trim(strVal, "\"")
tlsClientKeyFile = strVal
case "tls_client_key_env_variable":
tlsClientKeyEnvVar = strings.Trim(strVal, "\"")
tlsClientKeyEnvVar = strVal
case "tls_server_name":
tlsServerName = strings.Trim(strVal, "\"")
tlsServerName = strVal
case "headers":
headersVal := obj.Get(val).Value
headersValInterface, err := ast.JSON(headersVal)
Expand Down
Loading

0 comments on commit 7b64270

Please sign in to comment.