From de8875c6d941d5e25fab62703df18280763b7848 Mon Sep 17 00:00:00 2001 From: Johan Fylling Date: Tue, 14 Dec 2021 22:08:20 +0100 Subject: [PATCH] topdown+wasm: Verifying host based on allow_net allowlist in built-in functions Adding host allow-listing based on the allow_net capability in the http.send()- and net.lookup_ip_addr() built-in functions when running the eval command. Fixes: #3665 Signed-off-by: Johan Fylling --- ast/compile.go | 24 +++--- features/wasm/wasm.go | 1 + internal/rego/opa/options.go | 2 + internal/wasm/sdk/internal/wasm/bindings.go | 8 +- internal/wasm/sdk/internal/wasm/pool_test.go | 2 +- internal/wasm/sdk/internal/wasm/vm.go | 12 +-- internal/wasm/sdk/opa/opa.go | 5 +- rego/rego.go | 4 + rego/rego_wasmtarget_test.go | 52 ++++++++++++ topdown/builtins.go | 1 + topdown/eval.go | 6 ++ topdown/http.go | 60 +++++++++++--- topdown/http_test.go | 85 ++++++++++++++++++++ topdown/net.go | 10 ++- topdown/net_test.go | 57 +++++++++++++ topdown/topdown_test.go | 8 ++ 16 files changed, 303 insertions(+), 34 deletions(-) diff --git a/ast/compile.go b/ast/compile.go index f446606d16..d46222eeaf 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -86,6 +86,9 @@ type Compiler struct { // with the key being the generated name and value being the original. RewrittenVars map[Var]Var + // Capabilities is the user-supplied capabilities or features allowed for OPA. + Capabilities *Capabilities + localvargen *localVarGenerator moduleLoader ModuleLoader ruleIndices *util.HashMap @@ -99,10 +102,9 @@ type Compiler struct { pathExists func([]string) (bool, error) after map[string][]CompilerStageDefinition metrics metrics.Metrics - capabilities *Capabilities // user-supplied capabilities builtins map[string]*Builtin // universe of built-in functions - customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities) - unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities) + customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use Capabilities) + unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use Capabilities) enablePrintStatements bool // indicates if print statements should be elided (default) comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index initialized bool // indicates if init() has been called @@ -325,14 +327,14 @@ func (c *Compiler) WithMetrics(metrics metrics.Metrics) *Compiler { return c } -// WithCapabilities sets capabilities to enable during compilation. Capabilities allow the caller -// to specify the set of built-in functions available to the policy. In the future, capabilities +// WithCapabilities sets Capabilities to enable during compilation. Capabilities allow the caller +// to specify the set of built-in functions available to the policy. In the future, Capabilities // may be able to restrict access to other language features. Capabilities allow callers to check // if policies are compatible with a particular version of OPA. If policies are a compiled for a // specific version of OPA, there is no guarantee that _this_ version of OPA can evaluate them // successfully. func (c *Compiler) WithCapabilities(capabilities *Capabilities) *Compiler { - c.capabilities = capabilities + c.Capabilities = capabilities return c } @@ -1203,13 +1205,13 @@ func (c *Compiler) init() { return } - if c.capabilities == nil { - c.capabilities = CapabilitiesForThisVersion() + if c.Capabilities == nil { + c.Capabilities = CapabilitiesForThisVersion() } - c.builtins = make(map[string]*Builtin, len(c.capabilities.Builtins)+len(c.customBuiltins)) + c.builtins = make(map[string]*Builtin, len(c.Capabilities.Builtins)+len(c.customBuiltins)) - for _, bi := range c.capabilities.Builtins { + for _, bi := range c.Capabilities.Builtins { c.builtins[bi.Name] = bi } @@ -1220,7 +1222,7 @@ func (c *Compiler) init() { // Load the global input schema if one was provided. if c.schemaSet != nil { if schema := c.schemaSet.Get(SchemaRootRef); schema != nil { - tpe, err := loadSchema(schema, c.capabilities.AllowNet) + tpe, err := loadSchema(schema, c.Capabilities.AllowNet) if err != nil { c.err(NewError(TypeErr, nil, err.Error())) } else { diff --git a/features/wasm/wasm.go b/features/wasm/wasm.go index 8e8e52ea2d..8eee1e3372 100644 --- a/features/wasm/wasm.go +++ b/features/wasm/wasm.go @@ -65,6 +65,7 @@ func (o *OPA) Eval(ctx context.Context, opts opa.EvalOpts) (*opa.Result, error) Seed: opts.Seed, InterQueryBuiltinCache: opts.InterQueryBuiltinCache, PrintHook: opts.PrintHook, + Capabilities: opts.Capabilities, } res, err := o.opa.Eval(ctx, evalOptions) diff --git a/internal/rego/opa/options.go b/internal/rego/opa/options.go index ec6b2f0b5a..1e0477a0f6 100644 --- a/internal/rego/opa/options.go +++ b/internal/rego/opa/options.go @@ -4,6 +4,7 @@ import ( "io" "time" + "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/metrics" "github.com/open-policy-agent/opa/topdown/cache" "github.com/open-policy-agent/opa/topdown/print" @@ -23,4 +24,5 @@ type EvalOpts struct { Seed io.Reader InterQueryBuiltinCache cache.InterQueryCache PrintHook print.Hook + Capabilities *ast.Capabilities } diff --git a/internal/wasm/sdk/internal/wasm/bindings.go b/internal/wasm/sdk/internal/wasm/bindings.go index 6c27e330a4..eadd9278d3 100644 --- a/internal/wasm/sdk/internal/wasm/bindings.go +++ b/internal/wasm/sdk/internal/wasm/bindings.go @@ -81,7 +81,12 @@ func (d *builtinDispatcher) SetMap(m map[int32]topdown.BuiltinFunc) { } // Reset is called in Eval before using the builtinDispatcher. -func (d *builtinDispatcher) Reset(ctx context.Context, seed io.Reader, ns time.Time, iqbCache cache.InterQueryCache, ph print.Hook) { +func (d *builtinDispatcher) Reset(ctx context.Context, + seed io.Reader, + ns time.Time, + iqbCache cache.InterQueryCache, + ph print.Hook, + capabilities *ast.Capabilities) { if ns.IsZero() { ns = time.Now() } @@ -103,6 +108,7 @@ func (d *builtinDispatcher) Reset(ctx context.Context, seed io.Reader, ns time.T ParentID: 0, InterQueryBuiltinCache: iqbCache, PrintHook: ph, + Capabilities: capabilities, } } diff --git a/internal/wasm/sdk/internal/wasm/pool_test.go b/internal/wasm/sdk/internal/wasm/pool_test.go index 64089379e7..57d29a6f29 100644 --- a/internal/wasm/sdk/internal/wasm/pool_test.go +++ b/internal/wasm/sdk/internal/wasm/pool_test.go @@ -176,7 +176,7 @@ func ensurePoolResults(t *testing.T, ctx context.Context, testPool *wasm.Pool, p toRelease = append(toRelease, vm) cfg, _ := cache.ParseCachingConfig(nil) - result, err := vm.Eval(ctx, 0, input, metrics.New(), rand.New(rand.NewSource(0)), time.Now(), cache.NewInterQueryCache(cfg), nil) + result, err := vm.Eval(ctx, 0, input, metrics.New(), rand.New(rand.NewSource(0)), time.Now(), cache.NewInterQueryCache(cfg), nil, nil) if err != nil { t.Fatalf("Unexpected error: %s", err) } diff --git a/internal/wasm/sdk/internal/wasm/vm.go b/internal/wasm/sdk/internal/wasm/vm.go index e768dab448..d8a3923049 100644 --- a/internal/wasm/sdk/internal/wasm/vm.go +++ b/internal/wasm/sdk/internal/wasm/vm.go @@ -275,9 +275,10 @@ func (i *VM) Eval(ctx context.Context, seed io.Reader, ns time.Time, iqbCache cache.InterQueryCache, - ph print.Hook) ([]byte, error) { + ph print.Hook, + capabilities *ast.Capabilities) ([]byte, error) { if i.abiMinorVersion < int32(2) { - return i.evalCompat(ctx, entrypoint, input, metrics, seed, ns, iqbCache, ph) + return i.evalCompat(ctx, entrypoint, input, metrics, seed, ns, iqbCache, ph, capabilities) } metrics.Timer("wasm_vm_eval").Start() @@ -328,7 +329,7 @@ func (i *VM) Eval(ctx context.Context, // make use of it (e.g. `http.send`); and it will spawn a go routine // cancelling the builtins that use topdown.Cancel, when the context is // cancelled. - i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph) + i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph, capabilities) metrics.Timer("wasm_vm_eval_call").Start() resultAddr, err := i.evalOneOff(ctx, int32(entrypoint), i.dataAddr, inputAddr, inputLen, heapPtr) @@ -357,7 +358,8 @@ func (i *VM) evalCompat(ctx context.Context, seed io.Reader, ns time.Time, iqbCache cache.InterQueryCache, - ph print.Hook) ([]byte, error) { + ph print.Hook, + capabilities *ast.Capabilities) ([]byte, error) { metrics.Timer("wasm_vm_eval").Start() defer metrics.Timer("wasm_vm_eval").Stop() @@ -367,7 +369,7 @@ func (i *VM) evalCompat(ctx context.Context, // make use of it (e.g. `http.send`); and it will spawn a go routine // cancelling the builtins that use topdown.Cancel, when the context is // cancelled. - i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph) + i.dispatcher.Reset(ctx, seed, ns, iqbCache, ph, capabilities) err := i.setHeapState(ctx, i.evalHeapPtr) if err != nil { diff --git a/internal/wasm/sdk/opa/opa.go b/internal/wasm/sdk/opa/opa.go index 95130cc715..159771060a 100644 --- a/internal/wasm/sdk/opa/opa.go +++ b/internal/wasm/sdk/opa/opa.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/internal/wasm/sdk/internal/wasm" "github.com/open-policy-agent/opa/internal/wasm/sdk/opa/errors" sdk_errors "github.com/open-policy-agent/opa/internal/wasm/sdk/opa/errors" @@ -165,6 +166,7 @@ type EvalOpts struct { Seed io.Reader InterQueryBuiltinCache cache.InterQueryCache PrintHook print.Hook + Capabilities *ast.Capabilities } // Eval evaluates the policy with the given input, returning the @@ -188,7 +190,8 @@ func (o *OPA) Eval(ctx context.Context, opts EvalOpts) (*Result, error) { defer o.pool.Release(instance, m) - result, err := instance.Eval(ctx, opts.Entrypoint, opts.Input, m, opts.Seed, opts.Time, opts.InterQueryBuiltinCache, opts.PrintHook) + result, err := instance.Eval(ctx, opts.Entrypoint, opts.Input, m, opts.Seed, opts.Time, opts.InterQueryBuiltinCache, + opts.PrintHook, opts.Capabilities) if err != nil { return nil, err } diff --git a/rego/rego.go b/rego/rego.go index d7048a0e81..4d82bab3c8 100644 --- a/rego/rego.go +++ b/rego/rego.go @@ -118,6 +118,7 @@ type EvalContext struct { resolvers []refResolver sortSets bool printHook print.Hook + capabilities *ast.Capabilities } // EvalOption defines a function to set an option on an EvalConfig @@ -311,6 +312,7 @@ func (pq preparedQuery) newEvalContext(ctx context.Context, options []EvalOption earlyExit: true, resolvers: pq.r.resolvers, printHook: pq.r.printHook, + capabilities: pq.r.capabilities, } for _, o := range options { @@ -1969,6 +1971,7 @@ func (r *Rego) evalWasm(ctx context.Context, ectx *EvalContext) (ResultSet, erro Seed: ectx.seed, InterQueryBuiltinCache: ectx.interQueryBuiltinCache, PrintHook: ectx.printHook, + Capabilities: ectx.capabilities, }) if err != nil { return nil, err @@ -2075,6 +2078,7 @@ func (r *Rego) partialResult(ctx context.Context, pCfg *PrepareConfig) (PartialR instrumentation: r.instrumentation, indexing: true, resolvers: r.resolvers, + capabilities: r.capabilities, } disableInlining := r.disableInlining diff --git a/rego/rego_wasmtarget_test.go b/rego/rego_wasmtarget_test.go index 787fe4fbc4..9ad66f0198 100644 --- a/rego/rego_wasmtarget_test.go +++ b/rego/rego_wasmtarget_test.go @@ -12,6 +12,8 @@ import ( "math/rand" "net/http" "net/http/httptest" + "net/url" + "strings" "testing" "time" @@ -338,3 +340,53 @@ func TestEvalWasmWithInterQueryCache(t *testing.T) { t.Fatal("Expected server to be called only once") } } + +func TestEvalWasmWithHTTPAllowNet(t *testing.T) { + var requests []*http.Request + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests = append(requests, r) + + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"x": 1}`)) + })) + defer ts.Close() + + serverUrl, err := url.Parse(ts.URL) + if err != nil { + panic(err) + } + serverHost := strings.Split(serverUrl.Host, ":")[0] + + query := fmt.Sprintf(`http.send({"method": "get", "url": "%s", "force_json_decode": true, "cache": true})`, ts.URL) + capabilities := ast.CapabilitiesForThisVersion() + capabilities.AllowNet = []string{"example.com"} + + // add an inter-query cache + config, _ := cache.ParseCachingConfig(nil) + interQueryCache := cache.NewInterQueryCache(config) + + ctx := context.Background() + // StrictBuiltinErrors(true) has no effect when target is 'wasm' + // this request should be rejected by the allow_net allowlist + _, err = New(Target("wasm"), Query(query), InterQueryBuiltinCache(interQueryCache), Capabilities(capabilities)).Eval(ctx) + if err != nil { + t.Fatal(err) + } + + if len(requests) != 0 { + t.Fatal("Expected server to not be called") + } + + capabilities.AllowNet = []string{serverHost} + + // eval again with same query + // this request should not be rejected by the allow_net allowlist + _, err = New(Target("wasm"), Query(query), InterQueryBuiltinCache(interQueryCache), Capabilities(capabilities)).Eval(ctx) + if err != nil { + t.Fatal(err) + } + + if len(requests) != 1 { + t.Fatal("Expected server to never be called") + } +} diff --git a/topdown/builtins.go b/topdown/builtins.go index 4081878673..0df8bb8443 100644 --- a/topdown/builtins.go +++ b/topdown/builtins.go @@ -52,6 +52,7 @@ type ( PrintHook print.Hook // provides callback function to use for printing DistributedTracingOpts tracing.Options // options to be used by distributed tracing. rand *rand.Rand // randomization source for non-security-sensitive operations + Capabilities *ast.Capabilities } // BuiltinFunc defines an interface for implementing built-in functions. diff --git a/topdown/eval.go b/topdown/eval.go index d2fdce5411..e56b805079 100644 --- a/topdown/eval.go +++ b/topdown/eval.go @@ -701,6 +701,11 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error { parentID = e.parent.queryID } + var capabilities *ast.Capabilities + if e.compiler != nil { + capabilities = e.compiler.Capabilities + } + bctx := BuiltinContext{ Context: e.ctx, Metrics: e.metrics, @@ -717,6 +722,7 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error { ParentID: parentID, PrintHook: e.printHook, DistributedTracingOpts: e.tracingOpts, + Capabilities: capabilities, } eval := evalBuiltin{ diff --git a/topdown/http.go b/topdown/http.go index e855794cff..c8205f18d1 100644 --- a/topdown/http.go +++ b/topdown/http.go @@ -264,6 +264,36 @@ func useSocket(rawURL string, tlsConfig *tls.Config) (bool, string, *http.Transp return true, rawURL, tr } +func verifyHost(bctx BuiltinContext, host string) error { + if bctx.Capabilities == nil || bctx.Capabilities.AllowNet == nil { + return nil + } + + for _, allowed := range bctx.Capabilities.AllowNet { + if allowed == host { + return nil + } + } + + return fmt.Errorf("unallowed host: %s", host) +} + +func verifyURLHost(bctx BuiltinContext, unverifiedURL string) error { + // Eager return to avoid unnecessary URL parsing + if bctx.Capabilities == nil || bctx.Capabilities.AllowNet == nil { + return nil + } + + parsedURL, err := url.Parse(unverifiedURL) + if err != nil { + return err + } + + host := strings.Split(parsedURL.Host, ":")[0] + + return verifyHost(bctx, host) +} + func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *http.Client, error) { var url string var method string @@ -305,7 +335,7 @@ func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *htt var strVal string if s, ok := obj.Get(val).Value.(ast.String); ok { - strVal = string(s) + strVal = strings.Trim(string(s), "\"") } else { // Most parameters are strings, so consolidate the type checking. switch key { @@ -328,9 +358,13 @@ func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *htt switch key { case "method": - method = strings.ToUpper(strings.Trim(strVal, "\"")) + method = strings.ToUpper(strVal) case "url": - url = strings.Trim(strVal, "\"") + err := verifyURLHost(bctx, strVal) + if err != nil { + return nil, nil, err + } + url = strVal case "enable_redirect": enableRedirect, err = strconv.ParseBool(obj.Get(val).String()) if err != nil { @@ -357,25 +391,25 @@ func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *htt } tlsUseSystemCerts = &tempTLSUseSystemCerts case "tls_ca_cert": - tlsCaCert = bytes.Trim([]byte(strVal), "\"") + tlsCaCert = []byte(strVal) case "tls_ca_cert_file": - tlsCaCertFile = strings.Trim(strVal, "\"") + tlsCaCertFile = strVal case "tls_ca_cert_env_variable": - tlsCaCertEnvVar = strings.Trim(strVal, "\"") + tlsCaCertEnvVar = strVal case "tls_client_cert": - tlsClientCert = bytes.Trim([]byte(strVal), "\"") + tlsClientCert = []byte(strVal) case "tls_client_cert_file": - tlsClientCertFile = strings.Trim(strVal, "\"") + tlsClientCertFile = strVal case "tls_client_cert_env_variable": - tlsClientCertEnvVar = strings.Trim(strVal, "\"") + tlsClientCertEnvVar = strVal case "tls_client_key": - tlsClientKey = bytes.Trim([]byte(strVal), "\"") + tlsClientKey = []byte(strVal) case "tls_client_key_file": - tlsClientKeyFile = strings.Trim(strVal, "\"") + tlsClientKeyFile = strVal case "tls_client_key_env_variable": - tlsClientKeyEnvVar = strings.Trim(strVal, "\"") + tlsClientKeyEnvVar = strVal case "tls_server_name": - tlsServerName = strings.Trim(strVal, "\"") + tlsServerName = strVal case "headers": headersVal := obj.Get(val).Value headersValInterface, err := ast.JSON(headersVal) diff --git a/topdown/http_test.go b/topdown/http_test.go index 66ea15fc70..4578c26103 100644 --- a/topdown/http_test.go +++ b/topdown/http_test.go @@ -2681,3 +2681,88 @@ func TestDistributedTracingDisabled(t *testing.T) { t.Errorf("calls to NewTransported: expected %d, got %d", exp, act) } } + +func TestHTTPGetRequestAllowNet(t *testing.T) { + + // test data + body := map[string]bool{"ok": true} + + // test server + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(body) + })) + + defer ts.Close() + + // host + serverURL, err := url.Parse(ts.URL) + if err != nil { + panic(err) + } + serverHost := strings.Split(serverURL.Host, ":")[0] + + // expected result + expectedResult := make(map[string]interface{}) + expectedResult["status"] = "200 OK" + expectedResult["status_code"] = http.StatusOK + + expectedResult["body"] = body + expectedResult["raw_body"] = "{\"ok\":true}\n" + + resultObj, err := ast.InterfaceToValue(expectedResult) + if err != nil { + panic(err) + } + + expectedError := &Error{Code: "eval_builtin_error", Message: fmt.Sprintf("http.send: unallowed host: %s", serverHost)} + + rules := []string{fmt.Sprintf( + `p = x { http.send({"method": "get", "url": "%s", "force_json_decode": true}, resp); x := remove_headers(resp) }`, ts.URL)} + + // run the test + tests := []struct { + note string + rules []string + options func(*Query) *Query + expected interface{} + }{ + { + "http.send allow_net nil", + rules, + + setAllowNet(nil), + resultObj.String(), + }, + { + "http.send allow_net match", + rules, + setAllowNet([]string{serverHost}), + resultObj.String(), + }, + { + "http.send allow_net match + additional host", + rules, + setAllowNet([]string{serverHost, "example.com"}), + resultObj.String(), + }, + { + "http.send allow_net empty", + rules, + setAllowNet([]string{}), + expectedError, + }, + { + "http.send allow_net no match", + rules, + setAllowNet([]string{"example.com"}), + expectedError, + }, + } + + data := loadSmallTestData() + + for _, tc := range tests { + runTopDownTestCase(t, data, tc.note, append(tc.rules, httpSendHelperRules...), tc.expected, tc.options) + } +} diff --git a/topdown/net.go b/topdown/net.go index 24b2d0b032..a167cf13d1 100644 --- a/topdown/net.go +++ b/topdown/net.go @@ -17,7 +17,13 @@ type lookupIPAddrCacheKey string var resolv = &net.Resolver{} func builtinLookupIPAddr(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { - name, err := builtins.StringOperand(operands[0].Value, 1) + a, err := builtins.StringOperand(operands[0].Value, 1) + if err != nil { + return err + } + name := string(a) + + err = verifyHost(bctx, name) if err != nil { return err } @@ -27,7 +33,7 @@ func builtinLookupIPAddr(bctx BuiltinContext, operands []*ast.Term, iter func(*a return iter(val.(*ast.Term)) } - addrs, err := resolv.LookupIPAddr(bctx.Context, string(name)) + addrs, err := resolv.LookupIPAddr(bctx.Context, name) if err != nil { // NOTE(sr): We can't do better than this right now, see https://github.com/golang/go/issues/36208 if err.Error() == "operation was canceled" || err.Error() == "i/o timeout" { diff --git a/topdown/net_test.go b/topdown/net_test.go index 9db8b72087..41a18a433e 100644 --- a/topdown/net_test.go +++ b/topdown/net_test.go @@ -155,6 +155,63 @@ func TestNetLookupIPAddr(t *testing.T) { } }) } + + addr := "v4.org" + exp := ast.NewSet(ast.StringTerm("1.2.3.4")) + for name, allowNet := range map[string][]string{ + "allow_net nil": nil, + "allow_net match": {addr}, + "allow_net match + additional host": {addr, "example.com"}, + } { + t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + capabilities := ast.CapabilitiesForThisVersion() + capabilities.AllowNet = allowNet + bctx := BuiltinContext{ + Context: ctx, + Cache: make(builtins.Cache), + Capabilities: capabilities, + } + srv.PatchNet(resolv) + err := builtinLookupIPAddr(bctx, []*ast.Term{ast.StringTerm(addr)}, func(act *ast.Term) error { + if exp.Compare(act.Value) != 0 { + t.Errorf("expected %v, got %v", exp, act) + } + return nil + }) + if err != nil { + t.Error(err) + } + }) + } + + expError := fmt.Errorf("unallowed host: %s", addr) + for name, allowNet := range map[string][]string{ + "allow_net empty": {}, + "allow_net no match": {"example.com"}, + } { + t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + capabilities := ast.CapabilitiesForThisVersion() + capabilities.AllowNet = allowNet + bctx := BuiltinContext{ + Context: ctx, + Cache: make(builtins.Cache), + Capabilities: capabilities, + } + srv.PatchNet(resolv) + err := builtinLookupIPAddr(bctx, []*ast.Term{ast.StringTerm(addr)}, func(act *ast.Term) error { + t.Fatal("expected not to be called") + return nil + }) + if err == nil { + t.Error("expected error") + } + assertError(t, expError, err) + }) + } } type sink struct{} diff --git a/topdown/topdown_test.go b/topdown/topdown_test.go index 4617f09164..5b01c35340 100644 --- a/topdown/topdown_test.go +++ b/topdown/topdown_test.go @@ -740,6 +740,14 @@ func setTime(t time.Time) func(*Query) *Query { } } +func setAllowNet(a []string) func(*Query) *Query { + return func(q *Query) *Query { + c := q.compiler.Capabilities + c.AllowNet = a + return q.WithCompiler(q.compiler.WithCapabilities(c)) + } +} + func runTopDownTestCase(t *testing.T, data map[string]interface{}, note string, rules []string, expected interface{}, options ...func(*Query) *Query) { t.Helper()