diff --git a/pkg/agent/api/delegatedidentity/v1/service_test.go b/pkg/agent/api/delegatedidentity/v1/service_test.go index 0791c256cda..8baf8f20b15 100644 --- a/pkg/agent/api/delegatedidentity/v1/service_test.go +++ b/pkg/agent/api/delegatedidentity/v1/service_test.go @@ -9,7 +9,6 @@ import ( "testing" "time" - "github.com/andres-erbsen/clock" "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/spiffe/go-spiffe/v2/bundle/spiffebundle" @@ -812,5 +811,5 @@ func (m *FakeManager) SubscribeToBundleChanges() *cache.BundleStream { func newTestCache() *cache.Cache { log, _ := test.NewNullLogger() - return cache.New(log, trustDomain1, bundle1, telemetry.Blackhole{}, 0, clock.NewMock()) + return cache.New(log, trustDomain1, bundle1, telemetry.Blackhole{}) } diff --git a/pkg/agent/manager/cache/cache.go b/pkg/agent/manager/cache/cache.go index a0944564710..fb5372d199e 100644 --- a/pkg/agent/manager/cache/cache.go +++ b/pkg/agent/manager/cache/cache.go @@ -1,13 +1,13 @@ package cache import ( + "context" "crypto" "crypto/x509" "sort" "sync" "time" - "github.com/andres-erbsen/clock" "github.com/sirupsen/logrus" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/common/bundleutil" @@ -15,10 +15,6 @@ import ( "github.com/spiffe/spire/proto/spire/common" ) -const ( - DefaultSVIDCacheMaxSize = 1000 -) - type Selectors []*common.Selector type Bundle = bundleutil.Bundle @@ -63,9 +59,9 @@ type X509SVID struct { PrivateKey crypto.Signer } -// Cache caches each registration entry, bundles, and JWT SVIDs for the agent. -// The signed X509-SVIDs for those entries are stored in LRU-like cache. -// It allows subscriptions by (workload) selector sets and notifies subscribers when: +// Cache caches each registration entry, signed X509-SVIDs for those entries, +// bundles, and JWT SVIDs for the agent. It allows subscriptions by (workload) +// selector sets and notifies subscribers when: // // 1) a registration entry related to the selectors: // * is modified @@ -80,22 +76,6 @@ type X509SVID struct { // selector it encounters. Each selector index tracks the subscribers (i.e // workloads) and registration entries that have that selector. // -// The LRU-like SVID cache has configurable size limit and expiry period. -// 1. Size limit of SVID cache is a soft limit. If SVID has a subscriber present then -// that SVID is never removed from cache. -// 2. Least recently used SVIDs are removed from cache only after the cache expiry period has passed. -// This is done to reduce the overall cache churn. -// 3. Last access timestamp for SVID cache entry is updated when a new subscriber is created -// 4. When a new subscriber is created and there is a cache miss -// then subscriber needs to wait for next SVID sync event to receive WorkloadUpdate with newly minted SVID -// -// The advantage of above approach is that if agent has entry count less than cache size -// then all SVIDs are cached at all times. If agent has entry count greater than cache size then -// subscribers will continue to get SVID updates (potential delay for first WorkloadUpdate if cache miss) -// and least used SVIDs will be removed from cache which will save memory usage. -// This allows agent to support environments where the active simultaneous workload count -// is a small percentage of the large number of registrations assigned to the agent. 
-// // When registration entries are added/updated/removed, the set of relevant // selectors are gathered and the indexes for those selectors are combed for // all relevant subscribers. @@ -123,7 +103,6 @@ type Cache struct { log logrus.FieldLogger trustDomain spiffeid.TrustDomain - clk clock.Clock metrics telemetry.Metrics @@ -135,20 +114,14 @@ type Cache struct { // selectors holds the selector indices, keyed by a selector key selectors map[selector]*selectorIndex - // staleEntries holds stale or new registration entries which require new SVID to be stored in cache + // staleEntries holds stale registration entries staleEntries map[string]bool // bundles holds the trust bundles, keyed by trust domain id (i.e. "spiffe://domain.test") bundles map[spiffeid.TrustDomain]*bundleutil.Bundle - - // svids are stored by entry IDs - svids map[string]*X509SVID - - // svidCacheMaxSize is a soft limit of max number of SVIDs that would be stored in cache - svidCacheMaxSize int } -// StaleEntry holds stale or outdated entries which require new SVID with old SVIDs expiration time (if present) +// StaleEntry holds stale entries with SVIDs expiration time type StaleEntry struct { // Entry stale registration entry Entry *common.RegistrationEntry @@ -156,12 +129,7 @@ type StaleEntry struct { ExpiresAt time.Time } -func New(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundle, metrics telemetry.Metrics, - svidCacheMaxSize int, clk clock.Clock) *Cache { - if svidCacheMaxSize <= 0 { - svidCacheMaxSize = DefaultSVIDCacheMaxSize - } - +func New(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundle, metrics telemetry.Metrics) *Cache { return &Cache{ BundleCache: NewBundleCache(trustDomain, bundle), JWTSVIDCache: NewJWTSVIDCache(), @@ -175,9 +143,6 @@ func New(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundl bundles: map[spiffeid.TrustDomain]*bundleutil.Bundle{ trustDomain: bundle, }, - svids: make(map[string]*X509SVID), - svidCacheMaxSize: svidCacheMaxSize, - clk: clk, } } @@ -189,44 +154,41 @@ func (c *Cache) Identities() []Identity { out := make([]Identity, 0, len(c.records)) for _, record := range c.records { - svid, ok := c.svids[record.entry.EntryId] - if !ok { + if record.svid == nil { // The record does not have an SVID yet and should not be returned // from the cache. continue } - out = append(out, makeIdentity(record, svid)) + out = append(out, makeIdentity(record)) } sortIdentities(out) return out } -func (c *Cache) Entries() []*common.RegistrationEntry { +func (c *Cache) CountSVIDs() int { c.mu.RLock() defer c.mu.RUnlock() - out := make([]*common.RegistrationEntry, 0, len(c.records)) + var records int for _, record := range c.records { - out = append(out, record.entry) + if record.svid == nil { + // The record does not have an SVID yet and should not be returned + // from the cache. + continue + } + records++ } - sortEntriesByID(out) - return out -} - -func (c *Cache) CountSVIDs() int { - c.mu.RLock() - defer c.mu.RUnlock() - return len(c.svids) + return records } -func (c *Cache) MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry { +func (c *Cache) MatchingIdentities(selectors []*common.Selector) []Identity { set, setDone := allocSelectorSet(selectors...) 
defer setDone() c.mu.RLock() defer c.mu.RUnlock() - return c.matchingEntries(set) + return c.matchingIdentities(set) } func (c *Cache) FetchWorkloadUpdate(selectors []*common.Selector) *WorkloadUpdate { @@ -238,19 +200,8 @@ func (c *Cache) FetchWorkloadUpdate(selectors []*common.Selector) *WorkloadUpdat return c.buildWorkloadUpdate(set) } -// NewSubscriber creates a subscriber for given selector set. -// Separately call Notify for the first time after this method is invoked to receive latest updates. -func (c *Cache) NewSubscriber(selectors []*common.Selector) Subscriber { - c.mu.Lock() - defer c.mu.Unlock() - - sub := newSubscriber(c, selectors) - for s := range sub.set { - c.addSelectorIndexSub(s, sub) - } - // update lastAccessTimestamp of records containing provided selectors - c.updateLastAccessTimestamp(selectors) - return sub +func (c *Cache) SubscribeToWorkloadUpdates(ctx context.Context, selectors Selectors) (Subscriber, error) { + return c.subscribeToWorkloadUpdates(selectors), nil } // UpdateEntries updates the cache with the provided registration entries and bundles and @@ -324,14 +275,11 @@ func (c *Cache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.Regi c.delSelectorIndicesRecord(selRem, record) notifySets = append(notifySets, selRem) delete(c.records, id) - delete(c.svids, id) // Remove stale entry since, registration entry is no longer on cache. delete(c.staleEntries, id) } } - outdatedEntries := make(map[string]struct{}) - // Add/update records for registration entries in the update for _, newEntry := range update.RegistrationEntries { clearSelectorSet(selAdd) @@ -389,9 +337,9 @@ func (c *Cache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.Regi notifySets = append(notifySets, notifySet) } - // Identify stale/outdated entries - if existingEntry != nil && existingEntry.RevisionNumber != newEntry.RevisionNumber { - outdatedEntries[newEntry.EntryId] = struct{}{} + // Invoke the svid checker callback for this record + if checkSVID != nil && checkSVID(existingEntry, newEntry, record.svid) { + c.staleEntries[newEntry.EntryId] = true } // Log all the details of the update to the DEBUG log @@ -420,43 +368,6 @@ func (c *Cache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.Regi } } - // entries with active subscribers which are not cached will be put in staleEntries map; - // irrespective of what svid cache size as we cannot deny identity to a subscriber - activeSubsByEntryID, recordsWithLastAccessTime := c.syncSVIDsWithSubscribers() - extraSize := len(c.svids) - c.svidCacheMaxSize - - // delete svids without subscribers and which have not been accessed since svidCacheExpiryTime - if extraSize > 0 { - // sort recordsWithLastAccessTime - sortByTimestamps(recordsWithLastAccessTime) - - for _, record := range recordsWithLastAccessTime { - if extraSize <= 0 { - // no need to delete SVIDs any further as cache size <= svidCacheMaxSize - break - } - if _, ok := c.svids[record.id]; ok { - if _, exists := activeSubsByEntryID[record.id]; !exists { - // remove svid - c.log.WithField("record_id", record.id). - WithField("record_timestamp", record.timestamp). 
- Debug("Removing SVID record") - delete(c.svids, record.id) - extraSize-- - } - } - } - } - - // Update all stale svids or svids whose registration entry is outdated - for id, svid := range c.svids { - if _, ok := outdatedEntries[id]; ok || (checkSVID != nil && checkSVID(nil, c.records[id].entry, svid)) { - c.staleEntries[id] = true - } - } - c.log.WithField(telemetry.OutdatedSVIDs, len(outdatedEntries)). - Debug("Updating SVIDs with outdated attributes in cache") - if bundleRemoved || len(bundleChanged) > 0 { c.BundleCache.Update(c.bundles) } @@ -484,7 +395,7 @@ func (c *Cache) UpdateSVIDs(update *UpdateSVIDs) { continue } - c.svids[entryID] = svid + record.svid = svid notifySet.Merge(record.entry.Selectors...) log := c.log.WithFields(logrus.Fields{ telemetry.Entry: record.entry.EntryId, @@ -514,8 +425,8 @@ func (c *Cache) GetStaleEntries() []*StaleEntry { } var expiresAt time.Time - if cachedSvid, ok := c.svids[entryID]; ok { - expiresAt = cachedSvid.Chain[0].NotAfter + if cachedEntry.svid != nil { + expiresAt = cachedEntry.svid.Chain[0].NotAfter } staleEntries = append(staleEntries, &StaleEntry{ @@ -527,97 +438,58 @@ func (c *Cache) GetStaleEntries() []*StaleEntry { return staleEntries } -// SyncSVIDsWithSubscribers will sync svid cache: -// entries with active subscribers which are not cached will be put in staleEntries map -// records which are not cached for remainder of max cache size will also be put in staleEntries map -func (c *Cache) SyncSVIDsWithSubscribers() { - c.mu.Lock() - defer c.mu.Unlock() - - c.syncSVIDsWithSubscribers() -} - -// Notify subscribers of selector set only if all SVIDs for corresponding selector set are cached -// It returns whether all SVIDs are cached or not. -// This method should be retried with backoff to avoid lock contention. -func (c *Cache) Notify(selectors []*common.Selector) bool { +func (c *Cache) MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry { c.mu.RLock() defer c.mu.RUnlock() - set, setFree := allocSelectorSet(selectors...) - defer setFree() - if !c.missingSVIDRecords(set) { - c.notifyBySelectorSet(set) - return true - } - return false -} -func (c *Cache) missingSVIDRecords(set selectorSet) bool { + set, setDone := allocSelectorSet(selectors...) + defer setDone() + records, recordsDone := c.getRecordsForSelectors(set) defer recordsDone() + if len(records) == 0 { + return nil + } + + // Return identities in ascending "entry id" order to maintain a consistent + // ordering. + // TODO: figure out how to determine the "default" identity + out := make([]*common.RegistrationEntry, 0, len(records)) for record := range records { - if _, exists := c.svids[record.entry.EntryId]; !exists { - return true - } + out = append(out, record.entry) } - return false + sortEntriesByID(out) + return out } -func (c *Cache) updateLastAccessTimestamp(selectors []*common.Selector) { - set, setFree := allocSelectorSet(selectors...) 
- defer setFree() - - records, recordsDone := c.getRecordsForSelectors(set) - defer recordsDone() +func (c *Cache) Entries() []*common.RegistrationEntry { + c.mu.RLock() + defer c.mu.RUnlock() - now := c.clk.Now().UnixMilli() - for record := range records { - // Set lastAccessTimestamp so that svid LRU cache can be cleaned based on this timestamp - record.lastAccessTimestamp = now + out := make([]*common.RegistrationEntry, 0, len(c.records)) + for _, record := range c.records { + out = append(out, record.entry) } + sortEntriesByID(out) + return out } -// entries with active subscribers which are not cached will be put in staleEntries map -// records which are not cached for remainder of max cache size will also be put in staleEntries map -func (c *Cache) syncSVIDsWithSubscribers() (map[string]struct{}, []recordAccessEvent) { - activeSubsByEntryID := make(map[string]struct{}) - lastAccessTimestamps := make([]recordAccessEvent, 0, len(c.records)) +func (c *Cache) SyncSVIDsWithSubscribers() { + c.log.Error("SyncSVIDsWithSubscribers method is not implemented") + return +} - // iterate over all selectors from cached entries and obtain: - // 1. entries that have active subscribers - // 1.1 if those entries don't have corresponding SVID cached then put them in staleEntries - // so that SVID will be cached in next sync - // 2. get lastAccessTimestamp of each entry - for id, record := range c.records { - for _, sel := range record.entry.Selectors { - if index, ok := c.selectors[makeSelector(sel)]; ok && index != nil { - if len(index.subs) > 0 { - if _, ok := c.svids[record.entry.EntryId]; !ok { - c.staleEntries[id] = true - } - activeSubsByEntryID[id] = struct{}{} - break - } - } - } - lastAccessTimestamps = append(lastAccessTimestamps, newRecord(record.lastAccessTimestamp, id)) - } +func (c *Cache) subscribeToWorkloadUpdates(selectors []*common.Selector) Subscriber { + c.mu.Lock() + defer c.mu.Unlock() - remainderSize := c.svidCacheMaxSize - len(c.svids) - // add records which are not cached for remainder of cache size - for id := range c.records { - if len(c.staleEntries) >= remainderSize { - break - } - if _, svidCached := c.svids[id]; !svidCached { - if _, ok := c.staleEntries[id]; !ok { - c.staleEntries[id] = true - } - } + sub := newSubscriber(c, selectors) + for s := range sub.set { + c.addSelectorIndexSub(s, sub) } - - return activeSubsByEntryID, lastAccessTimestamps + c.notify(sub) + return sub } func (c *Cache) updateOrCreateRecord(newEntry *common.RegistrationEntry) (*cacheRecord, *common.RegistrationEntry) { @@ -791,33 +663,12 @@ func (c *Cache) matchingIdentities(set selectorSet) []Identity { // TODO: figure out how to determine the "default" identity out := make([]Identity, 0, len(records)) for record := range records { - if svid, ok := c.svids[record.entry.EntryId]; ok { - out = append(out, makeIdentity(record, svid)) - } + out = append(out, makeIdentity(record)) } sortIdentities(out) return out } -func (c *Cache) matchingEntries(set selectorSet) []*common.RegistrationEntry { - records, recordsDone := c.getRecordsForSelectors(set) - defer recordsDone() - - if len(records) == 0 { - return nil - } - - // Return identities in ascending "entry id" order to maintain a consistent - // ordering. 
- // TODO: figure out how to determine the "default" identity - out := make([]*common.RegistrationEntry, 0, len(records)) - for record := range records { - out = append(out, record.entry) - } - sortEntriesByID(out) - return out -} - func (c *Cache) buildWorkloadUpdate(set selectorSet) *WorkloadUpdate { w := &WorkloadUpdate{ Bundle: c.bundles[c.trustDomain], @@ -852,13 +703,17 @@ func (c *Cache) buildWorkloadUpdate(set selectorSet) *WorkloadUpdate { } func (c *Cache) getRecordsForSelectors(set selectorSet) (recordSet, func()) { - // Build and dedup a list of candidate entries. Don't check for selector set inclusion yet, since + // Build and dedup a list of candidate entries. Ignore those without an + // SVID but otherwise don't check for selector set inclusion yet, since // that is a more expensive operation and we could easily have duplicate // entries to check. records, recordsDone := allocRecordSet() for selector := range set { if index := c.getSelectorIndexForRead(selector); index != nil { for record := range index.records { + if record.svid == nil { + continue + } records[record] = struct{}{} } } @@ -900,9 +755,9 @@ func (c *Cache) getSelectorIndexForRead(s selector) *selectorIndex { } type cacheRecord struct { - entry *common.RegistrationEntry - subs map[*subscriber]struct{} - lastAccessTimestamp int64 + entry *common.RegistrationEntry + svid *X509SVID + subs map[*subscriber]struct{} } func newCacheRecord() *cacheRecord { @@ -936,31 +791,10 @@ func sortIdentities(identities []Identity) { }) } -func sortEntriesByID(entries []*common.RegistrationEntry) { - sort.Slice(entries, func(a, b int) bool { - return entries[a].EntryId < entries[b].EntryId - }) -} - -func sortByTimestamps(records []recordAccessEvent) { - sort.Slice(records, func(a, b int) bool { - return records[a].timestamp < records[b].timestamp - }) -} - -func makeIdentity(record *cacheRecord, svid *X509SVID) Identity { +func makeIdentity(record *cacheRecord) Identity { return Identity{ Entry: record.entry, - SVID: svid.Chain, - PrivateKey: svid.PrivateKey, + SVID: record.svid.Chain, + PrivateKey: record.svid.PrivateKey, } } - -type recordAccessEvent struct { - timestamp int64 - id string -} - -func newRecord(timestamp int64, id string) recordAccessEvent { - return recordAccessEvent{timestamp: timestamp, id: id} -} diff --git a/pkg/agent/manager/cache/cache_test.go b/pkg/agent/manager/cache/cache_test.go index 1b71ef3f52e..0a048e5c9c4 100644 --- a/pkg/agent/manager/cache/cache_test.go +++ b/pkg/agent/manager/cache/cache_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/andres-erbsen/clock" "github.com/sirupsen/logrus/hooks/test" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/common/bundleutil" @@ -59,7 +58,7 @@ func TestFetchWorkloadUpdate(t *testing.T) { }, workloadUpdate) } -func TestMatchingRegistrationIdentities(t *testing.T) { +func TestMatchingIdentities(t *testing.T) { cache := newTestCache() // populate the cache with FOO and BAR without SVIDS @@ -71,21 +70,19 @@ func TestMatchingRegistrationIdentities(t *testing.T) { } cache.UpdateEntries(updateEntries, nil) - assert.Equal(t, []*common.RegistrationEntry{bar, foo}, - cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) + identities := cache.MatchingIdentities(makeSelectors("A", "B")) + assert.Len(t, identities, 0, "identities should not be returned that don't have SVIDs") - // Update SVIDs and MatchingRegistrationEntries should return both entries updateSVIDs := &UpdateSVIDs{ X509SVIDs: makeX509SVIDs(foo, bar), } 
cache.UpdateSVIDs(updateSVIDs) - assert.Equal(t, []*common.RegistrationEntry{bar, foo}, - cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) - // Remove SVIDs and MatchingRegistrationEntries should still return both entries - cache.UpdateSVIDs(&UpdateSVIDs{}) - assert.Equal(t, []*common.RegistrationEntry{bar, foo}, - cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) + identities = cache.MatchingIdentities(makeSelectors("A", "B")) + assert.Equal(t, []Identity{ + {Entry: bar}, + {Entry: foo}, + }, identities) } func TestCountSVIDs(t *testing.T) { @@ -140,11 +137,11 @@ func TestAllSubscribersNotifiedOnBundleChange(t *testing.T) { cache := newTestCache() // create some subscribers and assert they get the initial bundle - subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + subA := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer subA.Finish() assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{Bundle: bundleV1}) - subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) + subB := cache.subscribeToWorkloadUpdates(makeSelectors("B")) defer subB.Finish() assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{Bundle: bundleV1}) @@ -171,11 +168,11 @@ func TestSomeSubscribersNotifiedOnFederatedBundleChange(t *testing.T) { }) // subscribe to A and B and assert initial updates are received. - subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + subA := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer subA.Finish() assertAnyWorkloadUpdate(t, subA) - subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) + subB := cache.subscribeToWorkloadUpdates(makeSelectors("B")) defer subB.Finish() assertAnyWorkloadUpdate(t, subB) @@ -234,11 +231,11 @@ func TestSubscribersGetEntriesWithSelectorSubsets(t *testing.T) { cache := newTestCache() // create subscribers for each combination of selectors - subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + subA := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer subA.Finish() - subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) + subB := cache.subscribeToWorkloadUpdates(makeSelectors("B")) defer subB.Finish() - subAB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A", "B")) + subAB := cache.subscribeToWorkloadUpdates(makeSelectors("A", "B")) defer subAB.Finish() // assert all subscribers get the initial update @@ -291,7 +288,7 @@ func TestSubscriberIsNotNotifiedIfNothingChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo), }) - sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -317,7 +314,7 @@ func TestSubscriberNotifiedOnSVIDChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo), }) - sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -332,7 +329,7 @@ func TestSubscriberNotifiedOnSVIDChanges(t *testing.T) { }) } -func TestSubscriberNotificationsOnSelectorChanges(t *testing.T) { +func TestSubcriberNotificationsOnSelectorChanges(t *testing.T) { cache := newTestCache() // initialize the cache with a FOO entry with selector A and an SVID @@ -346,7 +343,7 @@ func TestSubscriberNotificationsOnSelectorChanges(t *testing.T) { }) // create subscribers for A and make sure the initial update has FOO - sub := 
subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer sub.Finish() assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ Bundle: bundleV1, @@ -383,16 +380,21 @@ func TestSubscriberNotificationsOnSelectorChanges(t *testing.T) { }) } -func TestSubscriberNotifiedWhenEntryDropped(t *testing.T) { +func newTestCache() *Cache { + log, _ := test.NewNullLogger() + return New(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, telemetry.Blackhole{}) +} + +func TestSubcriberNotifiedWhenEntryDropped(t *testing.T) { cache := newTestCache() - subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + subA := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer subA.Finish() assertAnyWorkloadUpdate(t, subA) // subB's job here is to just make sure we don't notify unrelated // subscribers when dropping registration entries - subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) + subB := cache.subscribeToWorkloadUpdates(makeSelectors("B")) defer subB.Finish() assertAnyWorkloadUpdate(t, subB) @@ -426,7 +428,7 @@ func TestSubscriberNotifiedWhenEntryDropped(t *testing.T) { assertNoWorkloadUpdate(t, subB) } -func TestSubscriberOnlyGetsEntriesWithSVID(t *testing.T) { +func TestSubcriberOnlyGetsEntriesWithSVID(t *testing.T) { cache := newTestCache() foo := makeRegistrationEntry("FOO", "A") @@ -436,9 +438,13 @@ func TestSubscriberOnlyGetsEntriesWithSVID(t *testing.T) { } cache.UpdateEntries(updateEntries, nil) - sub := cache.NewSubscriber(makeSelectors("A")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer sub.Finish() - assertNoWorkloadUpdate(t, sub) + + // workload update does not include the identity because it has no SVID. 
+ assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV1, + }) // update to include the SVID and now we should get the update cache.UpdateSVIDs(&UpdateSVIDs{ @@ -453,7 +459,7 @@ func TestSubscriberOnlyGetsEntriesWithSVID(t *testing.T) { func TestSubscribersDoNotBlockNotifications(t *testing.T) { cache := newTestCache() - sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer sub.Finish() cache.UpdateEntries(&UpdateEntries{ @@ -483,23 +489,34 @@ func TestCheckSVIDCallback(t *testing.T) { foo := makeRegistrationEntryWithTTL("FOO", 60) + // called once for FOO with no SVID + callCount := 0 cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV2), RegistrationEntries: makeRegistrationEntries(foo), }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { - // should not get invoked - assert.Fail(t, "should not be called as no SVIDs are cached yet") + callCount++ + assert.Equal(t, "FOO", newEntry.EntryId) + + // there is no already existing entry, only the new entry + assert.Nil(t, existingEntry) + assert.Equal(t, foo, newEntry) + assert.Nil(t, svid) + return false }) + assert.Equal(t, 1, callCount) + assert.Empty(t, cache.staleEntries) // called once for FOO with new SVID + callCount = 0 svids := makeX509SVIDs(foo) cache.UpdateSVIDs(&UpdateSVIDs{ X509SVIDs: svids, }) // called once for FOO with existing SVID - callCount := 0 + callCount = 0 cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV2), RegistrationEntries: makeRegistrationEntries(foo), @@ -521,15 +538,22 @@ func TestGetStaleEntries(t *testing.T) { foo := makeRegistrationEntryWithTTL("FOO", 60) - // Create entry but don't mark it stale from checkSVID method; - // it will be marked stale cause it does not have SVID cached + // Create entry but don't mark it stale cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV2), RegistrationEntries: makeRegistrationEntries(foo), }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { return false }) + assert.Empty(t, cache.GetStaleEntries()) + // Update entry and mark it as stale + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + RegistrationEntries: makeRegistrationEntries(foo), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + return true + }) // Assert that the entry is returned as stale. The `ExpiresAt` field should be unset since there is no SVID. 
expectedEntries := []*StaleEntry{{Entry: cache.records[foo.EntryId].entry}} assert.Equal(t, expectedEntries, cache.GetStaleEntries()) @@ -583,7 +607,7 @@ func TestSubscriberNotNotifiedOnDifferentSVIDChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo, bar), }) - sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -608,7 +632,7 @@ func TestSubscriberNotNotifiedOnOverlappingSVIDChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo, bar), }) - sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A", "B")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A", "B")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -620,181 +644,6 @@ func TestSubscriberNotNotifiedOnOverlappingSVIDChanges(t *testing.T) { assertNoWorkloadUpdate(t, sub) } -func TestSVIDCacheExpiry(t *testing.T) { - clk := clock.NewMock() - cache := newTestCacheWithConfig(10, clk) - - clk.Add(1 * time.Second) - foo := makeRegistrationEntry("FOO", "A") - // validate workload update for foo - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo), - }, nil) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo), - }) - subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) - assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ - Bundle: bundleV1, - Identities: []Identity{{Entry: foo}}, - }) - subA.Finish() - - // move clk by 1 sec so that SVID access time will be different - clk.Add(1 * time.Second) - bar := makeRegistrationEntry("BAR", "B") - // validate workload update for bar - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo, bar), - }, nil) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(bar), - }) - - // not closing subscriber immediately - subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) - defer subB.Finish() - assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{ - Bundle: bundleV1, - Identities: []Identity{ - {Entry: bar}, - }, - }) - - // Move clk by 2 seconds - clk.Add(2 * time.Second) - // update total of 12 entries - updateEntries := createUpdateEntries(10, makeBundles(bundleV1)) - updateEntries.RegistrationEntries[foo.EntryId] = foo - updateEntries.RegistrationEntries[bar.EntryId] = bar - - cache.UpdateEntries(updateEntries, nil) - - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDsFromMap(updateEntries.RegistrationEntries), - }) - - for id, entry := range updateEntries.RegistrationEntries { - // create and close subscribers for remaining entries so that svid cache is full - if id != foo.EntryId && id != bar.EntryId { - sub := cache.NewSubscriber(entry.Selectors) - sub.Finish() - } - } - assert.Equal(t, 12, cache.CountSVIDs()) - - cache.UpdateEntries(updateEntries, nil) - assert.Equal(t, 10, cache.CountSVIDs()) - - // foo SVID should be removed from cache as it does not have active subscriber - assert.False(t, cache.Notify(makeSelectors("A"))) - // bar SVID should be cached as it has active subscriber - assert.True(t, cache.Notify(makeSelectors("B"))) - - subA = cache.NewSubscriber(makeSelectors("A")) - defer subA.Finish() - - cache.UpdateEntries(updateEntries, nil) - - // Make sure foo is marked as stale entry which does not have svid cached - require.Len(t, cache.GetStaleEntries(), 1) - assert.Equal(t, foo, cache.GetStaleEntries()[0].Entry) - - 
assert.Equal(t, 10, cache.CountSVIDs()) -} - -func TestMaxSVIDCacheSize(t *testing.T) { - clk := clock.NewMock() - cache := newTestCacheWithConfig(10, clk) - - // create entries more than maxSvidCacheSize - updateEntries := createUpdateEntries(12, makeBundles(bundleV1)) - cache.UpdateEntries(updateEntries, nil) - - require.Len(t, cache.GetStaleEntries(), 10) - - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDsFromStaleEntries(cache.GetStaleEntries()), - }) - require.Len(t, cache.GetStaleEntries(), 0) - assert.Equal(t, 10, cache.CountSVIDs()) - - // Validate that active subscriber will still get SVID even if SVID count is at maxSvidCacheSize - foo := makeRegistrationEntry("FOO", "A") - updateEntries.RegistrationEntries[foo.EntryId] = foo - - subA := cache.NewSubscriber(foo.Selectors) - defer subA.Finish() - - cache.UpdateEntries(updateEntries, nil) - require.Len(t, cache.GetStaleEntries(), 1) - assert.Equal(t, 10, cache.CountSVIDs()) - - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo), - }) - assert.Equal(t, 11, cache.CountSVIDs()) - require.Len(t, cache.GetStaleEntries(), 0) -} - -func TestSyncSVIDsWithSubscribers(t *testing.T) { - clk := clock.NewMock() - cache := newTestCacheWithConfig(5, clk) - - updateEntries := createUpdateEntries(5, makeBundles(bundleV1)) - cache.UpdateEntries(updateEntries, nil) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDsFromStaleEntries(cache.GetStaleEntries()), - }) - assert.Equal(t, 5, cache.CountSVIDs()) - - // Update foo but its SVID is not yet cached - foo := makeRegistrationEntry("FOO", "A") - updateEntries.RegistrationEntries[foo.EntryId] = foo - - cache.UpdateEntries(updateEntries, nil) - - // Create a subscriber for foo - subA := cache.NewSubscriber(foo.Selectors) - defer subA.Finish() - require.Len(t, cache.GetStaleEntries(), 0) - - // After SyncSVIDsWithSubscribers foo should be marked as stale, requiring signing - cache.SyncSVIDsWithSubscribers() - require.Len(t, cache.GetStaleEntries(), 1) - assert.Equal(t, []*StaleEntry{{Entry: cache.records[foo.EntryId].entry}}, cache.GetStaleEntries()) - - assert.Equal(t, 5, cache.CountSVIDs()) -} - -func TestNotify(t *testing.T) { - cache := newTestCache() - - foo := makeRegistrationEntry("FOO", "A") - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo), - }, nil) - - assert.False(t, cache.Notify(makeSelectors("A"))) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo), - }) - assert.True(t, cache.Notify(makeSelectors("A"))) -} - -func TestNewCache(t *testing.T) { - // negative value - cache := newTestCacheWithConfig(-5, clock.NewMock()) - require.Equal(t, DefaultSVIDCacheMaxSize, cache.svidCacheMaxSize) - - // zero value - cache = newTestCacheWithConfig(0, clock.NewMock()) - require.Equal(t, DefaultSVIDCacheMaxSize, cache.svidCacheMaxSize) -} - func BenchmarkCacheGlobalNotification(b *testing.B) { cache := newTestCache() @@ -823,7 +672,7 @@ func BenchmarkCacheGlobalNotification(b *testing.B) { cache.UpdateEntries(updateEntries, nil) for i := 0; i < numWorkloads; i++ { selectors := distinctSelectors(i, selectorsPerWorkload) - cache.NewSubscriber(selectors) + cache.subscribeToWorkloadUpdates(selectors) } runtime.GC() @@ -840,35 +689,40 @@ func BenchmarkCacheGlobalNotification(b *testing.B) { } } -func newTestCache() *Cache { - log, _ := test.NewNullLogger() - return New(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, - telemetry.Blackhole{}, 0, clock.NewMock()) -} +func 
TestMatchingRegistrationEntries(t *testing.T) { + cache := newTestCache() -func newTestCacheWithConfig(svidCacheMaxSize int, clk clock.Clock) *Cache { - log, _ := test.NewNullLogger() - return New(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, telemetry.Blackhole{}, - svidCacheMaxSize, clk) + // populate the cache with FOO and BAR without SVIDS + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + } + cache.UpdateEntries(updateEntries, nil) + + // Update SVIDs and MatchingRegistrationEntries should return both entries + updateSVIDs := &UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + } + cache.UpdateSVIDs(updateSVIDs) + assert.Equal(t, []*common.RegistrationEntry{bar, foo}, + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) } -// numEntries should not be more than 12 digits -func createUpdateEntries(numEntries int, bundles map[spiffeid.TrustDomain]*bundleutil.Bundle) *UpdateEntries { +func TestEntries(t *testing.T) { + cache := newTestCache() + + // populate the cache with FOO and BAR without SVIDS + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") updateEntries := &UpdateEntries{ - Bundles: bundles, - RegistrationEntries: make(map[string]*common.RegistrationEntry, numEntries), + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), } + cache.UpdateEntries(updateEntries, nil) - for i := 0; i < numEntries; i++ { - entryID := fmt.Sprintf("00000000-0000-0000-0000-%012d", i) - updateEntries.RegistrationEntries[entryID] = &common.RegistrationEntry{ - EntryId: entryID, - ParentId: "spiffe://domain.test/node", - SpiffeId: fmt.Sprintf("spiffe://domain.test/workload-%d", i), - Selectors: distinctSelectors(i, 1), - } - } - return updateEntries + assert.Equal(t, []*common.RegistrationEntry{bar, foo}, cache.Entries()) } func distinctSelectors(id, n int) []*common.Selector { @@ -926,22 +780,6 @@ func makeX509SVIDs(entries ...*common.RegistrationEntry) map[string]*X509SVID { return out } -func makeX509SVIDsFromMap(entries map[string]*common.RegistrationEntry) map[string]*X509SVID { - out := make(map[string]*X509SVID) - for _, entry := range entries { - out[entry.EntryId] = &X509SVID{} - } - return out -} - -func makeX509SVIDsFromStaleEntries(entries []*StaleEntry) map[string]*X509SVID { - out := make(map[string]*X509SVID) - for _, entry := range entries { - out[entry.Entry.EntryId] = &X509SVID{} - } - return out -} - func makeRegistrationEntry(id string, selectors ...string) *common.RegistrationEntry { return &common.RegistrationEntry{ EntryId: id, @@ -985,9 +823,3 @@ func makeFederatesWith(bundles ...*Bundle) []string { } return out } - -func subscribeToWorkloadUpdatesAndNotify(t *testing.T, cache *Cache, selectors []*common.Selector) Subscriber { - subscriber := cache.NewSubscriber(selectors) - assert.True(t, cache.Notify(selectors)) - return subscriber -} diff --git a/pkg/agent/manager/cache/lru_cache.go b/pkg/agent/manager/cache/lru_cache.go new file mode 100644 index 00000000000..8516ca51d56 --- /dev/null +++ b/pkg/agent/manager/cache/lru_cache.go @@ -0,0 +1,944 @@ +package cache + +import ( + "context" + "sort" + "sync" + "time" + + "github.com/andres-erbsen/clock" + "github.com/sirupsen/logrus" + "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/spiffe/spire/pkg/agent/common/backoff" + 
"github.com/spiffe/spire/pkg/common/bundleutil" + "github.com/spiffe/spire/pkg/common/telemetry" + "github.com/spiffe/spire/proto/spire/common" +) + +const ( + DefaultSVIDCacheMaxSize = 1000 + SvidSyncInterval = 500 * time.Millisecond +) + +// Cache caches each registration entry, bundles, and JWT SVIDs for the agent. +// The signed X509-SVIDs for those entries are stored in LRU-like cache. +// It allows subscriptions by (workload) selector sets and notifies subscribers when: +// +// 1) a registration entry related to the selectors: +// * is modified +// * has a new X509-SVID signed for it +// * federates with a federated bundle that is updated +// 2) the trust bundle for the agent trust domain is updated +// +// When notified, the subscriber is given a WorkloadUpdate containing +// related identities and trust bundles. +// +// The cache does this efficiently by building an index for each unique +// selector it encounters. Each selector index tracks the subscribers (i.e +// workloads) and registration entries that have that selector. +// +// The LRU-like SVID cache has configurable size limit and expiry period. +// 1. Size limit of SVID cache is a soft limit. If SVID has a subscriber present then +// that SVID is never removed from cache. +// 2. Least recently used SVIDs are removed from cache only after the cache expiry period has passed. +// This is done to reduce the overall cache churn. +// 3. Last access timestamp for SVID cache entry is updated when a new subscriber is created +// 4. When a new subscriber is created and there is a cache miss +// then subscriber needs to wait for next SVID sync event to receive WorkloadUpdate with newly minted SVID +// +// The advantage of above approach is that if agent has entry count less than cache size +// then all SVIDs are cached at all times. If agent has entry count greater than cache size then +// subscribers will continue to get SVID updates (potential delay for first WorkloadUpdate if cache miss) +// and least used SVIDs will be removed from cache which will save memory usage. +// This allows agent to support environments where the active simultaneous workload count +// is a small percentage of the large number of registrations assigned to the agent. +// +// When registration entries are added/updated/removed, the set of relevant +// selectors are gathered and the indexes for those selectors are combed for +// all relevant subscribers. +// +// For each relevant subscriber, the selector index for each selector of the +// subscriber is combed for registration whose selectors are a subset of the +// subscriber selector set. Identities for those entries are added to the +// workload update returned to the subscriber. +// +// NOTE: The cache is intended to be able to handle thousands of workload +// subscriptions, which can involve thousands of certificates, keys, bundles, +// and registration entries, etc. The selector index itself is intended to be +// scalable, but the objects themselves can take a considerable amount of +// memory. For maximal safety, the objects should be cloned both coming in and +// leaving the cache. However, during global updates (e.g. trust bundle is +// updated for the agent trust domain) in particular, cloning all of the +// relevant objects for each subscriber causes HUGE amounts of memory pressure +// which adds non-trivial amounts of latency and causes a giant memory spike +// that could OOM the agent on smaller VMs. For this reason, the cache is +// presumed to own ALL data passing in and out of the cache. 
Producers and +// consumers MUST NOT mutate the data. +type LRUCache struct { + *BundleCache + *JWTSVIDCache + + log logrus.FieldLogger + trustDomain spiffeid.TrustDomain + clk clock.Clock + + metrics telemetry.Metrics + + mu sync.RWMutex + + // records holds the records for registration entries, keyed by registration entry ID + records map[string]*lruCacheRecord + + // selectors holds the selector indices, keyed by a selector key + selectors map[selector]*selectorsMapIndex + + // staleEntries holds stale or new registration entries which require new SVID to be stored in cache + staleEntries map[string]bool + + // bundles holds the trust bundles, keyed by trust domain id (i.e. "spiffe://domain.test") + bundles map[spiffeid.TrustDomain]*bundleutil.Bundle + + // svids are stored by entry IDs + svids map[string]*X509SVID + + // svidCacheMaxSize is a soft limit of max number of SVIDs that would be stored in cache + svidCacheMaxSize int + subscribeBackoffFn func() backoff.BackOff +} + +func NewLRUCache(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundle, metrics telemetry.Metrics, + svidCacheMaxSize int, clk clock.Clock) *LRUCache { + if svidCacheMaxSize <= 0 { + svidCacheMaxSize = DefaultSVIDCacheMaxSize + } + + return &LRUCache{ + BundleCache: NewBundleCache(trustDomain, bundle), + JWTSVIDCache: NewJWTSVIDCache(), + + log: log, + metrics: metrics, + trustDomain: trustDomain, + records: make(map[string]*lruCacheRecord), + selectors: make(map[selector]*selectorsMapIndex), + staleEntries: make(map[string]bool), + bundles: map[spiffeid.TrustDomain]*bundleutil.Bundle{ + trustDomain: bundle, + }, + svids: make(map[string]*X509SVID), + svidCacheMaxSize: svidCacheMaxSize, + clk: clk, + subscribeBackoffFn: func() backoff.BackOff { + return backoff.NewBackoff(clk, SvidSyncInterval) + }, + } +} + +// Identities is only used by manager tests +// TODO: We should remove this and find a better way +func (c *LRUCache) Identities() []Identity { + c.mu.RLock() + defer c.mu.RUnlock() + + out := make([]Identity, 0, len(c.records)) + for _, record := range c.records { + svid, ok := c.svids[record.entry.EntryId] + if !ok { + // The record does not have an SVID yet and should not be returned + // from the cache. + continue + } + out = append(out, makeNewIdentity(record, svid)) + } + sortIdentities(out) + return out +} + +func (c *LRUCache) Entries() []*common.RegistrationEntry { + c.mu.RLock() + defer c.mu.RUnlock() + + out := make([]*common.RegistrationEntry, 0, len(c.records)) + for _, record := range c.records { + out = append(out, record.entry) + } + sortEntriesByID(out) + return out +} + +func (c *LRUCache) CountSVIDs() int { + c.mu.RLock() + defer c.mu.RUnlock() + + return len(c.svids) +} + +func (c *LRUCache) MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry { + set, setDone := allocSelectorSet(selectors...) + defer setDone() + + c.mu.RLock() + defer c.mu.RUnlock() + return c.matchingEntries(set) +} + +func (c *LRUCache) FetchWorkloadUpdate(selectors []*common.Selector) *WorkloadUpdate { + set, setDone := allocSelectorSet(selectors...) + defer setDone() + + c.mu.RLock() + defer c.mu.RUnlock() + return c.buildWorkloadUpdate(set) +} + +// NewSubscriber creates a subscriber for given selector set. +// Separately call Notify for the first time after this method is invoked to receive latest updates. 
+func (c *LRUCache) NewSubscriber(selectors []*common.Selector) Subscriber { + c.mu.Lock() + defer c.mu.Unlock() + + sub := newLRUCacheSubscriber(c, selectors) + for s := range sub.set { + c.addSelectorIndexSub(s, sub) + } + // update lastAccessTimestamp of records containing provided selectors + c.updateLastAccessTimestamp(selectors) + return sub +} + +// UpdateEntries updates the cache with the provided registration entries and bundles and +// notifies impacted subscribers. The checkSVID callback, if provided, is used to determine +// if the SVID for the entry is stale, or otherwise in need of rotation. Entries marked stale +// through the checkSVID callback are returned from GetStaleEntries() until the SVID is +// updated through a call to UpdateSVIDs. +func (c *LRUCache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.RegistrationEntry, *common.RegistrationEntry, *X509SVID) bool) { + c.mu.Lock() + defer c.mu.Unlock() + + // Remove bundles that no longer exist. The bundle for the agent trust + // domain should NOT be removed even if not present (which should only be + // the case if there is a bug on the server) since it is necessary to + // authenticate the server. + bundleRemoved := false + for id := range c.bundles { + if _, ok := update.Bundles[id]; !ok && id != c.trustDomain { + bundleRemoved = true + // bundle no longer exists. + c.log.WithField(telemetry.TrustDomainID, id).Debug("Bundle removed") + delete(c.bundles, id) + } + } + + // Update bundles with changes, populating a "changed" set that we can + // check when processing registration entries to know if they need to spawn + // a notification. + bundleChanged := make(map[spiffeid.TrustDomain]bool) + for id, bundle := range update.Bundles { + existing, ok := c.bundles[id] + if !(ok && existing.EqualTo(bundle)) { + if !ok { + c.log.WithField(telemetry.TrustDomainID, id).Debug("Bundle added") + } else { + c.log.WithField(telemetry.TrustDomainID, id).Debug("Bundle updated") + } + bundleChanged[id] = true + c.bundles[id] = bundle + } + } + trustDomainBundleChanged := bundleChanged[c.trustDomain] + + // Allocate sets from the pool to track changes to selectors and + // federatesWith declarations. These sets must be cleared after EACH use + // and returned to their respective pools when done processing the + // updates. + notifySets := make([]selectorSet, 0) + selAdd, selAddDone := allocSelectorSet() + defer selAddDone() + selRem, selRemDone := allocSelectorSet() + defer selRemDone() + fedAdd, fedAddDone := allocStringSet() + defer fedAddDone() + fedRem, fedRemDone := allocStringSet() + defer fedRemDone() + + // Remove records for registration entries that no longer exist + for id, record := range c.records { + if _, ok := update.RegistrationEntries[id]; !ok { + c.log.WithFields(logrus.Fields{ + telemetry.Entry: id, + telemetry.SPIFFEID: record.entry.SpiffeId, + }).Debug("Entry removed") + + // built a set of selectors for the record being removed, drop the + // record for each selector index, and add the entry selectors to + // the notify set. + clearSelectorSet(selRem) + selRem.Merge(record.entry.Selectors...) + c.delSelectorIndicesRecord(selRem, record) + notifySets = append(notifySets, selRem) + delete(c.records, id) + delete(c.svids, id) + // Remove stale entry since, registration entry is no longer on cache. 
+			delete(c.staleEntries, id)
+		}
+	}
+
+	outdatedEntries := make(map[string]struct{})
+
+	// Add/update records for registration entries in the update
+	for _, newEntry := range update.RegistrationEntries {
+		clearSelectorSet(selAdd)
+		clearSelectorSet(selRem)
+		clearStringSet(fedAdd)
+		clearStringSet(fedRem)
+
+		record, existingEntry := c.updateOrCreateRecord(newEntry)
+
+		// Calculate the difference in selectors, add/remove the record
+		// from impacted selector indices, and add the selector diff to the
+		// notify set.
+		c.diffSelectors(existingEntry, newEntry, selAdd, selRem)
+		selectorsChanged := len(selAdd) > 0 || len(selRem) > 0
+		c.addSelectorIndicesRecord(selAdd, record)
+		c.delSelectorIndicesRecord(selRem, record)
+
+		// Determine if there were changes to FederatesWith declarations or
+		// if any federated bundles related to the entry were updated.
+		c.diffFederatesWith(existingEntry, newEntry, fedAdd, fedRem)
+		federatedBundlesChanged := len(fedAdd) > 0 || len(fedRem) > 0
+		if !federatedBundlesChanged {
+			for _, id := range newEntry.FederatesWith {
+				td, err := spiffeid.TrustDomainFromString(id)
+				if err != nil {
+					c.log.WithFields(logrus.Fields{
+						telemetry.TrustDomainID: id,
+						logrus.ErrorKey:         err,
+					}).Warn("Invalid federated trust domain")
+					continue
+				}
+				if bundleChanged[td] {
+					federatedBundlesChanged = true
+					break
+				}
+			}
+		}
+
+		// If any selectors or federated bundles were changed, then make
+		// sure subscribers for the new and existing entry selector sets
+		// are notified.
+		if selectorsChanged {
+			if existingEntry != nil {
+				notifySet, selSetDone := allocSelectorSet()
+				defer selSetDone()
+				notifySet.Merge(existingEntry.Selectors...)
+				notifySets = append(notifySets, notifySet)
+			}
+		}
+
+		if federatedBundlesChanged || selectorsChanged {
+			notifySet, selSetDone := allocSelectorSet()
+			defer selSetDone()
+			notifySet.Merge(newEntry.Selectors...)
+			notifySets = append(notifySets, notifySet)
+		}
+
+		// Identify stale/outdated entries
+		if existingEntry != nil && existingEntry.RevisionNumber != newEntry.RevisionNumber {
+			outdatedEntries[newEntry.EntryId] = struct{}{}
+		}
+
+		// Log all the details of the update to the DEBUG log
+		if federatedBundlesChanged || selectorsChanged {
+			log := c.log.WithFields(logrus.Fields{
+				telemetry.Entry:    newEntry.EntryId,
+				telemetry.SPIFFEID: newEntry.SpiffeId,
+			})
+			if len(selAdd) > 0 {
+				log = log.WithField(telemetry.SelectorsAdded, len(selAdd))
+			}
+			if len(selRem) > 0 {
+				log = log.WithField(telemetry.SelectorsRemoved, len(selRem))
+			}
+			if len(fedAdd) > 0 {
+				log = log.WithField(telemetry.FederatedAdded, len(fedAdd))
+			}
+			if len(fedRem) > 0 {
+				log = log.WithField(telemetry.FederatedRemoved, len(fedRem))
+			}
+			if existingEntry != nil {
+				log.Debug("Entry updated")
+			} else {
+				log.Debug("Entry created")
+			}
+		}
+	}
+
+	// Entries that have active subscribers but no cached SVID are put in the
+	// staleEntries map irrespective of the SVID cache size, since we cannot
+	// deny an identity to a subscriber.
+	activeSubsByEntryID, recordsWithLastAccessTime := c.syncSVIDsWithSubscribers()
+	extraSize := len(c.svids) - c.svidCacheMaxSize
+
+	// Delete the least recently accessed SVIDs that have no active
+	// subscribers until the cache is back within the size limit.
+	if extraSize > 0 {
+		// sort recordsWithLastAccessTime
+		sortByTimestamps(recordsWithLastAccessTime)
+
+		for _, record := range recordsWithLastAccessTime {
+			if extraSize <= 0 {
+				// no need to delete SVIDs any further as cache size <= svidCacheMaxSize
+				break
+			}
+			if _, ok := c.svids[record.id]; ok {
+				if _, exists := activeSubsByEntryID[record.id]; !exists {
+					// remove svid
+					c.log.WithField("record_id", record.id).
+						WithField("record_timestamp", record.timestamp).
+						Debug("Removing SVID record")
+					delete(c.svids, record.id)
+					extraSize--
+				}
+			}
+		}
+	}
+
+	// Update all stale SVIDs, as well as SVIDs whose registration entry is outdated
+	for id, svid := range c.svids {
+		if _, ok := outdatedEntries[id]; ok || (checkSVID != nil && checkSVID(nil, c.records[id].entry, svid)) {
+			c.staleEntries[id] = true
+		}
+	}
+	c.log.WithField(telemetry.OutdatedSVIDs, len(outdatedEntries)).
+		Debug("Updating SVIDs with outdated attributes in cache")
+
+	if bundleRemoved || len(bundleChanged) > 0 {
+		c.BundleCache.Update(c.bundles)
+	}
+
+	if trustDomainBundleChanged {
+		c.notifyAll()
+	} else {
+		c.notifyBySelectorSet(notifySets...)
+	}
+}
+
+func (c *LRUCache) UpdateSVIDs(update *UpdateSVIDs) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	// Allocate a set of selectors used to notify the subscribers of the
+	// entries whose SVIDs are updated below.
+	notifySet, selSetDone := allocSelectorSet()
+	defer selSetDone()
+
+	// Update the cached SVID for each registration entry in the update
+	for entryID, svid := range update.X509SVIDs {
+		record, existingEntry := c.records[entryID]
+		if !existingEntry {
+			c.log.WithField(telemetry.RegistrationID, entryID).Error("Entry not found")
+			continue
+		}
+
+		c.svids[entryID] = svid
+		notifySet.Merge(record.entry.Selectors...)
+ log := c.log.WithFields(logrus.Fields{ + telemetry.Entry: record.entry.EntryId, + telemetry.SPIFFEID: record.entry.SpiffeId, + }) + log.Debug("SVID updated") + + // Registration entry is updated, remove it from stale map + delete(c.staleEntries, entryID) + c.notifyBySelectorSet(notifySet) + clearSelectorSet(notifySet) + } +} + +// GetStaleEntries obtains a list of stale entries +func (c *LRUCache) GetStaleEntries() []*StaleEntry { + c.mu.Lock() + defer c.mu.Unlock() + + var staleEntries []*StaleEntry + for entryID := range c.staleEntries { + cachedEntry, ok := c.records[entryID] + if !ok { + c.log.WithField(telemetry.RegistrationID, entryID).Debug("Stale marker found for unknown entry. Please fill a bug") + delete(c.staleEntries, entryID) + continue + } + + var expiresAt time.Time + if cachedSvid, ok := c.svids[entryID]; ok { + expiresAt = cachedSvid.Chain[0].NotAfter + } + + staleEntries = append(staleEntries, &StaleEntry{ + Entry: cachedEntry.entry, + ExpiresAt: expiresAt, + }) + } + + return staleEntries +} + +// SyncSVIDsWithSubscribers will sync svid cache: +// entries with active subscribers which are not cached will be put in staleEntries map +// records which are not cached for remainder of max cache size will also be put in staleEntries map +func (c *LRUCache) SyncSVIDsWithSubscribers() { + c.mu.Lock() + defer c.mu.Unlock() + + c.syncSVIDsWithSubscribers() +} + +// Notify subscribers of selector set only if all SVIDs for corresponding selector set are cached +// It returns whether all SVIDs are cached or not. +// This method should be retried with backoff to avoid lock contention. +func (c *LRUCache) Notify(selectors []*common.Selector) bool { + c.mu.RLock() + defer c.mu.RUnlock() + set, setFree := allocSelectorSet(selectors...) + defer setFree() + if !c.missingSVIDRecords(set) { + c.notifyBySelectorSet(set) + return true + } + return false +} + +func (c *LRUCache) SubscribeToWorkloadUpdates(ctx context.Context, selectors Selectors) (Subscriber, error) { + return c.subscribeToWorkloadUpdates(ctx, selectors, nil) +} + +func (c *LRUCache) subscribeToWorkloadUpdates(ctx context.Context, selectors Selectors, notifyCallbackFn func()) (Subscriber, error) { + subscriber := c.NewSubscriber(selectors) + bo := c.subscribeBackoffFn() + // block until all svids are cached and subscriber is notified + for { + // notifyCallbackFn is used for testing + if c.Notify(selectors) { + if notifyCallbackFn != nil { + notifyCallbackFn() + } + return subscriber, nil + } + c.log.WithField(telemetry.Selectors, selectors).Info("Waiting for SVID to get cached") + // used for testing + if notifyCallbackFn != nil { + notifyCallbackFn() + } + + select { + case <-ctx.Done(): + subscriber.Finish() + return nil, ctx.Err() + case <-c.clk.After(bo.NextBackOff()): + } + } +} + +func (c *LRUCache) missingSVIDRecords(set selectorSet) bool { + records, recordsDone := c.getRecordsForSelectors(set) + defer recordsDone() + + for record := range records { + if _, exists := c.svids[record.entry.EntryId]; !exists { + return true + } + } + return false +} + +func (c *LRUCache) updateLastAccessTimestamp(selectors []*common.Selector) { + set, setFree := allocSelectorSet(selectors...) 
+ defer setFree() + + records, recordsDone := c.getRecordsForSelectors(set) + defer recordsDone() + + now := c.clk.Now().UnixMilli() + for record := range records { + // Set lastAccessTimestamp so that svid LRU cache can be cleaned based on this timestamp + record.lastAccessTimestamp = now + } +} + +// entries with active subscribers which are not cached will be put in staleEntries map +// records which are not cached for remainder of max cache size will also be put in staleEntries map +func (c *LRUCache) syncSVIDsWithSubscribers() (map[string]struct{}, []recordAccessEvent) { + activeSubsByEntryID := make(map[string]struct{}) + lastAccessTimestamps := make([]recordAccessEvent, 0, len(c.records)) + + // iterate over all selectors from cached entries and obtain: + // 1. entries that have active subscribers + // 1.1 if those entries don't have corresponding SVID cached then put them in staleEntries + // so that SVID will be cached in next sync + // 2. get lastAccessTimestamp of each entry + for id, record := range c.records { + for _, sel := range record.entry.Selectors { + if index, ok := c.selectors[makeSelector(sel)]; ok && index != nil { + if len(index.subs) > 0 { + if _, ok := c.svids[record.entry.EntryId]; !ok { + c.staleEntries[id] = true + } + activeSubsByEntryID[id] = struct{}{} + break + } + } + } + lastAccessTimestamps = append(lastAccessTimestamps, newRecordAccessEvent(record.lastAccessTimestamp, id)) + } + + remainderSize := c.svidCacheMaxSize - len(c.svids) + // add records which are not cached for remainder of cache size + for id := range c.records { + if len(c.staleEntries) >= remainderSize { + break + } + if _, svidCached := c.svids[id]; !svidCached { + if _, ok := c.staleEntries[id]; !ok { + c.staleEntries[id] = true + } + } + } + + return activeSubsByEntryID, lastAccessTimestamps +} + +func (c *LRUCache) updateOrCreateRecord(newEntry *common.RegistrationEntry) (*lruCacheRecord, *common.RegistrationEntry) { + var existingEntry *common.RegistrationEntry + record, recordExists := c.records[newEntry.EntryId] + if !recordExists { + record = newLRUCacheRecord() + c.records[newEntry.EntryId] = record + } else { + existingEntry = record.entry + } + record.entry = newEntry + return record, existingEntry +} + +func (c *LRUCache) diffSelectors(existingEntry, newEntry *common.RegistrationEntry, added, removed selectorSet) { + // Make a set of all the selectors being added + if newEntry != nil { + added.Merge(newEntry.Selectors...) + } + + // Make a set of all the selectors that are being removed + if existingEntry != nil { + for _, selector := range existingEntry.Selectors { + s := makeSelector(selector) + if _, ok := added[s]; ok { + // selector already exists in entry + delete(added, s) + } else { + // selector has been removed from entry + removed[s] = struct{}{} + } + } + } +} + +func (c *LRUCache) diffFederatesWith(existingEntry, newEntry *common.RegistrationEntry, added, removed stringSet) { + // Make a set of all the selectors being added + if newEntry != nil { + added.Merge(newEntry.FederatesWith...) 
+ } + + // Make a set of all the selectors that are being removed + if existingEntry != nil { + for _, id := range existingEntry.FederatesWith { + if _, ok := added[id]; ok { + // Bundle already exists in entry + delete(added, id) + } else { + // Bundle has been removed from entry + removed[id] = struct{}{} + } + } + } +} + +func (c *LRUCache) addSelectorIndicesRecord(selectors selectorSet, record *lruCacheRecord) { + for selector := range selectors { + c.addSelectorIndexRecord(selector, record) + } +} + +func (c *LRUCache) addSelectorIndexRecord(s selector, record *lruCacheRecord) { + index := c.getSelectorIndexForWrite(s) + index.records[record] = struct{}{} +} + +func (c *LRUCache) delSelectorIndicesRecord(selectors selectorSet, record *lruCacheRecord) { + for selector := range selectors { + c.delSelectorIndexRecord(selector, record) + } +} + +// delSelectorIndexRecord removes the record from the selector index. If +// the selector index is empty afterwards, it is also removed. +func (c *LRUCache) delSelectorIndexRecord(s selector, record *lruCacheRecord) { + index, ok := c.selectors[s] + if ok { + delete(index.records, record) + if index.isEmpty() { + delete(c.selectors, s) + } + } +} + +func (c *LRUCache) addSelectorIndexSub(s selector, sub *lruCacheSubscriber) { + index := c.getSelectorIndexForWrite(s) + index.subs[sub] = struct{}{} +} + +// delSelectorIndexSub removes the subscription from the selector index. If +// the selector index is empty afterwards, it is also removed. +func (c *LRUCache) delSelectorIndexSub(s selector, sub *lruCacheSubscriber) { + index, ok := c.selectors[s] + if ok { + delete(index.subs, sub) + if index.isEmpty() { + delete(c.selectors, s) + } + } +} + +func (c *LRUCache) unsubscribe(sub *lruCacheSubscriber) { + c.mu.Lock() + defer c.mu.Unlock() + for selector := range sub.set { + c.delSelectorIndexSub(selector, sub) + } +} + +func (c *LRUCache) notifyAll() { + subs, subsDone := c.allSubscribers() + defer subsDone() + for sub := range subs { + c.notify(sub) + } +} + +func (c *LRUCache) notifyBySelectorSet(sets ...selectorSet) { + notifiedSubs, notifiedSubsDone := allocLRUCacheSubscriberSet() + defer notifiedSubsDone() + for _, set := range sets { + subs, subsDone := c.getSubscribers(set) + defer subsDone() + for sub := range subs { + if _, notified := notifiedSubs[sub]; !notified && sub.set.SuperSetOf(set) { + c.notify(sub) + notifiedSubs[sub] = struct{}{} + } + } + } +} + +func (c *LRUCache) notify(sub *lruCacheSubscriber) { + update := c.buildWorkloadUpdate(sub.set) + sub.notify(update) +} + +func (c *LRUCache) allSubscribers() (lruCacheSubscriberSet, func()) { + subs, subsDone := allocLRUCacheSubscriberSet() + for _, index := range c.selectors { + for sub := range index.subs { + subs[sub] = struct{}{} + } + } + return subs, subsDone +} + +func (c *LRUCache) getSubscribers(set selectorSet) (lruCacheSubscriberSet, func()) { + subs, subsDone := allocLRUCacheSubscriberSet() + for s := range set { + if index := c.getSelectorIndexForRead(s); index != nil { + for sub := range index.subs { + subs[sub] = struct{}{} + } + } + } + return subs, subsDone +} + +func (c *LRUCache) matchingIdentities(set selectorSet) []Identity { + records, recordsDone := c.getRecordsForSelectors(set) + defer recordsDone() + + if len(records) == 0 { + return nil + } + + // Return identities in ascending "entry id" order to maintain a consistent + // ordering. 
+	// TODO: figure out how to determine the "default" identity
+	out := make([]Identity, 0, len(records))
+	for record := range records {
+		if svid, ok := c.svids[record.entry.EntryId]; ok {
+			out = append(out, makeNewIdentity(record, svid))
+		}
+	}
+	sortIdentities(out)
+	return out
+}
+
+func (c *LRUCache) matchingEntries(set selectorSet) []*common.RegistrationEntry {
+	records, recordsDone := c.getRecordsForSelectors(set)
+	defer recordsDone()
+
+	if len(records) == 0 {
+		return nil
+	}
+
+	// Return entries in ascending "entry id" order to maintain a consistent
+	// ordering.
+	out := make([]*common.RegistrationEntry, 0, len(records))
+	for record := range records {
+		out = append(out, record.entry)
+	}
+	sortEntriesByID(out)
+	return out
+}
+
+func (c *LRUCache) buildWorkloadUpdate(set selectorSet) *WorkloadUpdate {
+	w := &WorkloadUpdate{
+		Bundle:           c.bundles[c.trustDomain],
+		FederatedBundles: make(map[spiffeid.TrustDomain]*bundleutil.Bundle),
+		Identities:       c.matchingIdentities(set),
+	}
+
+	// Add in the bundles the workload is federated with.
+	for _, identity := range w.Identities {
+		for _, federatesWith := range identity.Entry.FederatesWith {
+			td, err := spiffeid.TrustDomainFromString(federatesWith)
+			if err != nil {
+				c.log.WithFields(logrus.Fields{
+					telemetry.TrustDomainID: federatesWith,
+					logrus.ErrorKey:         err,
+				}).Warn("Invalid federated trust domain")
+				continue
+			}
+			if federatedBundle := c.bundles[td]; federatedBundle != nil {
+				w.FederatedBundles[td] = federatedBundle
+			} else {
+				c.log.WithFields(logrus.Fields{
+					telemetry.RegistrationID:  identity.Entry.EntryId,
+					telemetry.SPIFFEID:        identity.Entry.SpiffeId,
+					telemetry.FederatedBundle: federatesWith,
+				}).Warn("Federated bundle contents missing")
+			}
+		}
+	}
+
+	return w
+}
+
+func (c *LRUCache) getRecordsForSelectors(set selectorSet) (lruCacheRecordSet, func()) {
+	// Build and dedup a list of candidate entries. Don't check for selector set inclusion yet, since
+	// that is a more expensive operation and we could easily have duplicate
+	// entries to check.
+	records, recordsDone := allocLRUCacheRecordSet()
+	for selector := range set {
+		if index := c.getSelectorIndexForRead(selector); index != nil {
+			for record := range index.records {
+				records[record] = struct{}{}
+			}
+		}
+	}
+
+	// Filter out records whose registration entry selectors are not within
+	// the selector set.
+	for record := range records {
+		for _, s := range record.entry.Selectors {
+			if !set.In(s) {
+				delete(records, record)
+			}
+		}
+	}
+	return records, recordsDone
+}
+
+// getSelectorIndexForWrite gets the selector index for the selector. If one
+// doesn't exist, it is created. Callers must hold the write lock. If the index
+// is only being read, then getSelectorIndexForRead should be used instead.
+func (c *LRUCache) getSelectorIndexForWrite(s selector) *selectorsMapIndex {
+	index, ok := c.selectors[s]
+	if !ok {
+		index = newSelectorsMapIndex()
+		c.selectors[s] = index
+	}
+	return index
+}
+
+// getSelectorIndexForRead gets the selector index for the selector. If one
+// doesn't exist, nil is returned. Callers should hold the read or write lock.
+// If the index is being modified, callers should use getSelectorIndexForWrite
+// instead.
+func (c *LRUCache) getSelectorIndexForRead(s selector) *selectorsMapIndex { + if index, ok := c.selectors[s]; ok { + return index + } + return nil +} + +type lruCacheRecord struct { + entry *common.RegistrationEntry + subs map[*lruCacheSubscriber]struct{} + lastAccessTimestamp int64 +} + +func newLRUCacheRecord() *lruCacheRecord { + return &lruCacheRecord{ + subs: make(map[*lruCacheSubscriber]struct{}), + } +} + +type selectorsMapIndex struct { + // subs holds the subscriptions related to this selector + subs map[*lruCacheSubscriber]struct{} + + // records holds the cache records related to this selector + records map[*lruCacheRecord]struct{} +} + +func (x *selectorsMapIndex) isEmpty() bool { + return len(x.subs) == 0 && len(x.records) == 0 +} + +func newSelectorsMapIndex() *selectorsMapIndex { + return &selectorsMapIndex{ + subs: make(map[*lruCacheSubscriber]struct{}), + records: make(map[*lruCacheRecord]struct{}), + } +} + +func sortEntriesByID(entries []*common.RegistrationEntry) { + sort.Slice(entries, func(a, b int) bool { + return entries[a].EntryId < entries[b].EntryId + }) +} + +func sortByTimestamps(records []recordAccessEvent) { + sort.Slice(records, func(a, b int) bool { + return records[a].timestamp < records[b].timestamp + }) +} + +func makeNewIdentity(record *lruCacheRecord, svid *X509SVID) Identity { + return Identity{ + Entry: record.entry, + SVID: svid.Chain, + PrivateKey: svid.PrivateKey, + } +} + +type recordAccessEvent struct { + timestamp int64 + id string +} + +func newRecordAccessEvent(timestamp int64, id string) recordAccessEvent { + return recordAccessEvent{timestamp: timestamp, id: id} +} diff --git a/pkg/agent/manager/cache/lru_cache_subscriber.go b/pkg/agent/manager/cache/lru_cache_subscriber.go new file mode 100644 index 00000000000..00556f89a9c --- /dev/null +++ b/pkg/agent/manager/cache/lru_cache_subscriber.go @@ -0,0 +1,60 @@ +package cache + +import ( + "sync" + + "github.com/spiffe/spire/proto/spire/common" +) + +type lruCacheSubscriber struct { + cache *LRUCache + set selectorSet + setFree func() + + mu sync.Mutex + c chan *WorkloadUpdate + done bool +} + +func newLRUCacheSubscriber(cache *LRUCache, selectors []*common.Selector) *lruCacheSubscriber { + set, setFree := allocSelectorSet(selectors...) 
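+	// the subscriber keeps the pooled selector set until Finish is called,
+	// which releases it again via setFree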
+ return &lruCacheSubscriber{ + cache: cache, + set: set, + setFree: setFree, + c: make(chan *WorkloadUpdate, 1), + } +} + +func (s *lruCacheSubscriber) Updates() <-chan *WorkloadUpdate { + return s.c +} + +func (s *lruCacheSubscriber) Finish() { + s.mu.Lock() + done := s.done + if !done { + s.done = true + close(s.c) + } + s.mu.Unlock() + if !done { + s.cache.unsubscribe(s) + s.setFree() + s.set = nil + } +} + +func (s *lruCacheSubscriber) notify(update *WorkloadUpdate) { + s.mu.Lock() + defer s.mu.Unlock() + if s.done { + return + } + + select { + case <-s.c: + default: + } + s.c <- update +} diff --git a/pkg/agent/manager/cache/lru_cache_test.go b/pkg/agent/manager/cache/lru_cache_test.go new file mode 100644 index 00000000000..c270dd3aee9 --- /dev/null +++ b/pkg/agent/manager/cache/lru_cache_test.go @@ -0,0 +1,954 @@ +package cache + +import ( + "context" + "crypto/x509" + "fmt" + "runtime" + "testing" + "time" + + "github.com/andres-erbsen/clock" + "github.com/sirupsen/logrus/hooks/test" + "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/spiffe/spire/pkg/common/bundleutil" + "github.com/spiffe/spire/pkg/common/telemetry" + "github.com/spiffe/spire/proto/spire/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLRUCacheFetchWorkloadUpdate(t *testing.T) { + cache := newTestLRUCache() + // populate the cache with FOO and BAR without SVIDS + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") + bar.FederatesWith = makeFederatesWith(otherBundleV1) + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1, otherBundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + } + cache.UpdateEntries(updateEntries, nil) + + workloadUpdate := cache.FetchWorkloadUpdate(makeSelectors("A", "B")) + assert.Len(t, workloadUpdate.Identities, 0, "identities should not be returned that don't have SVIDs") + + updateSVIDs := &UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + } + cache.UpdateSVIDs(updateSVIDs) + + workloadUpdate = cache.FetchWorkloadUpdate(makeSelectors("A", "B")) + assert.Equal(t, &WorkloadUpdate{ + Bundle: bundleV1, + FederatedBundles: makeBundles(otherBundleV1), + Identities: []Identity{ + {Entry: bar}, + {Entry: foo}, + }, + }, workloadUpdate) +} + +func TestLRUCacheMatchingRegistrationIdentities(t *testing.T) { + cache := newTestLRUCache() + + // populate the cache with FOO and BAR without SVIDS + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + } + cache.UpdateEntries(updateEntries, nil) + + assert.Equal(t, []*common.RegistrationEntry{bar, foo}, + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) + + // Update SVIDs and MatchingRegistrationEntries should return both entries + updateSVIDs := &UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + } + cache.UpdateSVIDs(updateSVIDs) + assert.Equal(t, []*common.RegistrationEntry{bar, foo}, + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) + + // Remove SVIDs and MatchingRegistrationEntries should still return both entries + cache.UpdateSVIDs(&UpdateSVIDs{}) + assert.Equal(t, []*common.RegistrationEntry{bar, foo}, + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) +} + +func TestLRUCacheCountSVIDs(t *testing.T) { + cache := newTestLRUCache() + + // populate the cache with FOO and BAR without SVIDS + foo := makeRegistrationEntry("FOO", "A") + 
bar := makeRegistrationEntry("BAR", "B") + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + } + cache.UpdateEntries(updateEntries, nil) + + // No SVIDs expected + require.Equal(t, 0, cache.CountSVIDs()) + + updateSVIDs := &UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + } + cache.UpdateSVIDs(updateSVIDs) + + // Only one SVID expected + require.Equal(t, 1, cache.CountSVIDs()) +} + +func TestLRUCacheBundleChanges(t *testing.T) { + cache := newTestLRUCache() + + bundleStream := cache.SubscribeToBundleChanges() + assert.Equal(t, makeBundles(bundleV1), bundleStream.Value()) + + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1, otherBundleV1), + }, nil) + if assert.True(t, bundleStream.HasNext(), "has new bundle value after adding bundle") { + bundleStream.Next() + assert.Equal(t, makeBundles(bundleV1, otherBundleV1), bundleStream.Value()) + } + + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + }, nil) + + if assert.True(t, bundleStream.HasNext(), "has new bundle value after removing bundle") { + bundleStream.Next() + assert.Equal(t, makeBundles(bundleV1), bundleStream.Value()) + } +} + +func TestLRUCacheAllSubscribersNotifiedOnBundleChange(t *testing.T) { + cache := newTestLRUCache() + + // create some subscribers and assert they get the initial bundle + subA := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer subA.Finish() + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{Bundle: bundleV1}) + + subB := subscribeToWorkloadUpdates(t, cache, makeSelectors("B")) + defer subB.Finish() + assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{Bundle: bundleV1}) + + // update the bundle and assert all subscribers gets the updated bundle + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + }, nil) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{Bundle: bundleV2}) + assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{Bundle: bundleV2}) +} + +func TestLRUCacheSomeSubscribersNotifiedOnFederatedBundleChange(t *testing.T) { + cache := newTestLRUCache() + + // initialize the cache with an entry FOO that has a valid SVID and + // selector "A" + foo := makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + // subscribe to A and B and assert initial updates are received. + subA := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer subA.Finish() + assertAnyWorkloadUpdate(t, subA) + + subB := subscribeToWorkloadUpdates(t, cache, makeSelectors("B")) + defer subB.Finish() + assertAnyWorkloadUpdate(t, subB) + + // add the federated bundle with no registration entries federating with + // it and make sure nobody is notified. + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1, otherBundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + assertNoWorkloadUpdate(t, subA) + assertNoWorkloadUpdate(t, subB) + + // update FOO to federate with otherdomain.test and make sure subA is + // notified but not subB. 
+ foo = makeRegistrationEntry("FOO", "A") + foo.FederatesWith = makeFederatesWith(otherBundleV1) + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1, otherBundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + FederatedBundles: makeBundles(otherBundleV1), + Identities: []Identity{{Entry: foo}}, + }) + assertNoWorkloadUpdate(t, subB) + + // now change the federated bundle and make sure subA gets notified, but + // again, not subB. + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1, otherBundleV2), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + FederatedBundles: makeBundles(otherBundleV2), + Identities: []Identity{{Entry: foo}}, + }) + assertNoWorkloadUpdate(t, subB) + + // now drop the federation and make sure subA is again notified and no + // longer has the federated bundle. + foo = makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1, otherBundleV2), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) + assertNoWorkloadUpdate(t, subB) +} + +func TestLRUCacheSubscribersGetEntriesWithSelectorSubsets(t *testing.T) { + cache := newTestLRUCache() + + // create subscribers for each combination of selectors + subA := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer subA.Finish() + subB := subscribeToWorkloadUpdates(t, cache, makeSelectors("B")) + defer subB.Finish() + subAB := subscribeToWorkloadUpdates(t, cache, makeSelectors("A", "B")) + defer subAB.Finish() + + // assert all subscribers get the initial update + initialUpdate := &WorkloadUpdate{Bundle: bundleV1} + assertWorkloadUpdateEqual(t, subA, initialUpdate) + assertWorkloadUpdateEqual(t, subB, initialUpdate) + assertWorkloadUpdateEqual(t, subAB, initialUpdate) + + // create entry FOO that will target any subscriber with containing (A) + foo := makeRegistrationEntry("FOO", "A") + + // create entry BAR that will target any subscriber with containing (A,C) + bar := makeRegistrationEntry("BAR", "A", "C") + + // update the cache with foo and bar + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + }) + + // subA selector set contains (A), but not (A, C), so it should only get FOO + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) + + // subB selector set does not contain either (A) or (A,C) so it isn't even + // notified. 
+ assertNoWorkloadUpdate(t, subB) + + // subAB selector set contains (A) but not (A, C), so it should get FOO + assertWorkloadUpdateEqual(t, subAB, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) +} + +func TestLRUCacheSubscriberIsNotNotifiedIfNothingChanges(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + sub := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer sub.Finish() + assertAnyWorkloadUpdate(t, sub) + + // Second update is the same (other than X509SVIDs, which, when set, + // always constitute a "change" for the impacted registration entries. + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + + assertNoWorkloadUpdate(t, sub) +} + +func TestLRUCacheSubscriberNotifiedOnSVIDChanges(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + sub := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer sub.Finish() + assertAnyWorkloadUpdate(t, sub) + + // Update SVID + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) +} + +func TestLRUCacheSubscriberNotificationsOnSelectorChanges(t *testing.T) { + cache := newTestLRUCache() + + // initialize the cache with a FOO entry with selector A and an SVID + foo := makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + // create subscribers for A and make sure the initial update has FOO + sub := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer sub.Finish() + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) + + // update FOO to have selectors (A,B) and make sure the subscriber loses + // FOO, since (A,B) is not a subset of the subscriber set (A). 
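+	// (an entry is only delivered to subscribers whose selector set is a superset of the entry's selectors)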
+ foo = makeRegistrationEntry("FOO", "A", "B") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV1, + }) + + // update FOO to drop B and make sure the subscriber regains FOO + foo = makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) +} + +func TestLRUCacheSubscriberNotifiedWhenEntryDropped(t *testing.T) { + cache := newTestLRUCache() + + subA := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer subA.Finish() + assertAnyWorkloadUpdate(t, subA) + + // subB's job here is to just make sure we don't notify unrelated + // subscribers when dropping registration entries + subB := subscribeToWorkloadUpdates(t, cache, makeSelectors("B")) + defer subB.Finish() + assertAnyWorkloadUpdate(t, subB) + + foo := makeRegistrationEntry("FOO", "A") + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + } + cache.UpdateEntries(updateEntries, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + // make sure subA gets notified with FOO but not subB + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) + assertNoWorkloadUpdate(t, subB) + + updateEntries.RegistrationEntries = nil + cache.UpdateEntries(updateEntries, nil) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + }) + assertNoWorkloadUpdate(t, subB) + + // Make sure trying to update SVIDs of removed entry does not notify + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + assertNoWorkloadUpdate(t, subB) +} + +func TestLRUCacheSubscriberOnlyGetsEntriesWithSVID(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntry("FOO", "A") + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + } + cache.UpdateEntries(updateEntries, nil) + + sub := cache.NewSubscriber(makeSelectors("A")) + defer sub.Finish() + assertNoWorkloadUpdate(t, sub) + + // update to include the SVID and now we should get the update + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) +} + +func TestLRUCacheSubscribersDoNotBlockNotifications(t *testing.T) { + cache := newTestLRUCache() + + sub := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer sub.Finish() + + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + }, nil) + + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV3), + }, nil) + + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV3, + }) +} + +func TestLRUCacheCheckSVIDCallback(t *testing.T) { + cache := newTestLRUCache() + + // no calls because there are no registration entries + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + assert.Fail(t, 
"should not be called if there are no registration entries") + + return false + }) + + foo := makeRegistrationEntryWithTTL("FOO", 60) + + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + RegistrationEntries: makeRegistrationEntries(foo), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + // should not get invoked + assert.Fail(t, "should not be called as no SVIDs are cached yet") + return false + }) + + // called once for FOO with new SVID + svids := makeX509SVIDs(foo) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: svids, + }) + + // called once for FOO with existing SVID + callCount := 0 + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + RegistrationEntries: makeRegistrationEntries(foo), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + callCount++ + assert.Equal(t, "FOO", newEntry.EntryId) + if assert.NotNil(t, svid) { + assert.Exactly(t, svids["FOO"], svid) + } + + return true + }) + assert.Equal(t, 1, callCount) + assert.Equal(t, map[string]bool{foo.EntryId: true}, cache.staleEntries) +} + +func TestLRUCacheGetStaleEntries(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntryWithTTL("FOO", 60) + + // Create entry but don't mark it stale from checkSVID method; + // it will be marked stale cause it does not have SVID cached + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + RegistrationEntries: makeRegistrationEntries(foo), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + return false + }) + + // Assert that the entry is returned as stale. The `ExpiresAt` field should be unset since there is no SVID. + expectedEntries := []*StaleEntry{{Entry: cache.records[foo.EntryId].entry}} + assert.Equal(t, expectedEntries, cache.GetStaleEntries()) + + // Update the SVID for the stale entry + svids := make(map[string]*X509SVID) + expiredAt := time.Now() + svids[foo.EntryId] = &X509SVID{ + Chain: []*x509.Certificate{{NotAfter: expiredAt}}, + } + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: svids, + }) + // Assert that updating the SVID removes stale marker from entry + assert.Empty(t, cache.GetStaleEntries()) + + // Update entry again and mark it as stale + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + RegistrationEntries: makeRegistrationEntries(foo), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + return true + }) + + // Assert that the entry again returns as stale. This time the `ExpiresAt` field should be populated with the expiration of the SVID. 
+ expectedEntries = []*StaleEntry{{ + Entry: cache.records[foo.EntryId].entry, + ExpiresAt: expiredAt, + }} + assert.Equal(t, expectedEntries, cache.GetStaleEntries()) + + // Remove registration entry and assert that it is no longer returned as stale + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + return true + }) + assert.Empty(t, cache.GetStaleEntries()) +} + +func TestLRUCacheSubscriberNotNotifiedOnDifferentSVIDChanges(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + }) + + sub := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer sub.Finish() + assertAnyWorkloadUpdate(t, sub) + + // Update SVID + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(bar), + }) + + assertNoWorkloadUpdate(t, sub) +} + +func TestLRUCacheSubscriberNotNotifiedOnOverlappingSVIDChanges(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntry("FOO", "A", "C") + bar := makeRegistrationEntry("FOO", "A", "B") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + }) + + sub := subscribeToWorkloadUpdates(t, cache, makeSelectors("A", "B")) + defer sub.Finish() + assertAnyWorkloadUpdate(t, sub) + + // Update SVID + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + assertNoWorkloadUpdate(t, sub) +} + +func TestLRUCacheSVIDCacheExpiry(t *testing.T) { + clk := clock.NewMock() + cache := newTestLRUCacheWithConfig(10, clk) + + clk.Add(1 * time.Second) + foo := makeRegistrationEntry("FOO", "A") + // validate workload update for foo + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + subA := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) + subA.Finish() + + // move clk by 1 sec so that SVID access time will be different + clk.Add(1 * time.Second) + bar := makeRegistrationEntry("BAR", "B") + // validate workload update for bar + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(bar), + }) + + // not closing subscriber immediately + subB := subscribeToWorkloadUpdates(t, cache, makeSelectors("B")) + defer subB.Finish() + assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{ + {Entry: bar}, + }, + }) + + // Move clk by 2 seconds + clk.Add(2 * time.Second) + // update total of 12 entries + updateEntries := createUpdateEntries(10, makeBundles(bundleV1)) + updateEntries.RegistrationEntries[foo.EntryId] = foo + updateEntries.RegistrationEntries[bar.EntryId] = bar + + cache.UpdateEntries(updateEntries, nil) + + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDsFromMap(updateEntries.RegistrationEntries), + }) + + for id, entry := range 
updateEntries.RegistrationEntries { + // create and close subscribers for remaining entries so that svid cache is full + if id != foo.EntryId && id != bar.EntryId { + sub := cache.NewSubscriber(entry.Selectors) + sub.Finish() + } + } + assert.Equal(t, 12, cache.CountSVIDs()) + + cache.UpdateEntries(updateEntries, nil) + assert.Equal(t, 10, cache.CountSVIDs()) + + // foo SVID should be removed from cache as it does not have active subscriber + assert.False(t, cache.Notify(makeSelectors("A"))) + // bar SVID should be cached as it has active subscriber + assert.True(t, cache.Notify(makeSelectors("B"))) + + subA = cache.NewSubscriber(makeSelectors("A")) + defer subA.Finish() + + cache.UpdateEntries(updateEntries, nil) + + // Make sure foo is marked as stale entry which does not have svid cached + require.Len(t, cache.GetStaleEntries(), 1) + assert.Equal(t, foo, cache.GetStaleEntries()[0].Entry) + + assert.Equal(t, 10, cache.CountSVIDs()) +} + +func TestLRUCacheMaxSVIDCacheSize(t *testing.T) { + clk := clock.NewMock() + cache := newTestLRUCacheWithConfig(10, clk) + + // create entries more than maxSvidCacheSize + updateEntries := createUpdateEntries(12, makeBundles(bundleV1)) + cache.UpdateEntries(updateEntries, nil) + + require.Len(t, cache.GetStaleEntries(), 10) + + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDsFromStaleEntries(cache.GetStaleEntries()), + }) + require.Len(t, cache.GetStaleEntries(), 0) + assert.Equal(t, 10, cache.CountSVIDs()) + + // Validate that active subscriber will still get SVID even if SVID count is at maxSvidCacheSize + foo := makeRegistrationEntry("FOO", "A") + updateEntries.RegistrationEntries[foo.EntryId] = foo + + subA := cache.NewSubscriber(foo.Selectors) + defer subA.Finish() + + cache.UpdateEntries(updateEntries, nil) + require.Len(t, cache.GetStaleEntries(), 1) + assert.Equal(t, 10, cache.CountSVIDs()) + + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + assert.Equal(t, 11, cache.CountSVIDs()) + require.Len(t, cache.GetStaleEntries(), 0) +} + +func TestSyncSVIDsWithSubscribers(t *testing.T) { + clk := clock.NewMock() + cache := newTestLRUCacheWithConfig(5, clk) + + updateEntries := createUpdateEntries(5, makeBundles(bundleV1)) + cache.UpdateEntries(updateEntries, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDsFromStaleEntries(cache.GetStaleEntries()), + }) + assert.Equal(t, 5, cache.CountSVIDs()) + + // Update foo but its SVID is not yet cached + foo := makeRegistrationEntry("FOO", "A") + updateEntries.RegistrationEntries[foo.EntryId] = foo + + cache.UpdateEntries(updateEntries, nil) + + // Create a subscriber for foo + subA := cache.NewSubscriber(foo.Selectors) + defer subA.Finish() + require.Len(t, cache.GetStaleEntries(), 0) + + // After SyncSVIDsWithSubscribers foo should be marked as stale, requiring signing + cache.SyncSVIDsWithSubscribers() + require.Len(t, cache.GetStaleEntries(), 1) + assert.Equal(t, []*StaleEntry{{Entry: cache.records[foo.EntryId].entry}}, cache.GetStaleEntries()) + + assert.Equal(t, 5, cache.CountSVIDs()) +} + +func TestNotify(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + + assert.False(t, cache.Notify(makeSelectors("A"))) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + assert.True(t, cache.Notify(makeSelectors("A"))) +} + +func TestSubscribeToLRUCacheChanges(t *testing.T) { + 
clk := clock.NewMock()
+	cache := newTestLRUCacheWithConfig(1, clk)
+
+	foo := makeRegistrationEntry("FOO", "A")
+	bar := makeRegistrationEntry("BAR", "B")
+	cache.UpdateEntries(&UpdateEntries{
+		Bundles:             makeBundles(bundleV1),
+		RegistrationEntries: makeRegistrationEntries(foo, bar),
+	}, nil)
+
+	sub1WaitCh := make(chan struct{}, 1)
+	sub1ErrCh := make(chan error, 1)
+	go func() {
+		sub1, err := cache.subscribeToWorkloadUpdates(context.Background(), foo.Selectors, func() {
+			sub1WaitCh <- struct{}{}
+		})
+		if err != nil {
+			sub1ErrCh <- err
+			return
+		}
+
+		defer sub1.Finish()
+		u1 := <-sub1.Updates()
+		if len(u1.Identities) != 1 {
+			sub1ErrCh <- fmt.Errorf("expected 1 SVID, got: %d", len(u1.Identities))
+			return
+		}
+		sub1ErrCh <- nil
+	}()
+
+	sub2WaitCh := make(chan struct{}, 1)
+	sub2ErrCh := make(chan error, 1)
+	go func() {
+		sub2, err := cache.subscribeToWorkloadUpdates(context.Background(), bar.Selectors, func() {
+			sub2WaitCh <- struct{}{}
+		})
+		if err != nil {
+			sub2ErrCh <- err
+			return
+		}
+
+		defer sub2.Finish()
+		u2 := <-sub2.Updates()
+		if len(u2.Identities) != 1 {
+			sub2ErrCh <- fmt.Errorf("expected 1 SVID, got: %d", len(u2.Identities))
+			return
+		}
+		sub2ErrCh <- nil
+	}()
+
+	<-sub1WaitCh
+	<-sub2WaitCh
+	cache.SyncSVIDsWithSubscribers()
+
+	assert.Len(t, cache.GetStaleEntries(), 2)
+	cache.UpdateSVIDs(&UpdateSVIDs{
+		X509SVIDs: makeX509SVIDs(foo, bar),
+	})
+	assert.Equal(t, 2, cache.CountSVIDs())
+
+	clk.Add(SvidSyncInterval * 2)
+
+	sub1Err := <-sub1ErrCh
+	assert.NoError(t, sub1Err, "subscriber 1 error")
+
+	sub2Err := <-sub2ErrCh
+	assert.NoError(t, sub2Err, "subscriber 2 error")
+}
+
+func TestNewLRUCache(t *testing.T) {
+	// negative value
+	cache := newTestLRUCacheWithConfig(-5, clock.NewMock())
+	require.Equal(t, DefaultSVIDCacheMaxSize, cache.svidCacheMaxSize)
+
+	// zero value
+	cache = newTestLRUCacheWithConfig(0, clock.NewMock())
+	require.Equal(t, DefaultSVIDCacheMaxSize, cache.svidCacheMaxSize)
+}
+
+func BenchmarkLRUCacheGlobalNotification(b *testing.B) {
+	cache := newTestLRUCache()
+
+	const numEntries = 1000
+	const numWorkloads = 1000
+	const selectorsPerEntry = 3
+	const selectorsPerWorkload = 10
+
+	// build a set of 1000 registration entries with distinct selectors
+	bundlesV1 := makeBundles(bundleV1)
+	bundlesV2 := makeBundles(bundleV2)
+	updateEntries := &UpdateEntries{
+		Bundles:             bundlesV1,
+		RegistrationEntries: make(map[string]*common.RegistrationEntry, numEntries),
+	}
+	for i := 0; i < numEntries; i++ {
+		entryID := fmt.Sprintf("00000000-0000-0000-0000-%012d", i)
+		updateEntries.RegistrationEntries[entryID] = &common.RegistrationEntry{
+			EntryId:   entryID,
+			ParentId:  "spiffe://domain.test/node",
+			SpiffeId:  fmt.Sprintf("spiffe://domain.test/workload-%d", i),
+			Selectors: distinctSelectors(i, selectorsPerEntry),
+		}
+	}
+
+	cache.UpdateEntries(updateEntries, nil)
+	for i := 0; i < numWorkloads; i++ {
+		selectors := distinctSelectors(i, selectorsPerWorkload)
+		cache.NewSubscriber(selectors)
+	}
+
+	runtime.GC()
+
+	b.ResetTimer()
+	b.ReportAllocs()
+	for i := 0; i < b.N; i++ {
+		if i%2 == 0 {
+			updateEntries.Bundles = bundlesV2
+		} else {
+			updateEntries.Bundles = bundlesV1
+		}
+		cache.UpdateEntries(updateEntries, nil)
+	}
+}
+
+func newTestLRUCache() *LRUCache {
+	log, _ := test.NewNullLogger()
+	return NewLRUCache(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1,
+		telemetry.Blackhole{}, 0, clock.NewMock())
+}
+
+func newTestLRUCacheWithConfig(svidCacheMaxSize int, clk clock.Clock) *LRUCache {
+	log, _ :=
test.NewNullLogger() + return NewLRUCache(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, telemetry.Blackhole{}, + svidCacheMaxSize, clk) +} + +// numEntries should not be more than 12 digits +func createUpdateEntries(numEntries int, bundles map[spiffeid.TrustDomain]*bundleutil.Bundle) *UpdateEntries { + updateEntries := &UpdateEntries{ + Bundles: bundles, + RegistrationEntries: make(map[string]*common.RegistrationEntry, numEntries), + } + + for i := 0; i < numEntries; i++ { + entryID := fmt.Sprintf("00000000-0000-0000-0000-%012d", i) + updateEntries.RegistrationEntries[entryID] = &common.RegistrationEntry{ + EntryId: entryID, + ParentId: "spiffe://domain.test/node", + SpiffeId: fmt.Sprintf("spiffe://domain.test/workload-%d", i), + Selectors: distinctSelectors(i, 1), + } + } + return updateEntries +} + +func makeX509SVIDsFromMap(entries map[string]*common.RegistrationEntry) map[string]*X509SVID { + out := make(map[string]*X509SVID) + for _, entry := range entries { + out[entry.EntryId] = &X509SVID{} + } + return out +} + +func makeX509SVIDsFromStaleEntries(entries []*StaleEntry) map[string]*X509SVID { + out := make(map[string]*X509SVID) + for _, entry := range entries { + out[entry.Entry.EntryId] = &X509SVID{} + } + return out +} + +func subscribeToWorkloadUpdates(t *testing.T, cache *LRUCache, selectors []*common.Selector) Subscriber { + subscriber, err := cache.subscribeToWorkloadUpdates(context.Background(), selectors, nil) + assert.NoError(t, err) + return subscriber +} diff --git a/pkg/agent/manager/cache/sets.go b/pkg/agent/manager/cache/sets.go index 6c9e1701bb2..c7cc0d68950 100644 --- a/pkg/agent/manager/cache/sets.go +++ b/pkg/agent/manager/cache/sets.go @@ -30,6 +30,18 @@ var ( return make(recordSet) }, } + + lruCacheRecordSetPool = sync.Pool{ + New: func() interface{} { + return make(lruCacheRecordSet) + }, + } + + lruCacheSubscriberSetPool = sync.Pool{ + New: func() interface{} { + return make(lruCacheSubscriberSet) + }, + } ) // unique set of strings, allocated from a pool @@ -149,3 +161,37 @@ func clearRecordSet(set recordSet) { delete(set, k) } } + +// unique set of LRU cache records, allocated from a pool +type lruCacheRecordSet map[*lruCacheRecord]struct{} + +func allocLRUCacheRecordSet() (lruCacheRecordSet, func()) { + set := lruCacheRecordSetPool.Get().(lruCacheRecordSet) + return set, func() { + clearLRUCacheRecordSet(set) + lruCacheRecordSetPool.Put(set) + } +} + +func clearLRUCacheRecordSet(set lruCacheRecordSet) { + for k := range set { + delete(set, k) + } +} + +// unique set of LRU cache subscribers, allocated from a pool +type lruCacheSubscriberSet map[*lruCacheSubscriber]struct{} + +func allocLRUCacheSubscriberSet() (lruCacheSubscriberSet, func()) { + set := lruCacheSubscriberSetPool.Get().(lruCacheSubscriberSet) + return set, func() { + clearLRUCacheSubscriberSet(set) + lruCacheSubscriberSetPool.Put(set) + } +} + +func clearLRUCacheSubscriberSet(set lruCacheSubscriberSet) { + for k := range set { + delete(set, k) + } +} diff --git a/pkg/agent/manager/config.go b/pkg/agent/manager/config.go index 353cdbab7d0..b192a1ef925 100644 --- a/pkg/agent/manager/config.go +++ b/pkg/agent/manager/config.go @@ -58,8 +58,15 @@ func newManager(c *Config) *manager { c.Clk = clock.New() } - cache := cache.New(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle, - c.Metrics, c.SVIDCacheMaxSize, c.Clk) + var x509SVIDCache Cache + if c.SVIDCacheMaxSize > 0 { + // use LRU cache implementation + x509SVIDCache = 
cache.NewLRUCache(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle,
+			c.Metrics, c.SVIDCacheMaxSize, c.Clk)
+	} else {
+		x509SVIDCache = cache.New(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle,
+			c.Metrics)
+	}
 
 	rotCfg := &svid.RotatorConfig{
 		SVIDKeyManager: keymanager.ForSVID(c.Catalog.GetKeyManager()),
@@ -67,7 +74,7 @@ func newManager(c *Config) *manager {
 		Metrics:        c.Metrics,
 		SVID:           c.SVID,
 		SVIDKey:        c.SVIDKey,
-		BundleStream:   cache.SubscribeToBundleChanges(),
+		BundleStream:   x509SVIDCache.SubscribeToBundleChanges(),
 		ServerAddr:     c.ServerAddr,
 		TrustDomain:    c.TrustDomain,
 		Interval:       c.RotationInterval,
@@ -76,7 +83,7 @@ func newManager(c *Config) *manager {
 	svidRotator, client := svid.NewRotator(rotCfg)
 
 	m := &manager{
-		cache: cache,
+		cache: x509SVIDCache,
 		c:     c,
 		mtx:   new(sync.RWMutex),
 		svid:  svidRotator,
diff --git a/pkg/agent/manager/manager.go b/pkg/agent/manager/manager.go
index fc427f22dd3..775623acbf1 100644
--- a/pkg/agent/manager/manager.go
+++ b/pkg/agent/manager/manager.go
@@ -25,8 +25,6 @@ import (
 	"github.com/spiffe/spire/proto/spire/common"
 )
 
-const svidSyncInterval = 500 * time.Millisecond
-
 // Manager provides cache management functionalities for agents.
 type Manager interface {
 	// Initialize initializes the manager.
@@ -78,13 +76,52 @@ type Manager interface {
 	GetBundle() *cache.Bundle
 }
 
+// Cache stores each registration entry, signed X509-SVIDs for those entries,
+// bundles, and JWT SVIDs for the agent.
+type Cache interface {
+	SVIDCache
+
+	// Bundle gets the latest cached bundle
+	Bundle() *bundleutil.Bundle
+
+	// SyncSVIDsWithSubscribers syncs the SVID cache
+	SyncSVIDsWithSubscribers()
+
+	// SubscribeToWorkloadUpdates creates a subscriber for the given selector set.
+	SubscribeToWorkloadUpdates(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error)
+
+	// SubscribeToBundleChanges creates a stream for providing bundle changes
+	SubscribeToBundleChanges() *cache.BundleStream
+
+	// MatchingRegistrationEntries returns the registration entries that match the given selectors
+	MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry
+
+	// CountSVIDs returns the number of SVIDs stored in the cache
+	CountSVIDs() int
+
+	// FetchWorkloadUpdate returns the workload update for the given selectors
+	FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate
+
+	// GetJWTSVID gets the cached JWT-SVID, if any, for the given SPIFFE ID and audience
+	GetJWTSVID(id spiffeid.ID, audience []string) (*client.JWTSVID, bool)
+
+	// SetJWTSVID adds a JWT-SVID to the cache
+	SetJWTSVID(id spiffeid.ID, audience []string, svid *client.JWTSVID)
+
+	// Entries returns all cached registration entries
+	Entries() []*common.RegistrationEntry
+
+	// Identities returns all identities in the cache
+	Identities() []cache.Identity
+}
+
 type manager struct {
 	c *Config
 
 	// Fields protected by mtx mutex.
mtx *sync.RWMutex - cache *cache.Cache + cache Cache svid svid.Rotator storage storage.Storage @@ -93,7 +130,6 @@ type manager struct { // fetch attempt synchronizeBackoff backoff.BackOff svidSyncBackoff backoff.BackOff - subscribeBackoffFn func() backoff.BackOff client client.Client @@ -111,12 +147,7 @@ func (m *manager) Initialize(ctx context.Context) error { m.storeBundle(m.cache.Bundle()) m.synchronizeBackoff = backoff.NewBackoff(m.clk, m.c.SyncInterval) - m.svidSyncBackoff = backoff.NewBackoff(m.clk, svidSyncInterval) - if m.subscribeBackoffFn == nil { - m.subscribeBackoffFn = func() backoff.BackOff { - return backoff.NewBackoff(m.clk, svidSyncInterval) - } - } + m.svidSyncBackoff = backoff.NewBackoff(m.clk, cache.SvidSyncInterval) err := m.synchronize(ctx) if nodeutil.ShouldAgentReattest(err) { @@ -151,30 +182,7 @@ func (m *manager) Run(ctx context.Context) error { } func (m *manager) SubscribeToCacheChanges(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) { - return m.subscribeToCacheChanges(ctx, selectors, nil) -} - -func (m *manager) subscribeToCacheChanges(ctx context.Context, selectors cache.Selectors, notifyCallbackFn func()) (cache.Subscriber, error) { - subscriber := m.cache.NewSubscriber(selectors) - bo := m.subscribeBackoffFn() - // block until all svids are cached and subscriber is notified - for { - svidsInCache := m.cache.Notify(selectors) - // used for testing - if notifyCallbackFn != nil { - notifyCallbackFn() - } - if svidsInCache { - return subscriber, nil - } - m.c.Log.WithField(telemetry.Selectors, selectors).Info("Waiting for SVID to get cached") - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-m.clk.After(bo.NextBackOff()): - } - } + return m.cache.SubscribeToWorkloadUpdates(ctx, selectors) } func (m *manager) SubscribeToSVIDChanges() observer.Stream { diff --git a/pkg/agent/manager/manager_test.go b/pkg/agent/manager/manager_test.go index 4e4d5c3512e..49a7a0c9a77 100644 --- a/pkg/agent/manager/manager_test.go +++ b/pkg/agent/manager/manager_test.go @@ -14,7 +14,6 @@ import ( "testing" "time" - backoff "github.com/cenkalti/backoff/v3" testlog "github.com/sirupsen/logrus/hooks/test" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/go-spiffe/v2/svid/x509svid" @@ -73,16 +72,15 @@ func TestInitializationFailure(t *testing.T) { cat.SetKeyManager(km) c := &Config{ - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - Metrics: &telemetry.Blackhole{}, - TrustDomain: trustDomain, - Storage: openStorage(t, dir), - Clk: clk, - Catalog: cat, - SVIDCacheMaxSize: 1, - SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + Metrics: &telemetry.Blackhole{}, + TrustDomain: trustDomain, + Storage: openStorage(t, dir), + Clk: clk, + Catalog: cat, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } m := newManager(c) require.Error(t, m.Initialize(context.Background())) @@ -101,16 +99,15 @@ func TestStoreBundleOnStartup(t *testing.T) { sto := openStorage(t, dir) c := &Config{ - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - Metrics: &telemetry.Blackhole{}, - TrustDomain: trustDomain, - Storage: sto, - Bundle: bundleutil.BundleFromRootCA(trustDomain, ca), - Clk: clk, - Catalog: cat, - SVIDCacheMaxSize: 1, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + Metrics: &telemetry.Blackhole{}, + TrustDomain: trustDomain, + Storage: sto, + Bundle: 
bundleutil.BundleFromRootCA(trustDomain, ca), + Clk: clk, + Catalog: cat, } m := newManager(c) @@ -150,15 +147,14 @@ func TestStoreSVIDOnStartup(t *testing.T) { sto := openStorage(t, dir) c := &Config{ - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - Metrics: &telemetry.Blackhole{}, - TrustDomain: trustDomain, - Storage: sto, - Clk: clk, - Catalog: cat, - SVIDCacheMaxSize: 1, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + Metrics: &telemetry.Blackhole{}, + TrustDomain: trustDomain, + Storage: sto, + Clk: clk, + Catalog: cat, } if _, err := sto.LoadSVID(); !errors.Is(err, storage.ErrNotCached) { @@ -403,7 +399,6 @@ func TestSVIDRotation(t *testing.T) { SyncInterval: 1 * time.Hour, Clk: clk, WorkloadKeyType: workloadkey.ECP256, - SVIDCacheMaxSize: 1, SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } @@ -521,11 +516,6 @@ func TestSynchronization(t *testing.T) { m := newManager(c) - if err := m.Initialize(context.Background()); err != nil { - t.Fatal(err) - } - require.Equal(t, clk.Now(), m.GetLastSync()) - sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{ {Type: "unix", Value: "uid:1111"}, {Type: "spiffe_id", Value: joinTokenID.String()}, @@ -533,6 +523,11 @@ func TestSynchronization(t *testing.T) { require.NoError(t, err) defer sub.Finish() + if err := m.Initialize(context.Background()); err != nil { + t.Fatal(err) + } + require.Equal(t, clk.Now(), m.GetLastSync()) + // Before synchronization identitiesBefore := identitiesByEntryID(m.cache.Identities()) if len(identitiesBefore) != 3 { @@ -657,19 +652,18 @@ func TestSynchronizationClearsStaleCacheEntries(t *testing.T) { cat.SetKeyManager(km) c := &Config{ - ServerAddr: api.addr, - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - TrustDomain: trustDomain, - Storage: openStorage(t, dir), - Bundle: api.bundle, - Metrics: &telemetry.Blackhole{}, - Clk: clk, - Catalog: cat, - WorkloadKeyType: workloadkey.ECP256, - SVIDCacheMaxSize: 1, - SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + ServerAddr: api.addr, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + TrustDomain: trustDomain, + Storage: openStorage(t, dir), + Bundle: api.bundle, + Metrics: &telemetry.Blackhole{}, + Clk: clk, + Catalog: cat, + WorkloadKeyType: workloadkey.ECP256, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } m := newManager(c) @@ -731,19 +725,18 @@ func TestSynchronizationUpdatesRegistrationEntries(t *testing.T) { cat.SetKeyManager(km) c := &Config{ - ServerAddr: api.addr, - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - TrustDomain: trustDomain, - Storage: openStorage(t, dir), - Bundle: api.bundle, - Metrics: &telemetry.Blackhole{}, - Clk: clk, - Catalog: cat, - WorkloadKeyType: workloadkey.ECP256, - SVIDCacheMaxSize: 1, - SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + ServerAddr: api.addr, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + TrustDomain: trustDomain, + Storage: openStorage(t, dir), + Bundle: api.bundle, + Metrics: &telemetry.Blackhole{}, + Clk: clk, + Catalog: cat, + WorkloadKeyType: workloadkey.ECP256, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } m := newManager(c) @@ -825,21 +818,21 @@ func TestSubscribersGetUpToDateBundle(t *testing.T) { }) } -func TestSyncSVIDs(t *testing.T) { +func 
TestSynchronizationWithLRUCache(t *testing.T) {
 	dir := spiretest.TempDir(t)
 	km := fakeagentkeymanager.New(t, dir)
 	clk := clock.NewMock(t)
+	ttl := 3
 	api := newMockAPI(t, &mockAPIConfig{
 		km: km,
-		getAuthorizedEntries: func(h *mockAPI, count int32, req *entryv1.GetAuthorizedEntriesRequest) (*entryv1.GetAuthorizedEntriesResponse, error) {
+		getAuthorizedEntries: func(*mockAPI, int32, *entryv1.GetAuthorizedEntriesRequest) (*entryv1.GetAuthorizedEntriesResponse, error) {
 			return makeGetAuthorizedEntriesResponse(t, "resp1", "resp2"), nil
 		},
-		batchNewX509SVIDEntries: func(h *mockAPI, count int32) []*common.RegistrationEntry {
-			h.rotateCA()
+		batchNewX509SVIDEntries: func(*mockAPI, int32) []*common.RegistrationEntry {
 			return makeBatchNewX509SVIDEntries("resp1", "resp2")
 		},
-		svidTTL: 200,
+		svidTTL: ttl,
 		clk:     clk,
 	})
@@ -856,106 +849,218 @@ func TestSyncSVIDs(t *testing.T) {
 		Storage:          openStorage(t, dir),
 		Bundle:           api.bundle,
 		Metrics:          &telemetry.Blackhole{},
-		RotationInterval: 1 * time.Hour,
-		SyncInterval:     1 * time.Hour,
-		SVIDCacheMaxSize: 1,
+		RotationInterval: time.Hour,
+		SyncInterval:     time.Hour,
 		Clk:              clk,
 		Catalog:          cat,
 		WorkloadKeyType:  workloadkey.ECP256,
+		SVIDCacheMaxSize: 10,
 		SVIDStoreCache:   storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}),
 	}
 
 	m := newManager(c)
-	m.subscribeBackoffFn = func() backoff.BackOff {
-		return backoff.NewConstantBackOff(svidSyncInterval)
+
+	if err := m.Initialize(context.Background()); err != nil {
+		t.Fatal(err)
 	}
+	require.Equal(t, clk.Now(), m.GetLastSync())
 
-	err := m.Initialize(context.Background())
+	sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{
+		{Type: "unix", Value: "uid:1111"},
+		{Type: "spiffe_id", Value: joinTokenID.String()},
+	})
 	require.NoError(t, err)
+	defer sub.Finish()
 
-	// After Initialize, just 1 SVID should be cached
-	assert.Equal(t, 1, m.CountSVIDs())
-	ctx := context.Background()
+	// Before synchronization
+	identitiesBefore := identitiesByEntryID(m.cache.Identities())
+	if len(identitiesBefore) != 3 {
+		t.Fatalf("3 cached identities were expected; got %d", len(identitiesBefore))
+	}
 
-	// Validate the update received by subscribers
-	// Spawn subscriber 1 in new goroutine to allow SVID sync to run in parallel
-	sub1WaitCh := make(chan struct{}, 1)
-	sub1ErrCh := make(chan error, 1)
-	go func() {
-		sub1, err := m.subscribeToCacheChanges(ctx, cache.Selectors{{Type: "unix", Value: "uid:1111"}}, func() {
-			sub1WaitCh <- struct{}{}
-		})
-		if err != nil {
-			sub1ErrCh <- err
-			return
-		}
+	// This is the initial update based on the selector set
+	u := <-sub.Updates()
+	if len(u.Identities) != 3 {
+		t.Fatalf("expected 3 identities, got: %d", len(u.Identities))
+	}
 
-		defer sub1.Finish()
-		u1 := <-sub1.Updates()
+	if len(u.Bundle.RootCAs()) != 1 {
+		t.Fatal("expected 1 bundle root CA")
+	}
 
-		if len(u1.Identities) != 2 {
-			sub1ErrCh <- fmt.Errorf("expected 2 SVIDs, got: %d", len(u1.Identities))
-			return
-		}
-		if !u1.Bundle.EqualTo(c.Bundle) {
-			sub1ErrCh <- errors.New("bundles were expected to be equal")
-			return
+	if !u.Bundle.EqualTo(api.bundle) {
+		t.Fatal("received bundle should be equal to the server bundle")
+	}
+
+	for key, eu := range identitiesByEntryID(u.Identities) {
+		eb, ok := identitiesBefore[key]
+		if !ok {
+			t.Fatalf("an update was received for a nonexistent entry on the cache with EntryId=%v", key)
 		}
+		require.Equal(t, eb, eu, "identity received does not match identity on cache")
+	}
 
-		sub1ErrCh <- nil
-	}()
+	require.Equal(t, clk.Now(), m.GetLastSync())
 
-	// Spawn subscriber 2 in new goroutine to allow SVID sync to run in parallel
-	sub2WaitCh := make(chan struct{}, 1)
-	sub2ErrCh := make(chan error, 1)
-	go func() {
-		sub2, err := m.subscribeToCacheChanges(ctx, cache.Selectors{{Type: "spiffe_id", Value: "spiffe://example.org/spire/agent/join_token/abcd"}}, func() {
-			sub2WaitCh <- struct{}{}
-		})
-		if err != nil {
-			sub2ErrCh <- err
-			return
-		}
+	// SVIDs expire after 3 seconds, so we shouldn't expect any updates after
+	// 1 second has elapsed.
+	clk.Add(time.Second)
+	require.NoError(t, m.synchronize(context.Background()))
+	select {
+	case <-sub.Updates():
+		t.Fatal("update unexpected after 1 second")
+	default:
+	}
+
+	// After advancing another second, the SVIDs should have been refreshed,
+	// since the half-time has been exceeded.
+	clk.Add(time.Second)
+	require.NoError(t, m.synchronize(context.Background()))
+	select {
+	case u = <-sub.Updates():
+	default:
+		t.Fatal("update expected after 2 seconds")
+	}
 
-		defer sub2.Finish()
-		u2 := <-sub2.Updates()
+	// Make sure the update contains the updated entries and that the cache
+	// has a consistent view.
+	identitiesAfter := identitiesByEntryID(m.cache.Identities())
+	if len(identitiesAfter) != 3 {
+		t.Fatalf("expected 3 identities, got: %d", len(identitiesAfter))
+	}
 
-		if len(u2.Identities) != 1 {
-			sub2ErrCh <- fmt.Errorf("expected 1 SVID, got: %d", len(u2.Identities))
-			return
+	for key, eb := range identitiesBefore {
+		ea, ok := identitiesAfter[key]
+		if !ok {
+			t.Fatalf("expected identity with EntryId=%v after synchronization", key)
 		}
-		if !u2.Bundle.EqualTo(c.Bundle) {
-			sub2ErrCh <- errors.New("bundles were expected to be equal")
-			return
+		require.NotEqual(t, eb, ea, "there is at least one identity that was not refreshed: %v", ea)
+	}
+
+	if len(u.Identities) != 3 {
+		t.Fatalf("expected 3 identities, got: %d", len(u.Identities))
+	}
+
+	if len(u.Bundle.RootCAs()) != 1 {
+		t.Fatal("expected 1 bundle root CA")
+	}
+
+	if !u.Bundle.EqualTo(api.bundle) {
+		t.Fatal("received bundle should be equal to the server bundle")
+	}
+
+	for key, eu := range identitiesByEntryID(u.Identities) {
+		ea, ok := identitiesAfter[key]
+		if !ok {
+			t.Fatalf("an update was received for a nonexistent entry on the cache with EntryId=%v", key)
 		}
+		require.Equal(t, eu, ea, "entry received does not match entry on cache")
+	}
 
-		sub2ErrCh <- nil
-	}()
+	require.Equal(t, clk.Now(), m.GetLastSync())
+}
+
+func TestSyncSVIDsWithLRUCache(t *testing.T) {
+	dir := spiretest.TempDir(t)
+	km := fakeagentkeymanager.New(t, dir)
+
+	clk := clock.NewMock(t)
+	api := newMockAPI(t, &mockAPIConfig{
+		km: km,
+		getAuthorizedEntries: func(h *mockAPI, count int32, _ *entryv1.GetAuthorizedEntriesRequest) (*entryv1.GetAuthorizedEntriesResponse, error) {
+			switch count {
+			case 1:
+				return makeGetAuthorizedEntriesResponse(t, "resp2"), nil
+			case 2:
+				return makeGetAuthorizedEntriesResponse(t, "resp2"), nil
+			default:
+				return nil, fmt.Errorf("unexpected getAuthorizedEntries call count: %d", count)
+			}
+		},
+		batchNewX509SVIDEntries: func(h *mockAPI, count int32) []*common.RegistrationEntry {
+			switch count {
+			case 1:
+				return makeBatchNewX509SVIDEntries("resp2")
+			case 2:
+				return makeBatchNewX509SVIDEntries("resp2")
+			default:
+				return nil
+			}
+		},
+		svidTTL: 3,
+		clk:     clk,
+	})
+
+	baseSVID, baseSVIDKey := api.newSVID(joinTokenID, 1*time.Hour)
+	cat := fakeagentcatalog.New()
+	cat.SetKeyManager(km)
+
+	c := &Config{
+		ServerAddr:       api.addr,
+		SVID:             baseSVID,
+		SVIDKey:          baseSVIDKey,
+		Log:              testLogger,
+		TrustDomain:      trustDomain,
+		Storage:          openStorage(t, dir),
+		Bundle:           api.bundle,
+		Metrics:          &telemetry.Blackhole{},
+		Clk:              clk,
+		Catalog:          cat,
+		WorkloadKeyType:  workloadkey.ECP256,
+		SVIDCacheMaxSize: 1,
+		SVIDStoreCache:   storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}),
+	}
 
-	// Wait until subscribers have been created
-	<-sub1WaitCh
-	<-sub2WaitCh
+	m := newManager(c)
 
-	// Sync SVIDs to populate cache
-	svidSyncErr := m.syncSVIDs(ctx)
-	require.NoError(t, svidSyncErr, "syncSVIDs method failed")
+	if err := m.Initialize(context.Background()); err != nil {
+		t.Fatal(err)
+	}
 
-	// Advance clock so subscribers can check for latest SVIDs in cache
-	clk.Add(svidSyncInterval)
+	ctx, cancel := context.WithCancel(context.Background())
+	subErrCh := make(chan error, 1)
+	go func(ctx context.Context) {
+		sub, err := m.SubscribeToCacheChanges(ctx, cache.Selectors{
+			{Type: "unix", Value: "uid:1111"},
+		})
+		if err != nil {
+			subErrCh <- err
+			return
+		}
+		defer sub.Finish()
+		subErrCh <- nil
+	}(ctx)
+
+	syncErrCh := make(chan error, 1)
+	// run svid sync
+	go func(ctx context.Context) {
+		if err := m.runSyncSVIDs(ctx); err != nil {
+			syncErrCh <- err
+		}
+		syncErrCh <- nil
+	}(ctx)
 
-	sub1Err := <-sub1ErrCh
-	assert.NoError(t, sub1Err, "subscriber 1 error")
+	// keep clk moving so that subscriber keeps looking for svid
+	go func(ctx context.Context) {
+		for {
+			clk.Add(cache.SvidSyncInterval)
+			if ctx.Err() != nil {
+				return
+			}
+		}
+	}(ctx)
 
-	sub2Err := <-sub2ErrCh
-	assert.NoError(t, sub2Err, "subscriber 2 error")
+	subErr := <-subErrCh
+	assert.NoError(t, subErr, "subscriber error")
 
-	// All 3 SVIDs should be cached
-	assert.Equal(t, 3, m.CountSVIDs())
+	// ensure 2 SVIDs corresponding to selectors are cached.
+	assert.Equal(t, 2, m.cache.CountSVIDs())
 
-	assert.NoError(t, m.synchronize(ctx))
+	// cancel the ctx to stop goroutines
+	cancel()
 
-	// Make sure svid count is SVIDCacheMaxSize and non-active SVIDs are deleted from cache
-	assert.Equal(t, 1, m.CountSVIDs())
+	syncErr := <-syncErrCh
+	assert.NoError(t, syncErr, "svid sync error")
 }
 
 func TestSurvivesCARotation(t *testing.T) {
@@ -1003,9 +1108,6 @@ func TestSurvivesCARotation(t *testing.T) {
 	}
 
 	m := newManager(c)
-	m.subscribeBackoffFn = func() backoff.BackOff {
-		return backoff.NewConstantBackOff(svidSyncInterval)
-	}
 
 	sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{{Type: "unix", Value: "uid:1111"}})
 	require.NoError(t, err)
@@ -1052,19 +1154,18 @@ func TestFetchJWTSVID(t *testing.T) {
 	baseSVID, baseSVIDKey := api.newSVID(joinTokenID, 1*time.Hour)
 
 	c := &Config{
-		ServerAddr:       api.addr,
-		SVID:             baseSVID,
-		SVIDKey:          baseSVIDKey,
-		Log:              testLogger,
-		TrustDomain:      trustDomain,
-		Storage:          openStorage(t, dir),
-		Bundle:           api.bundle,
-		Metrics:          &telemetry.Blackhole{},
-		Catalog:          cat,
-		Clk:              clk,
-		WorkloadKeyType:  workloadkey.ECP256,
-		SVIDCacheMaxSize: 1,
-		SVIDStoreCache:   storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}),
+		ServerAddr:      api.addr,
+		SVID:            baseSVID,
+		SVIDKey:         baseSVIDKey,
+		Log:             testLogger,
+		TrustDomain:     trustDomain,
+		Storage:         openStorage(t, dir),
+		Bundle:          api.bundle,
+		Metrics:         &telemetry.Blackhole{},
+		Catalog:         cat,
+		Clk:             clk,
+		WorkloadKeyType: workloadkey.ECP256,
+		SVIDStoreCache:  storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}),
 	}
 
 	m := newManager(c)
@@ -1176,19 +1277,18 @@ func TestStorableSVIDsSync(t *testing.T) {
 	cat.SetKeyManager(fakeagentkeymanager.New(t, dir))
 
 	c := &Config{
-		ServerAddr:       api.addr,
-		SVID:             baseSVID,
-		SVIDKey:          baseSVIDKey,
-		Log:              testLogger,
-		TrustDomain:      trustDomain,
-		Storage:          openStorage(t, dir),
-		Bundle:           api.bundle,
-		Metrics:          &telemetry.Blackhole{},
-		Clk:              clk,
-		Catalog:          cat,
-		WorkloadKeyType:  workloadkey.ECP256,
-		SVIDCacheMaxSize: 1,
-		SVIDStoreCache:   storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}),
+		ServerAddr:      api.addr,
+		SVID:            baseSVID,
+		SVIDKey:         baseSVIDKey,
+		Log:             testLogger,
+		TrustDomain:     trustDomain,
+		Storage:         openStorage(t, dir),
+		Bundle:          api.bundle,
+		Metrics:         &telemetry.Blackhole{},
+		Clk:             clk,
+		Catalog:         cat,
+		WorkloadKeyType: workloadkey.ECP256,
+		SVIDStoreCache:  storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}),
 	}
 
 	m, closer := initializeAndRunNewManager(t, c)
diff --git a/pkg/agent/manager/sync.go b/pkg/agent/manager/sync.go
index 56a9a633b99..8660dd50cfd 100644
--- a/pkg/agent/manager/sync.go
+++ b/pkg/agent/manager/sync.go
@@ -25,7 +25,7 @@ type csrRequest struct {
 	CurrentSVIDExpiresAt time.Time
 }
 
-type Cache interface {
+type SVIDCache interface {
 	// UpdateEntries updates entries on cache
 	UpdateEntries(update *cache.UpdateEntries, checkSVID func(*common.RegistrationEntry, *common.RegistrationEntry, *cache.X509SVID) bool)
 
@@ -37,10 +37,13 @@ type Cache interface {
 }
 
 func (m *manager) syncSVIDs(ctx context.Context) (err error) {
-	m.cache.SyncSVIDsWithSubscribers()
-	staleEntries := m.cache.GetStaleEntries()
-	if len(staleEntries) > 0 {
-		return m.updateSVIDs(ctx, staleEntries, m.cache)
+	// perform syncSVIDs only if using LRU cache
+	if m.c.SVIDCacheMaxSize > 0 {
+		m.cache.SyncSVIDsWithSubscribers()
+		staleEntries := m.cache.GetStaleEntries()
+		if len(staleEntries) > 0 {
+			return m.updateSVIDs(ctx, staleEntries, m.cache)
+		}
 	}
 	return nil
 }
@@ -66,7 +69,7 @@ func (m *manager) synchronize(ctx context.Context) (err error) {
 	return nil
 }
 
-func (m *manager) updateCache(ctx context.Context, update *cache.UpdateEntries, log logrus.FieldLogger, cacheType string, c Cache) error {
+func (m *manager) updateCache(ctx context.Context, update *cache.UpdateEntries, log logrus.FieldLogger, cacheType string, c SVIDCache) error {
 	// update the cache and build a list of CSRs that need to be processed
 	// in this interval.
 	//
@@ -117,7 +120,7 @@ func (m *manager) updateCache(ctx context.Context, update *cache.UpdateEntries,
 	return nil
 }
 
-func (m *manager) updateSVIDs(ctx context.Context, entries []*cache.StaleEntry, c Cache) error {
+func (m *manager) updateSVIDs(ctx context.Context, entries []*cache.StaleEntry, c SVIDCache) error {
 	var csrs []csrRequest
 	for _, entry := range entries {
 		// we've exceeded the CSR limit, don't make any more CSRs