diff --git a/server/authorizer/authorizer.go b/server/authorizer/authorizer.go index dd7344d435..4ff3315687 100644 --- a/server/authorizer/authorizer.go +++ b/server/authorizer/authorizer.go @@ -18,6 +18,7 @@ import ( "github.com/open-policy-agent/opa/server/types" "github.com/open-policy-agent/opa/server/writer" "github.com/open-policy-agent/opa/storage" + "github.com/open-policy-agent/opa/topdown/cache" "github.com/open-policy-agent/opa/topdown/print" "github.com/open-policy-agent/opa/util" ) @@ -31,6 +32,7 @@ type Basic struct { decision func() ast.Ref printHook print.Hook enablePrintStatements bool + interQueryCache cache.InterQueryCache } // Runtime returns an argument that sets the runtime on the authorizer. @@ -65,6 +67,13 @@ func EnablePrintStatements(yes bool) func(r *Basic) { } } +// InterQueryCache enables the inter-query cache on the authorizer +func InterQueryCache(interQueryCache cache.InterQueryCache) func(*Basic) { + return func(b *Basic) { + b.interQueryCache = interQueryCache + } +} + // NewBasic returns a new Basic object. func NewBasic(inner http.Handler, compiler func() *ast.Compiler, store storage.Store, opts ...func(*Basic)) http.Handler { b := &Basic{ @@ -98,6 +107,7 @@ func (h *Basic) ServeHTTP(w http.ResponseWriter, r *http.Request) { rego.Runtime(h.runtime), rego.EnablePrintStatements(h.enablePrintStatements), rego.PrintHook(h.printHook), + rego.InterQueryBuiltinCache(h.interQueryCache), ) rs, err := rego.Eval(r.Context()) diff --git a/server/authorizer/authorizer_test.go b/server/authorizer/authorizer_test.go index 6803a3c996..e30833fe52 100644 --- a/server/authorizer/authorizer_test.go +++ b/server/authorizer/authorizer_test.go @@ -7,6 +7,7 @@ package authorizer import ( "bytes" "encoding/json" + "fmt" "net/http" "net/http/httptest" "reflect" @@ -17,6 +18,7 @@ import ( "github.com/open-policy-agent/opa/server/identifier" "github.com/open-policy-agent/opa/server/types" "github.com/open-policy-agent/opa/storage/inmem" + "github.com/open-policy-agent/opa/topdown/cache" "github.com/open-policy-agent/opa/topdown/print" "github.com/open-policy-agent/opa/util" ) @@ -260,7 +262,7 @@ func TestBasicEscapeError(t *testing.T) { recorder := httptest.NewRecorder() req, err := http.NewRequest(http.MethodGet, "http://localhost:8181", nil) if err != nil { - panic(err) + t.Fatal(err) } req.URL.Path = `/invalid/path/foo%LALALA` @@ -293,7 +295,7 @@ func TestMakeInput(t *testing.T) { path := "/foo/bar?pretty=true&explain=\"full\"" req, err := http.NewRequest(http.MethodGet, "http://localhost:8181"+path, nil) if err != nil { - panic(err) + t.Fatal(err) } req.Header.Add("x-custom", "foo") @@ -312,7 +314,7 @@ func TestMakeInput(t *testing.T) { _, result, err := makeInput(req) if err != nil { - panic(err) + t.Fatal(err) } expectedResult := util.MustUnmarshalJSON([]byte(` @@ -455,6 +457,64 @@ func TestMakeInputWithBody(t *testing.T) { } +func TestInterQueryCache(t *testing.T) { + + count := 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + count++ + })) + + t.Cleanup(func() { + ts.Close() + }) + + compiler := func() *ast.Compiler { + module := fmt.Sprintf(` + package system.authz + + allow { + http.send({ + "method": "GET", + "url": "%v", + "force_cache": true, + "force_cache_duration_seconds": 60 + }).status_code == 200 + } + `, ts.URL) + c := ast.NewCompiler() + c.Compile(map[string]*ast.Module{ + "test.rego": ast.MustParseModule(module), + }) + if c.Failed() { + t.Fatalf("Unexpected error compiling test module: %v", c.Errors) + } + return c + } + + recorder := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "http://localhost:8181/v1/data", nil) + if err != nil { + t.Fatal(err) + } + + config, _ := cache.ParseCachingConfig(nil) + interQueryCache := cache.NewInterQueryCache(config) + + basic := NewBasic(&mockHandler{}, compiler, inmem.New(), InterQueryCache(interQueryCache), Decision(func() ast.Ref { + return ast.MustParseRef("data.system.authz.allow") + })) + + // Execute the policy twice + basic.ServeHTTP(recorder, req) + basic.ServeHTTP(recorder, req) + + // And make sure the test server was only hit once + if count != 1 { + t.Error("Expected http.send response to be cached") + } +} + func Equal(a, b []string) bool { if len(a) != len(b) { return false diff --git a/server/server.go b/server/server.go index 846ab30a55..5065a416c4 100644 --- a/server/server.go +++ b/server/server.go @@ -632,7 +632,8 @@ func (s *Server) initHandlerAuth(handler http.Handler) http.Handler { authorizer.Runtime(s.runtime), authorizer.Decision(s.manager.Config.DefaultAuthorizationDecisionRef), authorizer.PrintHook(s.manager.PrintHook()), - authorizer.EnablePrintStatements(s.manager.EnablePrintStatements())) + authorizer.EnablePrintStatements(s.manager.EnablePrintStatements()), + authorizer.InterQueryCache(s.interQueryBuiltinCache)) } switch s.authentication {