Skip to content

Commit

Permalink
Prune expired nodes from cache
Browse files Browse the repository at this point in the history
Signed-off-by: Faisal Memon <fymemon@yahoo.com>
  • Loading branch information
faisal-memon committed Apr 4, 2024
1 parent 3264a23 commit 026af65
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 25 deletions.
9 changes: 6 additions & 3 deletions pkg/server/authorizedentries/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"sync"
"time"

"github.com/andres-erbsen/clock"
"github.com/google/btree"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/spire-api-sdk/proto/spire/api/types"
Expand Down Expand Up @@ -33,7 +34,8 @@ func (s Selector) String() string {
}

type Cache struct {
mu sync.RWMutex
mu sync.RWMutex
clk clock.Clock

agentsByID *btree.BTreeG[agentRecord]
agentsByExpiresAt *btree.BTreeG[agentRecord]
Expand All @@ -45,8 +47,9 @@ type Cache struct {
entriesByParentID *btree.BTreeG[entryRecord]
}

func NewCache() *Cache {
func NewCache(clk clock.Clock) *Cache {
return &Cache{
clk: clk,
agentsByID: btree.NewG(agentRecordDegree, agentRecordByID),
agentsByExpiresAt: btree.NewG(agentRecordDegree, agentRecordByExpiresAt),
aliasesByEntryID: btree.NewG(aliasRecordDegree, aliasRecordByEntryID),
Expand Down Expand Up @@ -130,7 +133,7 @@ func (c *Cache) RemoveAgent(agentID string) {
}

func (c *Cache) PruneExpiredAgents() int {
now := time.Now().Unix()
now := c.clk.Now().Unix()
pruned := 0

c.mu.Lock()
Expand Down
25 changes: 14 additions & 11 deletions pkg/server/authorizedentries/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/spiffe/spire-api-sdk/proto/spire/api/types"
"github.com/spiffe/spire/pkg/common/idutil"
"github.com/spiffe/spire/pkg/server/api"
"github.com/spiffe/spire/test/clock"
"github.com/spiffe/spire/test/spiretest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -58,7 +59,7 @@ func TestGetAuthorizedEntries(t *testing.T) {
workload := makeWorkload(agent1)
cache := testCache().
withAgent(agent1, sel1).
withEntries(workload).hydrate()
withEntries(workload).hydrate(t)
cache.RemoveEntry(workload.Id)
assertAuthorizedEntries(t, cache, agent1)
})
Expand Down Expand Up @@ -116,7 +117,7 @@ func TestGetAuthorizedEntries(t *testing.T) {
cache := testCache().
withEntries(workloadEntry, aliasEntry).
withAgent(agent1, sel1, sel2).
hydrate()
hydrate(t)

cache.RemoveEntry(aliasEntry.Id)
assertAuthorizedEntries(t, cache, agent1)
Expand All @@ -131,7 +132,7 @@ func TestGetAuthorizedEntries(t *testing.T) {
cache := testCache().
withEntries(workloadEntry, aliasEntry).
withAgent(agent1, sel1, sel2).
hydrate()
hydrate(t)

cache.RemoveAgent(agent1.String())
assertAuthorizedEntries(t, cache, agent1)
Expand All @@ -149,7 +150,7 @@ func TestGetAuthorizedEntries(t *testing.T) {
withExpiredAgent(agent2, time.Hour, sel1, sel2).
withExpiredAgent(agent3, time.Hour*2, sel1, sel2).
withAgent(agent4, sel1, sel2).
hydrate()
hydrate(t)
assertAuthorizedEntries(t, cache, agent1, workloadEntry)
assertAuthorizedEntries(t, cache, agent2, workloadEntry)
assertAuthorizedEntries(t, cache, agent3, workloadEntry)
Expand All @@ -169,8 +170,9 @@ func TestCacheInternalStats(t *testing.T) {
// across various operations. The motivation is to ensure that as the cache
// is updated that we are appropriately inserting and removing records from
// the indexees.
clk := clock.NewMock(t)
t.Run("pristine", func(t *testing.T) {
cache := NewCache()
cache := NewCache(clk)
require.Zero(t, cache.stats())
})

Expand All @@ -182,7 +184,7 @@ func TestCacheInternalStats(t *testing.T) {
entry2b := makeAlias(alias1, sel1, sel2)
entry2b.Id = entry2a.Id

cache := NewCache()
cache := NewCache(clk)
cache.UpdateEntry(entry1)
require.Equal(t, cacheStats{
EntriesByEntryID: 1,
Expand Down Expand Up @@ -218,7 +220,7 @@ func TestCacheInternalStats(t *testing.T) {
})

t.Run("agents", func(t *testing.T) {
cache := NewCache()
cache := NewCache(clk)
cache.UpdateAgent(agent1.String(), now.Add(time.Hour), []*types.Selector{sel1})
require.Equal(t, cacheStats{
AgentsByID: 1,
Expand Down Expand Up @@ -297,8 +299,9 @@ func (a *cacheTest) withExpiredAgent(node spiffeid.ID, expiredBy time.Duration,
return a
}

func (a *cacheTest) hydrate() *Cache {
cache := NewCache()
func (a *cacheTest) hydrate(tb testing.TB) *Cache {
clk := clock.NewMock(tb)
cache := NewCache(clk)
for _, entry := range a.entries {
cache.UpdateEntry(entry)
}
Expand All @@ -310,7 +313,7 @@ func (a *cacheTest) hydrate() *Cache {

func (a *cacheTest) assertAuthorizedEntries(t *testing.T, agent spiffeid.ID, expectEntries ...*types.Entry) {
t.Helper()
assertAuthorizedEntries(t, a.hydrate(), agent, expectEntries...)
assertAuthorizedEntries(t, a.hydrate(t), agent, expectEntries...)
}

func makeAlias(alias spiffeid.ID, selectors ...*types.Selector) *types.Entry {
Expand Down Expand Up @@ -407,7 +410,7 @@ func BenchmarkGetAuthorizedEntriesInMemory(b *testing.B) {
})
}

cache := test.hydrate()
cache := test.hydrate(b)
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.GetAuthorizedEntries(test.pickAgent())
Expand Down
18 changes: 8 additions & 10 deletions pkg/server/endpoints/authorized_entryfetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,14 @@ func (a *AuthorizedEntryFetcherWithEventsBasedCache) RunUpdateCacheTask(ctx cont
select {
case <-ctx.Done():
a.log.Debug("Stopping in-memory entry cache hydrator")
return nil
return ctx.Err()
case <-a.clk.After(a.cacheReloadInterval):
err := a.updateCache(ctx)
if err != nil {
if err := a.updateCache(ctx); err != nil {
a.log.WithError(err).Error("Failed to update entry cache")
}
if pruned := a.cache.PruneExpiredAgents(); pruned > 0 {
a.log.Debugf("Pruned %d expired agents from entry cache", pruned)
}
}
}
}
Expand All @@ -78,7 +80,7 @@ func (a *AuthorizedEntryFetcherWithEventsBasedCache) PruneEventsTask(ctx context
select {
case <-ctx.Done():
a.log.Debug("Stopping event pruner")
return nil
return ctx.Err()
case <-a.clk.After(a.pruneEventsOlderThan / 2):
a.log.Debug("Pruning events")
if err := a.pruneEvents(ctx, a.pruneEventsOlderThan); err != nil {
Expand Down Expand Up @@ -162,6 +164,7 @@ func (a *AuthorizedEntryFetcherWithEventsBasedCache) updateAttestedNodesCache(ct
}
a.lastAttestedNodeEventID = event.EventID

// Node was deleted
if node == nil {
a.cache.RemoveAgent(event.SpiffeID)
continue
Expand All @@ -174,19 +177,14 @@ func (a *AuthorizedEntryFetcherWithEventsBasedCache) updateAttestedNodesCache(ct
node.Selectors = selectors

agentExpiresAt := time.Unix(node.CertNotAfter, 0)
if agentExpiresAt.Before(a.clk.Now()) {
a.cache.RemoveAgent(event.SpiffeID)
continue
}

a.cache.UpdateAgent(node.SpiffeId, agentExpiresAt, api.ProtoFromSelectors(node.Selectors))
}

return nil
}

func buildCache(ctx context.Context, ds datastore.DataStore, clk clock.Clock) (*authorizedentries.Cache, uint, uint, error) {
cache := authorizedentries.NewCache()
cache := authorizedentries.NewCache(clk)

lastRegistrationEntryEventID, err := buildRegistrationEntriesCache(ctx, ds, cache, buildCachePageSize)
if err != nil {
Expand Down
71 changes: 70 additions & 1 deletion pkg/server/endpoints/authorized_entryfetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package endpoints
import (
"context"
"errors"
"regexp"
"sort"
"strconv"
"testing"
"time"

"github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/spire/pkg/common/idutil"
Expand Down Expand Up @@ -114,6 +116,7 @@ func TestNewAuthorizedEntryFetcherWithEventsBasedCacheErrorBuildingCache(t *test

func TestBuildRegistrationEntriesCache(t *testing.T) {
ctx := context.Background()
clk := clock.NewMock(t)
ds := fakedatastore.New(t)

agentID, err := spiffeid.FromString("spiffe://example.org/myagent")
Expand Down Expand Up @@ -160,7 +163,7 @@ func TestBuildRegistrationEntriesCache(t *testing.T) {
} {
tt := tt
t.Run(tt.name, func(t *testing.T) {
cache := authorizedentries.NewCache()
cache := authorizedentries.NewCache(clk)
lastRegistrationEntryEventID, err := buildRegistrationEntriesCache(ctx, ds, cache, tt.pageSize)
if tt.err != "" {
require.Equal(t, uint(0), lastRegistrationEntryEventID)
Expand All @@ -187,3 +190,69 @@ func TestBuildRegistrationEntriesCache(t *testing.T) {
})
}
}

func TestRunUpdateCacheTaskPrunesExpiredAgents(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
log, hook := test.NewNullLogger()
log.SetLevel(logrus.DebugLevel)
clk := clock.NewMock(t)
ds := fakedatastore.New(t)

ef, err := NewAuthorizedEntryFetcherWithEventsBasedCache(ctx, log, clk, ds, defaultCacheReloadInterval, defaultPruneEventsOlderThan)
require.NoError(t, err)
require.NotNil(t, ef)

agentID, err := spiffeid.FromString("spiffe://example.org/myagent")
require.NoError(t, err)

// Start Update Task
updateCacheTaskErr := make(chan error)
go func() {
updateCacheTaskErr <- ef.RunUpdateCacheTask(ctx)
}()
clk.WaitForAfter(time.Second, "waiting for initial task pause")
entries, err := ef.FetchAuthorizedEntries(ctx, agentID)
assert.NoError(t, err)
require.Zero(t, entries)

// Create Attested Node and Registration Entry
_, err = ds.CreateAttestedNode(ctx, &common.AttestedNode{
SpiffeId: agentID.String(),
CertNotAfter: clk.Now().Add(6 * time.Second).Unix(),
})
assert.NoError(t, err)

_, err = ds.CreateRegistrationEntry(ctx, &common.RegistrationEntry{
SpiffeId: "spiffe://example.org/workload",
ParentId: agentID.String(),
Selectors: []*common.Selector{
{
Type: "workload",
Value: "one",
},
},
})
assert.NoError(t, err)

// Bump clock and rerun UpdateCacheTask
clk.Add(defaultCacheReloadInterval)
clk.WaitForAfter(time.Second, "waiting for task to pause after creating entries")
entries, err = ef.FetchAuthorizedEntries(ctx, agentID)
assert.NoError(t, err)
require.Equal(t, 1, len(entries))

// Make sure nothing was pruned yet
for _, entry := range hook.AllEntries() {
require.NotRegexp(t, regexp.MustCompile(`Pruned \d* expired agents from entry cache`), entry.Message)
}

// Bump clock so entry expires and is pruned
clk.Add(defaultCacheReloadInterval)
clk.WaitForAfter(time.Second, "waiting for task to pause after expiring agent")
assert.Equal(t, "Pruned 1 expired agents from entry cache", hook.LastEntry().Message)

// Stop the task
cancel()
err = <-updateCacheTaskErr
require.ErrorIs(t, err, context.Canceled)
}

0 comments on commit 026af65

Please sign in to comment.