diff --git a/pkg/server/authorizedentries/cache.go b/pkg/server/authorizedentries/cache.go index 293789577e..9814795816 100644 --- a/pkg/server/authorizedentries/cache.go +++ b/pkg/server/authorizedentries/cache.go @@ -5,6 +5,7 @@ import ( "sync" "time" + "github.com/andres-erbsen/clock" "github.com/google/btree" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire-api-sdk/proto/spire/api/types" @@ -33,7 +34,8 @@ func (s Selector) String() string { } type Cache struct { - mu sync.RWMutex + mu sync.RWMutex + clk clock.Clock agentsByID *btree.BTreeG[agentRecord] agentsByExpiresAt *btree.BTreeG[agentRecord] @@ -45,8 +47,9 @@ type Cache struct { entriesByParentID *btree.BTreeG[entryRecord] } -func NewCache() *Cache { +func NewCache(clk clock.Clock) *Cache { return &Cache{ + clk: clk, agentsByID: btree.NewG(agentRecordDegree, agentRecordByID), agentsByExpiresAt: btree.NewG(agentRecordDegree, agentRecordByExpiresAt), aliasesByEntryID: btree.NewG(aliasRecordDegree, aliasRecordByEntryID), @@ -130,7 +133,7 @@ func (c *Cache) RemoveAgent(agentID string) { } func (c *Cache) PruneExpiredAgents() int { - now := time.Now().Unix() + now := c.clk.Now().Unix() pruned := 0 c.mu.Lock() diff --git a/pkg/server/authorizedentries/cache_test.go b/pkg/server/authorizedentries/cache_test.go index a80abf2051..be4c2ef929 100644 --- a/pkg/server/authorizedentries/cache_test.go +++ b/pkg/server/authorizedentries/cache_test.go @@ -10,6 +10,7 @@ import ( "github.com/spiffe/spire-api-sdk/proto/spire/api/types" "github.com/spiffe/spire/pkg/common/idutil" "github.com/spiffe/spire/pkg/server/api" + "github.com/spiffe/spire/test/clock" "github.com/spiffe/spire/test/spiretest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -58,7 +59,7 @@ func TestGetAuthorizedEntries(t *testing.T) { workload := makeWorkload(agent1) cache := testCache(). withAgent(agent1, sel1). - withEntries(workload).hydrate() + withEntries(workload).hydrate(t) cache.RemoveEntry(workload.Id) assertAuthorizedEntries(t, cache, agent1) }) @@ -116,7 +117,7 @@ func TestGetAuthorizedEntries(t *testing.T) { cache := testCache(). withEntries(workloadEntry, aliasEntry). withAgent(agent1, sel1, sel2). - hydrate() + hydrate(t) cache.RemoveEntry(aliasEntry.Id) assertAuthorizedEntries(t, cache, agent1) @@ -131,7 +132,7 @@ func TestGetAuthorizedEntries(t *testing.T) { cache := testCache(). withEntries(workloadEntry, aliasEntry). withAgent(agent1, sel1, sel2). - hydrate() + hydrate(t) cache.RemoveAgent(agent1.String()) assertAuthorizedEntries(t, cache, agent1) @@ -149,7 +150,7 @@ func TestGetAuthorizedEntries(t *testing.T) { withExpiredAgent(agent2, time.Hour, sel1, sel2). withExpiredAgent(agent3, time.Hour*2, sel1, sel2). withAgent(agent4, sel1, sel2). - hydrate() + hydrate(t) assertAuthorizedEntries(t, cache, agent1, workloadEntry) assertAuthorizedEntries(t, cache, agent2, workloadEntry) assertAuthorizedEntries(t, cache, agent3, workloadEntry) @@ -169,8 +170,9 @@ func TestCacheInternalStats(t *testing.T) { // across various operations. The motivation is to ensure that as the cache // is updated that we are appropriately inserting and removing records from // the indexees. + clk := clock.NewMock(t) t.Run("pristine", func(t *testing.T) { - cache := NewCache() + cache := NewCache(clk) require.Zero(t, cache.stats()) }) @@ -182,7 +184,7 @@ func TestCacheInternalStats(t *testing.T) { entry2b := makeAlias(alias1, sel1, sel2) entry2b.Id = entry2a.Id - cache := NewCache() + cache := NewCache(clk) cache.UpdateEntry(entry1) require.Equal(t, cacheStats{ EntriesByEntryID: 1, @@ -218,7 +220,7 @@ func TestCacheInternalStats(t *testing.T) { }) t.Run("agents", func(t *testing.T) { - cache := NewCache() + cache := NewCache(clk) cache.UpdateAgent(agent1.String(), now.Add(time.Hour), []*types.Selector{sel1}) require.Equal(t, cacheStats{ AgentsByID: 1, @@ -297,8 +299,9 @@ func (a *cacheTest) withExpiredAgent(node spiffeid.ID, expiredBy time.Duration, return a } -func (a *cacheTest) hydrate() *Cache { - cache := NewCache() +func (a *cacheTest) hydrate(tb testing.TB) *Cache { + clk := clock.NewMock(tb) + cache := NewCache(clk) for _, entry := range a.entries { cache.UpdateEntry(entry) } @@ -310,7 +313,7 @@ func (a *cacheTest) hydrate() *Cache { func (a *cacheTest) assertAuthorizedEntries(t *testing.T, agent spiffeid.ID, expectEntries ...*types.Entry) { t.Helper() - assertAuthorizedEntries(t, a.hydrate(), agent, expectEntries...) + assertAuthorizedEntries(t, a.hydrate(t), agent, expectEntries...) } func makeAlias(alias spiffeid.ID, selectors ...*types.Selector) *types.Entry { @@ -407,7 +410,7 @@ func BenchmarkGetAuthorizedEntriesInMemory(b *testing.B) { }) } - cache := test.hydrate() + cache := test.hydrate(b) b.ResetTimer() for i := 0; i < b.N; i++ { cache.GetAuthorizedEntries(test.pickAgent()) diff --git a/pkg/server/endpoints/authorized_entryfetcher.go b/pkg/server/endpoints/authorized_entryfetcher.go index 09d8f7d1b2..76d4691628 100644 --- a/pkg/server/endpoints/authorized_entryfetcher.go +++ b/pkg/server/endpoints/authorized_entryfetcher.go @@ -62,12 +62,14 @@ func (a *AuthorizedEntryFetcherWithEventsBasedCache) RunUpdateCacheTask(ctx cont select { case <-ctx.Done(): a.log.Debug("Stopping in-memory entry cache hydrator") - return nil + return ctx.Err() case <-a.clk.After(a.cacheReloadInterval): - err := a.updateCache(ctx) - if err != nil { + if err := a.updateCache(ctx); err != nil { a.log.WithError(err).Error("Failed to update entry cache") } + if pruned := a.cache.PruneExpiredAgents(); pruned > 0 { + a.log.Debugf("Pruned %d expired agents from entry cache", pruned) + } } } } @@ -78,7 +80,7 @@ func (a *AuthorizedEntryFetcherWithEventsBasedCache) PruneEventsTask(ctx context select { case <-ctx.Done(): a.log.Debug("Stopping event pruner") - return nil + return ctx.Err() case <-a.clk.After(a.pruneEventsOlderThan / 2): a.log.Debug("Pruning events") if err := a.pruneEvents(ctx, a.pruneEventsOlderThan); err != nil { @@ -162,6 +164,7 @@ func (a *AuthorizedEntryFetcherWithEventsBasedCache) updateAttestedNodesCache(ct } a.lastAttestedNodeEventID = event.EventID + // Node was deleted if node == nil { a.cache.RemoveAgent(event.SpiffeID) continue @@ -174,11 +177,6 @@ func (a *AuthorizedEntryFetcherWithEventsBasedCache) updateAttestedNodesCache(ct node.Selectors = selectors agentExpiresAt := time.Unix(node.CertNotAfter, 0) - if agentExpiresAt.Before(a.clk.Now()) { - a.cache.RemoveAgent(event.SpiffeID) - continue - } - a.cache.UpdateAgent(node.SpiffeId, agentExpiresAt, api.ProtoFromSelectors(node.Selectors)) } @@ -186,7 +184,7 @@ func (a *AuthorizedEntryFetcherWithEventsBasedCache) updateAttestedNodesCache(ct } func buildCache(ctx context.Context, ds datastore.DataStore, clk clock.Clock) (*authorizedentries.Cache, uint, uint, error) { - cache := authorizedentries.NewCache() + cache := authorizedentries.NewCache(clk) lastRegistrationEntryEventID, err := buildRegistrationEntriesCache(ctx, ds, cache, buildCachePageSize) if err != nil { diff --git a/pkg/server/endpoints/authorized_entryfetcher_test.go b/pkg/server/endpoints/authorized_entryfetcher_test.go index 0fa509a2b3..725a5b7861 100644 --- a/pkg/server/endpoints/authorized_entryfetcher_test.go +++ b/pkg/server/endpoints/authorized_entryfetcher_test.go @@ -3,11 +3,13 @@ package endpoints import ( "context" "errors" + "regexp" "sort" "strconv" "testing" "time" + "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/common/idutil" @@ -114,6 +116,7 @@ func TestNewAuthorizedEntryFetcherWithEventsBasedCacheErrorBuildingCache(t *test func TestBuildRegistrationEntriesCache(t *testing.T) { ctx := context.Background() + clk := clock.NewMock(t) ds := fakedatastore.New(t) agentID, err := spiffeid.FromString("spiffe://example.org/myagent") @@ -160,7 +163,7 @@ func TestBuildRegistrationEntriesCache(t *testing.T) { } { tt := tt t.Run(tt.name, func(t *testing.T) { - cache := authorizedentries.NewCache() + cache := authorizedentries.NewCache(clk) lastRegistrationEntryEventID, err := buildRegistrationEntriesCache(ctx, ds, cache, tt.pageSize) if tt.err != "" { require.Equal(t, uint(0), lastRegistrationEntryEventID) @@ -187,3 +190,69 @@ func TestBuildRegistrationEntriesCache(t *testing.T) { }) } } + +func TestRunUpdateCacheTaskPrunesExpiredAgents(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + log, hook := test.NewNullLogger() + log.SetLevel(logrus.DebugLevel) + clk := clock.NewMock(t) + ds := fakedatastore.New(t) + + ef, err := NewAuthorizedEntryFetcherWithEventsBasedCache(ctx, log, clk, ds, defaultCacheReloadInterval, defaultPruneEventsOlderThan) + require.NoError(t, err) + require.NotNil(t, ef) + + agentID, err := spiffeid.FromString("spiffe://example.org/myagent") + require.NoError(t, err) + + // Start Update Task + updateCacheTaskErr := make(chan error) + go func() { + updateCacheTaskErr <- ef.RunUpdateCacheTask(ctx) + }() + clk.WaitForAfter(time.Second, "waiting for initial task pause") + entries, err := ef.FetchAuthorizedEntries(ctx, agentID) + assert.NoError(t, err) + require.Zero(t, entries) + + // Create Attested Node and Registration Entry + _, err = ds.CreateAttestedNode(ctx, &common.AttestedNode{ + SpiffeId: agentID.String(), + CertNotAfter: clk.Now().Add(6 * time.Second).Unix(), + }) + assert.NoError(t, err) + + _, err = ds.CreateRegistrationEntry(ctx, &common.RegistrationEntry{ + SpiffeId: "spiffe://example.org/workload", + ParentId: agentID.String(), + Selectors: []*common.Selector{ + { + Type: "workload", + Value: "one", + }, + }, + }) + assert.NoError(t, err) + + // Bump clock and rerun UpdateCacheTask + clk.Add(defaultCacheReloadInterval) + clk.WaitForAfter(time.Second, "waiting for task to pause after creating entries") + entries, err = ef.FetchAuthorizedEntries(ctx, agentID) + assert.NoError(t, err) + require.Equal(t, 1, len(entries)) + + // Make sure nothing was pruned yet + for _, entry := range hook.AllEntries() { + require.NotRegexp(t, regexp.MustCompile(`Pruned \d* expired agents from entry cache`), entry.Message) + } + + // Bump clock so entry expires and is pruned + clk.Add(defaultCacheReloadInterval) + clk.WaitForAfter(time.Second, "waiting for task to pause after expiring agent") + assert.Equal(t, "Pruned 1 expired agents from entry cache", hook.LastEntry().Message) + + // Stop the task + cancel() + err = <-updateCacheTaskErr + require.ErrorIs(t, err, context.Canceled) +}