From cb43b1e0547853cdf06fb89463c42a895dad7008 Mon Sep 17 00:00:00 2001 From: Rudrakh Panigrahi Date: Tue, 7 Nov 2023 20:34:36 +0530 Subject: [PATCH] clean expired cache entries periodically Signed-off-by: Rudrakh Panigrahi --- internal/wasm/sdk/internal/wasm/pool_test.go | 3 +- rego/rego_test.go | 3 +- rego/rego_wasmtarget_test.go | 5 +- runtime/runtime.go | 10 +- sdk/opa.go | 2 +- server/authorizer/authorizer_test.go | 4 +- server/server.go | 8 +- server/server_test.go | 14 +- topdown/cache/cache.go | 110 ++++++++++-- topdown/cache/cache_test.go | 174 ++++++++++++++++--- topdown/http.go | 19 +- topdown/http_test.go | 17 +- topdown/topdown_test.go | 5 +- 13 files changed, 303 insertions(+), 71 deletions(-) diff --git a/internal/wasm/sdk/internal/wasm/pool_test.go b/internal/wasm/sdk/internal/wasm/pool_test.go index 3254b624395..ce4d9c4fd92 100644 --- a/internal/wasm/sdk/internal/wasm/pool_test.go +++ b/internal/wasm/sdk/internal/wasm/pool_test.go @@ -9,6 +9,7 @@ package wasm_test import ( "context" + "github.com/open-policy-agent/opa/logging" "math/rand" "strings" "testing" @@ -177,7 +178,7 @@ func ensurePoolResults(t *testing.T, ctx context.Context, testPool *wasm.Pool, p toRelease = append(toRelease, vm) cfg, _ := cache.ParseCachingConfig(nil) - result, err := vm.Eval(ctx, 0, input, metrics.New(), rand.New(rand.NewSource(0)), time.Now(), cache.NewInterQueryCache(cfg), builtins.NDBCache{}, nil, nil) + result, err := vm.Eval(ctx, 0, input, metrics.New(), rand.New(rand.NewSource(0)), time.Now(), cache.NewInterQueryCache(ctx, logging.NewNoOpLogger(), cfg), builtins.NDBCache{}, nil, nil) if err != nil { t.Fatalf("Unexpected error: %s", err) } diff --git a/rego/rego_test.go b/rego/rego_test.go index 8e857c7ddd1..cfe56d3860e 100644 --- a/rego/rego_test.go +++ b/rego/rego_test.go @@ -26,6 +26,7 @@ import ( "github.com/open-policy-agent/opa/ast/location" "github.com/open-policy-agent/opa/bundle" "github.com/open-policy-agent/opa/internal/storage/mock" + "github.com/open-policy-agent/opa/logging" "github.com/open-policy-agent/opa/metrics" "github.com/open-policy-agent/opa/storage" "github.com/open-policy-agent/opa/storage/inmem" @@ -2155,7 +2156,7 @@ func TestEvalWithInterQueryCache(t *testing.T) { // add an inter-query cache config, _ := cache.ParseCachingConfig(nil) - interQueryCache := cache.NewInterQueryCache(config) + interQueryCache := cache.NewInterQueryCache(context.Background(), logging.NewNoOpLogger(), config) ctx := context.Background() _, err := New(Query(query), InterQueryBuiltinCache(interQueryCache)).Eval(ctx) diff --git a/rego/rego_wasmtarget_test.go b/rego/rego_wasmtarget_test.go index 3d874ae44af..5331fc64737 100644 --- a/rego/rego_wasmtarget_test.go +++ b/rego/rego_wasmtarget_test.go @@ -22,6 +22,7 @@ import ( "github.com/open-policy-agent/opa/ast" sdk_errors "github.com/open-policy-agent/opa/internal/wasm/sdk/opa/errors" + "github.com/open-policy-agent/opa/logging" "github.com/open-policy-agent/opa/storage/inmem" "github.com/open-policy-agent/opa/topdown" "github.com/open-policy-agent/opa/topdown/cache" @@ -325,7 +326,7 @@ func TestEvalWasmWithInterQueryCache(t *testing.T) { // add an inter-query cache config, _ := cache.ParseCachingConfig(nil) - interQueryCache := cache.NewInterQueryCache(config) + interQueryCache := cache.NewInterQueryCache(context.Background(), logging.NewNoOpLogger(), config) ctx := context.Background() _, err := New(Target("wasm"), Query(query), InterQueryBuiltinCache(interQueryCache)).Eval(ctx) @@ -367,7 +368,7 @@ func TestEvalWasmWithHTTPAllowNet(t *testing.T) { // add an inter-query cache config, _ := cache.ParseCachingConfig(nil) - interQueryCache := cache.NewInterQueryCache(config) + interQueryCache := cache.NewInterQueryCache(context.Background(), logging.NewNoOpLogger(), config) ctx := context.Background() // StrictBuiltinErrors(true) has no effect when target is 'wasm' diff --git a/runtime/runtime.go b/runtime/runtime.go index 09d1ab1d8c4..c90955e87a0 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -562,15 +562,18 @@ func (rt *Runtime) Serve(ctx context.Context) error { rt.server = rt.server.WithUnixSocketPermission(rt.Params.UnixSocketPerm) } - rt.server, err = rt.server.Init(ctx) + ctx, cancel := context.WithCancel(ctx) + rt.server, err = rt.server.Init(ctx, rt.logger) if err != nil { rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Unable to initialize server.") + cancel() return err } if rt.Params.Watch { if err := rt.startWatcher(ctx, rt.Params.Paths, rt.onReloadLogger); err != nil { rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Unable to open watch.") + cancel() return err } } @@ -594,12 +597,14 @@ func (rt *Runtime) Serve(ctx context.Context) error { 100*time.Millisecond, time.Second*time.Duration(rt.Params.ReadyTimeout)); err != nil { rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Failed to wait for plugins activation.") + cancel() return err } loops, err := rt.server.Listeners() if err != nil { rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Unable to create listeners.") + cancel() return err } @@ -630,11 +635,14 @@ func (rt *Runtime) Serve(ctx context.Context) error { for { select { case <-ctx.Done(): + cancel() return rt.gracefulServerShutdown(rt.server) case <-signalc: + cancel() return rt.gracefulServerShutdown(rt.server) case err := <-errc: rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Listener failed.") + cancel() os.Exit(1) } } diff --git a/sdk/opa.go b/sdk/opa.go index f7e2cccaa1f..add1f29ba0b 100644 --- a/sdk/opa.go +++ b/sdk/opa.go @@ -212,7 +212,7 @@ func (opa *OPA) configure(ctx context.Context, bs []byte, ready chan struct{}, b opa.state.manager = manager opa.state.queryCache.Clear() - opa.state.interQueryBuiltinCache = cache.NewInterQueryCache(manager.InterQueryBuiltinCacheConfig()) + opa.state.interQueryBuiltinCache = cache.NewInterQueryCache(ctx, opa.logger, manager.InterQueryBuiltinCacheConfig()) opa.config = bs return nil diff --git a/server/authorizer/authorizer_test.go b/server/authorizer/authorizer_test.go index e30833fe52d..19fd08c3afd 100644 --- a/server/authorizer/authorizer_test.go +++ b/server/authorizer/authorizer_test.go @@ -6,6 +6,7 @@ package authorizer import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -15,6 +16,7 @@ import ( "testing" "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/logging" "github.com/open-policy-agent/opa/server/identifier" "github.com/open-policy-agent/opa/server/types" "github.com/open-policy-agent/opa/storage/inmem" @@ -499,7 +501,7 @@ func TestInterQueryCache(t *testing.T) { } config, _ := cache.ParseCachingConfig(nil) - interQueryCache := cache.NewInterQueryCache(config) + interQueryCache := cache.NewInterQueryCache(context.Background(), logging.NewNoOpLogger(), config) basic := NewBasic(&mockHandler{}, compiler, inmem.New(), InterQueryCache(interQueryCache), Decision(func() ast.Ref { return ast.MustParseRef("data.system.authz.allow") diff --git a/server/server.go b/server/server.go index a2b652dd2ff..6c95925e5ab 100644 --- a/server/server.go +++ b/server/server.go @@ -164,8 +164,8 @@ func New() *Server { // Init initializes the server. This function MUST be called before starting any loops // from s.Listeners(). -func (s *Server) Init(ctx context.Context) (*Server, error) { - s.initRouters() +func (s *Server) Init(ctx context.Context, logger logging.Logger) (*Server, error) { + s.initRouters(ctx, logger) txn, err := s.store.NewTransaction(ctx, storage.WriteParams) if err != nil { @@ -706,7 +706,7 @@ func (s *Server) initHandlerCompression(handler http.Handler) (http.Handler, err return compressHandler, nil } -func (s *Server) initRouters() { +func (s *Server) initRouters(ctx context.Context, logger logging.Logger) { mainRouter := s.router if mainRouter == nil { mainRouter = mux.NewRouter() @@ -715,7 +715,7 @@ func (s *Server) initRouters() { diagRouter := mux.NewRouter() // authorizer, if configured, needs the iCache to be set up already - s.interQueryBuiltinCache = iCache.NewInterQueryCache(s.manager.InterQueryBuiltinCacheConfig()) + s.interQueryBuiltinCache = iCache.NewInterQueryCache(ctx, logger, s.manager.InterQueryBuiltinCacheConfig()) s.manager.RegisterCacheTrigger(s.updateCacheConfig) // Add authorization handler. This must come BEFORE authentication handler diff --git a/server/server_test.go b/server/server_test.go index c2a4675ee9f..155b2c7d028 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -3395,7 +3395,7 @@ func TestQueryPostBasic(t *testing.T) { WithAddresses([]string{"localhost:8182"}). WithStore(f.server.store). WithManager(f.server.manager). - Init(context.Background()) + Init(context.Background(), logging.NewNoOpLogger()) setup := []tr{ {http.MethodPost, "/query", `{"query": "a=data.k.x with data.k as {\"x\" : 7}"}`, 200, `{"result":[{"a":7}]}`}, @@ -3909,7 +3909,7 @@ func TestAuthorization(t *testing.T) { WithStore(store). WithManager(m). WithAuthorization(AuthorizationBasic). - Init(ctx) + Init(ctx, logging.NewNoOpLogger()) if err != nil { panic(err) @@ -4040,7 +4040,7 @@ allow { WithStore(store). WithManager(m). WithAuthorization(AuthorizationBasic). - Init(ctx) + Init(ctx, logging.NewNoOpLogger()) if err != nil { t.Fatal(err) @@ -4225,7 +4225,7 @@ func TestQueryBindingIterationError(t *testing.T) { panic(err) } - server, err := New().WithStore(mock).WithManager(m).WithAddresses([]string{":8182"}).Init(ctx) + server, err := New().WithStore(mock).WithManager(m).WithAddresses([]string{":8182"}).Init(ctx, logging.NewNoOpLogger()) if err != nil { panic(err) } @@ -4288,7 +4288,7 @@ func newFixture(t *testing.T, opts ...func(*Server)) *fixture { if err := m.Start(ctx); err != nil { t.Fatal(err) } - server, err = server.Init(ctx) + server, err = server.Init(ctx, logging.NewNoOpLogger()) if err != nil { t.Fatal(err) } @@ -4318,7 +4318,7 @@ func newFixtureWithConfig(t *testing.T, config string, opts ...func(*Server)) *f if err := m.Start(ctx); err != nil { t.Fatal(err) } - server, err = server.Init(ctx) + server, err = server.Init(ctx, logging.NewNoOpLogger()) if err != nil { t.Fatal(err) } @@ -4349,7 +4349,7 @@ func newFixtureWithStore(t *testing.T, store storage.Store, opts ...func(*Server for _, opt := range opts { opt(server) } - server, err = server.Init(ctx) + server, err = server.Init(ctx, logging.NewNoOpLogger()) if err != nil { panic(err) } diff --git a/topdown/cache/cache.go b/topdown/cache/cache.go index f9d2bcff752..3be32349ecd 100644 --- a/topdown/cache/cache.go +++ b/topdown/cache/cache.go @@ -7,15 +7,21 @@ package cache import ( "container/list" + "context" + "fmt" + "math" "sync" + "time" "github.com/open-policy-agent/opa/ast" - + "github.com/open-policy-agent/opa/logging" "github.com/open-policy-agent/opa/util" ) const ( - defaultMaxSizeBytes = int64(0) // unlimited + defaultMaxSizeBytes = int64(0) // unlimited + defaultForcedEvictionThresholdPercentage = int64(100) // trigger at max_size_bytes + defaultStaleEntryEvictionPeriodSeconds = int64(0) // never ) // Config represents the configuration of the inter-query cache. @@ -25,7 +31,9 @@ type Config struct { // InterQueryBuiltinCacheConfig represents the configuration of the inter-query cache that built-in functions can utilize. type InterQueryBuiltinCacheConfig struct { - MaxSizeBytes *int64 `json:"max_size_bytes,omitempty"` + MaxSizeBytes *int64 `json:"max_size_bytes,omitempty"` + ForcedEvictionThresholdPercentage *int64 `json:"forced_eviction_threshold_percentage,omitempty"` + StaleEntryEvictionPeriodSeconds *int64 `json:"stale_entry_eviction_period_seconds,omitempty"` } // ParseCachingConfig returns the config for the inter-query cache. @@ -33,7 +41,11 @@ func ParseCachingConfig(raw []byte) (*Config, error) { if raw == nil { maxSize := new(int64) *maxSize = defaultMaxSizeBytes - return &Config{InterQueryBuiltinCache: InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize}}, nil + threshold := new(int64) + *threshold = defaultForcedEvictionThresholdPercentage + period := new(int64) + *period = defaultStaleEntryEvictionPeriodSeconds + return &Config{InterQueryBuiltinCache: InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize, ForcedEvictionThresholdPercentage: threshold, StaleEntryEvictionPeriodSeconds: period}}, nil } var config Config @@ -55,6 +67,26 @@ func (c *Config) validateAndInjectDefaults() error { *maxSize = defaultMaxSizeBytes c.InterQueryBuiltinCache.MaxSizeBytes = maxSize } + if c.InterQueryBuiltinCache.ForcedEvictionThresholdPercentage == nil { + threshold := new(int64) + *threshold = defaultForcedEvictionThresholdPercentage + c.InterQueryBuiltinCache.ForcedEvictionThresholdPercentage = threshold + } else { + threshold := *c.InterQueryBuiltinCache.ForcedEvictionThresholdPercentage + if threshold < 0 || threshold > 100 { + return fmt.Errorf("invalid forced_eviction_threshold_percentage %v", threshold) + } + } + if c.InterQueryBuiltinCache.StaleEntryEvictionPeriodSeconds == nil { + period := new(int64) + *period = defaultStaleEntryEvictionPeriodSeconds + c.InterQueryBuiltinCache.StaleEntryEvictionPeriodSeconds = period + } else { + period := *c.InterQueryBuiltinCache.StaleEntryEvictionPeriodSeconds + if period < 0 { + return fmt.Errorf("invalid stale_entry_eviction_period_seconds %v", period) + } + } return nil } @@ -67,24 +99,48 @@ type InterQueryCacheValue interface { // InterQueryCache defines the interface for the inter-query cache. type InterQueryCache interface { Get(key ast.Value) (value InterQueryCacheValue, found bool) - Insert(key ast.Value, value InterQueryCacheValue) int + Insert(key ast.Value, value InterQueryCacheValue, expiresAt time.Time) int Delete(key ast.Value) UpdateConfig(config *Config) Clone(value InterQueryCacheValue) (InterQueryCacheValue, error) } // NewInterQueryCache returns a new inter-query cache. -func NewInterQueryCache(config *Config) InterQueryCache { - return &cache{ +func NewInterQueryCache(ctx context.Context, logger logging.Logger, config *Config) InterQueryCache { + iqCache := &cache{ items: map[string]cacheItem{}, usage: 0, config: config, l: list.New(), + logger: logger, } + + // Start routine to clean up stale values once every StaleEntryEvictionPeriodSeconds + cleanupPeriod := iqCache.staleEntryEvictionTimePeriodSeconds() + if cleanupPeriod > 0 { + ticker := time.NewTicker(time.Duration(cleanupPeriod) * time.Second) + go func() { + defer func() { + iqCache.logger.Info("Stopping ticker for cache cleanup") + ticker.Stop() + }() + for { + select { + case <-ticker.C: + iqCache.logger.Debug("Dropped %v stale values", iqCache.cleanStaleValues()) + case <-ctx.Done(): + return + } + } + }() + } + + return iqCache } type cacheItem struct { value InterQueryCacheValue + expiresAt time.Time keyElement *list.Element } @@ -94,13 +150,14 @@ type cache struct { config *Config l *list.List mtx sync.Mutex + logger logging.Logger } // Insert inserts a key k into the cache with value v. -func (c *cache) Insert(k ast.Value, v InterQueryCacheValue) (dropped int) { +func (c *cache) Insert(k ast.Value, v InterQueryCacheValue, expiresAt time.Time) (dropped int) { c.mtx.Lock() defer c.mtx.Unlock() - return c.unsafeInsert(k, v) + return c.unsafeInsert(k, v, expiresAt) } // Get returns the value in the cache for k. @@ -137,10 +194,9 @@ func (c *cache) Clone(value InterQueryCacheValue) (InterQueryCacheValue, error) return c.unsafeClone(value) } -func (c *cache) unsafeInsert(k ast.Value, v InterQueryCacheValue) (dropped int) { +func (c *cache) unsafeInsert(k ast.Value, v InterQueryCacheValue, expiresAt time.Time) (dropped int) { size := v.SizeInBytes() - limit := c.maxSizeBytes() - + limit := int64(math.Ceil(float64(c.forcedEvictionThresholdPercentage())/100.0) * (float64(c.maxSizeBytes()))) if limit > 0 { if size > limit { dropped++ @@ -159,6 +215,7 @@ func (c *cache) unsafeInsert(k ast.Value, v InterQueryCacheValue) (dropped int) c.items[k.String()] = cacheItem{ value: v, + expiresAt: expiresAt, keyElement: c.l.PushBack(k), } c.usage += size @@ -191,3 +248,32 @@ func (c *cache) maxSizeBytes() int64 { } return *c.config.InterQueryBuiltinCache.MaxSizeBytes } + +func (c *cache) forcedEvictionThresholdPercentage() int64 { + if c.config == nil { + return defaultForcedEvictionThresholdPercentage + } + return *c.config.InterQueryBuiltinCache.ForcedEvictionThresholdPercentage +} + +func (c *cache) staleEntryEvictionTimePeriodSeconds() int64 { + if c.config == nil { + return defaultStaleEntryEvictionPeriodSeconds + } + return *c.config.InterQueryBuiltinCache.StaleEntryEvictionPeriodSeconds +} + +func (c *cache) cleanStaleValues() (dropped int) { + c.mtx.Lock() + defer c.mtx.Unlock() + for key := c.l.Front(); key != nil; { + nextKey := key.Next() + // if expiresAt is zero, the item doesn't have an expiry + if ea := c.items[(key.Value.(ast.Value)).String()].expiresAt; !ea.IsZero() && ea.Before(time.Now()) { + c.unsafeDelete(key.Value.(ast.Value)) + dropped++ + } + key = nextKey + } + return dropped +} diff --git a/topdown/cache/cache_test.go b/topdown/cache/cache_test.go index 85375e93f0c..f03cbd3747f 100644 --- a/topdown/cache/cache_test.go +++ b/topdown/cache/cache_test.go @@ -5,17 +5,24 @@ package cache import ( + "context" "reflect" "sync" "testing" + "time" "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/logging" ) func TestParseCachingConfig(t *testing.T) { maxSize := new(int64) *maxSize = defaultMaxSizeBytes - expected := &Config{InterQueryBuiltinCache: InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize}} + period := new(int64) + *period = defaultStaleEntryEvictionPeriodSeconds + threshold := new(int64) + *threshold = defaultForcedEvictionThresholdPercentage + expected := &Config{InterQueryBuiltinCache: InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize, StaleEntryEvictionPeriodSeconds: period, ForcedEvictionThresholdPercentage: threshold}} tests := map[string]struct { input []byte @@ -80,11 +87,11 @@ func TestInsert(t *testing.T) { t.Fatalf("Unexpected error %v", err) } - cache := NewInterQueryCache(config) + cache := NewInterQueryCache(context.Background(), logging.NewNoOpLogger(), config) // large cache value that exceeds limit cacheValueLarge := newInterQueryCacheValue(ast.StringTerm("bar").Value, 40) - dropped := cache.Insert(ast.StringTerm("foo").Value, cacheValueLarge) + dropped := cache.Insert(ast.StringTerm("foo").Value, cacheValueLarge, time.Time{}) if dropped != 1 { t.Fatal("Expected dropped to be one") @@ -96,7 +103,7 @@ func TestInsert(t *testing.T) { } cacheValue := newInterQueryCacheValue(ast.StringTerm("bar").Value, 20) - dropped = cache.Insert(ast.StringTerm("foo").Value, cacheValue) + dropped = cache.Insert(ast.StringTerm("foo").Value, cacheValue, time.Time{}) if dropped != 0 { t.Fatal("Expected dropped to be zero") @@ -104,7 +111,7 @@ func TestInsert(t *testing.T) { // exceed cache limit cacheValue2 := newInterQueryCacheValue(ast.StringTerm("bar2").Value, 20) - dropped = cache.Insert(ast.StringTerm("foo2").Value, cacheValue2) + dropped = cache.Insert(ast.StringTerm("foo2").Value, cacheValue2, time.Time{}) if dropped != 1 { t.Fatal("Expected dropped to be one") @@ -120,11 +127,11 @@ func TestInsert(t *testing.T) { t.Fatal("Unexpected key \"foo\" in cache") } cacheValue3 := newInterQueryCacheValue(ast.StringTerm("bar3").Value, 10) - cache.Insert(ast.StringTerm("foo3").Value, cacheValue3) + cache.Insert(ast.StringTerm("foo3").Value, cacheValue3, time.Time{}) cacheValue4 := newInterQueryCacheValue(ast.StringTerm("bar4").Value, 10) - cache.Insert(ast.StringTerm("foo4").Value, cacheValue4) + cache.Insert(ast.StringTerm("foo4").Value, cacheValue4, time.Time{}) cacheValue5 := newInterQueryCacheValue(ast.StringTerm("bar5").Value, 20) - dropped = cache.Insert(ast.StringTerm("foo5").Value, cacheValue5) + dropped = cache.Insert(ast.StringTerm("foo5").Value, cacheValue5, time.Time{}) if dropped != 2 { t.Fatal("Expected dropped to be two") } @@ -143,15 +150,15 @@ func TestInsert(t *testing.T) { verifyCacheList(t, cache) // replacing an existing key should not affect cache size - cache = NewInterQueryCache(config) + cache = NewInterQueryCache(context.Background(), logging.NewNoOpLogger(), config) cacheValue6 := newInterQueryCacheValue(ast.String("bar6"), 10) - cache.Insert(ast.String("foo6"), cacheValue6) - cache.Insert(ast.String("foo6"), cacheValue6) + cache.Insert(ast.String("foo6"), cacheValue6, time.Time{}) + cache.Insert(ast.String("foo6"), cacheValue6, time.Time{}) verifyCacheList(t, cache) cacheValue7 := newInterQueryCacheValue(ast.String("bar7"), 10) - dropped = cache.Insert(ast.StringTerm("foo7").Value, cacheValue7) + dropped = cache.Insert(ast.StringTerm("foo7").Value, cacheValue7, time.Time{}) verifyCacheList(t, cache) if dropped != 0 { @@ -167,10 +174,10 @@ func TestConcurrentInsert(t *testing.T) { t.Fatalf("Unexpected error %v", err) } - cache := NewInterQueryCache(config) + cache := NewInterQueryCache(context.Background(), logging.NewNoOpLogger(), config) cacheValue := newInterQueryCacheValue(ast.String("bar"), 10) - cache.Insert(ast.String("foo"), cacheValue) + cache.Insert(ast.String("foo"), cacheValue, time.Time{}) wg := sync.WaitGroup{} @@ -181,14 +188,14 @@ func TestConcurrentInsert(t *testing.T) { defer wg.Done() cacheValue2 := newInterQueryCacheValue(ast.String("bar2"), 5) - cache.Insert(ast.String("foo2"), cacheValue2) + cache.Insert(ast.String("foo2"), cacheValue2, time.Time{}) }() } wg.Wait() cacheValue3 := newInterQueryCacheValue(ast.String("bar3"), 5) - dropped := cache.Insert(ast.String("foo3"), cacheValue3) + dropped := cache.Insert(ast.String("foo3"), cacheValue3, time.Time{}) verifyCacheList(t, cache) if dropped != 0 { @@ -219,10 +226,10 @@ func TestClone(t *testing.T) { t.Fatalf("Unexpected error %v", err) } - cache := NewInterQueryCache(config) + cache := NewInterQueryCache(context.Background(), logging.NewNoOpLogger(), config) cacheValue := newInterQueryCacheValue(ast.StringTerm("bar").Value, 20) - dropped := cache.Insert(ast.StringTerm("foo").Value, cacheValue) + dropped := cache.Insert(ast.StringTerm("foo").Value, cacheValue, time.Time{}) if dropped != 0 { t.Fatal("Expected dropped to be zero") } @@ -258,10 +265,10 @@ func TestDelete(t *testing.T) { t.Fatalf("Unexpected error %v", err) } - cache := NewInterQueryCache(config) + cache := NewInterQueryCache(context.Background(), logging.NewNoOpLogger(), config) cacheValue := newInterQueryCacheValue(ast.StringTerm("bar").Value, 20) - dropped := cache.Insert(ast.StringTerm("foo").Value, cacheValue) + dropped := cache.Insert(ast.StringTerm("foo").Value, cacheValue, time.Time{}) if dropped != 0 { t.Fatal("Expected dropped to be zero") @@ -277,12 +284,135 @@ func TestDelete(t *testing.T) { verifyCacheList(t, cache) } +func TestInsertWithExpiryAndEviction(t *testing.T) { + // 50 byte max size + // 1s stale cleanup period + // 80% threshold to for FIFO eviction (eviction after 40 bytes) + in := `{"inter_query_builtin_cache": {"max_size_bytes": 50, "stale_entry_eviction_period_seconds": 1, "forced_eviction_threshold_percentage": 80},}` + + config, err := ParseCachingConfig([]byte(in)) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + + cache := NewInterQueryCache(context.Background(), logging.NewNoOpLogger(), config) + + cacheValue := newInterQueryCacheValue(ast.StringTerm("bar").Value, 20) + cache.Insert(ast.StringTerm("force_evicted_foo").Value, cacheValue, time.Now().Add(100*time.Second)) + if fetchedCacheValue, found := cache.Get(ast.StringTerm("force_evicted_foo").Value); !found { + t.Fatalf("Expected cache entry with value %v, found %v", cacheValue, fetchedCacheValue) + } + cache.Insert(ast.StringTerm("expired_foo").Value, cacheValue, time.Now().Add(1*time.Second)) + if fetchedCacheValue, found := cache.Get(ast.StringTerm("expired_foo").Value); !found { + t.Fatalf("Expected cache entry with value %v, found %v", cacheValue, fetchedCacheValue) + } + cache.Insert(ast.StringTerm("foo").Value, cacheValue, time.Now().Add(10*time.Second)) + if fetchedCacheValue, found := cache.Get(ast.StringTerm("foo").Value); !found { + t.Fatalf("Expected cache entry with value %v, found %v", cacheValue, fetchedCacheValue) + } + + // Ensure stale entries clean up routine runs at least once + time.Sleep(2 * time.Second) + + // Entry deleted even though not expired because force evicted when foo is inserted + if fetchedCacheValue, found := cache.Get(ast.StringTerm("force_evicted_foo").Value); found { + t.Fatalf("Didn't expect cache entry for force_evicted_foo, found entry with value %v", fetchedCacheValue) + } + // Stale clean up routine runs and deletes expired entry + if fetchedCacheValue, found := cache.Get(ast.StringTerm("expired_foo").Value); found { + t.Fatalf("Didn't expect cache entry for expired_foo, found entry with value %v", fetchedCacheValue) + } + // Stale clean up routine runs but doesn't delete the entry + if fetchedCacheValue, found := cache.Get(ast.StringTerm("foo").Value); !found { + t.Fatalf("Expected cache entry with value %v for foo, found %v", cacheValue, fetchedCacheValue) + } +} + +func TestInsertHighTTLWithStaleEntryCleanup(t *testing.T) { + // 40 byte max size + // 1s stale cleanup period + // 100% threshold to for FIFO eviction (eviction after 40 bytes) + in := `{"inter_query_builtin_cache": {"max_size_bytes": 40, "stale_entry_eviction_period_seconds": 1, "forced_eviction_threshold_percentage": 100},}` + + config, err := ParseCachingConfig([]byte(in)) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + + cache := NewInterQueryCache(context.Background(), logging.NewNoOpLogger(), config) + + cacheValue := newInterQueryCacheValue(ast.StringTerm("bar").Value, 20) + cache.Insert(ast.StringTerm("high_ttl_foo").Value, cacheValue, time.Now().Add(100*time.Second)) + if fetchedCacheValue, found := cache.Get(ast.StringTerm("high_ttl_foo").Value); !found { + t.Fatalf("Expected cache entry with value %v, found %v", cacheValue, fetchedCacheValue) + } + cache.Insert(ast.StringTerm("expired_foo").Value, cacheValue, time.Now().Add(1*time.Second)) + if fetchedCacheValue, found := cache.Get(ast.StringTerm("expired_foo").Value); !found { + t.Fatalf("Expected cache entry with value %v, found no entry", fetchedCacheValue) + } + + // Ensure stale entries clean up routine runs at least once + time.Sleep(2 * time.Second) + + cache.Insert(ast.StringTerm("foo").Value, cacheValue, time.Now().Add(10*time.Second)) + if fetchedCacheValue, found := cache.Get(ast.StringTerm("foo").Value); !found { + t.Fatalf("Expected cache entry with value %v, found %v", cacheValue, fetchedCacheValue) + } + + // Since expired_foo is deleted by stale cleanup routine, high_ttl_foo is not evicted when foo is inserted + if fetchedCacheValue, found := cache.Get(ast.StringTerm("high_ttl_foo").Value); !found { + t.Fatalf("Expected cache entry with value %v for high_ttl_foo, found %v", cacheValue, fetchedCacheValue) + } + // Stale clean up routine runs and deletes expired entry + if fetchedCacheValue, found := cache.Get(ast.StringTerm("expired_foo").Value); found { + t.Fatalf("Didn't expect cache entry for expired_foo, found entry with value %v", fetchedCacheValue) + } +} + +func TestInsertHighTTLWithoutStaleEntryCleanup(t *testing.T) { + // 40 byte max size + // 0s stale cleanup period -> no cleanup + // 100% threshold to for FIFO eviction (eviction after 40 bytes) + in := `{"inter_query_builtin_cache": {"max_size_bytes": 40, "stale_entry_eviction_period_seconds": 0, "forced_eviction_threshold_percentage": 100},}` + + config, err := ParseCachingConfig([]byte(in)) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + + cache := NewInterQueryCache(context.Background(), logging.NewNoOpLogger(), config) + + cacheValue := newInterQueryCacheValue(ast.StringTerm("bar").Value, 20) + cache.Insert(ast.StringTerm("high_ttl_foo").Value, cacheValue, time.Now().Add(100*time.Second)) + if fetchedCacheValue, found := cache.Get(ast.StringTerm("high_ttl_foo").Value); !found { + t.Fatalf("Expected cache entry with value %v for high_ttl_foo, found no entry", fetchedCacheValue) + } + cache.Insert(ast.StringTerm("expired_foo").Value, cacheValue, time.Now().Add(1*time.Second)) + if fetchedCacheValue, found := cache.Get(ast.StringTerm("expired_foo").Value); !found { + t.Fatalf("Expected cache entry with value %v for expired_foo, found no entry", fetchedCacheValue) + } + + cache.Insert(ast.StringTerm("foo").Value, cacheValue, time.Now().Add(10*time.Second)) + if fetchedCacheValue, found := cache.Get(ast.StringTerm("foo").Value); !found { + t.Fatalf("Expected cache entry with value %v for foo, found no entry", fetchedCacheValue) + } + + // Since stale cleanup routine is disabled, high_ttl_foo is evicted when foo is inserted + if fetchedCacheValue, found := cache.Get(ast.StringTerm("high_ttl_foo").Value); found { + t.Fatalf("Didn't expect cache entry for high_ttl_foo, found entry with value %v", fetchedCacheValue) + } + // Stale clean up disabled so expired entry exists + if fetchedCacheValue, found := cache.Get(ast.StringTerm("expired_foo").Value); !found { + t.Fatalf("Expected cache entry with value %v for expired_foo, found no entry", fetchedCacheValue) + } +} + func TestUpdateConfig(t *testing.T) { config, err := ParseCachingConfig(nil) if err != nil { t.Fatalf("Unexpected error %v", err) } - c := NewInterQueryCache(config) + c := NewInterQueryCache(context.Background(), logging.NewNoOpLogger(), config) actualC, ok := c.(*cache) if !ok { t.Fatal("Unexpected error converting InterQueryCache to cache struct") @@ -305,7 +435,7 @@ func TestUpdateConfig(t *testing.T) { } func TestDefaultMaxSizeBytes(t *testing.T) { - c := NewInterQueryCache(nil) + c := NewInterQueryCache(context.Background(), logging.NewNoOpLogger(), nil) actualC, ok := c.(*cache) if !ok { t.Fatal("Unexpected error converting InterQueryCache to cache struct") diff --git a/topdown/http.go b/topdown/http.go index bf5dbb55d30..3f97b47e342 100644 --- a/topdown/http.go +++ b/topdown/http.go @@ -888,7 +888,7 @@ func (c *interQueryCache) checkHTTPSendInterQueryCache() (ast.Value, error) { pcv = cachedRespData } - c.bctx.InterQueryBuiltinCache.Insert(c.key, pcv) + c.bctx.InterQueryBuiltinCache.Insert(c.key, pcv, cachedRespData.ExpiresAt) return cachedRespData.formatToAST(c.forceJSONDecode, c.forceYAMLDecode) } @@ -924,18 +924,19 @@ func insertIntoHTTPSendInterQueryCache(bctx BuiltinContext, key ast.Value, resp } var pcv cache.InterQueryCacheValue - + var pcvData *interQueryCacheData if cachingMode == defaultCachingMode { - pcv, err = newInterQueryCacheValue(bctx, resp, respBody, cacheParams) + pcv, pcvData, err = newInterQueryCacheValue(bctx, resp, respBody, cacheParams) } else { - pcv, err = newInterQueryCacheData(bctx, resp, respBody, cacheParams) + pcvData, err = newInterQueryCacheData(bctx, resp, respBody, cacheParams) + pcv = pcvData } if err != nil { return err } - requestCache.Insert(key, pcv) + requestCache.Insert(key, pcv, pcvData.ExpiresAt) return nil } @@ -1030,17 +1031,17 @@ type interQueryCacheValue struct { Data []byte } -func newInterQueryCacheValue(bctx BuiltinContext, resp *http.Response, respBody []byte, cacheParams *forceCacheParams) (*interQueryCacheValue, error) { +func newInterQueryCacheValue(bctx BuiltinContext, resp *http.Response, respBody []byte, cacheParams *forceCacheParams) (*interQueryCacheValue, *interQueryCacheData, error) { data, err := newInterQueryCacheData(bctx, resp, respBody, cacheParams) if err != nil { - return nil, err + return nil, nil, err } b, err := json.Marshal(data) if err != nil { - return nil, err + return nil, nil, err } - return &interQueryCacheValue{Data: b}, nil + return &interQueryCacheValue{Data: b}, data, nil } func (cb interQueryCacheValue) Clone() (cache.InterQueryCacheValue, error) { diff --git a/topdown/http_test.go b/topdown/http_test.go index 8f9c71b9e37..f78c72664a7 100644 --- a/topdown/http_test.go +++ b/topdown/http_test.go @@ -27,6 +27,7 @@ import ( "time" "github.com/open-policy-agent/opa/internal/version" + "github.com/open-policy-agent/opa/logging" "github.com/open-policy-agent/opa/metrics" "github.com/open-policy-agent/opa/storage" "github.com/open-policy-agent/opa/topdown/builtins" @@ -1130,7 +1131,7 @@ func TestHTTPSendIntraQueryCaching(t *testing.T) { defer ts.Close() config, _ := iCache.ParseCachingConfig(nil) - interQueryCache := iCache.NewInterQueryCache(config) + interQueryCache := iCache.NewInterQueryCache(context.Background(), logging.NewNoOpLogger(), config) opts := []func(*Query) *Query{ setTime(t0), @@ -1539,7 +1540,7 @@ func TestHTTPSendInterQueryForceCachingRefresh(t *testing.T) { request = strings.ReplaceAll(request, "%CACHE%", strconv.Itoa(cacheTime)) full := fmt.Sprintf("http.send(%s, x)", request) config, _ := iCache.ParseCachingConfig(nil) - interQueryCache := iCache.NewInterQueryCache(config) + interQueryCache := iCache.NewInterQueryCache(context.Background(), logging.NewNoOpLogger(), config) q := NewQuery(ast.MustParseBody(full)). WithInterQueryBuiltinCache(interQueryCache). WithTime(t0) @@ -1598,7 +1599,7 @@ func TestHTTPSendInterQueryForceCachingRefresh(t *testing.T) { t.Fatal(err) } - interQueryCache.Insert(cacheKey, v) + interQueryCache.Insert(cacheKey, v, m.ExpiresAt) } actualCount := len(requests) @@ -1770,7 +1771,7 @@ func TestHTTPSendInterQueryCachingNewResp(t *testing.T) { func newQuery(qStr string, t0 time.Time) *Query { config, _ := iCache.ParseCachingConfig(nil) - interQueryCache := iCache.NewInterQueryCache(config) + interQueryCache := iCache.NewInterQueryCache(context.Background(), logging.NewNoOpLogger(), config) ctx := context.Background() store := inmem.New() txn := storage.NewTransactionOrDie(ctx, store) @@ -2159,7 +2160,7 @@ func TestNewInterQueryCacheValue(t *testing.T) { Body: io.NopCloser(bytes.NewBuffer(b)), } - result, err := newInterQueryCacheValue(BuiltinContext{}, response, b, &forceCacheParams{}) + result, _, err := newInterQueryCacheValue(BuiltinContext{}, response, b, &forceCacheParams{}) if err != nil { t.Fatalf("Unexpected error %v", err) } @@ -2935,7 +2936,7 @@ func TestHTTPSendCacheDefaultStatusCodesInterQueryCache(t *testing.T) { // add an inter-query cache config, _ := iCache.ParseCachingConfig(nil) - interQueryCache := iCache.NewInterQueryCache(config) + interQueryCache := iCache.NewInterQueryCache(context.Background(), logging.NewNoOpLogger(), config) m := metrics.New() @@ -2988,7 +2989,7 @@ func (c *onlyOnceInterQueryCache) Get(_ ast.Value) (value iCache.InterQueryCache return nil, false } -func (c *onlyOnceInterQueryCache) Insert(_ ast.Value, _ iCache.InterQueryCacheValue) int { +func (c *onlyOnceInterQueryCache) Insert(_ ast.Value, _ iCache.InterQueryCacheValue, _ time.Time) int { return 0 } @@ -3275,7 +3276,7 @@ func TestHTTPSendMetrics(t *testing.T) { t.Run("cache hits", func(t *testing.T) { // add an inter-query cache config, _ := iCache.ParseCachingConfig(nil) - interQueryCache := iCache.NewInterQueryCache(config) + interQueryCache := iCache.NewInterQueryCache(context.Background(), logging.NewNoOpLogger(), config) // Execute query twice and verify http.send inter-query cache hit metric is incremented. m := metrics.New() diff --git a/topdown/topdown_test.go b/topdown/topdown_test.go index b2d3838e810..5ddb98f0f32 100644 --- a/topdown/topdown_test.go +++ b/topdown/topdown_test.go @@ -26,6 +26,7 @@ import ( iCache "github.com/open-policy-agent/opa/topdown/cache" "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/logging" "github.com/open-policy-agent/opa/storage" inmem "github.com/open-policy-agent/opa/storage/inmem/test" "github.com/open-policy-agent/opa/types" @@ -1152,7 +1153,7 @@ func assertTopDownWithPathAndContext(ctx context.Context, t *testing.T, compiler // add an inter-query cache config, _ := iCache.ParseCachingConfig(nil) - interQueryCache := iCache.NewInterQueryCache(config) + interQueryCache := iCache.NewInterQueryCache(ctx, logging.NewNoOpLogger(), config) var strictBuiltinErrors bool @@ -1245,7 +1246,7 @@ func runTopDownPartialTestCase(ctx context.Context, t *testing.T, compiler *ast. // add an inter-query cache config, _ := iCache.ParseCachingConfig(nil) - interQueryCache := iCache.NewInterQueryCache(config) + interQueryCache := iCache.NewInterQueryCache(ctx, logging.NewNoOpLogger(), config) partialQuery := NewQuery(body). WithCompiler(compiler).