From 24a19c811cf815ec73405dc2678f145853292e95 Mon Sep 17 00:00:00 2001 From: Prasad Borole Date: Tue, 21 Jun 2022 17:04:28 -0700 Subject: [PATCH 01/19] Implement LRU cache for storing SVIDs in SPIRE Agent Signed-off-by: Prasad Borole --- cmd/spire-agent/cli/api/common.go | 2 +- cmd/spire-agent/cli/run/run.go | 15 +- cmd/spire-agent/cli/run/run_test.go | 44 ++ pkg/agent/agent.go | 24 +- pkg/agent/api/delegatedidentity/v1/service.go | 18 +- .../api/delegatedidentity/v1/service_test.go | 15 +- pkg/agent/config.go | 6 + pkg/agent/endpoints/workload/handler.go | 12 +- pkg/agent/endpoints/workload/handler_test.go | 8 +- pkg/agent/manager/cache/cache.go | 303 +++++++++-- pkg/agent/manager/cache/cache_test.go | 325 +++++++++--- pkg/agent/manager/config.go | 31 +- pkg/agent/manager/manager.go | 54 +- pkg/agent/manager/manager_test.go | 474 ++++++++++++++---- pkg/agent/manager/sync.go | 23 +- test/integration/common | 19 + test/integration/setup/debugagent/main.go | 28 +- test/integration/suites/fetch-svids/00-setup | 6 + .../suites/fetch-svids/01-start-server | 3 + .../suites/fetch-svids/02-bootstrap-agent | 5 + .../suites/fetch-svids/03-start-agent | 3 + .../04-create-registration-entries | 18 + .../suites/fetch-svids/05-fetch-svids | 14 + .../06-create-registration-entries | 19 + .../suites/fetch-svids/07-fetch-svids | 28 ++ test/integration/suites/fetch-svids/README.md | 5 + .../suites/fetch-svids/conf/agent/agent.conf | 32 ++ .../fetch-svids/conf/server/server.conf | 26 + .../suites/fetch-svids/docker-compose.yaml | 15 + test/integration/suites/fetch-svids/teardown | 6 + 30 files changed, 1321 insertions(+), 260 deletions(-) create mode 100755 test/integration/suites/fetch-svids/00-setup create mode 100755 test/integration/suites/fetch-svids/01-start-server create mode 100755 test/integration/suites/fetch-svids/02-bootstrap-agent create mode 100755 test/integration/suites/fetch-svids/03-start-agent create mode 100755 test/integration/suites/fetch-svids/04-create-registration-entries create mode 100755 test/integration/suites/fetch-svids/05-fetch-svids create mode 100755 test/integration/suites/fetch-svids/06-create-registration-entries create mode 100755 test/integration/suites/fetch-svids/07-fetch-svids create mode 100644 test/integration/suites/fetch-svids/README.md create mode 100644 test/integration/suites/fetch-svids/conf/agent/agent.conf create mode 100644 test/integration/suites/fetch-svids/conf/server/server.conf create mode 100644 test/integration/suites/fetch-svids/docker-compose.yaml create mode 100755 test/integration/suites/fetch-svids/teardown diff --git a/cmd/spire-agent/cli/api/common.go b/cmd/spire-agent/cli/api/common.go index b586f2bd48..380e963590 100644 --- a/cmd/spire-agent/cli/api/common.go +++ b/cmd/spire-agent/cli/api/common.go @@ -73,7 +73,7 @@ func adaptCommand(env *cli.Env, clientsMaker workloadClientMaker, cmd command) * clientsMaker: clientsMaker, cmd: cmd, env: env, - timeout: cli.DurationFlag(time.Second), + timeout: cli.DurationFlag(2 * time.Second), } fs := flag.NewFlagSet(cmd.name(), flag.ContinueOnError) diff --git a/cmd/spire-agent/cli/run/run.go b/cmd/spire-agent/cli/run/run.go index 5245d5ee7b..6d7e1453d8 100644 --- a/cmd/spire-agent/cli/run/run.go +++ b/cmd/spire-agent/cli/run/run.go @@ -97,7 +97,9 @@ type experimentalConfig struct { SyncInterval string `hcl:"sync_interval"` TCPSocketPort int `hcl:"tcp_socket_port"` - UnusedKeys []string `hcl:",unusedKeys"` + UnusedKeys []string `hcl:",unusedKeys"` + MaxSvidCacheSize int `hcl:"max_svid_cache_size"` + 
SVIDCacheExpiryPeriod string `hcl:"svid_cache_expiry_interval"` } type Command struct { @@ -400,6 +402,17 @@ func NewAgentConfig(c *Config, logOptions []log.Option, allowUnknownConfig bool) } } + if c.Agent.Experimental.MaxSvidCacheSize != 0 { + ac.MaxSvidCacheSize = c.Agent.Experimental.MaxSvidCacheSize + } + if c.Agent.Experimental.SVIDCacheExpiryPeriod != "" { + var err error + ac.SVIDCacheExpiryPeriod, err = time.ParseDuration(c.Agent.Experimental.SVIDCacheExpiryPeriod) + if err != nil { + return nil, fmt.Errorf("could not parse svid cache expiry interval: %w", err) + } + } + serverHostPort := net.JoinHostPort(c.Agent.ServerAddress, strconv.Itoa(c.Agent.ServerPort)) ac.ServerAddress = fmt.Sprintf("dns:///%s", serverHostPort) diff --git a/cmd/spire-agent/cli/run/run_test.go b/cmd/spire-agent/cli/run/run_test.go index 1819e27bc6..17484fc46a 100644 --- a/cmd/spire-agent/cli/run/run_test.go +++ b/cmd/spire-agent/cli/run/run_test.go @@ -689,6 +689,50 @@ func TestNewAgentConfig(t *testing.T) { require.Nil(t, c) }, }, + { + msg: "svid_cache_expiry_interval parses a duration", + input: func(c *Config) { + c.Agent.Experimental.SVIDCacheExpiryPeriod = "1s50ms" + }, + test: func(t *testing.T, c *agent.Config) { + require.EqualValues(t, 1050000000, c.SVIDCacheExpiryPeriod) + }, + }, + { + msg: "invalid svid_cache_expiry_interval returns an error", + expectError: true, + input: func(c *Config) { + c.Agent.Experimental.SVIDCacheExpiryPeriod = "moo" + }, + test: func(t *testing.T, c *agent.Config) { + require.Nil(t, c) + }, + }, + { + msg: "svid_cache_expiry_interval is not set", + input: func(c *Config) { + }, + test: func(t *testing.T, c *agent.Config) { + require.EqualValues(t, 0, c.SVIDCacheExpiryPeriod) + }, + }, + { + msg: "max_svid_cache_size is set", + input: func(c *Config) { + c.Agent.Experimental.MaxSvidCacheSize = 100 + }, + test: func(t *testing.T, c *agent.Config) { + require.EqualValues(t, 100, c.MaxSvidCacheSize) + }, + }, + { + msg: "max_svid_cache_size is not set", + input: func(c *Config) { + }, + test: func(t *testing.T, c *agent.Config) { + require.EqualValues(t, 0, c.MaxSvidCacheSize) + }, + }, { msg: "admin_socket_path not provided", input: func(c *Config) { diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index a6abf363cf..f10f19a87c 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -200,17 +200,19 @@ func (a *Agent) attest(ctx context.Context, cat catalog.Catalog, metrics telemet func (a *Agent) newManager(ctx context.Context, cat catalog.Catalog, metrics telemetry.Metrics, as *node_attestor.AttestationResult, cache *storecache.Cache) (manager.Manager, error) { config := &manager.Config{ SVID: as.SVID, - SVIDKey: as.Key, - Bundle: as.Bundle, - Catalog: cat, - TrustDomain: a.c.TrustDomain, - ServerAddr: a.c.ServerAddress, - Log: a.c.Log.WithField(telemetry.SubsystemName, telemetry.Manager), - Metrics: metrics, - BundleCachePath: a.bundleCachePath(), - SVIDCachePath: a.agentSVIDPath(), - SyncInterval: a.c.SyncInterval, - SVIDStoreCache: cache, + SVIDKey: as.Key, + Bundle: as.Bundle, + Catalog: cat, + TrustDomain: a.c.TrustDomain, + ServerAddr: a.c.ServerAddress, + Log: a.c.Log.WithField(telemetry.SubsystemName, telemetry.Manager), + Metrics: metrics, + BundleCachePath: a.bundleCachePath(), + SVIDCachePath: a.agentSVIDPath(), + SyncInterval: a.c.SyncInterval, + MaxSvidCacheSize: a.c.MaxSvidCacheSize, + SVIDCacheExpiryPeriod: a.c.SVIDCacheExpiryPeriod, + SVIDStoreCache: cache, } mgr := manager.New(config) diff --git a/pkg/agent/api/delegatedidentity/v1/service.go 
b/pkg/agent/api/delegatedidentity/v1/service.go index 87a2020988..ef9b9f2b04 100644 --- a/pkg/agent/api/delegatedidentity/v1/service.go +++ b/pkg/agent/api/delegatedidentity/v1/service.go @@ -82,16 +82,16 @@ func (s *Service) isCallerAuthorized(ctx context.Context, log logrus.FieldLogger } } - identities := s.manager.MatchingIdentities(callerSelectors) - numRegisteredIDs := len(identities) + entries := s.manager.MatchingRegistrationEntries(callerSelectors) + numRegisteredIDs := len(entries) if numRegisteredIDs == 0 { log.Error("no identity issued") return nil, status.Error(codes.PermissionDenied, "no identity issued") } - for _, identity := range identities { - if _, ok := s.authorizedDelegates[identity.Entry.SpiffeId]; ok { + for _, entry := range entries { + if _, ok := s.authorizedDelegates[entry.SpiffeId]; ok { return callerSelectors, nil } } @@ -99,7 +99,7 @@ func (s *Service) isCallerAuthorized(ctx context.Context, log logrus.FieldLogger // caller has identity associeted with but none is authorized log.WithFields(logrus.Fields{ "num_registered_ids": numRegisteredIDs, - "default_id": identities[0].Entry.SpiffeId, + "default_id": entries[0].SpiffeId, }).Error("Permission denied; caller not configured as an authorized delegate.") return nil, status.Error(codes.PermissionDenied, "caller not configured as an authorized delegate") @@ -268,11 +268,11 @@ func (s *Service) FetchJWTSVIDs(ctx context.Context, req *delegatedidentityv1.Fe } var spiffeIDs []spiffeid.ID - identities := s.manager.MatchingIdentities(selectors) - for _, identity := range identities { - spiffeID, err := spiffeid.FromString(identity.Entry.SpiffeId) + entries := s.manager.MatchingRegistrationEntries(selectors) + for _, entry := range entries { + spiffeID, err := spiffeid.FromString(entry.SpiffeId) if err != nil { - log.WithField(telemetry.SPIFFEID, identity.Entry.SpiffeId).WithError(err).Error("Invalid requested SPIFFE ID") + log.WithField(telemetry.SPIFFEID, entry.SpiffeId).WithError(err).Error("Invalid requested SPIFFE ID") return nil, status.Errorf(codes.InvalidArgument, "invalid requested SPIFFE ID: %v", err) } diff --git a/pkg/agent/api/delegatedidentity/v1/service_test.go b/pkg/agent/api/delegatedidentity/v1/service_test.go index fe7d6e73ff..78d727390f 100644 --- a/pkg/agent/api/delegatedidentity/v1/service_test.go +++ b/pkg/agent/api/delegatedidentity/v1/service_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/andres-erbsen/clock" "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/spiffe/go-spiffe/v2/bundle/spiffebundle" @@ -653,10 +654,6 @@ func (fa FakeAttestor) Attest(ctx context.Context) ([]*common.Selector, error) { return fa.selectors, fa.err } -func (m *FakeManager) MatchingIdentities(selectors []*common.Selector) []cache.Identity { - return m.identities -} - type FakeManager struct { manager.Manager @@ -692,6 +689,14 @@ func (m *FakeManager) FetchJWTSVID(ctx context.Context, spiffeID spiffeid.ID, au }, nil } +func (m *FakeManager) MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry { + out := make([]*common.RegistrationEntry, 0, len(m.identities)) + for _, identity := range m.identities { + out = append(out, identity.Entry) + } + return out +} + type fakeSubscriber struct { m *FakeManager ch chan *cache.WorkloadUpdate @@ -794,5 +799,5 @@ func (m *FakeManager) SubscribeToBundleChanges() *cache.BundleStream { func newTestCache() *cache.Cache { log, _ := test.NewNullLogger() - return cache.New(log, trustDomain1, bundle1, 
telemetry.Blackhole{}) + return cache.New(log, trustDomain1, bundle1, telemetry.Blackhole{}, 0, 0, clock.NewMock()) } diff --git a/pkg/agent/config.go b/pkg/agent/config.go index 2586796682..2b7a61b2c3 100644 --- a/pkg/agent/config.go +++ b/pkg/agent/config.go @@ -52,6 +52,12 @@ type Config struct { // SyncInterval controls how often the agent sync synchronizer waits SyncInterval time.Duration + // MaxSvidCacheSize is a soft limit of max number of SVIDs that would be stored in cache + MaxSvidCacheSize int + + // SVIDCacheExpiryPeriod is a period after which svids that don't have subscribers will be removed from cache + SVIDCacheExpiryPeriod time.Duration + // Trust domain and associated CA bundle TrustDomain spiffeid.TrustDomain TrustBundle []*x509.Certificate diff --git a/pkg/agent/endpoints/workload/handler.go b/pkg/agent/endpoints/workload/handler.go index 0392f0a3da..7b72554029 100644 --- a/pkg/agent/endpoints/workload/handler.go +++ b/pkg/agent/endpoints/workload/handler.go @@ -30,7 +30,7 @@ import ( type Manager interface { SubscribeToCacheChanges(cache.Selectors) cache.Subscriber - MatchingIdentities([]*common.Selector) []cache.Identity + MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry FetchJWTSVID(ctx context.Context, spiffeID spiffeid.ID, audience []string) (*client.JWTSVID, error) FetchWorkloadUpdate([]*common.Selector) *cache.WorkloadUpdate } @@ -84,15 +84,15 @@ func (h *Handler) FetchJWTSVID(ctx context.Context, req *workload.JWTSVIDRequest log = log.WithField(telemetry.Registered, true) - identities := h.c.Manager.MatchingIdentities(selectors) - for _, identity := range identities { - if req.SpiffeId != "" && identity.Entry.SpiffeId != req.SpiffeId { + entries := h.c.Manager.MatchingRegistrationEntries(selectors) + for _, entry := range entries { + if req.SpiffeId != "" && entry.SpiffeId != req.SpiffeId { continue } - spiffeID, err := spiffeid.FromString(identity.Entry.SpiffeId) + spiffeID, err := spiffeid.FromString(entry.SpiffeId) if err != nil { - log.WithField(telemetry.SPIFFEID, identity.Entry.SpiffeId).WithError(err).Error("Invalid requested SPIFFE ID") + log.WithField(telemetry.SPIFFEID, entry.SpiffeId).WithError(err).Error("Invalid requested SPIFFE ID") return nil, status.Errorf(codes.InvalidArgument, "invalid requested SPIFFE ID: %v", err) } diff --git a/pkg/agent/endpoints/workload/handler_test.go b/pkg/agent/endpoints/workload/handler_test.go index 42acbe354c..1417c17ae8 100644 --- a/pkg/agent/endpoints/workload/handler_test.go +++ b/pkg/agent/endpoints/workload/handler_test.go @@ -1014,8 +1014,12 @@ type FakeManager struct { err error } -func (m *FakeManager) MatchingIdentities(selectors []*common.Selector) []cache.Identity { - return m.identities +func (m *FakeManager) MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry { + out := make([]*common.RegistrationEntry, 0, len(m.identities)) + for _, identity := range m.identities { + out = append(out, identity.Entry) + } + return out } func (m *FakeManager) FetchJWTSVID(ctx context.Context, spiffeID spiffeid.ID, audience []string) (*client.JWTSVID, error) { diff --git a/pkg/agent/manager/cache/cache.go b/pkg/agent/manager/cache/cache.go index 37345f9575..427bcc12b5 100644 --- a/pkg/agent/manager/cache/cache.go +++ b/pkg/agent/manager/cache/cache.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "github.com/andres-erbsen/clock" "github.com/sirupsen/logrus" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/common/bundleutil" @@ -14,6 
+15,11 @@ import ( "github.com/spiffe/spire/proto/spire/common" ) +const ( + DefaultMaxSvidCacheSize = 1000 + DefaultSVIDCacheExpiryPeriod = 1 * time.Hour +) + type Selectors []*common.Selector type Bundle = bundleutil.Bundle @@ -58,9 +64,9 @@ type X509SVID struct { PrivateKey crypto.Signer } -// Cache caches each registration entry, signed X509-SVIDs for those entries, -// bundles, and JWT SVIDs for the agent. It allows subscriptions by (workload) -// selector sets and notifies subscribers when: +// Cache caches each registration entry, bundles, and JWT SVIDs for the agent. +// The signed X509-SVIDs for those entries are stored in LRU-like cache. +// It allows subscriptions by (workload) selector sets and notifies subscribers when: // // 1) a registration entry related to the selectors: // * is modified @@ -75,6 +81,22 @@ type X509SVID struct { // selector it encounters. Each selector index tracks the subscribers (i.e // workloads) and registration entries that have that selector. // +// The LRU-like SVID cache has configurable size limit and expiry period. +// 1. Size limit of SVID cache is a soft limit which means if SVID has a subscriber present then +// that SVID is never removed from cache. +// 2. Least recently used SVIDs are removed from cache only after the cache expiry period has passed. +// This is done to reduce the overall cache churn. +// 3. Last access timestamp for SVID cache entry is updated when a new subscriber is created +// 4. When a new subscriber is created and if there is a cache miss +// then subscriber needs to wait for next SVID sync event to receive WorkloadUpdate with newly minted SVID +// +// The advantage of above approach is that if agent has entry count less than cache size +// then all SVIDs are cached at all times. If agent has entry count greater than cache size then +// subscribers will continue to get SVID updates (potential delay for first WorkloadUpdate if cache miss) +// and least used SVIDs will be removed from cache which will save memory usage. +// It will allow agent to support large number of registrations. +// +// // When registration entries are added/updated/removed, the set of relevant // selectors are gathered and the indexes for those selectors are combed for // all relevant subscribers. @@ -102,6 +124,7 @@ type Cache struct { log logrus.FieldLogger trustDomain spiffeid.TrustDomain + clk clock.Clock metrics telemetry.Metrics @@ -113,14 +136,23 @@ type Cache struct { // selectors holds the selector indices, keyed by a selector key selectors map[selector]*selectorIndex - // staleEntries holds stale registration entries + // staleEntries holds stale or new registration entries which require new SVID to be stored in cache staleEntries map[string]bool // bundles holds the trust bundles, keyed by trust domain id (i.e. 
"spiffe://domain.test") bundles map[spiffeid.TrustDomain]*bundleutil.Bundle + + // svids are stored by entry IDs + svids map[string]*X509SVID + + // maxSVIDCacheSize is a soft limit of max number of SVIDs that would be stored in cache + maxSvidCacheSize int + + // svidCacheExpiryPeriod is a period after which svids that don't have subscribers will be removed from cache + svidCacheExpiryPeriod time.Duration } -// StaleEntry holds stale entries with SVIDs expiration time +// StaleEntry holds stale or outdated entries which require new SVID with old SVIDs expiration time (if present) type StaleEntry struct { // Entry stale registration entry Entry *common.RegistrationEntry @@ -128,7 +160,16 @@ type StaleEntry struct { ExpiresAt time.Time } -func New(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundle, metrics telemetry.Metrics) *Cache { +func New(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundle, metrics telemetry.Metrics, + maxSvidCacheSize int, svidCacheExpiryPeriod time.Duration, clk clock.Clock) *Cache { + if maxSvidCacheSize == 0 { + maxSvidCacheSize = DefaultMaxSvidCacheSize + } + + if svidCacheExpiryPeriod == 0 { + svidCacheExpiryPeriod = DefaultSVIDCacheExpiryPeriod + } + return &Cache{ BundleCache: NewBundleCache(trustDomain, bundle), JWTSVIDCache: NewJWTSVIDCache(), @@ -142,6 +183,10 @@ func New(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundl bundles: map[spiffeid.TrustDomain]*bundleutil.Bundle{ trustDomain: bundle, }, + svids: make(map[string]*X509SVID), + maxSvidCacheSize: maxSvidCacheSize, + svidCacheExpiryPeriod: svidCacheExpiryPeriod, + clk: clk, } } @@ -153,41 +198,44 @@ func (c *Cache) Identities() []Identity { out := make([]Identity, 0, len(c.records)) for _, record := range c.records { - if record.svid == nil { + svid, ok := c.svids[record.entry.EntryId] + if !ok { // The record does not have an SVID yet and should not be returned // from the cache. continue } - out = append(out, makeIdentity(record)) + out = append(out, makeIdentity(record, svid)) } sortIdentities(out) return out } -func (c *Cache) CountSVIDs() int { +func (c *Cache) Entries() []*common.RegistrationEntry { c.mu.RLock() defer c.mu.RUnlock() - var records int + out := make([]*common.RegistrationEntry, 0, len(c.records)) for _, record := range c.records { - if record.svid == nil { - // The record does not have an SVID yet and should not be returned - // from the cache. - continue - } - records++ + out = append(out, record.entry) } + sortEntries(out) + return out +} + +func (c *Cache) CountSVIDs() int { + c.mu.RLock() + defer c.mu.RUnlock() - return records + return len(c.svids) } -func (c *Cache) MatchingIdentities(selectors []*common.Selector) []Identity { +func (c *Cache) MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry { set, setDone := allocSelectorSet(selectors...) defer setDone() c.mu.RLock() defer c.mu.RUnlock() - return c.matchingIdentities(set) + return c.matchingEntries(set) } func (c *Cache) FetchWorkloadUpdate(selectors []*common.Selector) *WorkloadUpdate { @@ -199,6 +247,8 @@ func (c *Cache) FetchWorkloadUpdate(selectors []*common.Selector) *WorkloadUpdat return c.buildWorkloadUpdate(set) } +// SubscribeToWorkloadUpdates creates a subscriber for given selector set. +// Separately call Notify for the first time after this method is invoked to receive latest updates. 
func (c *Cache) SubscribeToWorkloadUpdates(selectors []*common.Selector) Subscriber { c.mu.Lock() defer c.mu.Unlock() @@ -207,7 +257,8 @@ func (c *Cache) SubscribeToWorkloadUpdates(selectors []*common.Selector) Subscri for s := range sub.set { c.addSelectorIndexSub(s, sub) } - c.notify(sub) + // update lastAccessTimestamp of records containing provided selectors + c.updateLastAccessTimestamp(selectors) return sub } @@ -282,11 +333,14 @@ func (c *Cache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.Regi c.delSelectorIndicesRecord(selRem, record) notifySets = append(notifySets, selRem) delete(c.records, id) + delete(c.svids, id) // Remove stale entry since, registration entry is no longer on cache. delete(c.staleEntries, id) } } + outdatedEntries := make(map[string]struct{}) + // Add/update records for registration entries in the update for _, newEntry := range update.RegistrationEntries { clearSelectorSet(selAdd) @@ -344,9 +398,9 @@ func (c *Cache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.Regi notifySets = append(notifySets, notifySet) } - // Invoke the svid checker callback for this record - if checkSVID != nil && checkSVID(existingEntry, newEntry, record.svid) { - c.staleEntries[newEntry.EntryId] = true + // Identify stale/outdated entries + if existingEntry != nil && existingEntry.RevisionNumber != newEntry.RevisionNumber { + outdatedEntries[newEntry.EntryId] = struct{}{} } // Log all the details of the update to the DEBUG log @@ -375,6 +429,43 @@ func (c *Cache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.Regi } } + // entries with active subscribers which are not cached will be put in staleEntries map + activeSubs, recordsWithLastAccessTime := c.syncSVIDs() + extraSize := len(c.svids) - c.maxSvidCacheSize + + // delete svids without subscribers and which have not been accessed since svidCacheExpiryTime + if extraSize > 0 { + // sort recordsWithLastAccessTime + sortTimestamps(recordsWithLastAccessTime) + now := c.clk.Now() + svidCacheExpiryTime := now.Add(-1 * c.svidCacheExpiryPeriod).UnixMilli() + for _, record := range recordsWithLastAccessTime { + if extraSize <= 0 { + // no need to delete SVIDs any further as cache size <= maxSvidCacheSize + break + } + if _, ok := c.svids[record.id]; ok { + if _, exists := activeSubs[record.id]; !exists { + // remove svid if it has not been accessed since svidCacheExpiryTime + if record.timestamp < svidCacheExpiryTime { + c.log.WithField("record_id", record.id). + WithField("record_timestamp", record.timestamp). + Debug("Removing SVID record") + delete(c.svids, record.id) + extraSize-- + } + } + } + } + } + + // Update all stale svids or svids whose registration entry is outdated + for id, svid := range c.svids { + if _, ok := outdatedEntries[id]; ok || (checkSVID != nil && checkSVID(nil, c.records[id].entry, svid)) { + c.staleEntries[id] = true + } + } + if bundleRemoved || len(bundleChanged) > 0 { c.BundleCache.Update(c.bundles) } @@ -402,7 +493,7 @@ func (c *Cache) UpdateSVIDs(update *UpdateSVIDs) { continue } - record.svid = svid + c.svids[entryID] = svid notifySet.Merge(record.entry.Selectors...) 
log := c.log.WithFields(logrus.Fields{ telemetry.Entry: record.entry.EntryId, @@ -432,8 +523,8 @@ func (c *Cache) GetStaleEntries() []*StaleEntry { } var expiresAt time.Time - if cachedEntry.svid != nil { - expiresAt = cachedEntry.svid.Chain[0].NotAfter + if cachedSvid, ok := c.svids[entryID]; ok { + expiresAt = cachedSvid.Chain[0].NotAfter } staleEntries = append(staleEntries, &StaleEntry{ @@ -445,6 +536,104 @@ func (c *Cache) GetStaleEntries() []*StaleEntry { return staleEntries } +// SyncSVIDsWithSubscribers will sync svid cache: +// entries with active subscribers which are not cached will be put in staleEntries map +// records which are not cached for remainder of max cache size will also be put in staleEntries map +func (c *Cache) SyncSVIDsWithSubscribers() { + c.mu.Lock() + defer c.mu.Unlock() + + c.syncSVIDs() +} + +// Notify subscribers of selector set only if all SVIDs for corresponding selector set are cached +// It returns whether all SVIDs are cached or not. +// This method should be retried with backoff to avoid lock contention. +func (c *Cache) Notify(selectors []*common.Selector) bool { + c.mu.Lock() + defer c.mu.Unlock() + if len(c.missingSVIDRecords(selectors)) == 0 { + set, setFree := allocSelectorSet(selectors...) + defer setFree() + c.notifyBySelectorSet(set) + return true + } + return false +} + +func (c *Cache) missingSVIDRecords(selectors []*common.Selector) []*StaleEntry { + set, setFree := allocSelectorSet(selectors...) + defer setFree() + + records, recordsDone := c.getRecordsForSelectors(set) + defer recordsDone() + + if len(records) == 0 { + return nil + } + out := make([]*StaleEntry, 0, len(records)) + for record := range records { + if _, ok := c.svids[record.entry.EntryId]; !ok { + out = append(out, &StaleEntry{ + Entry: record.entry, + }) + } + } + return out +} + +func (c *Cache) updateLastAccessTimestamp(selectors []*common.Selector) { + set, setFree := allocSelectorSet(selectors...) 
+ defer setFree() + + records, recordsDone := c.getRecordsForSelectors(set) + defer recordsDone() + + for record := range records { + // Set lastAccessTimestamp so that svid LRU cache can be cleaned based on this timestamp + record.lastAccessTimestamp = c.clk.Now().UnixMilli() + } +} + +// entries with active subscribers which are not cached will be put in staleEntries map +// records which are not cached for remainder of max cache size will also be put in staleEntries map +func (c *Cache) syncSVIDs() (map[string]struct{}, []record) { + activeSubs := make(map[string]struct{}) + lastAccessTimestamps := make([]record, len(c.records)) + + i := 0 + for id, record := range c.records { + for _, sel := range record.entry.Selectors { + if index, ok := c.selectors[makeSelector(sel)]; ok && index != nil { + if len(index.subs) > 0 { + if _, ok := c.svids[record.entry.EntryId]; !ok { + c.staleEntries[id] = true + } + activeSubs[id] = struct{}{} + break + } + } + } + lastAccessTimestamps[i] = newRecord(record.lastAccessTimestamp, id) + i++ + } + + remainderSize := c.maxSvidCacheSize - len(c.svids) + // add records which are not cached for remainder of cache size + for id, _ := range c.records { + if len(c.staleEntries) >= remainderSize { + break + } + if _, ok := c.svids[id]; !ok { + if _, ok := c.staleEntries[id]; !ok { + c.staleEntries[id] = true + } + } + } + + return activeSubs, lastAccessTimestamps +} + func (c *Cache) updateOrCreateRecord(newEntry *common.RegistrationEntry) (*cacheRecord, *common.RegistrationEntry) { var existingEntry *common.RegistrationEntry record, recordExists := c.records[newEntry.EntryId] @@ -616,12 +805,33 @@ func (c *Cache) matchingIdentities(set selectorSet) []Identity { // TODO: figure out how to determine the "default" identity out := make([]Identity, 0, len(records)) for record := range records { - out = append(out, makeIdentity(record)) + if svid, ok := c.svids[record.entry.EntryId]; ok { + out = append(out, makeIdentity(record, svid)) + } } sortIdentities(out) return out } +func (c *Cache) matchingEntries(set selectorSet) []*common.RegistrationEntry { + records, recordsDone := c.getRecordsForSelectors(set) + defer recordsDone() + + if len(records) == 0 { + return nil + } + + // Return identities in ascending "entry id" order to maintain a consistent + // ordering. + // TODO: figure out how to determine the "default" identity + out := make([]*common.RegistrationEntry, 0, len(records)) + for record := range records { + out = append(out, record.entry) + } + sortEntries(out) + return out +} + func (c *Cache) buildWorkloadUpdate(set selectorSet) *WorkloadUpdate { w := &WorkloadUpdate{ Bundle: c.bundles[c.trustDomain], @@ -656,17 +866,13 @@ func (c *Cache) buildWorkloadUpdate(set selectorSet) *WorkloadUpdate { } func (c *Cache) getRecordsForSelectors(set selectorSet) (recordSet, func()) { - // Build and dedup a list of candidate entries. Ignore those without an - // SVID but otherwise don't check for selector set inclusion yet, since + // Build and dedup a list of candidate entries. Don't check for selector set inclusion yet, since // that is a more expensive operation and we could easily have duplicate // entries to check. 
records, recordsDone := allocRecordSet() for selector := range set { if index := c.getSelectorIndexForRead(selector); index != nil { for record := range index.records { - if record.svid == nil { - continue - } records[record] = struct{}{} } } @@ -708,9 +914,9 @@ func (c *Cache) getSelectorIndexForRead(s selector) *selectorIndex { } type cacheRecord struct { - entry *common.RegistrationEntry - svid *X509SVID - subs map[*subscriber]struct{} + entry *common.RegistrationEntry + subs map[*subscriber]struct{} + lastAccessTimestamp int64 } func newCacheRecord() *cacheRecord { @@ -744,10 +950,31 @@ func sortIdentities(identities []Identity) { }) } -func makeIdentity(record *cacheRecord) Identity { +func sortEntries(entries []*common.RegistrationEntry) { + sort.Slice(entries, func(a, b int) bool { + return entries[a].EntryId < entries[b].EntryId + }) +} + +func sortTimestamps(records []record) { + sort.Slice(records, func(a, b int) bool { + return records[a].timestamp < records[b].timestamp + }) +} + +func makeIdentity(record *cacheRecord, svid *X509SVID) Identity { return Identity{ Entry: record.entry, - SVID: record.svid.Chain, - PrivateKey: record.svid.PrivateKey, + SVID: svid.Chain, + PrivateKey: svid.PrivateKey, } } + +type record struct { + timestamp int64 + id string +} + +func newRecord(timestamp int64, id string) record { + return record{timestamp: timestamp, id: id} +} diff --git a/pkg/agent/manager/cache/cache_test.go b/pkg/agent/manager/cache/cache_test.go index e7101b260d..9849b5b360 100644 --- a/pkg/agent/manager/cache/cache_test.go +++ b/pkg/agent/manager/cache/cache_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/andres-erbsen/clock" "github.com/sirupsen/logrus/hooks/test" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/common/bundleutil" @@ -58,7 +59,7 @@ func TestFetchWorkloadUpdate(t *testing.T) { }, workloadUpdate) } -func TestMatchingIdentities(t *testing.T) { +func TestMatchingRegistrationIdentities(t *testing.T) { cache := newTestCache() // populate the cache with FOO and BAR without SVIDS @@ -70,19 +71,21 @@ func TestMatchingIdentities(t *testing.T) { } cache.UpdateEntries(updateEntries, nil) - identities := cache.MatchingIdentities(makeSelectors("A", "B")) - assert.Len(t, identities, 0, "identities should not be returned that don't have SVIDs") + assert.Equal(t, []*common.RegistrationEntry{bar, foo}, + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) + // Update SVIDs and MatchingRegistrationEntries should return both entries updateSVIDs := &UpdateSVIDs{ X509SVIDs: makeX509SVIDs(foo, bar), } cache.UpdateSVIDs(updateSVIDs) + assert.Equal(t, []*common.RegistrationEntry{bar, foo}, + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) - identities = cache.MatchingIdentities(makeSelectors("A", "B")) - assert.Equal(t, []Identity{ - {Entry: bar}, - {Entry: foo}, - }, identities) + // Remove SVIDs and MatchingRegistrationEntries should still return both entries + cache.UpdateSVIDs(&UpdateSVIDs{}) + assert.Equal(t, []*common.RegistrationEntry{bar, foo}, + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) } func TestCountSVIDs(t *testing.T) { @@ -137,11 +140,11 @@ func TestAllSubscribersNotifiedOnBundleChange(t *testing.T) { cache := newTestCache() // create some subscribers and assert they get the initial bundle - subA := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) defer subA.Finish() assertWorkloadUpdateEqual(t, subA, 
&WorkloadUpdate{Bundle: bundleV1}) - subB := cache.SubscribeToWorkloadUpdates(makeSelectors("B")) + subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) defer subB.Finish() assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{Bundle: bundleV1}) @@ -168,11 +171,11 @@ func TestSomeSubscribersNotifiedOnFederatedBundleChange(t *testing.T) { }) // subscribe to A and B and assert initial updates are received. - subA := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) defer subA.Finish() assertAnyWorkloadUpdate(t, subA) - subB := cache.SubscribeToWorkloadUpdates(makeSelectors("B")) + subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) defer subB.Finish() assertAnyWorkloadUpdate(t, subB) @@ -231,11 +234,11 @@ func TestSubscribersGetEntriesWithSelectorSubsets(t *testing.T) { cache := newTestCache() // create subscribers for each combination of selectors - subA := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) defer subA.Finish() - subB := cache.SubscribeToWorkloadUpdates(makeSelectors("B")) + subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) defer subB.Finish() - subAB := cache.SubscribeToWorkloadUpdates(makeSelectors("A", "B")) + subAB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A", "B")) defer subAB.Finish() // assert all subscribers get the initial update @@ -288,7 +291,7 @@ func TestSubscriberIsNotNotifiedIfNothingChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo), }) - sub := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -314,7 +317,7 @@ func TestSubscriberNotifiedOnSVIDChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo), }) - sub := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -329,7 +332,7 @@ func TestSubscriberNotifiedOnSVIDChanges(t *testing.T) { }) } -func TestSubcriberNotificationsOnSelectorChanges(t *testing.T) { +func TestSubscriberNotificationsOnSelectorChanges(t *testing.T) { cache := newTestCache() // initialize the cache with a FOO entry with selector A and an SVID @@ -343,7 +346,7 @@ func TestSubcriberNotificationsOnSelectorChanges(t *testing.T) { }) // create subscribers for A and make sure the initial update has FOO - sub := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) defer sub.Finish() assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ Bundle: bundleV1, @@ -380,21 +383,16 @@ func TestSubcriberNotificationsOnSelectorChanges(t *testing.T) { }) } -func newTestCache() *Cache { - log, _ := test.NewNullLogger() - return New(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, telemetry.Blackhole{}) -} - -func TestSubcriberNotifiedWhenEntryDropped(t *testing.T) { +func TestSubscriberNotifiedWhenEntryDropped(t *testing.T) { cache := newTestCache() - subA := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) defer subA.Finish() assertAnyWorkloadUpdate(t, subA) // subB's job here is to just make sure we don't notify unrelated // subscribers when dropping registration entries - subB := 
cache.SubscribeToWorkloadUpdates(makeSelectors("B")) + subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) defer subB.Finish() assertAnyWorkloadUpdate(t, subB) @@ -428,7 +426,7 @@ func TestSubcriberNotifiedWhenEntryDropped(t *testing.T) { assertNoWorkloadUpdate(t, subB) } -func TestSubcriberOnlyGetsEntriesWithSVID(t *testing.T) { +func TestSubscriberOnlyGetsEntriesWithSVID(t *testing.T) { cache := newTestCache() foo := makeRegistrationEntry("FOO", "A") @@ -440,11 +438,7 @@ func TestSubcriberOnlyGetsEntriesWithSVID(t *testing.T) { sub := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) defer sub.Finish() - - // workload update does not include the identity because it has no SVID. - assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ - Bundle: bundleV1, - }) + assertNoWorkloadUpdate(t, sub) // update to include the SVID and now we should get the update cache.UpdateSVIDs(&UpdateSVIDs{ @@ -459,7 +453,7 @@ func TestSubcriberOnlyGetsEntriesWithSVID(t *testing.T) { func TestSubscribersDoNotBlockNotifications(t *testing.T) { cache := newTestCache() - sub := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) defer sub.Finish() cache.UpdateEntries(&UpdateEntries{ @@ -489,34 +483,23 @@ func TestCheckSVIDCallback(t *testing.T) { foo := makeRegistrationEntryWithTTL("FOO", 60) - // called once for FOO with no SVID - callCount := 0 cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV2), RegistrationEntries: makeRegistrationEntries(foo), }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { - callCount++ - assert.Equal(t, "FOO", newEntry.EntryId) - - // there is no already existing entry, only the new entry - assert.Nil(t, existingEntry) - assert.Equal(t, foo, newEntry) - assert.Nil(t, svid) - + // should not get invoked + assert.Fail(t, "should not be called as no SVIDs are cached yet") return false }) - assert.Equal(t, 1, callCount) - assert.Empty(t, cache.staleEntries) // called once for FOO with new SVID - callCount = 0 svids := makeX509SVIDs(foo) cache.UpdateSVIDs(&UpdateSVIDs{ X509SVIDs: svids, }) // called once for FOO with existing SVID - callCount = 0 + callCount := 0 cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV2), RegistrationEntries: makeRegistrationEntries(foo), @@ -537,30 +520,23 @@ func TestGetStaleEntries(t *testing.T) { cache := newTestCache() foo := makeRegistrationEntryWithTTL("FOO", 60) + expiredAt := time.Now() - // Create entry but don't mark it stale + // Create entry but don't mark it stale from checkSVID method; + // it will be marked stale cause it does not have SVID cached cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV2), RegistrationEntries: makeRegistrationEntries(foo), }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { return false }) - assert.Empty(t, cache.GetStaleEntries()) - // Update entry and mark it as stale - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV2), - RegistrationEntries: makeRegistrationEntries(foo), - }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { - return true - }) // Assert that the entry is returned as stale. The `ExpiresAt` field should be unset since there is no SVID. 
expectedEntries := []*StaleEntry{{Entry: cache.records[foo.EntryId].entry}} assert.Equal(t, expectedEntries, cache.GetStaleEntries()) // Update the SVID for the stale entry svids := make(map[string]*X509SVID) - expiredAt := time.Now() svids[foo.EntryId] = &X509SVID{ Chain: []*x509.Certificate{{NotAfter: expiredAt}}, } @@ -607,7 +583,7 @@ func TestSubscriberNotNotifiedOnDifferentSVIDChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo, bar), }) - sub := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -632,7 +608,7 @@ func TestSubscriberNotNotifiedOnOverlappingSVIDChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo, bar), }) - sub := cache.SubscribeToWorkloadUpdates(makeSelectors("A", "B")) + sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A", "B")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -644,6 +620,183 @@ func TestSubscriberNotNotifiedOnOverlappingSVIDChanges(t *testing.T) { assertNoWorkloadUpdate(t, sub) } +func TestSVIDCacheExpiry(t *testing.T) { + clk := clock.NewMock() + cache := newTestCacheWithConfig(10, 1*time.Minute, clk) + + clk.Add(1 * time.Second) + foo := makeRegistrationEntry("FOO", "A") + // validate workload update for foo + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) + subA.Finish() + + // move clk by 1 sec so that SVID access time will be different + clk.Add(1 * time.Second) + bar := makeRegistrationEntry("BAR", "B") + // validate workload update for bar + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(bar), + }) + + // not closing subscriber immediately + subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) + defer subB.Finish() + assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{ + {Entry: bar}, + }, + }) + + // Move clk by a second + clk.Add(2 * time.Second) + // update total of 12 entries + updateEntries := createUpdateEntries(10, makeBundles(bundleV1)) + updateEntries.RegistrationEntries[foo.EntryId] = foo + updateEntries.RegistrationEntries[bar.EntryId] = bar + + cache.UpdateEntries(updateEntries, nil) + + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDsFromMap(updateEntries.RegistrationEntries), + }) + + for id, entry := range updateEntries.RegistrationEntries { + // create and close subscribers for remaining entries so that svid cache is full + if id != foo.EntryId && id != bar.EntryId { + sub := cache.SubscribeToWorkloadUpdates(entry.Selectors) + sub.Finish() + } + } + + // Move clk by 58 sec so that a minute has passed since last foo was accessed + // svid for foo should be deleted + clk.Add(58 * time.Second) + cache.UpdateEntries(updateEntries, nil) + + subA = cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + defer subA.Finish() + //cache.NotifyBySelectorSet(makeSelectors("A")) + assert.False(t, cache.Notify(makeSelectors("A"))) + assert.Equal(t, 11, cache.CountSVIDs()) + + // move clk by another minute and update 
entries + clk.Add(1 * time.Minute) + cache.UpdateEntries(updateEntries, nil) + + // Make sure foo is marked as stale entry which does not have svid cached + require.Len(t, cache.GetStaleEntries(), 1) + assert.Equal(t, foo, cache.GetStaleEntries()[0].Entry) + + // bar should not be removed from cache as it has another active subscriber + subB2 := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) + defer subB2.Finish() + assertWorkloadUpdateEqual(t, subB2, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{ + {Entry: bar}, + }, + }) + + // ensure SVIDs without active subscribers are still cached for remainder of cache size + assert.Equal(t, 10, cache.CountSVIDs()) +} + +func TestMaxSVIDCacheSize(t *testing.T) { + clk := clock.NewMock() + cache := newTestCacheWithConfig(10, 1*time.Minute, clk) + + // create entries more than maxSvidCacheSize + updateEntries := createUpdateEntries(12, makeBundles(bundleV1)) + cache.UpdateEntries(updateEntries, nil) + + require.Len(t, cache.GetStaleEntries(), 10) + + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDsFromStaleEntries(cache.GetStaleEntries()), + }) + require.Len(t, cache.GetStaleEntries(), 0) + assert.Equal(t, 10, cache.CountSVIDs()) + + // Validate that active subscriber will still get SVID even if SVID count is at maxSvidCacheSize + foo := makeRegistrationEntry("FOO", "A") + updateEntries.RegistrationEntries[foo.EntryId] = foo + + subA := cache.SubscribeToWorkloadUpdates(foo.Selectors) + defer subA.Finish() + + cache.UpdateEntries(updateEntries, nil) + require.Len(t, cache.GetStaleEntries(), 1) + assert.Equal(t, 10, cache.CountSVIDs()) + + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + assert.Equal(t, 11, cache.CountSVIDs()) + require.Len(t, cache.GetStaleEntries(), 0) +} + +func TestSyncSVIDsWithSubscribers(t *testing.T) { + clk := clock.NewMock() + cache := newTestCacheWithConfig(5, 1*time.Minute, clk) + + updateEntries := createUpdateEntries(5, makeBundles(bundleV1)) + cache.UpdateEntries(updateEntries, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDsFromStaleEntries(cache.GetStaleEntries()), + }) + assert.Equal(t, 5, cache.CountSVIDs()) + + // Update foo but its SVID is not yet cached + foo := makeRegistrationEntry("FOO", "A") + updateEntries.RegistrationEntries[foo.EntryId] = foo + + cache.UpdateEntries(updateEntries, nil) + + // Create a subscriber for foo + subA := cache.SubscribeToWorkloadUpdates(foo.Selectors) + defer subA.Finish() + require.Len(t, cache.GetStaleEntries(), 0) + + // After SyncSVIDsWithSubscribers foo should be marked as stale + cache.SyncSVIDsWithSubscribers() + require.Len(t, cache.GetStaleEntries(), 1) + assert.Equal(t, []*StaleEntry{{Entry: cache.records[foo.EntryId].entry}}, cache.GetStaleEntries()) + + assert.Equal(t, 5, cache.CountSVIDs()) +} + +func TestNotify(t *testing.T) { + cache := newTestCache() + + foo := makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + + assert.False(t, cache.Notify(makeSelectors("A"))) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + assert.True(t, cache.Notify(makeSelectors("A"))) +} + func BenchmarkCacheGlobalNotification(b *testing.B) { cache := newTestCache() @@ -689,6 +842,36 @@ func BenchmarkCacheGlobalNotification(b *testing.B) { } } +func newTestCache() *Cache { + log, _ := test.NewNullLogger() + return New(log, 
spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, + telemetry.Blackhole{}, 0, 0, clock.NewMock()) +} + +func newTestCacheWithConfig(maxSvidCacheSize int, svidCacheExpiryPeriod time.Duration, clk clock.Clock) *Cache { + log, _ := test.NewNullLogger() + return New(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, telemetry.Blackhole{}, + maxSvidCacheSize, svidCacheExpiryPeriod, clk) +} + +func createUpdateEntries(numEntries int, bundles map[spiffeid.TrustDomain]*bundleutil.Bundle) *UpdateEntries { + updateEntries := &UpdateEntries{ + Bundles: bundles, + RegistrationEntries: make(map[string]*common.RegistrationEntry, numEntries), + } + + for i := 0; i < numEntries; i++ { + entryID := fmt.Sprintf("00000000-0000-0000-0000-%012d", i) + updateEntries.RegistrationEntries[entryID] = &common.RegistrationEntry{ + EntryId: entryID, + ParentId: "spiffe://domain.test/node", + SpiffeId: fmt.Sprintf("spiffe://domain.test/workload-%d", i), + Selectors: distinctSelectors(i, 1), + } + } + return updateEntries +} + func distinctSelectors(id, n int) []*common.Selector { out := make([]*common.Selector, 0, n) for i := 0; i < n; i++ { @@ -744,6 +927,22 @@ func makeX509SVIDs(entries ...*common.RegistrationEntry) map[string]*X509SVID { return out } +func makeX509SVIDsFromMap(entries map[string]*common.RegistrationEntry) map[string]*X509SVID { + out := make(map[string]*X509SVID) + for _, entry := range entries { + out[entry.EntryId] = &X509SVID{} + } + return out +} + +func makeX509SVIDsFromStaleEntries(entries []*StaleEntry) map[string]*X509SVID { + out := make(map[string]*X509SVID) + for _, entry := range entries { + out[entry.Entry.EntryId] = &X509SVID{} + } + return out +} + func makeRegistrationEntry(id string, selectors ...string) *common.RegistrationEntry { return &common.RegistrationEntry{ EntryId: id, @@ -787,3 +986,9 @@ func makeFederatesWith(bundles ...*Bundle) []string { } return out } + +func subscribeToWorkloadUpdatesAndNotify(t *testing.T, cache *Cache, selectors []*common.Selector) Subscriber { + subscriber := cache.SubscribeToWorkloadUpdates(selectors) + assert.True(t, cache.Notify(selectors)) + return subscriber +} diff --git a/pkg/agent/manager/config.go b/pkg/agent/manager/config.go index fa1d93714a..53208ebb3c 100644 --- a/pkg/agent/manager/config.go +++ b/pkg/agent/manager/config.go @@ -19,19 +19,21 @@ import ( // Config holds a cache manager configuration type Config struct { // Agent SVID and key resulting from successful attestation. 
- SVID []*x509.Certificate - SVIDKey keymanager.Key - Bundle *cache.Bundle - Catalog catalog.Catalog - TrustDomain spiffeid.TrustDomain - Log logrus.FieldLogger - Metrics telemetry.Metrics - ServerAddr string - SVIDCachePath string - BundleCachePath string - SyncInterval time.Duration - RotationInterval time.Duration - SVIDStoreCache *storecache.Cache + SVID []*x509.Certificate + SVIDKey keymanager.Key + Bundle *cache.Bundle + Catalog catalog.Catalog + TrustDomain spiffeid.TrustDomain + Log logrus.FieldLogger + Metrics telemetry.Metrics + ServerAddr string + SVIDCachePath string + BundleCachePath string + SyncInterval time.Duration + RotationInterval time.Duration + SVIDStoreCache *storecache.Cache + MaxSvidCacheSize int + SVIDCacheExpiryPeriod time.Duration // Clk is the clock the manager will use to get time Clk clock.Clock @@ -55,7 +57,8 @@ func newManager(c *Config) *manager { c.Clk = clock.New() } - cache := cache.New(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle, c.Metrics) + cache := cache.New(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle, + c.Metrics, c.MaxSvidCacheSize, c.SVIDCacheExpiryPeriod, c.Clk) rotCfg := &svid.RotatorConfig{ SVIDKeyManager: keymanager.ForSVID(c.Catalog.GetKeyManager()), diff --git a/pkg/agent/manager/manager.go b/pkg/agent/manager/manager.go index 98dcddb19e..1d0f60a2e8 100644 --- a/pkg/agent/manager/manager.go +++ b/pkg/agent/manager/manager.go @@ -24,6 +24,8 @@ import ( "github.com/spiffe/spire/proto/spire/common" ) +const svidSyncInterval = 500 * time.Millisecond + // Cache Manager errors var ( ErrNotCached = errors.New("not cached") @@ -59,9 +61,9 @@ type Manager interface { // SetRotationFinishedHook sets a hook that will be called when a rotation finished SetRotationFinishedHook(func()) - // MatchingIdentities returns all of the cached identities whose - // registration entry selectors are a subset of the passed selectors. - MatchingIdentities(selectors []*common.Selector) []cache.Identity + // MatchingRegistrationEntries returns all of the cached registration entries whose + // selectors are a subset of the passed selectors. 
+ MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry // FetchWorkloadUpdates gets the latest workload update for the selectors FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate @@ -94,7 +96,8 @@ type manager struct { // backoff calculator for fetch interval, backing off if error is returned on // fetch attempt - backoff backoff.BackOff + backoff backoff.BackOff + svidSyncBackoff backoff.BackOff client client.Client @@ -112,6 +115,7 @@ func (m *manager) Initialize(ctx context.Context) error { m.storeBundle(m.cache.Bundle()) m.backoff = backoff.NewBackoff(m.clk, m.c.SyncInterval) + m.svidSyncBackoff = backoff.NewBackoff(m.clk, svidSyncInterval) err := m.synchronize(ctx) if nodeutil.ShouldAgentReattest(err) { @@ -126,6 +130,7 @@ func (m *manager) Run(ctx context.Context) error { err := util.RunTasks(ctx, m.runSynchronizer, + m.runSyncSVIDs, m.runSVIDObserver, m.runBundleObserver, m.svid.Run) @@ -145,7 +150,17 @@ func (m *manager) Run(ctx context.Context) error { } func (m *manager) SubscribeToCacheChanges(selectors cache.Selectors) cache.Subscriber { - return m.cache.SubscribeToWorkloadUpdates(selectors) + subscriber := m.cache.SubscribeToWorkloadUpdates(selectors) + backoff := backoff.NewBackoff(m.clk, svidSyncInterval) + // block until all svids are cached and subscriber is notified + for { + if m.cache.Notify(selectors) { + return subscriber + } + select { + case <-m.clk.After(backoff.NextBackOff()): + } + } } func (m *manager) SubscribeToSVIDChanges() observer.Stream { @@ -168,8 +183,8 @@ func (m *manager) SetRotationFinishedHook(f func()) { m.svid.SetRotationFinishedHook(f) } -func (m *manager) MatchingIdentities(selectors []*common.Selector) []cache.Identity { - return m.cache.MatchingIdentities(selectors) +func (m *manager) MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry { + return m.cache.MatchingRegistrationEntries(selectors) } func (m *manager) CountSVIDs() int { @@ -211,9 +226,9 @@ func (m *manager) FetchJWTSVID(ctx context.Context, spiffeID spiffeid.ID, audien } func (m *manager) getEntryID(spiffeID string) string { - for _, identity := range m.cache.Identities() { - if identity.Entry.SpiffeId == spiffeID { - return identity.Entry.EntryId + for _, entry := range m.cache.Entries() { + if entry.SpiffeId == spiffeID { + return entry.EntryId } } return "" @@ -241,6 +256,25 @@ func (m *manager) runSynchronizer(ctx context.Context) error { } } +func (m *manager) runSyncSVIDs(ctx context.Context) error { + for { + select { + case <-m.clk.After(m.svidSyncBackoff.NextBackOff()): + case <-ctx.Done(): + return nil + } + + err := m.syncSVIDs(ctx) + switch { + case err != nil: + // Just log the error and wait for next synchronization + m.c.Log.WithError(err).Error("SVID sync failed") + default: + m.svidSyncBackoff.Reset() + } + } +} + func (m *manager) setLastSync() { m.mtx.Lock() defer m.mtx.Unlock() diff --git a/pkg/agent/manager/manager_test.go b/pkg/agent/manager/manager_test.go index e7b94b8b52..53279b54e6 100644 --- a/pkg/agent/manager/manager_test.go +++ b/pkg/agent/manager/manager_test.go @@ -71,16 +71,17 @@ func TestInitializationFailure(t *testing.T) { cat.SetKeyManager(km) c := &Config{ - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - Metrics: &telemetry.Blackhole{}, - TrustDomain: trustDomain, - SVIDCachePath: path.Join(dir, "svid.der"), - BundleCachePath: path.Join(dir, "bundle.der"), - Clk: clk, - Catalog: cat, - SVIDStoreCache: 
storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + Metrics: &telemetry.Blackhole{}, + TrustDomain: trustDomain, + SVIDCachePath: path.Join(dir, "svid.der"), + BundleCachePath: path.Join(dir, "bundle.der"), + Clk: clk, + Catalog: cat, + MaxSvidCacheSize: 1, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } m := newManager(c) require.Error(t, m.Initialize(context.Background())) @@ -97,16 +98,17 @@ func TestStoreBundleOnStartup(t *testing.T) { cat.SetKeyManager(km) c := &Config{ - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - Metrics: &telemetry.Blackhole{}, - TrustDomain: trustDomain, - SVIDCachePath: path.Join(dir, "svid.der"), - BundleCachePath: path.Join(dir, "bundle.der"), - Bundle: bundleutil.BundleFromRootCA(trustDomain, ca), - Clk: clk, - Catalog: cat, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + Metrics: &telemetry.Blackhole{}, + TrustDomain: trustDomain, + SVIDCachePath: path.Join(dir, "svid.der"), + BundleCachePath: path.Join(dir, "bundle.der"), + Bundle: bundleutil.BundleFromRootCA(trustDomain, ca), + Clk: clk, + Catalog: cat, + MaxSvidCacheSize: 1, } m := newManager(c) @@ -144,15 +146,16 @@ func TestStoreSVIDOnStartup(t *testing.T) { cat.SetKeyManager(km) c := &Config{ - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - Metrics: &telemetry.Blackhole{}, - TrustDomain: trustDomain, - SVIDCachePath: path.Join(dir, "svid.der"), - BundleCachePath: path.Join(dir, "bundle.der"), - Clk: clk, - Catalog: cat, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + Metrics: &telemetry.Blackhole{}, + TrustDomain: trustDomain, + SVIDCachePath: path.Join(dir, "svid.der"), + BundleCachePath: path.Join(dir, "bundle.der"), + Clk: clk, + Catalog: cat, + MaxSvidCacheSize: 1, } _, err := ReadSVID(c.SVIDCachePath) @@ -228,9 +231,9 @@ func TestHappyPathWithoutSyncNorRotation(t *testing.T) { t.Fatal("PrivateKey is not equals to configured one") } - matches := m.MatchingIdentities(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + matches := m.MatchingRegistrationEntries(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) if len(matches) != 2 { - t.Fatal("expected 2 identities") + t.Fatal("expected 2 registration entries") } // Verify bundle @@ -244,7 +247,7 @@ func TestHappyPathWithoutSyncNorRotation(t *testing.T) { compareRegistrationEntries(t, regEntriesMap["resp2"], - []*common.RegistrationEntry{matches[0].Entry, matches[1].Entry}) + []*common.RegistrationEntry{matches[0], matches[1]}) util.RunWithTimeout(t, 5*time.Second, func() { sub := m.SubscribeToCacheChanges(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) @@ -307,6 +310,7 @@ func TestSVIDRotation(t *testing.T) { RotationInterval: baseTTLSeconds / 2, SyncInterval: 1 * time.Hour, Clk: clk, + MaxSvidCacheSize: 1, SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } @@ -415,6 +419,7 @@ func TestSynchronization(t *testing.T) { SyncInterval: time.Hour, Clk: clk, Catalog: cat, + MaxSvidCacheSize: 1, SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } @@ -555,18 +560,19 @@ func TestSynchronizationClearsStaleCacheEntries(t *testing.T) { cat.SetKeyManager(km) c := &Config{ - ServerAddr: api.addr, - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - TrustDomain: trustDomain, - SVIDCachePath: path.Join(dir, "svid.der"), - BundleCachePath: path.Join(dir, 
"bundle.der"), - Bundle: api.bundle, - Metrics: &telemetry.Blackhole{}, - Clk: clk, - Catalog: cat, - SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + ServerAddr: api.addr, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + TrustDomain: trustDomain, + SVIDCachePath: path.Join(dir, "svid.der"), + BundleCachePath: path.Join(dir, "bundle.der"), + Bundle: api.bundle, + Metrics: &telemetry.Blackhole{}, + Clk: clk, + Catalog: cat, + MaxSvidCacheSize: 1, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } m := newManager(c) @@ -579,7 +585,7 @@ func TestSynchronizationClearsStaleCacheEntries(t *testing.T) { // entries. compareRegistrationEntries(t, append(regEntriesMap["resp1"], regEntriesMap["resp2"]...), - regEntriesFromIdentities(m.cache.Identities())) + m.cache.Entries()) // manually synchronize again if err := m.synchronize(context.Background()); err != nil { @@ -589,7 +595,7 @@ func TestSynchronizationClearsStaleCacheEntries(t *testing.T) { // now the cache should have entries from resp2 removed compareRegistrationEntries(t, regEntriesMap["resp1"], - regEntriesFromIdentities(m.cache.Identities())) + m.cache.Entries()) } func TestSynchronizationUpdatesRegistrationEntries(t *testing.T) { @@ -628,18 +634,19 @@ func TestSynchronizationUpdatesRegistrationEntries(t *testing.T) { cat.SetKeyManager(km) c := &Config{ - ServerAddr: api.addr, - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - TrustDomain: trustDomain, - SVIDCachePath: path.Join(dir, "svid.der"), - BundleCachePath: path.Join(dir, "bundle.der"), - Bundle: api.bundle, - Metrics: &telemetry.Blackhole{}, - Clk: clk, - Catalog: cat, - SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + ServerAddr: api.addr, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + TrustDomain: trustDomain, + SVIDCachePath: path.Join(dir, "svid.der"), + BundleCachePath: path.Join(dir, "bundle.der"), + Bundle: api.bundle, + Metrics: &telemetry.Blackhole{}, + Clk: clk, + Catalog: cat, + MaxSvidCacheSize: 1, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } m := newManager(c) @@ -651,7 +658,7 @@ func TestSynchronizationUpdatesRegistrationEntries(t *testing.T) { // after initialization, the cache should contain resp2 entries compareRegistrationEntries(t, regEntriesMap["resp2"], - regEntriesFromIdentities(m.cache.Identities())) + m.cache.Entries()) // manually synchronize again if err := m.synchronize(context.Background()); err != nil { @@ -661,7 +668,7 @@ func TestSynchronizationUpdatesRegistrationEntries(t *testing.T) { // now the cache should have the updated entries from resp3 compareRegistrationEntries(t, regEntriesMap["resp3"], - regEntriesFromIdentities(m.cache.Identities())) + m.cache.Entries()) } func TestSubscribersGetUpToDateBundle(t *testing.T) { @@ -699,6 +706,7 @@ func TestSubscribersGetUpToDateBundle(t *testing.T) { RotationInterval: 1 * time.Hour, SyncInterval: 1 * time.Hour, Clk: clk, + MaxSvidCacheSize: 1, Catalog: cat, SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } @@ -721,6 +729,256 @@ func TestSubscribersGetUpToDateBundle(t *testing.T) { }) } +func TestSynchronizationClearsExpiredSVIDCache(t *testing.T) { + dir := spiretest.TempDir(t) + km := fakeagentkeymanager.New(t, dir) + + clk := clock.NewMock(t) + api := newMockAPI(t, &mockAPIConfig{ + km: km, + getAuthorizedEntries: func(h 
*mockAPI, count int32, req *entryv1.GetAuthorizedEntriesRequest) (*entryv1.GetAuthorizedEntriesResponse, error) { + return makeGetAuthorizedEntriesResponse(t, "resp1", "resp2"), nil + }, + batchNewX509SVIDEntries: func(h *mockAPI, count int32) []*common.RegistrationEntry { + h.rotateCA() + return makeBatchNewX509SVIDEntries("resp1", "resp2") + }, + svidTTL: 200, + clk: clk, + }) + + baseSVID, baseSVIDKey := api.newSVID(joinTokenID, 1*time.Hour) + cat := fakeagentcatalog.New() + cat.SetKeyManager(km) + + c := &Config{ + ServerAddr: api.addr, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + TrustDomain: trustDomain, + SVIDCachePath: path.Join(dir, "svid.der"), + BundleCachePath: path.Join(dir, "bundle.der"), + Bundle: api.bundle, + Metrics: &telemetry.Blackhole{}, + RotationInterval: 1 * time.Hour, + SyncInterval: 1 * time.Hour, + MaxSvidCacheSize: 1, + SVIDCacheExpiryPeriod: 5 * time.Second, + Clk: clk, + Catalog: cat, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + } + + clk.Add(1 * time.Second) + + m := newManager(c) + if err := m.Initialize(context.Background()); err != nil { + t.Fatal(err) + } + + // After Initialize, just 1 SVID should be cached + require.Equal(t, 1, m.CountSVIDs()) + waitCh := make(chan struct{}) + + closer := runSVIDSync(t, waitCh, m) + defer closer() + + // Keep clk moving so that each subscriber gets SVID after SVID sync + clkCloser := moveClkAfterInterval(clk, 100*time.Millisecond, svidSyncInterval, waitCh) + + sub1 := m.SubscribeToCacheChanges(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + + sub2 := m.SubscribeToCacheChanges( + cache.Selectors{{Type: "spiffe_id", Value: "spiffe://example.org/spire/agent/join_token/abcd"}}) + + sub1.Finish() + sub2.Finish() + close(waitCh) + clkCloser() + + // All 3 SVIDs should be cached + require.Equal(t, 3, m.CountSVIDs()) + + // Move clock so that svid cache is expired + clk.Add(6 * time.Second) + + require.NoError(t, m.synchronize(context.Background())) + + // Make sure svid count is MaxSvidCacheSize and remaining SVIDs are deleted from cache + require.Equal(t, 1, m.CountSVIDs()) +} + +func TestSyncSVIDs(t *testing.T) { + dir := spiretest.TempDir(t) + km := fakeagentkeymanager.New(t, dir) + + clk := clock.NewMock(t) + api := newMockAPI(t, &mockAPIConfig{ + km: km, + getAuthorizedEntries: func(h *mockAPI, count int32, req *entryv1.GetAuthorizedEntriesRequest) (*entryv1.GetAuthorizedEntriesResponse, error) { + return makeGetAuthorizedEntriesResponse(t, "resp1", "resp2"), nil + }, + batchNewX509SVIDEntries: func(h *mockAPI, count int32) []*common.RegistrationEntry { + h.rotateCA() + return makeBatchNewX509SVIDEntries("resp1", "resp2") + }, + svidTTL: 200, + clk: clk, + }) + + baseSVID, baseSVIDKey := api.newSVID(joinTokenID, 1*time.Hour) + cat := fakeagentcatalog.New() + cat.SetKeyManager(km) + + c := &Config{ + ServerAddr: api.addr, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + TrustDomain: trustDomain, + SVIDCachePath: path.Join(dir, "svid.der"), + BundleCachePath: path.Join(dir, "bundle.der"), + Bundle: api.bundle, + Metrics: &telemetry.Blackhole{}, + RotationInterval: 1 * time.Hour, + SyncInterval: 1 * time.Hour, + MaxSvidCacheSize: 1, + SVIDCacheExpiryPeriod: 5 * time.Second, + Clk: clk, + Catalog: cat, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + } + + clk.Add(1 * time.Second) + + m := newManager(c) + closer := initializeAndRunManager(t, m) + defer closer() + + // After Initialize, 
just 1 SVID should be cached + require.Equal(t, 1, m.CountSVIDs()) + waitCh := make(chan struct{}) + + // Keep clk moving so that each subscriber gets SVID after SVID sync + clkCloser := moveClkAfterInterval(clk, 100*time.Millisecond, svidSyncInterval, waitCh) + defer clkCloser() + + sub1 := m.SubscribeToCacheChanges(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + defer sub1.Finish() + + sub2 := m.SubscribeToCacheChanges( + cache.Selectors{{Type: "spiffe_id", Value: "spiffe://example.org/spire/agent/join_token/abcd"}}) + defer sub2.Finish() + + close(waitCh) + + // All 3 SVIDs should be cached + require.Equal(t, 3, m.CountSVIDs()) +} + +func TestSubscribersWaitForSVID(t *testing.T) { + dir := spiretest.TempDir(t) + km := fakeagentkeymanager.New(t, dir) + + clk := clock.NewMock(t) + api := newMockAPI(t, &mockAPIConfig{ + km: km, + getAuthorizedEntries: func(h *mockAPI, count int32, req *entryv1.GetAuthorizedEntriesRequest) (*entryv1.GetAuthorizedEntriesResponse, error) { + return makeGetAuthorizedEntriesResponse(t, "resp1", "resp2"), nil + }, + batchNewX509SVIDEntries: func(h *mockAPI, count int32) []*common.RegistrationEntry { + h.rotateCA() + return makeBatchNewX509SVIDEntries("resp1", "resp2") + }, + svidTTL: 200, + clk: clk, + }) + + baseSVID, baseSVIDKey := api.newSVID(joinTokenID, 1*time.Hour) + cat := fakeagentcatalog.New() + cat.SetKeyManager(km) + + c := &Config{ + ServerAddr: api.addr, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + TrustDomain: trustDomain, + SVIDCachePath: path.Join(dir, "svid.der"), + BundleCachePath: path.Join(dir, "bundle.der"), + Bundle: api.bundle, + Metrics: &telemetry.Blackhole{}, + RotationInterval: 1 * time.Hour, + SyncInterval: 1 * time.Hour, + MaxSvidCacheSize: 1, + Clk: clk, + Catalog: cat, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + } + + m := newManager(c) + + if err := m.Initialize(context.Background()); err != nil { + t.Fatal(err) + } + + // After Initialize, just 1 SVID should be cached + require.Equal(t, 1, m.CountSVIDs()) + + waitCh := make(chan struct{}) + + closer := runSVIDSync(t, waitCh, m) + defer closer() + + // Keep clk moving so that each subscriber gets SVID after SVID sync + clkCloser := moveClkAfterInterval(clk, 100*time.Millisecond, svidSyncInterval, waitCh) + defer clkCloser() + + go func() { + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + sub1 := m.SubscribeToCacheChanges(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + defer sub1.Finish() + u := <-sub1.Updates() + if len(u.Identities) != 2 { + t.Fatalf("expected 2 SVIDs, got: %d", len(u.Identities)) + } + if !u.Bundle.EqualTo(c.Bundle) { + t.Fatal("bundles were expected to be equal") + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + sub2 := m.SubscribeToCacheChanges( + cache.Selectors{{Type: "spiffe_id", Value: "spiffe://example.org/spire/agent/join_token/abcd"}}) + defer sub2.Finish() + u := <-sub2.Updates() + if len(u.Identities) != 1 { + t.Fatalf("expected 1 SVID, got: %d", len(u.Identities)) + } + if !u.Bundle.EqualTo(c.Bundle) { + t.Fatal("bundles were expected to be equal") + } + }() + + wg.Wait() + close(waitCh) + }() + + select { + case <-waitCh: + case <-time.After(5 * time.Second): + t.Fatalf("subscriber update wait timed out") + } + + require.Equal(t, 3, m.CountSVIDs()) +} + func TestSurvivesCARotation(t *testing.T) { dir := spiretest.TempDir(t) km := fakeagentkeymanager.New(t, dir) @@ -762,6 +1020,7 @@ func TestSurvivesCARotation(t *testing.T) { 
SyncInterval: syncInterval, Clk: clk, Catalog: cat, + MaxSvidCacheSize: 1, SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } @@ -811,18 +1070,19 @@ func TestFetchJWTSVID(t *testing.T) { baseSVID, baseSVIDKey := api.newSVID(joinTokenID, 1*time.Hour) c := &Config{ - ServerAddr: api.addr, - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - TrustDomain: trustDomain, - SVIDCachePath: path.Join(dir, "svid.der"), - BundleCachePath: path.Join(dir, "bundle.der"), - Bundle: api.bundle, - Metrics: &telemetry.Blackhole{}, - Catalog: cat, - Clk: clk, - SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + ServerAddr: api.addr, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + TrustDomain: trustDomain, + SVIDCachePath: path.Join(dir, "svid.der"), + BundleCachePath: path.Join(dir, "bundle.der"), + Bundle: api.bundle, + Metrics: &telemetry.Blackhole{}, + Catalog: cat, + Clk: clk, + MaxSvidCacheSize: 1, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } m := newManager(c) @@ -935,18 +1195,19 @@ func TestStorableSVIDsSync(t *testing.T) { cat.SetKeyManager(fakeagentkeymanager.New(t, dir)) c := &Config{ - ServerAddr: api.addr, - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - TrustDomain: trustDomain, - SVIDCachePath: path.Join(dir, "svid.der"), - BundleCachePath: path.Join(dir, "bundle.der"), - Bundle: api.bundle, - Metrics: &telemetry.Blackhole{}, - Clk: clk, - Catalog: cat, - SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + ServerAddr: api.addr, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + TrustDomain: trustDomain, + SVIDCachePath: path.Join(dir, "svid.der"), + BundleCachePath: path.Join(dir, "bundle.der"), + Bundle: api.bundle, + Metrics: &telemetry.Blackhole{}, + Clk: clk, + Catalog: cat, + MaxSvidCacheSize: 1, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } m, closer := initializeAndRunNewManager(t, c) @@ -984,6 +1245,44 @@ func TestStorableSVIDsSync(t *testing.T) { validateResponse(records, entries) } +func runSVIDSync(t *testing.T, waitCh chan struct{}, m *manager) (closer func()) { + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-waitCh: + return + case <-m.clk.After(svidSyncInterval): + require.NoError(t, m.syncSVIDs(context.Background())) + } + } + }() + return func() { + wg.Wait() + } +} + +func moveClkAfterInterval(clk *clock.Mock, interval, period time.Duration, waitCh chan struct{}) (closer func()) { + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-waitCh: + return + case <-time.After(interval): + clk.Add(period) + } + } + }() + return func() { + wg.Wait() + } +} + func makeGetAuthorizedEntriesResponse(t *testing.T, respKeys ...string) *entryv1.GetAuthorizedEntriesResponse { var entries []*types.Entry for _, respKey := range respKeys { @@ -1032,13 +1331,6 @@ func identitiesByEntryID(ces []cache.Identity) (result map[string]cache.Identity return result } -func regEntriesFromIdentities(ces []cache.Identity) (result []*common.RegistrationEntry) { - for _, ce := range ces { - result = append(result, ce.Entry) - } - return result -} - func compareRegistrationEntries(t *testing.T, expected, actual []*common.RegistrationEntry) { if len(expected) != len(actual) { t.Fatalf("entries count doesn't match, expected: %d, got: 
%d", len(expected), len(actual)) diff --git a/pkg/agent/manager/sync.go b/pkg/agent/manager/sync.go index 040432912d..655924ce69 100644 --- a/pkg/agent/manager/sync.go +++ b/pkg/agent/manager/sync.go @@ -37,6 +37,11 @@ type Cache interface { GetStaleEntries() []*cache.StaleEntry } +func (m *manager) syncSVIDs(ctx context.Context) (err error) { + m.cache.SyncSVIDsWithSubscribers() + return m.updateSVIDs(ctx, m.cache.GetStaleEntries(), m.cache) +} + // synchronize fetches the authorized entries from the server, updates the // cache, and fetches missing/expiring SVIDs. func (m *manager) synchronize(ctx context.Context) (err error) { @@ -63,7 +68,6 @@ func (m *manager) updateCache(ctx context.Context, update *cache.UpdateEntries, // in this interval. // // the values in `update` now belong to the cache. DO NOT MODIFY. - var csrs []csrRequest var expiring int var outdated int c.UpdateEntries(update, func(existingEntry, newEntry *common.RegistrationEntry, svid *cache.X509SVID) bool { @@ -105,16 +109,24 @@ func (m *manager) updateCache(ctx context.Context, update *cache.UpdateEntries, telemetry.Count: len(staleEntries), telemetry.Limit: limits.SignLimitPerIP, }).Debug("Renewing stale entries") - for _, staleEntry := range staleEntries { + return m.updateSVIDs(ctx, staleEntries, c) + } + return nil +} + +func (m *manager) updateSVIDs(ctx context.Context, entries []*cache.StaleEntry, c Cache) error { + var csrs []csrRequest + if len(entries) > 0 { + for _, entry := range entries { // we've exceeded the CSR limit, don't make any more CSRs if len(csrs) >= limits.SignLimitPerIP { break } csrs = append(csrs, csrRequest{ - EntryID: staleEntry.Entry.EntryId, - SpiffeID: staleEntry.Entry.SpiffeId, - CurrentSVIDExpiresAt: staleEntry.ExpiresAt, + EntryID: entry.Entry.EntryId, + SpiffeID: entry.Entry.SpiffeId, + CurrentSVIDExpiresAt: entry.ExpiresAt, }) } @@ -125,7 +137,6 @@ func (m *manager) updateCache(ctx context.Context, update *cache.UpdateEntries, // the values in `update` now belong to the cache. DO NOT MODIFY. c.UpdateSVIDs(update) } - return nil } diff --git a/test/integration/common b/test/integration/common index 89c7bb2f4c..114ad4914c 100644 --- a/test/integration/common +++ b/test/integration/common @@ -83,6 +83,25 @@ check-synced-entry() { fail-now "timed out waiting for agent to sync down entry" } +check-svid-count() { + MAXCHECKS=50 + CHECKINTERVAL=1 + + for ((i=0;i<=MAXCHECKS;i++)); do + if (( $i>=$MAXCHECKS )); then + fail-now "svid count validation failed" + fi + log-info "check svid count on agent debug endpoint ($(($i+1)) of $MAXCHECKS max)..." 
+ COUNT=`docker-compose exec -T $1 /opt/spire/conf/agent/debugclient -testCase "printDebugPage" | jq '."svidsCount"'` + log-info "svidsCount: ${COUNT}" + if [ "$COUNT" -eq "$2" ]; then + log-info "SVID count of $COUNT from cache matches the expected count of $2" + break + fi + sleep "${CHECKINTERVAL}" + done +} + build-mashup-image() { ENVOY_VERSION=$1 ENVOY_IMAGE_TAG="${ENVOY_VERSION}-latest" diff --git a/test/integration/setup/debugagent/main.go b/test/integration/setup/debugagent/main.go index bff310d966..67d10ab634 100644 --- a/test/integration/setup/debugagent/main.go +++ b/test/integration/setup/debugagent/main.go @@ -38,6 +38,8 @@ func run() error { var err error switch *testCaseFlag { + case "printDebugPage": + err = printDebugPage(ctx) case "agentEndpoints": err = agentEndpoints(ctx) case "serverWithWorkload": @@ -52,26 +54,40 @@ func run() error { } func agentEndpoints(ctx context.Context) error { + s,err := retrieveDebugPage(ctx) + if err == nil { + log.Printf("Debug info: %s", string(s)) + } + return nil +} + +func printDebugPage(ctx context.Context) error { + s,err := retrieveDebugPage(ctx) + if err == nil { + fmt.Println(s) + } + return nil +} + +func retrieveDebugPage(ctx context.Context) (string,error) { conn, err := grpc.Dial(*socketPathFlag, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { - return fmt.Errorf("failed to connect server: %w", err) + return "", fmt.Errorf("failed to connect server: %w", err) } defer conn.Close() client := agent_debugv1.NewDebugClient(conn) resp, err := client.GetInfo(ctx, &agent_debugv1.GetInfoRequest{}) if err != nil { - return fmt.Errorf("failed to get info: %w", err) + return "", fmt.Errorf("failed to get info: %w", err) } m := protojson.MarshalOptions{Indent: " "} s, err := m.Marshal(resp) if err != nil { - return fmt.Errorf("failed to parse proto: %w", err) + return "", fmt.Errorf("failed to parse proto: %w", err) } - - log.Printf("Debug info: %s", string(s)) - return nil + return string(s),nil } func serverWithWorkload(ctx context.Context) error { diff --git a/test/integration/suites/fetch-svids/00-setup b/test/integration/suites/fetch-svids/00-setup new file mode 100755 index 0000000000..c1fb18218e --- /dev/null +++ b/test/integration/suites/fetch-svids/00-setup @@ -0,0 +1,6 @@ +#!/bin/bash + +"${ROOTDIR}/setup/x509pop/setup.sh" conf/server conf/agent + +"${ROOTDIR}/setup/debugserver/build.sh" "${RUNDIR}/conf/server/debugclient" +"${ROOTDIR}/setup/debugagent/build.sh" "${RUNDIR}/conf/agent/debugclient" diff --git a/test/integration/suites/fetch-svids/01-start-server b/test/integration/suites/fetch-svids/01-start-server new file mode 100755 index 0000000000..a3e999b264 --- /dev/null +++ b/test/integration/suites/fetch-svids/01-start-server @@ -0,0 +1,3 @@ +#!/bin/bash + +docker-up spire-server diff --git a/test/integration/suites/fetch-svids/02-bootstrap-agent b/test/integration/suites/fetch-svids/02-bootstrap-agent new file mode 100755 index 0000000000..405147f2fd --- /dev/null +++ b/test/integration/suites/fetch-svids/02-bootstrap-agent @@ -0,0 +1,5 @@ +#!/bin/bash + +log-debug "bootstrapping agent..." 
+docker-compose exec -T spire-server \ + /opt/spire/bin/spire-server bundle show > conf/agent/bootstrap.crt diff --git a/test/integration/suites/fetch-svids/03-start-agent b/test/integration/suites/fetch-svids/03-start-agent new file mode 100755 index 0000000000..ac36d05f0d --- /dev/null +++ b/test/integration/suites/fetch-svids/03-start-agent @@ -0,0 +1,3 @@ +#!/bin/bash + +docker-up spire-agent diff --git a/test/integration/suites/fetch-svids/04-create-registration-entries b/test/integration/suites/fetch-svids/04-create-registration-entries new file mode 100755 index 0000000000..6a79e4b6ad --- /dev/null +++ b/test/integration/suites/fetch-svids/04-create-registration-entries @@ -0,0 +1,18 @@ +#!/bin/bash + +SIZE=12 + +# Create entries for uid 1001 +for ((m=1; m<=$SIZE;m++)); do + log-debug "creating registration entry: $m" + docker-compose exec -T spire-server \ + /opt/spire/bin/spire-server entry create \ + -parentID "spiffe://domain.test/spire/agent/x509pop/$(fingerprint conf/agent/agent.crt.pem)" \ + -spiffeID "spiffe://domain.test/workload-$m" \ + -selector "unix:uid:1001" \ + -ttl 0 & +done + +for ((m=1; m<=$SIZE;m++)); do + check-synced-entry "spire-agent" "spiffe://domain.test/workload-$m" +done diff --git a/test/integration/suites/fetch-svids/05-fetch-svids b/test/integration/suites/fetch-svids/05-fetch-svids new file mode 100755 index 0000000000..646579ea02 --- /dev/null +++ b/test/integration/suites/fetch-svids/05-fetch-svids @@ -0,0 +1,14 @@ +#!/bin/bash + +ENTRYCOUNT=12 +CACHESIZE=8 + +docker-compose exec -u 1001 -T spire-agent \ + /opt/spire/bin/spire-agent api fetch x509 \ + -socketPath /opt/spire/sockets/workload_api.sock || fail-now "x509-SVID check failed" + +# Call agent debug endpoints and check if svid count is equal number of entries registered +check-svid-count "spire-agent" $ENTRYCOUNT + +# Call agent debug endpoints and check if svids from cache are cleaned up after expiry +check-svid-count "spire-agent" $CACHESIZE diff --git a/test/integration/suites/fetch-svids/06-create-registration-entries b/test/integration/suites/fetch-svids/06-create-registration-entries new file mode 100755 index 0000000000..abacd71a36 --- /dev/null +++ b/test/integration/suites/fetch-svids/06-create-registration-entries @@ -0,0 +1,19 @@ +#!/bin/bash + +SIZE=5 + +# Create entries for uid 1002 +for ((m=1; m<=$SIZE;m++)); do + log-debug "creating regular registration entry...($m)" + docker-compose exec -T spire-server \ + /opt/spire/bin/spire-server entry create \ + -parentID "spiffe://domain.test/spire/agent/x509pop/$(fingerprint conf/agent/agent.crt.pem)" \ + -spiffeID "spiffe://domain.test/workload/$m" \ + -selector "unix:uid:1002" \ + -ttl 0 & +done + +for ((m=1; m<=$SIZE;m++)); do + check-synced-entry "spire-agent" "spiffe://domain.test/workload/$m" + ((m++)) +done diff --git a/test/integration/suites/fetch-svids/07-fetch-svids b/test/integration/suites/fetch-svids/07-fetch-svids new file mode 100755 index 0000000000..afc484e260 --- /dev/null +++ b/test/integration/suites/fetch-svids/07-fetch-svids @@ -0,0 +1,28 @@ +#!/bin/bash + +CACHESIZE=8 + +docker-compose exec -u 1002 -T spire-agent \ + /opt/spire/bin/spire-agent api fetch x509 \ + -socketPath /opt/spire/sockets/workload_api.sock || fail-now "x509-SVID check failed" + +# Call agent debug endpoints and check if svid count is equal to cache size limit +check-svid-count "spire-agent" $CACHESIZE + +# introduce some delay between two fetch calls so that we can validate cache cleanup of svids from first fetch. 
+sleep 5 + +docker-compose exec -u 1001 -T spire-agent \ + /opt/spire/bin/spire-agent api fetch x509 \ + -socketPath /opt/spire/sockets/workload_api.sock || fail-now "x509-SVID check failed" + +# Call agent debug endpoints and check if svid count is equal to 17 (registration entry count for uids 1001 and 1002) +check-svid-count "spire-agent" 17 + +# Call agent debug endpoints and check if svid count is equal to 12 +# 17(svid-count) - 5(svids from entries with uuid 1002 will be removed first after svid_cache_expiry_interval) = 12 +check-svid-count "spire-agent" 12 + +# Call agent debug endpoints and check if svid count is equal to 8 +# 12 - 8(cache size) = 4 extra svids from entries with uuid 1001 will be removed after svid_cache_expiry_interval +check-svid-count "spire-agent" $CACHESIZE diff --git a/test/integration/suites/fetch-svids/README.md b/test/integration/suites/fetch-svids/README.md new file mode 100644 index 0000000000..509c2d0390 --- /dev/null +++ b/test/integration/suites/fetch-svids/README.md @@ -0,0 +1,5 @@ +# Fetch x509 SVID Suite + +## Description + +This suite validates svid cache operations from spire-agent cache. diff --git a/test/integration/suites/fetch-svids/conf/agent/agent.conf b/test/integration/suites/fetch-svids/conf/agent/agent.conf new file mode 100644 index 0000000000..c2ec900236 --- /dev/null +++ b/test/integration/suites/fetch-svids/conf/agent/agent.conf @@ -0,0 +1,32 @@ +agent { + data_dir = "/opt/spire/data/agent" + log_level = "DEBUG" + server_address = "spire-server" + server_port = "8081" + socket_path = "/opt/spire/sockets/workload_api.sock" + trust_bundle_path = "/opt/spire/conf/agent/bootstrap.crt" + trust_domain = "domain.test" + admin_socket_path = "/opt/debug.sock" + experimental { + max_svid_cache_size = 8 + svid_cache_expiry_interval = "30s" + } +} + +plugins { + NodeAttestor "x509pop" { + plugin_data { + private_key_path = "/opt/spire/conf/agent/agent.key.pem" + certificate_path = "/opt/spire/conf/agent/agent.crt.pem" + } + } + KeyManager "disk" { + plugin_data { + directory = "/opt/spire/data/agent" + } + } + WorkloadAttestor "unix" { + plugin_data { + } + } +} diff --git a/test/integration/suites/fetch-svids/conf/server/server.conf b/test/integration/suites/fetch-svids/conf/server/server.conf new file mode 100644 index 0000000000..a8f18c0680 --- /dev/null +++ b/test/integration/suites/fetch-svids/conf/server/server.conf @@ -0,0 +1,26 @@ +server { + bind_address = "0.0.0.0" + bind_port = "8081" + trust_domain = "domain.test" + data_dir = "/opt/spire/data/server" + log_level = "DEBUG" + ca_ttl = "1h" + default_svid_ttl = "10m" +} + +plugins { + DataStore "sql" { + plugin_data { + database_type = "sqlite3" + connection_string = "/opt/spire/data/server/datastore.sqlite3" + } + } + NodeAttestor "x509pop" { + plugin_data { + ca_bundle_path = "/opt/spire/conf/server/agent-cacert.pem" + } + } + KeyManager "memory" { + plugin_data = {} + } +} diff --git a/test/integration/suites/fetch-svids/docker-compose.yaml b/test/integration/suites/fetch-svids/docker-compose.yaml new file mode 100644 index 0000000000..0e67183c23 --- /dev/null +++ b/test/integration/suites/fetch-svids/docker-compose.yaml @@ -0,0 +1,15 @@ +version: '3' +services: + spire-server: + image: spire-server:latest-local + hostname: spire-server + volumes: + - ./conf/server:/opt/spire/conf/server + command: ["-config", "/opt/spire/conf/server/server.conf"] + spire-agent: + image: spire-agent:latest-local + hostname: spire-agent + depends_on: ["spire-server"] + volumes: + - 
./conf/agent:/opt/spire/conf/agent + command: ["-config", "/opt/spire/conf/agent/agent.conf"] diff --git a/test/integration/suites/fetch-svids/teardown b/test/integration/suites/fetch-svids/teardown new file mode 100755 index 0000000000..9953dcd3f9 --- /dev/null +++ b/test/integration/suites/fetch-svids/teardown @@ -0,0 +1,6 @@ +#!/bin/bash + +if [ -z "$SUCCESS" ]; then + docker-compose logs +fi +docker-down From 709de6b6135c306b190f24249024b760ecec44ee Mon Sep 17 00:00:00 2001 From: Prasad Borole Date: Tue, 28 Jun 2022 16:49:06 -0700 Subject: [PATCH 02/19] Address comments: update comments, new unit tests and var renames Signed-off-by: Prasad Borole --- cmd/spire-agent/cli/api/common.go | 4 +- cmd/spire-agent/cli/run/run.go | 9 +- cmd/spire-agent/cli/run/run_test.go | 29 +++ pkg/agent/api/delegatedidentity/v1/service.go | 8 +- pkg/agent/manager/cache/cache.go | 17 +- pkg/agent/manager/cache/cache_test.go | 242 +++++++++--------- pkg/agent/manager/manager.go | 12 +- test/integration/setup/debugagent/main.go | 9 +- 8 files changed, 192 insertions(+), 138 deletions(-) diff --git a/cmd/spire-agent/cli/api/common.go b/cmd/spire-agent/cli/api/common.go index 380e963590..4b74ed9b64 100644 --- a/cmd/spire-agent/cli/api/common.go +++ b/cmd/spire-agent/cli/api/common.go @@ -15,6 +15,8 @@ import ( "google.golang.org/grpc/metadata" ) +const commandTimeout = 5 * time.Second + type workloadClient struct { workload.SpiffeWorkloadAPIClient timeout time.Duration @@ -73,7 +75,7 @@ func adaptCommand(env *cli.Env, clientsMaker workloadClientMaker, cmd command) * clientsMaker: clientsMaker, cmd: cmd, env: env, - timeout: cli.DurationFlag(2 * time.Second), + timeout: cli.DurationFlag(commandTimeout), } fs := flag.NewFlagSet(cmd.name(), flag.ContinueOnError) diff --git a/cmd/spire-agent/cli/run/run.go b/cmd/spire-agent/cli/run/run.go index 6d7e1453d8..a57f44f426 100644 --- a/cmd/spire-agent/cli/run/run.go +++ b/cmd/spire-agent/cli/run/run.go @@ -402,15 +402,20 @@ func NewAgentConfig(c *Config, logOptions []log.Option, allowUnknownConfig bool) } } - if c.Agent.Experimental.MaxSvidCacheSize != 0 { - ac.MaxSvidCacheSize = c.Agent.Experimental.MaxSvidCacheSize + if c.Agent.Experimental.MaxSvidCacheSize < 0 { + return nil, fmt.Errorf("max_svid_cache_size should not be negative") } + ac.MaxSvidCacheSize = c.Agent.Experimental.MaxSvidCacheSize + if c.Agent.Experimental.SVIDCacheExpiryPeriod != "" { var err error ac.SVIDCacheExpiryPeriod, err = time.ParseDuration(c.Agent.Experimental.SVIDCacheExpiryPeriod) if err != nil { return nil, fmt.Errorf("could not parse svid cache expiry interval: %w", err) } + if ac.SVIDCacheExpiryPeriod < 0 { + return nil, fmt.Errorf("svid_cache_expiry_interval should not be negative") + } } serverHostPort := net.JoinHostPort(c.Agent.ServerAddress, strconv.Itoa(c.Agent.ServerPort)) diff --git a/cmd/spire-agent/cli/run/run_test.go b/cmd/spire-agent/cli/run/run_test.go index 17484fc46a..5acb48db2c 100644 --- a/cmd/spire-agent/cli/run/run_test.go +++ b/cmd/spire-agent/cli/run/run_test.go @@ -698,6 +698,16 @@ func TestNewAgentConfig(t *testing.T) { require.EqualValues(t, 1050000000, c.SVIDCacheExpiryPeriod) }, }, + { + msg: "svid_cache_expiry_interval is negative", + expectError: true, + input: func(c *Config) { + c.Agent.Experimental.SVIDCacheExpiryPeriod = "-1s50ms" + }, + test: func(t *testing.T, c *agent.Config) { + require.Nil(t, c) + }, + }, { msg: "invalid svid_cache_expiry_interval returns an error", expectError: true, @@ -733,6 +743,25 @@ func TestNewAgentConfig(t *testing.T) { 
require.EqualValues(t, 0, c.MaxSvidCacheSize) }, }, + { + msg: "max_svid_cache_size is zero", + input: func(c *Config) { + c.Agent.Experimental.MaxSvidCacheSize = 0 + }, + test: func(t *testing.T, c *agent.Config) { + require.EqualValues(t, 0, c.MaxSvidCacheSize) + }, + }, + { + msg: "max_svid_cache_size is negative", + expectError: true, + input: func(c *Config) { + c.Agent.Experimental.MaxSvidCacheSize = -10 + }, + test: func(t *testing.T, c *agent.Config) { + require.Nil(t, c) + }, + }, { msg: "admin_socket_path not provided", input: func(c *Config) { diff --git a/pkg/agent/api/delegatedidentity/v1/service.go b/pkg/agent/api/delegatedidentity/v1/service.go index ef9b9f2b04..887babef8a 100644 --- a/pkg/agent/api/delegatedidentity/v1/service.go +++ b/pkg/agent/api/delegatedidentity/v1/service.go @@ -83,9 +83,9 @@ func (s *Service) isCallerAuthorized(ctx context.Context, log logrus.FieldLogger } entries := s.manager.MatchingRegistrationEntries(callerSelectors) - numRegisteredIDs := len(entries) + numRegisteredEntries := len(entries) - if numRegisteredIDs == 0 { + if numRegisteredEntries == 0 { log.Error("no identity issued") return nil, status.Error(codes.PermissionDenied, "no identity issued") } @@ -98,8 +98,8 @@ func (s *Service) isCallerAuthorized(ctx context.Context, log logrus.FieldLogger // caller has identity associeted with but none is authorized log.WithFields(logrus.Fields{ - "num_registered_ids": numRegisteredIDs, - "default_id": entries[0].SpiffeId, + "num_registered_entries": numRegisteredEntries, + "default_id": entries[0].SpiffeId, }).Error("Permission denied; caller not configured as an authorized delegate.") return nil, status.Error(codes.PermissionDenied, "caller not configured as an authorized delegate") diff --git a/pkg/agent/manager/cache/cache.go b/pkg/agent/manager/cache/cache.go index 427bcc12b5..968bbbdc88 100644 --- a/pkg/agent/manager/cache/cache.go +++ b/pkg/agent/manager/cache/cache.go @@ -82,20 +82,20 @@ type X509SVID struct { // workloads) and registration entries that have that selector. // // The LRU-like SVID cache has configurable size limit and expiry period. -// 1. Size limit of SVID cache is a soft limit which means if SVID has a subscriber present then +// 1. Size limit of SVID cache is a soft limit. If SVID has a subscriber present then // that SVID is never removed from cache. // 2. Least recently used SVIDs are removed from cache only after the cache expiry period has passed. // This is done to reduce the overall cache churn. // 3. Last access timestamp for SVID cache entry is updated when a new subscriber is created -// 4. When a new subscriber is created and if there is a cache miss +// 4. When a new subscriber is created and there is a cache miss // then subscriber needs to wait for next SVID sync event to receive WorkloadUpdate with newly minted SVID // // The advantage of above approach is that if agent has entry count less than cache size // then all SVIDs are cached at all times. If agent has entry count greater than cache size then // subscribers will continue to get SVID updates (potential delay for first WorkloadUpdate if cache miss) // and least used SVIDs will be removed from cache which will save memory usage. -// It will allow agent to support large number of registrations. -// +// This allows agent to support environments where the active simultaneous workload count +// is a small percentage of the large number of registrations assigned to the agent. 
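
To make the eviction rules described above concrete, the following is a minimal, hypothetical Go sketch of the decision an LRU-style SVID cache of this kind could apply on each sync pass. The record fields, the helper name, and the single-pass structure are assumptions made for illustration only; they are not the cache implementation contained in this patch. The sketch encodes the three rules stated above: records with an active subscriber are never dropped, only records beyond the size limit are candidates, and candidates are dropped least-recently-used first and only after the expiry period has elapsed since their last access.

package lrusketch

// Illustrative sketch only: type, field, and function names are assumptions
// for explanation, not the actual SPIRE agent cache code.

import (
	"sort"
	"time"
)

type svidRecord struct {
	entryID     string
	lastAccess  time.Time
	subscribers int // active subscribers pinning this SVID in cache
}

// evictable returns the records that an LRU-style pass could drop:
// never a pinned record, never more records than the overage beyond
// maxSize, and only records whose last access is older than expiry.
func evictable(records []svidRecord, maxSize int, expiry time.Duration, now time.Time) []svidRecord {
	var candidates []svidRecord
	for _, r := range records {
		if r.subscribers == 0 {
			candidates = append(candidates, r)
		}
	}
	// Consider least recently used records first.
	sort.Slice(candidates, func(i, j int) bool {
		return candidates[i].lastAccess.Before(candidates[j].lastAccess)
	})

	overLimit := len(records) - maxSize
	var out []svidRecord
	for _, r := range candidates {
		if len(out) >= overLimit {
			break // cache is back within its soft size limit
		}
		if now.Sub(r.lastAccess) > expiry {
			out = append(out, r)
		}
	}
	return out
}

Under these assumed rules, an agent with fewer entries than the size limit never evicts anything, while an agent with many registrations keeps only the SVIDs that were recently used or are still pinned by subscribers, which matches the behavior exercised by the expiry tests later in this patch.
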
// // When registration entries are added/updated/removed, the set of relevant // selectors are gathered and the indexes for those selectors are combed for @@ -162,11 +162,11 @@ type StaleEntry struct { func New(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundle, metrics telemetry.Metrics, maxSvidCacheSize int, svidCacheExpiryPeriod time.Duration, clk clock.Clock) *Cache { - if maxSvidCacheSize == 0 { + if maxSvidCacheSize <= 0 { maxSvidCacheSize = DefaultMaxSvidCacheSize } - if svidCacheExpiryPeriod == 0 { + if svidCacheExpiryPeriod <= 0 { svidCacheExpiryPeriod = DefaultSVIDCacheExpiryPeriod } @@ -602,6 +602,11 @@ func (c *Cache) syncSVIDs() (map[string]struct{}, []record) { lastAccessTimestamps := make([]record, len(c.records)) i := 0 + // iterate over all selectors from cached entries and obtain: + // 1. entries that have active subscribers + // 1.1 if those entries don't have corresponding SVID cached then put them in staleEntries + // so that SVID will be cached in next sync + // 2. get lastAccessTimestamp of each entry for id, record := range c.records { for _, sel := range record.entry.Selectors { if index, ok := c.selectors[makeSelector(sel)]; ok && index != nil { diff --git a/pkg/agent/manager/cache/cache_test.go b/pkg/agent/manager/cache/cache_test.go index 9849b5b360..bcbb942093 100644 --- a/pkg/agent/manager/cache/cache_test.go +++ b/pkg/agent/manager/cache/cache_test.go @@ -31,8 +31,8 @@ var ( func TestFetchWorkloadUpdate(t *testing.T) { cache := newTestCache() // populate the cache with FOO and BAR without SVIDS - foo := makeRegistrationEntry("FOO", "A") - bar := makeRegistrationEntry("BAR", "B") + foo := makeRegistrationEntry("FOO", "foo") + bar := makeRegistrationEntry("BAR", "bar") bar.FederatesWith = makeFederatesWith(otherBundleV1) updateEntries := &UpdateEntries{ Bundles: makeBundles(bundleV1, otherBundleV1), @@ -40,7 +40,7 @@ func TestFetchWorkloadUpdate(t *testing.T) { } cache.UpdateEntries(updateEntries, nil) - workloadUpdate := cache.FetchWorkloadUpdate(makeSelectors("A", "B")) + workloadUpdate := cache.FetchWorkloadUpdate(makeSelectors("foo", "bar")) assert.Len(t, workloadUpdate.Identities, 0, "identities should not be returned that don't have SVIDs") updateSVIDs := &UpdateSVIDs{ @@ -48,7 +48,7 @@ func TestFetchWorkloadUpdate(t *testing.T) { } cache.UpdateSVIDs(updateSVIDs) - workloadUpdate = cache.FetchWorkloadUpdate(makeSelectors("A", "B")) + workloadUpdate = cache.FetchWorkloadUpdate(makeSelectors("foo", "bar")) assert.Equal(t, &WorkloadUpdate{ Bundle: bundleV1, FederatedBundles: makeBundles(otherBundleV1), @@ -63,8 +63,8 @@ func TestMatchingRegistrationIdentities(t *testing.T) { cache := newTestCache() // populate the cache with FOO and BAR without SVIDS - foo := makeRegistrationEntry("FOO", "A") - bar := makeRegistrationEntry("BAR", "B") + foo := makeRegistrationEntry("FOO", "foo") + bar := makeRegistrationEntry("BAR", "bar") updateEntries := &UpdateEntries{ Bundles: makeBundles(bundleV1), RegistrationEntries: makeRegistrationEntries(foo, bar), @@ -72,7 +72,7 @@ func TestMatchingRegistrationIdentities(t *testing.T) { cache.UpdateEntries(updateEntries, nil) assert.Equal(t, []*common.RegistrationEntry{bar, foo}, - cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) + cache.MatchingRegistrationEntries(makeSelectors("foo", "bar"))) // Update SVIDs and MatchingRegistrationEntries should return both entries updateSVIDs := &UpdateSVIDs{ @@ -80,20 +80,20 @@ func TestMatchingRegistrationIdentities(t *testing.T) { } 
cache.UpdateSVIDs(updateSVIDs) assert.Equal(t, []*common.RegistrationEntry{bar, foo}, - cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) + cache.MatchingRegistrationEntries(makeSelectors("foo", "bar"))) // Remove SVIDs and MatchingRegistrationEntries should still return both entries cache.UpdateSVIDs(&UpdateSVIDs{}) assert.Equal(t, []*common.RegistrationEntry{bar, foo}, - cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) + cache.MatchingRegistrationEntries(makeSelectors("foo", "bar"))) } func TestCountSVIDs(t *testing.T) { cache := newTestCache() // populate the cache with FOO and BAR without SVIDS - foo := makeRegistrationEntry("FOO", "A") - bar := makeRegistrationEntry("BAR", "B") + foo := makeRegistrationEntry("FOO", "foo") + bar := makeRegistrationEntry("BAR", "bar") updateEntries := &UpdateEntries{ Bundles: makeBundles(bundleV1), RegistrationEntries: makeRegistrationEntries(foo, bar), @@ -140,28 +140,28 @@ func TestAllSubscribersNotifiedOnBundleChange(t *testing.T) { cache := newTestCache() // create some subscribers and assert they get the initial bundle - subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) - defer subA.Finish() - assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{Bundle: bundleV1}) + subFoo := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo")) + defer subFoo.Finish() + assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{Bundle: bundleV1}) - subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) - defer subB.Finish() - assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{Bundle: bundleV1}) + subBar := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("bar")) + defer subBar.Finish() + assertWorkloadUpdateEqual(t, subBar, &WorkloadUpdate{Bundle: bundleV1}) // update the bundle and assert all subscribers gets the updated bundle cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV2), }, nil) - assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{Bundle: bundleV2}) - assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{Bundle: bundleV2}) + assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{Bundle: bundleV2}) + assertWorkloadUpdateEqual(t, subBar, &WorkloadUpdate{Bundle: bundleV2}) } func TestSomeSubscribersNotifiedOnFederatedBundleChange(t *testing.T) { cache := newTestCache() // initialize the cache with an entry FOO that has a valid SVID and - // selector "A" - foo := makeRegistrationEntry("FOO", "A") + // selector "foo" + foo := makeRegistrationEntry("FOO", "foo") cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1), RegistrationEntries: makeRegistrationEntries(foo), @@ -171,13 +171,13 @@ func TestSomeSubscribersNotifiedOnFederatedBundleChange(t *testing.T) { }) // subscribe to A and B and assert initial updates are received. - subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) - defer subA.Finish() - assertAnyWorkloadUpdate(t, subA) + subFoo := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo")) + defer subFoo.Finish() + assertAnyWorkloadUpdate(t, subFoo) - subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) - defer subB.Finish() - assertAnyWorkloadUpdate(t, subB) + subBar := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("bar")) + defer subBar.Finish() + assertAnyWorkloadUpdate(t, subBar) // add the federated bundle with no registration entries federating with // it and make sure nobody is notified. 
@@ -185,73 +185,73 @@ func TestSomeSubscribersNotifiedOnFederatedBundleChange(t *testing.T) { Bundles: makeBundles(bundleV1, otherBundleV1), RegistrationEntries: makeRegistrationEntries(foo), }, nil) - assertNoWorkloadUpdate(t, subA) - assertNoWorkloadUpdate(t, subB) + assertNoWorkloadUpdate(t, subFoo) + assertNoWorkloadUpdate(t, subBar) - // update FOO to federate with otherdomain.test and make sure subA is - // notified but not subB. - foo = makeRegistrationEntry("FOO", "A") + // update FOO to federate with otherdomain.test and make sure subFoo is + // notified but not subBar. + foo = makeRegistrationEntry("FOO", "foo") foo.FederatesWith = makeFederatesWith(otherBundleV1) cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1, otherBundleV1), RegistrationEntries: makeRegistrationEntries(foo), }, nil) - assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{ Bundle: bundleV1, FederatedBundles: makeBundles(otherBundleV1), Identities: []Identity{{Entry: foo}}, }) - assertNoWorkloadUpdate(t, subB) + assertNoWorkloadUpdate(t, subBar) - // now change the federated bundle and make sure subA gets notified, but - // again, not subB. + // now change the federated bundle and make sure subFoo gets notified, but + // again, not subBar. cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1, otherBundleV2), RegistrationEntries: makeRegistrationEntries(foo), }, nil) - assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{ Bundle: bundleV1, FederatedBundles: makeBundles(otherBundleV2), Identities: []Identity{{Entry: foo}}, }) - assertNoWorkloadUpdate(t, subB) + assertNoWorkloadUpdate(t, subBar) - // now drop the federation and make sure subA is again notified and no + // now drop the federation and make sure subFoo is again notified and no // longer has the federated bundle. 
- foo = makeRegistrationEntry("FOO", "A") + foo = makeRegistrationEntry("FOO", "foo") cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1, otherBundleV2), RegistrationEntries: makeRegistrationEntries(foo), }, nil) - assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{ Bundle: bundleV1, Identities: []Identity{{Entry: foo}}, }) - assertNoWorkloadUpdate(t, subB) + assertNoWorkloadUpdate(t, subBar) } func TestSubscribersGetEntriesWithSelectorSubsets(t *testing.T) { cache := newTestCache() // create subscribers for each combination of selectors - subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) - defer subA.Finish() - subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) - defer subB.Finish() - subAB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A", "B")) - defer subAB.Finish() + subFoo := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo")) + defer subFoo.Finish() + subBar := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("bar")) + defer subBar.Finish() + subFooBar := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo", "bar")) + defer subFooBar.Finish() // assert all subscribers get the initial update initialUpdate := &WorkloadUpdate{Bundle: bundleV1} - assertWorkloadUpdateEqual(t, subA, initialUpdate) - assertWorkloadUpdateEqual(t, subB, initialUpdate) - assertWorkloadUpdateEqual(t, subAB, initialUpdate) + assertWorkloadUpdateEqual(t, subFoo, initialUpdate) + assertWorkloadUpdateEqual(t, subBar, initialUpdate) + assertWorkloadUpdateEqual(t, subFooBar, initialUpdate) - // create entry FOO that will target any subscriber with containing (A) - foo := makeRegistrationEntry("FOO", "A") + // create entry FOO that will target any subscriber with containing (foo) + foo := makeRegistrationEntry("FOO", "foo") - // create entry BAR that will target any subscriber with containing (A,C) - bar := makeRegistrationEntry("BAR", "A", "C") + // create entry BAR that will target any subscriber with containing (foo,baz) + bar := makeRegistrationEntry("BAR", "foo", "baz") // update the cache with foo and bar cache.UpdateEntries(&UpdateEntries{ @@ -262,18 +262,18 @@ func TestSubscribersGetEntriesWithSelectorSubsets(t *testing.T) { X509SVIDs: makeX509SVIDs(foo, bar), }) - // subA selector set contains (A), but not (A, C), so it should only get FOO - assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + // subFoo selector set contains (foo), but not (foo, baz), so it should only get FOO + assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{ Bundle: bundleV1, Identities: []Identity{{Entry: foo}}, }) - // subB selector set does not contain either (A) or (A,C) so it isn't even + // subBar selector set does not contain either (foo) or (foo,baz) so it isn't even // notified. 
- assertNoWorkloadUpdate(t, subB) + assertNoWorkloadUpdate(t, subBar) - // subAB selector set contains (A) but not (A, C), so it should get FOO - assertWorkloadUpdateEqual(t, subAB, &WorkloadUpdate{ + // subFooBar selector set contains (foo) but not (foo, baz), so it should get FOO + assertWorkloadUpdateEqual(t, subFooBar, &WorkloadUpdate{ Bundle: bundleV1, Identities: []Identity{{Entry: foo}}, }) @@ -282,7 +282,7 @@ func TestSubscribersGetEntriesWithSelectorSubsets(t *testing.T) { func TestSubscriberIsNotNotifiedIfNothingChanges(t *testing.T) { cache := newTestCache() - foo := makeRegistrationEntry("FOO", "A") + foo := makeRegistrationEntry("FOO", "foo") cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1), RegistrationEntries: makeRegistrationEntries(foo), @@ -291,7 +291,7 @@ func TestSubscriberIsNotNotifiedIfNothingChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo), }) - sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -308,7 +308,7 @@ func TestSubscriberIsNotNotifiedIfNothingChanges(t *testing.T) { func TestSubscriberNotifiedOnSVIDChanges(t *testing.T) { cache := newTestCache() - foo := makeRegistrationEntry("FOO", "A") + foo := makeRegistrationEntry("FOO", "foo") cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1), RegistrationEntries: makeRegistrationEntries(foo), @@ -317,7 +317,7 @@ func TestSubscriberNotifiedOnSVIDChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo), }) - sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -386,17 +386,17 @@ func TestSubscriberNotificationsOnSelectorChanges(t *testing.T) { func TestSubscriberNotifiedWhenEntryDropped(t *testing.T) { cache := newTestCache() - subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) - defer subA.Finish() - assertAnyWorkloadUpdate(t, subA) + subFoo := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo")) + defer subFoo.Finish() + assertAnyWorkloadUpdate(t, subFoo) - // subB's job here is to just make sure we don't notify unrelated + // subBar's job here is to just make sure we don't notify unrelated // subscribers when dropping registration entries - subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) - defer subB.Finish() - assertAnyWorkloadUpdate(t, subB) + subBar := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("bar")) + defer subBar.Finish() + assertAnyWorkloadUpdate(t, subBar) - foo := makeRegistrationEntry("FOO", "A") + foo := makeRegistrationEntry("FOO", "foo") updateEntries := &UpdateEntries{ Bundles: makeBundles(bundleV1), RegistrationEntries: makeRegistrationEntries(foo), @@ -405,46 +405,46 @@ func TestSubscriberNotifiedWhenEntryDropped(t *testing.T) { cache.UpdateSVIDs(&UpdateSVIDs{ X509SVIDs: makeX509SVIDs(foo), }) - // make sure subA gets notified with FOO but not subB - assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + // make sure subFoo gets notified with FOO but not subBar + assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{ Bundle: bundleV1, Identities: []Identity{{Entry: foo}}, }) - assertNoWorkloadUpdate(t, subB) + assertNoWorkloadUpdate(t, subBar) updateEntries.RegistrationEntries = nil cache.UpdateEntries(updateEntries, nil) - assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + 
assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{ Bundle: bundleV1, }) - assertNoWorkloadUpdate(t, subB) + assertNoWorkloadUpdate(t, subBar) // Make sure trying to update SVIDs of removed entry does not notify cache.UpdateSVIDs(&UpdateSVIDs{ X509SVIDs: makeX509SVIDs(foo), }) - assertNoWorkloadUpdate(t, subB) + assertNoWorkloadUpdate(t, subBar) } func TestSubscriberOnlyGetsEntriesWithSVID(t *testing.T) { cache := newTestCache() - foo := makeRegistrationEntry("FOO", "A") + foo := makeRegistrationEntry("FOO", "foo") updateEntries := &UpdateEntries{ Bundles: makeBundles(bundleV1), RegistrationEntries: makeRegistrationEntries(foo), } cache.UpdateEntries(updateEntries, nil) - sub := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) - defer sub.Finish() - assertNoWorkloadUpdate(t, sub) + subFoo := cache.SubscribeToWorkloadUpdates(makeSelectors("foo")) + defer subFoo.Finish() + assertNoWorkloadUpdate(t, subFoo) // update to include the SVID and now we should get the update cache.UpdateSVIDs(&UpdateSVIDs{ X509SVIDs: makeX509SVIDs(foo), }) - assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{ Bundle: bundleV1, Identities: []Identity{{Entry: foo}}, }) @@ -453,8 +453,8 @@ func TestSubscriberOnlyGetsEntriesWithSVID(t *testing.T) { func TestSubscribersDoNotBlockNotifications(t *testing.T) { cache := newTestCache() - sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) - defer sub.Finish() + subFoo := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo")) + defer subFoo.Finish() cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV2), @@ -464,7 +464,7 @@ func TestSubscribersDoNotBlockNotifications(t *testing.T) { Bundles: makeBundles(bundleV3), }, nil) - assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{ Bundle: bundleV3, }) } @@ -573,8 +573,8 @@ func TestGetStaleEntries(t *testing.T) { func TestSubscriberNotNotifiedOnDifferentSVIDChanges(t *testing.T) { cache := newTestCache() - foo := makeRegistrationEntry("FOO", "A") - bar := makeRegistrationEntry("BAR", "B") + foo := makeRegistrationEntry("FOO", "foo") + bar := makeRegistrationEntry("BAR", "bar") cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1), RegistrationEntries: makeRegistrationEntries(foo, bar), @@ -583,7 +583,7 @@ func TestSubscriberNotNotifiedOnDifferentSVIDChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo, bar), }) - sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -625,7 +625,7 @@ func TestSVIDCacheExpiry(t *testing.T) { cache := newTestCacheWithConfig(10, 1*time.Minute, clk) clk.Add(1 * time.Second) - foo := makeRegistrationEntry("FOO", "A") + foo := makeRegistrationEntry("FOO", "foo") // validate workload update for foo cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1), @@ -634,16 +634,16 @@ func TestSVIDCacheExpiry(t *testing.T) { cache.UpdateSVIDs(&UpdateSVIDs{ X509SVIDs: makeX509SVIDs(foo), }) - subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) - assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + subFoo := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo")) + assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{ Bundle: bundleV1, Identities: []Identity{{Entry: foo}}, }) - subA.Finish() + subFoo.Finish() // move clk by 1 sec so that SVID access 
time will be different clk.Add(1 * time.Second) - bar := makeRegistrationEntry("BAR", "B") + bar := makeRegistrationEntry("BAR", "bar") // validate workload update for bar cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1), @@ -654,9 +654,9 @@ func TestSVIDCacheExpiry(t *testing.T) { }) // not closing subscriber immediately - subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) - defer subB.Finish() - assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{ + subBar := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("bar")) + defer subBar.Finish() + assertWorkloadUpdateEqual(t, subBar, &WorkloadUpdate{ Bundle: bundleV1, Identities: []Identity{ {Entry: bar}, @@ -689,10 +689,9 @@ func TestSVIDCacheExpiry(t *testing.T) { clk.Add(58 * time.Second) cache.UpdateEntries(updateEntries, nil) - subA = cache.SubscribeToWorkloadUpdates(makeSelectors("A")) - defer subA.Finish() - //cache.NotifyBySelectorSet(makeSelectors("A")) - assert.False(t, cache.Notify(makeSelectors("A"))) + subFoo = cache.SubscribeToWorkloadUpdates(makeSelectors("foo")) + defer subFoo.Finish() + assert.False(t, cache.Notify(makeSelectors("foo"))) assert.Equal(t, 11, cache.CountSVIDs()) // move clk by another minute and update entries @@ -704,9 +703,9 @@ func TestSVIDCacheExpiry(t *testing.T) { assert.Equal(t, foo, cache.GetStaleEntries()[0].Entry) // bar should not be removed from cache as it has another active subscriber - subB2 := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) - defer subB2.Finish() - assertWorkloadUpdateEqual(t, subB2, &WorkloadUpdate{ + subBar2 := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("bar")) + defer subBar2.Finish() + assertWorkloadUpdateEqual(t, subBar2, &WorkloadUpdate{ Bundle: bundleV1, Identities: []Identity{ {Entry: bar}, @@ -734,11 +733,11 @@ func TestMaxSVIDCacheSize(t *testing.T) { assert.Equal(t, 10, cache.CountSVIDs()) // Validate that active subscriber will still get SVID even if SVID count is at maxSvidCacheSize - foo := makeRegistrationEntry("FOO", "A") + foo := makeRegistrationEntry("FOO", "foo") updateEntries.RegistrationEntries[foo.EntryId] = foo - subA := cache.SubscribeToWorkloadUpdates(foo.Selectors) - defer subA.Finish() + subFoo := cache.SubscribeToWorkloadUpdates(foo.Selectors) + defer subFoo.Finish() cache.UpdateEntries(updateEntries, nil) require.Len(t, cache.GetStaleEntries(), 1) @@ -763,17 +762,17 @@ func TestSyncSVIDsWithSubscribers(t *testing.T) { assert.Equal(t, 5, cache.CountSVIDs()) // Update foo but its SVID is not yet cached - foo := makeRegistrationEntry("FOO", "A") + foo := makeRegistrationEntry("FOO", "foo") updateEntries.RegistrationEntries[foo.EntryId] = foo cache.UpdateEntries(updateEntries, nil) // Create a subscriber for foo - subA := cache.SubscribeToWorkloadUpdates(foo.Selectors) - defer subA.Finish() + subFoo := cache.SubscribeToWorkloadUpdates(foo.Selectors) + defer subFoo.Finish() require.Len(t, cache.GetStaleEntries(), 0) - // After SyncSVIDsWithSubscribers foo should be marked as stale + // After SyncSVIDsWithSubscribers foo should be marked as stale, requiring signing cache.SyncSVIDsWithSubscribers() require.Len(t, cache.GetStaleEntries(), 1) assert.Equal(t, []*StaleEntry{{Entry: cache.records[foo.EntryId].entry}}, cache.GetStaleEntries()) @@ -784,17 +783,29 @@ func TestSyncSVIDsWithSubscribers(t *testing.T) { func TestNotify(t *testing.T) { cache := newTestCache() - foo := makeRegistrationEntry("FOO", "A") + foo := makeRegistrationEntry("FOO", "foo") 
cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1), RegistrationEntries: makeRegistrationEntries(foo), }, nil) - assert.False(t, cache.Notify(makeSelectors("A"))) + assert.False(t, cache.Notify(makeSelectors("foo"))) cache.UpdateSVIDs(&UpdateSVIDs{ X509SVIDs: makeX509SVIDs(foo), }) - assert.True(t, cache.Notify(makeSelectors("A"))) + assert.True(t, cache.Notify(makeSelectors("foo"))) +} + +func TestNewCache(t *testing.T) { + // negative values + cache := newTestCacheWithConfig(-5, -5, clock.NewMock()) + require.Equal(t, DefaultMaxSvidCacheSize, cache.maxSvidCacheSize) + require.Equal(t, DefaultSVIDCacheExpiryPeriod, cache.svidCacheExpiryPeriod) + + // zero values + cache = newTestCacheWithConfig(0, 0, clock.NewMock()) + require.Equal(t, DefaultMaxSvidCacheSize, cache.maxSvidCacheSize) + require.Equal(t, DefaultSVIDCacheExpiryPeriod, cache.svidCacheExpiryPeriod) } func BenchmarkCacheGlobalNotification(b *testing.B) { @@ -854,6 +865,7 @@ func newTestCacheWithConfig(maxSvidCacheSize int, svidCacheExpiryPeriod time.Dur maxSvidCacheSize, svidCacheExpiryPeriod, clk) } +// numEntries should not be more than 12 digits func createUpdateEntries(numEntries int, bundles map[spiffeid.TrustDomain]*bundleutil.Bundle) *UpdateEntries { updateEntries := &UpdateEntries{ Bundles: bundles, diff --git a/pkg/agent/manager/manager.go b/pkg/agent/manager/manager.go index 1d0f60a2e8..56070c18b9 100644 --- a/pkg/agent/manager/manager.go +++ b/pkg/agent/manager/manager.go @@ -94,10 +94,10 @@ type manager struct { svidCachePath string bundleCachePath string - // backoff calculator for fetch interval, backing off if error is returned on + // synchronizeBackoff calculator for fetch interval, backing off if error is returned on // fetch attempt - backoff backoff.BackOff - svidSyncBackoff backoff.BackOff + synchronizeBackoff backoff.BackOff + svidSyncBackoff backoff.BackOff client client.Client @@ -114,7 +114,7 @@ func (m *manager) Initialize(ctx context.Context) error { m.storeSVID(m.svid.State().SVID) m.storeBundle(m.cache.Bundle()) - m.backoff = backoff.NewBackoff(m.clk, m.c.SyncInterval) + m.synchronizeBackoff = backoff.NewBackoff(m.clk, m.c.SyncInterval) m.svidSyncBackoff = backoff.NewBackoff(m.clk, svidSyncInterval) err := m.synchronize(ctx) @@ -237,7 +237,7 @@ func (m *manager) getEntryID(spiffeID string) string { func (m *manager) runSynchronizer(ctx context.Context) error { for { select { - case <-m.clk.After(m.backoff.NextBackOff()): + case <-m.clk.After(m.synchronizeBackoff.NextBackOff()): case <-ctx.Done(): return nil } @@ -251,7 +251,7 @@ func (m *manager) runSynchronizer(ctx context.Context) error { // Just log the error and wait for next synchronization m.c.Log.WithError(err).Error("Synchronize failed") default: - m.backoff.Reset() + m.synchronizeBackoff.Reset() } } } diff --git a/test/integration/setup/debugagent/main.go b/test/integration/setup/debugagent/main.go index 67d10ab634..283119a08b 100644 --- a/test/integration/setup/debugagent/main.go +++ b/test/integration/setup/debugagent/main.go @@ -54,22 +54,23 @@ func run() error { } func agentEndpoints(ctx context.Context) error { - s,err := retrieveDebugPage(ctx) + s, err := retrieveDebugPage(ctx) if err == nil { log.Printf("Debug info: %s", string(s)) } return nil } +// printDebugPage allows integration tests to easily parse debug page with jq func printDebugPage(ctx context.Context) error { - s,err := retrieveDebugPage(ctx) + s, err := retrieveDebugPage(ctx) if err == nil { fmt.Println(s) } return nil } -func retrieveDebugPage(ctx 
context.Context) (string,error) { +func retrieveDebugPage(ctx context.Context) (string, error) { conn, err := grpc.Dial(*socketPathFlag, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return "", fmt.Errorf("failed to connect server: %w", err) @@ -87,7 +88,7 @@ func retrieveDebugPage(ctx context.Context) (string,error) { if err != nil { return "", fmt.Errorf("failed to parse proto: %w", err) } - return string(s),nil + return string(s), nil } func serverWithWorkload(ctx context.Context) error { From b48bf9d2de719c1b783b7dbcb24f2807133eb57c Mon Sep 17 00:00:00 2001 From: Prasad Borole Date: Fri, 22 Jul 2022 15:25:42 -0700 Subject: [PATCH 03/19] Removing svid_cache_expiry_interval configuration Signed-off-by: Prasad Borole --- cmd/spire-agent/cli/run/run.go | 22 +- cmd/spire-agent/cli/run/run_test.go | 57 +--- pkg/agent/agent.go | 27 +- .../api/delegatedidentity/v1/service_test.go | 2 +- pkg/agent/config.go | 7 +- pkg/agent/manager/cache/cache.go | 92 +++--- pkg/agent/manager/cache/cache_test.go | 273 ++++++++--------- pkg/agent/manager/config.go | 31 +- pkg/agent/manager/manager.go | 6 +- pkg/agent/manager/manager_test.go | 284 ++++-------------- pkg/agent/manager/sync.go | 41 +-- test/integration/common | 19 +- test/integration/setup/debugagent/main.go | 10 +- .../04-create-registration-entries | 2 +- .../suites/fetch-svids/05-fetch-svids | 17 +- .../06-create-registration-entries | 2 +- .../suites/fetch-svids/07-fetch-svids | 35 ++- .../suites/fetch-svids/conf/agent/agent.conf | 3 +- 18 files changed, 344 insertions(+), 586 deletions(-) diff --git a/cmd/spire-agent/cli/run/run.go b/cmd/spire-agent/cli/run/run.go index a57f44f426..5870d3de02 100644 --- a/cmd/spire-agent/cli/run/run.go +++ b/cmd/spire-agent/cli/run/run.go @@ -97,9 +97,8 @@ type experimentalConfig struct { SyncInterval string `hcl:"sync_interval"` TCPSocketPort int `hcl:"tcp_socket_port"` - UnusedKeys []string `hcl:",unusedKeys"` - MaxSvidCacheSize int `hcl:"max_svid_cache_size"` - SVIDCacheExpiryPeriod string `hcl:"svid_cache_expiry_interval"` + UnusedKeys []string `hcl:",unusedKeys"` + X509SVIDCacheMaxSize int `hcl:"x509_svid_cache_max_size"` } type Command struct { @@ -402,21 +401,10 @@ func NewAgentConfig(c *Config, logOptions []log.Option, allowUnknownConfig bool) } } - if c.Agent.Experimental.MaxSvidCacheSize < 0 { - return nil, fmt.Errorf("max_svid_cache_size should not be negative") - } - ac.MaxSvidCacheSize = c.Agent.Experimental.MaxSvidCacheSize - - if c.Agent.Experimental.SVIDCacheExpiryPeriod != "" { - var err error - ac.SVIDCacheExpiryPeriod, err = time.ParseDuration(c.Agent.Experimental.SVIDCacheExpiryPeriod) - if err != nil { - return nil, fmt.Errorf("could not parse svid cache expiry interval: %w", err) - } - if ac.SVIDCacheExpiryPeriod < 0 { - return nil, fmt.Errorf("svid_cache_expiry_interval should not be negative") - } + if c.Agent.Experimental.X509SVIDCacheMaxSize < 0 { + return nil, errors.New("x509_svid_cache_max_size should not be negative") } + ac.X509SVIDCacheMaxSize = c.Agent.Experimental.X509SVIDCacheMaxSize serverHostPort := net.JoinHostPort(c.Agent.ServerAddress, strconv.Itoa(c.Agent.ServerPort)) ac.ServerAddress = fmt.Sprintf("dns:///%s", serverHostPort) diff --git a/cmd/spire-agent/cli/run/run_test.go b/cmd/spire-agent/cli/run/run_test.go index 5acb48db2c..3562ec8306 100644 --- a/cmd/spire-agent/cli/run/run_test.go +++ b/cmd/spire-agent/cli/run/run_test.go @@ -690,73 +690,36 @@ func TestNewAgentConfig(t *testing.T) { }, }, { - msg: "svid_cache_expiry_interval parses a 
duration", + msg: "x509_svid_cache_max_size is set", input: func(c *Config) { - c.Agent.Experimental.SVIDCacheExpiryPeriod = "1s50ms" + c.Agent.Experimental.X509SVIDCacheMaxSize = 100 }, test: func(t *testing.T, c *agent.Config) { - require.EqualValues(t, 1050000000, c.SVIDCacheExpiryPeriod) + require.EqualValues(t, 100, c.X509SVIDCacheMaxSize) }, }, { - msg: "svid_cache_expiry_interval is negative", - expectError: true, - input: func(c *Config) { - c.Agent.Experimental.SVIDCacheExpiryPeriod = "-1s50ms" - }, - test: func(t *testing.T, c *agent.Config) { - require.Nil(t, c) - }, - }, - { - msg: "invalid svid_cache_expiry_interval returns an error", - expectError: true, - input: func(c *Config) { - c.Agent.Experimental.SVIDCacheExpiryPeriod = "moo" - }, - test: func(t *testing.T, c *agent.Config) { - require.Nil(t, c) - }, - }, - { - msg: "svid_cache_expiry_interval is not set", - input: func(c *Config) { - }, - test: func(t *testing.T, c *agent.Config) { - require.EqualValues(t, 0, c.SVIDCacheExpiryPeriod) - }, - }, - { - msg: "max_svid_cache_size is set", - input: func(c *Config) { - c.Agent.Experimental.MaxSvidCacheSize = 100 - }, - test: func(t *testing.T, c *agent.Config) { - require.EqualValues(t, 100, c.MaxSvidCacheSize) - }, - }, - { - msg: "max_svid_cache_size is not set", + msg: "x509_svid_cache_max_size is not set", input: func(c *Config) { }, test: func(t *testing.T, c *agent.Config) { - require.EqualValues(t, 0, c.MaxSvidCacheSize) + require.EqualValues(t, 0, c.X509SVIDCacheMaxSize) }, }, { - msg: "max_svid_cache_size is zero", + msg: "x509_svid_cache_max_size is zero", input: func(c *Config) { - c.Agent.Experimental.MaxSvidCacheSize = 0 + c.Agent.Experimental.X509SVIDCacheMaxSize = 0 }, test: func(t *testing.T, c *agent.Config) { - require.EqualValues(t, 0, c.MaxSvidCacheSize) + require.EqualValues(t, 0, c.X509SVIDCacheMaxSize) }, }, { - msg: "max_svid_cache_size is negative", + msg: "x509_svid_cache_max_size is negative", expectError: true, input: func(c *Config) { - c.Agent.Experimental.MaxSvidCacheSize = -10 + c.Agent.Experimental.X509SVIDCacheMaxSize = -10 }, test: func(t *testing.T, c *agent.Config) { require.Nil(t, c) diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index f10f19a87c..c81294fcbf 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -199,20 +199,19 @@ func (a *Agent) attest(ctx context.Context, cat catalog.Catalog, metrics telemet func (a *Agent) newManager(ctx context.Context, cat catalog.Catalog, metrics telemetry.Metrics, as *node_attestor.AttestationResult, cache *storecache.Cache) (manager.Manager, error) { config := &manager.Config{ - SVID: as.SVID, - SVIDKey: as.Key, - Bundle: as.Bundle, - Catalog: cat, - TrustDomain: a.c.TrustDomain, - ServerAddr: a.c.ServerAddress, - Log: a.c.Log.WithField(telemetry.SubsystemName, telemetry.Manager), - Metrics: metrics, - BundleCachePath: a.bundleCachePath(), - SVIDCachePath: a.agentSVIDPath(), - SyncInterval: a.c.SyncInterval, - MaxSvidCacheSize: a.c.MaxSvidCacheSize, - SVIDCacheExpiryPeriod: a.c.SVIDCacheExpiryPeriod, - SVIDStoreCache: cache, + SVID: as.SVID, + SVIDKey: as.Key, + Bundle: as.Bundle, + Catalog: cat, + TrustDomain: a.c.TrustDomain, + ServerAddr: a.c.ServerAddress, + Log: a.c.Log.WithField(telemetry.SubsystemName, telemetry.Manager), + Metrics: metrics, + BundleCachePath: a.bundleCachePath(), + SVIDCachePath: a.agentSVIDPath(), + SyncInterval: a.c.SyncInterval, + SVIDCacheMaxSize: a.c.X509SVIDCacheMaxSize, + SVIDStoreCache: cache, } mgr := manager.New(config) diff --git 
a/pkg/agent/api/delegatedidentity/v1/service_test.go b/pkg/agent/api/delegatedidentity/v1/service_test.go index 78d727390f..662dac0d69 100644 --- a/pkg/agent/api/delegatedidentity/v1/service_test.go +++ b/pkg/agent/api/delegatedidentity/v1/service_test.go @@ -799,5 +799,5 @@ func (m *FakeManager) SubscribeToBundleChanges() *cache.BundleStream { func newTestCache() *cache.Cache { log, _ := test.NewNullLogger() - return cache.New(log, trustDomain1, bundle1, telemetry.Blackhole{}, 0, 0, clock.NewMock()) + return cache.New(log, trustDomain1, bundle1, telemetry.Blackhole{}, 0, clock.NewMock()) } diff --git a/pkg/agent/config.go b/pkg/agent/config.go index 2b7a61b2c3..8eae409208 100644 --- a/pkg/agent/config.go +++ b/pkg/agent/config.go @@ -52,11 +52,8 @@ type Config struct { // SyncInterval controls how often the agent sync synchronizer waits SyncInterval time.Duration - // MaxSvidCacheSize is a soft limit of max number of SVIDs that would be stored in cache - MaxSvidCacheSize int - - // SVIDCacheExpiryPeriod is a period after which svids that don't have subscribers will be removed from cache - SVIDCacheExpiryPeriod time.Duration + // X509SVIDCacheMaxSize is a soft limit of max number of SVIDs that would be stored in cache + X509SVIDCacheMaxSize int // Trust domain and associated CA bundle TrustDomain spiffeid.TrustDomain diff --git a/pkg/agent/manager/cache/cache.go b/pkg/agent/manager/cache/cache.go index 968bbbdc88..a1228d8fb6 100644 --- a/pkg/agent/manager/cache/cache.go +++ b/pkg/agent/manager/cache/cache.go @@ -16,8 +16,7 @@ import ( ) const ( - DefaultMaxSvidCacheSize = 1000 - DefaultSVIDCacheExpiryPeriod = 1 * time.Hour + DefaultSVIDCacheMaxSize = 1000 ) type Selectors []*common.Selector @@ -145,11 +144,8 @@ type Cache struct { // svids are stored by entry IDs svids map[string]*X509SVID - // maxSVIDCacheSize is a soft limit of max number of SVIDs that would be stored in cache - maxSvidCacheSize int - - // svidCacheExpiryPeriod is a period after which svids that don't have subscribers will be removed from cache - svidCacheExpiryPeriod time.Duration + // svidCacheMaxSize is a soft limit of max number of SVIDs that would be stored in cache + svidCacheMaxSize int } // StaleEntry holds stale or outdated entries which require new SVID with old SVIDs expiration time (if present) @@ -161,13 +157,9 @@ type StaleEntry struct { } func New(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundle, metrics telemetry.Metrics, - maxSvidCacheSize int, svidCacheExpiryPeriod time.Duration, clk clock.Clock) *Cache { - if maxSvidCacheSize <= 0 { - maxSvidCacheSize = DefaultMaxSvidCacheSize - } - - if svidCacheExpiryPeriod <= 0 { - svidCacheExpiryPeriod = DefaultSVIDCacheExpiryPeriod + svidCacheMaxSize int, clk clock.Clock) *Cache { + if svidCacheMaxSize <= 0 { + svidCacheMaxSize = DefaultSVIDCacheMaxSize } return &Cache{ @@ -183,10 +175,9 @@ func New(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundl bundles: map[spiffeid.TrustDomain]*bundleutil.Bundle{ trustDomain: bundle, }, - svids: make(map[string]*X509SVID), - maxSvidCacheSize: maxSvidCacheSize, - svidCacheExpiryPeriod: svidCacheExpiryPeriod, - clk: clk, + svids: make(map[string]*X509SVID), + svidCacheMaxSize: svidCacheMaxSize, + clk: clk, } } @@ -431,29 +422,26 @@ func (c *Cache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.Regi // entries with active subscribers which are not cached will be put in staleEntries map activeSubs, recordsWithLastAccessTime := c.syncSVIDs() - extraSize := len(c.svids) - 
c.maxSvidCacheSize + extraSize := len(c.svids) - c.svidCacheMaxSize // delete svids without subscribers and which have not been accessed since svidCacheExpiryTime if extraSize > 0 { // sort recordsWithLastAccessTime sortTimestamps(recordsWithLastAccessTime) - now := c.clk.Now() - svidCacheExpiryTime := now.Add(-1 * c.svidCacheExpiryPeriod).UnixMilli() + for _, record := range recordsWithLastAccessTime { if extraSize <= 0 { - // no need to delete SVIDs any further as cache size <= maxSvidCacheSize + // no need to delete SVIDs any further as cache size <= svidCacheMaxSize break } if _, ok := c.svids[record.id]; ok { if _, exists := activeSubs[record.id]; !exists { - // remove svid if it has not been accessed since svidCacheExpiryTime - if record.timestamp < svidCacheExpiryTime { - c.log.WithField("record_id", record.id). - WithField("record_timestamp", record.timestamp). - Debug("Removing SVID record") - delete(c.svids, record.id) - extraSize-- - } + // remove svid + c.log.WithField("record_id", record.id). + WithField("record_timestamp", record.timestamp). + Debug("Removing SVID record") + delete(c.svids, record.id) + extraSize-- } } } @@ -550,21 +538,18 @@ func (c *Cache) SyncSVIDsWithSubscribers() { // It returns whether all SVIDs are cached or not. // This method should be retried with backoff to avoid lock contention. func (c *Cache) Notify(selectors []*common.Selector) bool { - c.mu.Lock() - defer c.mu.Unlock() - if len(c.missingSVIDRecords(selectors)) == 0 { - set, setFree := allocSelectorSet(selectors...) - defer setFree() + c.mu.RLock() + defer c.mu.RUnlock() + set, setFree := allocSelectorSet(selectors...) + defer setFree() + if len(c.missingSVIDRecords(set)) == 0 { c.notifyBySelectorSet(set) return true } return false } -func (c *Cache) missingSVIDRecords(selectors []*common.Selector) []*StaleEntry { - set, setFree := allocSelectorSet(selectors...) 
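// The selector set is now allocated once in Notify and handed to
// missingSVIDRecords, which lets Notify hold only the read lock while it checks
// for uncached SVIDs before fanning out notifications.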
- defer setFree() - +func (c *Cache) missingSVIDRecords(set selectorSet) []*StaleEntry { records, recordsDone := c.getRecordsForSelectors(set) defer recordsDone() @@ -589,17 +574,18 @@ func (c *Cache) updateLastAccessTimestamp(selectors []*common.Selector) { records, recordsDone := c.getRecordsForSelectors(set) defer recordsDone() + now := c.clk.Now().UnixMilli() for record := range records { // Set lastAccessTimestamp so that svid LRU cache can be cleaned based on this timestamp - record.lastAccessTimestamp = c.clk.Now().UnixMilli() + record.lastAccessTimestamp = now } } // entries with active subscribers which are not cached will be put in staleEntries map // records which are not cached for remainder of max cache size will also be put in staleEntries map -func (c *Cache) syncSVIDs() (map[string]struct{}, []record) { - activeSubs := make(map[string]struct{}) - lastAccessTimestamps := make([]record, len(c.records)) +func (c *Cache) syncSVIDs() (map[string]struct{}, []recordAccessEvent) { + activeSubsByEntryID := make(map[string]struct{}) + lastAccessTimestamps := make([]recordAccessEvent, 0, len(c.records)) i := 0 // iterate over all selectors from cached entries and obtain: @@ -614,29 +600,29 @@ func (c *Cache) syncSVIDs() (map[string]struct{}, []record) { if _, ok := c.svids[record.entry.EntryId]; !ok { c.staleEntries[id] = true } - activeSubs[id] = struct{}{} + activeSubsByEntryID[id] = struct{}{} break } } } - lastAccessTimestamps[i] = newRecord(record.lastAccessTimestamp, id) + lastAccessTimestamps = append(lastAccessTimestamps, newRecord(record.lastAccessTimestamp, id)) i++ } - remainderSize := c.maxSvidCacheSize - len(c.svids) + remainderSize := c.svidCacheMaxSize - len(c.svids) // add records which are not cached for remainder of cache size - for id, _ := range c.records { + for id := range c.records { if len(c.staleEntries) >= remainderSize { break } - if _, ok := c.svids[id]; !ok { + if _, svidCached := c.svids[id]; !svidCached { if _, ok := c.staleEntries[id]; !ok { c.staleEntries[id] = true } } } - return activeSubs, lastAccessTimestamps + return activeSubsByEntryID, lastAccessTimestamps } func (c *Cache) updateOrCreateRecord(newEntry *common.RegistrationEntry) (*cacheRecord, *common.RegistrationEntry) { @@ -961,7 +947,7 @@ func sortEntries(entries []*common.RegistrationEntry) { }) } -func sortTimestamps(records []record) { +func sortTimestamps(records []recordAccessEvent) { sort.Slice(records, func(a, b int) bool { return records[a].timestamp < records[b].timestamp }) @@ -975,11 +961,11 @@ func makeIdentity(record *cacheRecord, svid *X509SVID) Identity { } } -type record struct { +type recordAccessEvent struct { timestamp int64 id string } -func newRecord(timestamp int64, id string) record { - return record{timestamp: timestamp, id: id} +func newRecord(timestamp int64, id string) recordAccessEvent { + return recordAccessEvent{timestamp: timestamp, id: id} } diff --git a/pkg/agent/manager/cache/cache_test.go b/pkg/agent/manager/cache/cache_test.go index bcbb942093..49839a8155 100644 --- a/pkg/agent/manager/cache/cache_test.go +++ b/pkg/agent/manager/cache/cache_test.go @@ -31,8 +31,8 @@ var ( func TestFetchWorkloadUpdate(t *testing.T) { cache := newTestCache() // populate the cache with FOO and BAR without SVIDS - foo := makeRegistrationEntry("FOO", "foo") - bar := makeRegistrationEntry("BAR", "bar") + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") bar.FederatesWith = makeFederatesWith(otherBundleV1) updateEntries := &UpdateEntries{ 
Bundles: makeBundles(bundleV1, otherBundleV1), @@ -40,7 +40,7 @@ func TestFetchWorkloadUpdate(t *testing.T) { } cache.UpdateEntries(updateEntries, nil) - workloadUpdate := cache.FetchWorkloadUpdate(makeSelectors("foo", "bar")) + workloadUpdate := cache.FetchWorkloadUpdate(makeSelectors("A", "B")) assert.Len(t, workloadUpdate.Identities, 0, "identities should not be returned that don't have SVIDs") updateSVIDs := &UpdateSVIDs{ @@ -48,7 +48,7 @@ func TestFetchWorkloadUpdate(t *testing.T) { } cache.UpdateSVIDs(updateSVIDs) - workloadUpdate = cache.FetchWorkloadUpdate(makeSelectors("foo", "bar")) + workloadUpdate = cache.FetchWorkloadUpdate(makeSelectors("A", "B")) assert.Equal(t, &WorkloadUpdate{ Bundle: bundleV1, FederatedBundles: makeBundles(otherBundleV1), @@ -63,8 +63,8 @@ func TestMatchingRegistrationIdentities(t *testing.T) { cache := newTestCache() // populate the cache with FOO and BAR without SVIDS - foo := makeRegistrationEntry("FOO", "foo") - bar := makeRegistrationEntry("BAR", "bar") + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") updateEntries := &UpdateEntries{ Bundles: makeBundles(bundleV1), RegistrationEntries: makeRegistrationEntries(foo, bar), @@ -72,7 +72,7 @@ func TestMatchingRegistrationIdentities(t *testing.T) { cache.UpdateEntries(updateEntries, nil) assert.Equal(t, []*common.RegistrationEntry{bar, foo}, - cache.MatchingRegistrationEntries(makeSelectors("foo", "bar"))) + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) // Update SVIDs and MatchingRegistrationEntries should return both entries updateSVIDs := &UpdateSVIDs{ @@ -80,20 +80,20 @@ func TestMatchingRegistrationIdentities(t *testing.T) { } cache.UpdateSVIDs(updateSVIDs) assert.Equal(t, []*common.RegistrationEntry{bar, foo}, - cache.MatchingRegistrationEntries(makeSelectors("foo", "bar"))) + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) // Remove SVIDs and MatchingRegistrationEntries should still return both entries cache.UpdateSVIDs(&UpdateSVIDs{}) assert.Equal(t, []*common.RegistrationEntry{bar, foo}, - cache.MatchingRegistrationEntries(makeSelectors("foo", "bar"))) + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) } func TestCountSVIDs(t *testing.T) { cache := newTestCache() // populate the cache with FOO and BAR without SVIDS - foo := makeRegistrationEntry("FOO", "foo") - bar := makeRegistrationEntry("BAR", "bar") + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") updateEntries := &UpdateEntries{ Bundles: makeBundles(bundleV1), RegistrationEntries: makeRegistrationEntries(foo, bar), @@ -140,28 +140,28 @@ func TestAllSubscribersNotifiedOnBundleChange(t *testing.T) { cache := newTestCache() // create some subscribers and assert they get the initial bundle - subFoo := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo")) - defer subFoo.Finish() - assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{Bundle: bundleV1}) + subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + defer subA.Finish() + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{Bundle: bundleV1}) - subBar := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("bar")) - defer subBar.Finish() - assertWorkloadUpdateEqual(t, subBar, &WorkloadUpdate{Bundle: bundleV1}) + subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) + defer subB.Finish() + assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{Bundle: bundleV1}) // update the bundle and assert all subscribers gets the updated 
bundle cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV2), }, nil) - assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{Bundle: bundleV2}) - assertWorkloadUpdateEqual(t, subBar, &WorkloadUpdate{Bundle: bundleV2}) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{Bundle: bundleV2}) + assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{Bundle: bundleV2}) } func TestSomeSubscribersNotifiedOnFederatedBundleChange(t *testing.T) { cache := newTestCache() // initialize the cache with an entry FOO that has a valid SVID and - // selector "foo" - foo := makeRegistrationEntry("FOO", "foo") + // selector "A" + foo := makeRegistrationEntry("FOO", "A") cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1), RegistrationEntries: makeRegistrationEntries(foo), @@ -171,13 +171,13 @@ func TestSomeSubscribersNotifiedOnFederatedBundleChange(t *testing.T) { }) // subscribe to A and B and assert initial updates are received. - subFoo := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo")) - defer subFoo.Finish() - assertAnyWorkloadUpdate(t, subFoo) + subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + defer subA.Finish() + assertAnyWorkloadUpdate(t, subA) - subBar := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("bar")) - defer subBar.Finish() - assertAnyWorkloadUpdate(t, subBar) + subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) + defer subB.Finish() + assertAnyWorkloadUpdate(t, subB) // add the federated bundle with no registration entries federating with // it and make sure nobody is notified. @@ -185,73 +185,73 @@ func TestSomeSubscribersNotifiedOnFederatedBundleChange(t *testing.T) { Bundles: makeBundles(bundleV1, otherBundleV1), RegistrationEntries: makeRegistrationEntries(foo), }, nil) - assertNoWorkloadUpdate(t, subFoo) - assertNoWorkloadUpdate(t, subBar) + assertNoWorkloadUpdate(t, subA) + assertNoWorkloadUpdate(t, subB) - // update FOO to federate with otherdomain.test and make sure subFoo is - // notified but not subBar. - foo = makeRegistrationEntry("FOO", "foo") + // update FOO to federate with otherdomain.test and make sure subA is + // notified but not subB. + foo = makeRegistrationEntry("FOO", "A") foo.FederatesWith = makeFederatesWith(otherBundleV1) cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1, otherBundleV1), RegistrationEntries: makeRegistrationEntries(foo), }, nil) - assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{ + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ Bundle: bundleV1, FederatedBundles: makeBundles(otherBundleV1), Identities: []Identity{{Entry: foo}}, }) - assertNoWorkloadUpdate(t, subBar) + assertNoWorkloadUpdate(t, subB) - // now change the federated bundle and make sure subFoo gets notified, but - // again, not subBar. + // now change the federated bundle and make sure subA gets notified, but + // again, not subB. cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1, otherBundleV2), RegistrationEntries: makeRegistrationEntries(foo), }, nil) - assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{ + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ Bundle: bundleV1, FederatedBundles: makeBundles(otherBundleV2), Identities: []Identity{{Entry: foo}}, }) - assertNoWorkloadUpdate(t, subBar) + assertNoWorkloadUpdate(t, subB) - // now drop the federation and make sure subFoo is again notified and no + // now drop the federation and make sure subA is again notified and no // longer has the federated bundle. 
- foo = makeRegistrationEntry("FOO", "foo") + foo = makeRegistrationEntry("FOO", "A") cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1, otherBundleV2), RegistrationEntries: makeRegistrationEntries(foo), }, nil) - assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{ + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ Bundle: bundleV1, Identities: []Identity{{Entry: foo}}, }) - assertNoWorkloadUpdate(t, subBar) + assertNoWorkloadUpdate(t, subB) } func TestSubscribersGetEntriesWithSelectorSubsets(t *testing.T) { cache := newTestCache() // create subscribers for each combination of selectors - subFoo := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo")) - defer subFoo.Finish() - subBar := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("bar")) - defer subBar.Finish() - subFooBar := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo", "bar")) - defer subFooBar.Finish() + subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + defer subA.Finish() + subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) + defer subB.Finish() + subAB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A", "B")) + defer subAB.Finish() // assert all subscribers get the initial update initialUpdate := &WorkloadUpdate{Bundle: bundleV1} - assertWorkloadUpdateEqual(t, subFoo, initialUpdate) - assertWorkloadUpdateEqual(t, subBar, initialUpdate) - assertWorkloadUpdateEqual(t, subFooBar, initialUpdate) + assertWorkloadUpdateEqual(t, subA, initialUpdate) + assertWorkloadUpdateEqual(t, subB, initialUpdate) + assertWorkloadUpdateEqual(t, subAB, initialUpdate) - // create entry FOO that will target any subscriber with containing (foo) - foo := makeRegistrationEntry("FOO", "foo") + // create entry FOO that will target any subscriber with containing (A) + foo := makeRegistrationEntry("FOO", "A") - // create entry BAR that will target any subscriber with containing (foo,baz) - bar := makeRegistrationEntry("BAR", "foo", "baz") + // create entry BAR that will target any subscriber with containing (A,C) + bar := makeRegistrationEntry("BAR", "A", "C") // update the cache with foo and bar cache.UpdateEntries(&UpdateEntries{ @@ -262,18 +262,18 @@ func TestSubscribersGetEntriesWithSelectorSubsets(t *testing.T) { X509SVIDs: makeX509SVIDs(foo, bar), }) - // subFoo selector set contains (foo), but not (foo, baz), so it should only get FOO - assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{ + // subA selector set contains (A), but not (A, C), so it should only get FOO + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ Bundle: bundleV1, Identities: []Identity{{Entry: foo}}, }) - // subBar selector set does not contain either (foo) or (foo,baz) so it isn't even + // subB selector set does not contain either (A) or (A,C) so it isn't even // notified. 
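	// An entry is delivered to a subscriber only when the entry's selectors are a
	// subset of the subscriber's selector set, which is why FOO (A) reaches subA and
	// subAB while BAR (A, C) reaches neither of them here.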
- assertNoWorkloadUpdate(t, subBar) + assertNoWorkloadUpdate(t, subB) - // subFooBar selector set contains (foo) but not (foo, baz), so it should get FOO - assertWorkloadUpdateEqual(t, subFooBar, &WorkloadUpdate{ + // subAB selector set contains (A) but not (A, C), so it should get FOO + assertWorkloadUpdateEqual(t, subAB, &WorkloadUpdate{ Bundle: bundleV1, Identities: []Identity{{Entry: foo}}, }) @@ -282,7 +282,7 @@ func TestSubscribersGetEntriesWithSelectorSubsets(t *testing.T) { func TestSubscriberIsNotNotifiedIfNothingChanges(t *testing.T) { cache := newTestCache() - foo := makeRegistrationEntry("FOO", "foo") + foo := makeRegistrationEntry("FOO", "A") cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1), RegistrationEntries: makeRegistrationEntries(foo), @@ -291,7 +291,7 @@ func TestSubscriberIsNotNotifiedIfNothingChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo), }) - sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo")) + sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -308,7 +308,7 @@ func TestSubscriberIsNotNotifiedIfNothingChanges(t *testing.T) { func TestSubscriberNotifiedOnSVIDChanges(t *testing.T) { cache := newTestCache() - foo := makeRegistrationEntry("FOO", "foo") + foo := makeRegistrationEntry("FOO", "A") cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1), RegistrationEntries: makeRegistrationEntries(foo), @@ -317,7 +317,7 @@ func TestSubscriberNotifiedOnSVIDChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo), }) - sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo")) + sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -386,17 +386,17 @@ func TestSubscriberNotificationsOnSelectorChanges(t *testing.T) { func TestSubscriberNotifiedWhenEntryDropped(t *testing.T) { cache := newTestCache() - subFoo := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo")) - defer subFoo.Finish() - assertAnyWorkloadUpdate(t, subFoo) + subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + defer subA.Finish() + assertAnyWorkloadUpdate(t, subA) - // subBar's job here is to just make sure we don't notify unrelated + // subB's job here is to just make sure we don't notify unrelated // subscribers when dropping registration entries - subBar := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("bar")) - defer subBar.Finish() - assertAnyWorkloadUpdate(t, subBar) + subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) + defer subB.Finish() + assertAnyWorkloadUpdate(t, subB) - foo := makeRegistrationEntry("FOO", "foo") + foo := makeRegistrationEntry("FOO", "A") updateEntries := &UpdateEntries{ Bundles: makeBundles(bundleV1), RegistrationEntries: makeRegistrationEntries(foo), @@ -405,46 +405,46 @@ func TestSubscriberNotifiedWhenEntryDropped(t *testing.T) { cache.UpdateSVIDs(&UpdateSVIDs{ X509SVIDs: makeX509SVIDs(foo), }) - // make sure subFoo gets notified with FOO but not subBar - assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{ + // make sure subA gets notified with FOO but not subB + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ Bundle: bundleV1, Identities: []Identity{{Entry: foo}}, }) - assertNoWorkloadUpdate(t, subBar) + assertNoWorkloadUpdate(t, subB) updateEntries.RegistrationEntries = nil cache.UpdateEntries(updateEntries, nil) - assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{ + 
assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ Bundle: bundleV1, }) - assertNoWorkloadUpdate(t, subBar) + assertNoWorkloadUpdate(t, subB) // Make sure trying to update SVIDs of removed entry does not notify cache.UpdateSVIDs(&UpdateSVIDs{ X509SVIDs: makeX509SVIDs(foo), }) - assertNoWorkloadUpdate(t, subBar) + assertNoWorkloadUpdate(t, subB) } func TestSubscriberOnlyGetsEntriesWithSVID(t *testing.T) { cache := newTestCache() - foo := makeRegistrationEntry("FOO", "foo") + foo := makeRegistrationEntry("FOO", "A") updateEntries := &UpdateEntries{ Bundles: makeBundles(bundleV1), RegistrationEntries: makeRegistrationEntries(foo), } cache.UpdateEntries(updateEntries, nil) - subFoo := cache.SubscribeToWorkloadUpdates(makeSelectors("foo")) - defer subFoo.Finish() - assertNoWorkloadUpdate(t, subFoo) + sub := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + defer sub.Finish() + assertNoWorkloadUpdate(t, sub) // update to include the SVID and now we should get the update cache.UpdateSVIDs(&UpdateSVIDs{ X509SVIDs: makeX509SVIDs(foo), }) - assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{ + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ Bundle: bundleV1, Identities: []Identity{{Entry: foo}}, }) @@ -453,8 +453,8 @@ func TestSubscriberOnlyGetsEntriesWithSVID(t *testing.T) { func TestSubscribersDoNotBlockNotifications(t *testing.T) { cache := newTestCache() - subFoo := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo")) - defer subFoo.Finish() + sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + defer sub.Finish() cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV2), @@ -464,7 +464,7 @@ func TestSubscribersDoNotBlockNotifications(t *testing.T) { Bundles: makeBundles(bundleV3), }, nil) - assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{ + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ Bundle: bundleV3, }) } @@ -520,7 +520,6 @@ func TestGetStaleEntries(t *testing.T) { cache := newTestCache() foo := makeRegistrationEntryWithTTL("FOO", 60) - expiredAt := time.Now() // Create entry but don't mark it stale from checkSVID method; // it will be marked stale cause it does not have SVID cached @@ -537,6 +536,7 @@ func TestGetStaleEntries(t *testing.T) { // Update the SVID for the stale entry svids := make(map[string]*X509SVID) + expiredAt := time.Now() svids[foo.EntryId] = &X509SVID{ Chain: []*x509.Certificate{{NotAfter: expiredAt}}, } @@ -573,8 +573,8 @@ func TestGetStaleEntries(t *testing.T) { func TestSubscriberNotNotifiedOnDifferentSVIDChanges(t *testing.T) { cache := newTestCache() - foo := makeRegistrationEntry("FOO", "foo") - bar := makeRegistrationEntry("BAR", "bar") + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1), RegistrationEntries: makeRegistrationEntries(foo, bar), @@ -583,7 +583,7 @@ func TestSubscriberNotNotifiedOnDifferentSVIDChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo, bar), }) - sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo")) + sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -622,10 +622,10 @@ func TestSubscriberNotNotifiedOnOverlappingSVIDChanges(t *testing.T) { func TestSVIDCacheExpiry(t *testing.T) { clk := clock.NewMock() - cache := newTestCacheWithConfig(10, 1*time.Minute, clk) + cache := newTestCacheWithConfig(10, clk) clk.Add(1 * time.Second) - foo := makeRegistrationEntry("FOO", "foo") + foo := 
makeRegistrationEntry("FOO", "A") // validate workload update for foo cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1), @@ -634,16 +634,16 @@ func TestSVIDCacheExpiry(t *testing.T) { cache.UpdateSVIDs(&UpdateSVIDs{ X509SVIDs: makeX509SVIDs(foo), }) - subFoo := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("foo")) - assertWorkloadUpdateEqual(t, subFoo, &WorkloadUpdate{ + subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ Bundle: bundleV1, Identities: []Identity{{Entry: foo}}, }) - subFoo.Finish() + subA.Finish() // move clk by 1 sec so that SVID access time will be different clk.Add(1 * time.Second) - bar := makeRegistrationEntry("BAR", "bar") + bar := makeRegistrationEntry("BAR", "B") // validate workload update for bar cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1), @@ -654,16 +654,16 @@ func TestSVIDCacheExpiry(t *testing.T) { }) // not closing subscriber immediately - subBar := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("bar")) - defer subBar.Finish() - assertWorkloadUpdateEqual(t, subBar, &WorkloadUpdate{ + subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) + defer subB.Finish() + assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{ Bundle: bundleV1, Identities: []Identity{ {Entry: bar}, }, }) - // Move clk by a second + // Move clk by 2 seconds clk.Add(2 * time.Second) // update total of 12 entries updateEntries := createUpdateEntries(10, makeBundles(bundleV1)) @@ -683,42 +683,31 @@ func TestSVIDCacheExpiry(t *testing.T) { sub.Finish() } } + assert.Equal(t, 12, cache.CountSVIDs()) - // Move clk by 58 sec so that a minute has passed since last foo was accessed - // svid for foo should be deleted - clk.Add(58 * time.Second) cache.UpdateEntries(updateEntries, nil) + assert.Equal(t, 10, cache.CountSVIDs()) - subFoo = cache.SubscribeToWorkloadUpdates(makeSelectors("foo")) - defer subFoo.Finish() - assert.False(t, cache.Notify(makeSelectors("foo"))) - assert.Equal(t, 11, cache.CountSVIDs()) + // foo SVID should be removed from cache as it does not have active subscriber + assert.False(t, cache.Notify(makeSelectors("A"))) + // bar SVID should be cached as it has active subscriber + assert.True(t, cache.Notify(makeSelectors("B"))) + + subA = cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + defer subA.Finish() - // move clk by another minute and update entries - clk.Add(1 * time.Minute) cache.UpdateEntries(updateEntries, nil) // Make sure foo is marked as stale entry which does not have svid cached require.Len(t, cache.GetStaleEntries(), 1) assert.Equal(t, foo, cache.GetStaleEntries()[0].Entry) - // bar should not be removed from cache as it has another active subscriber - subBar2 := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("bar")) - defer subBar2.Finish() - assertWorkloadUpdateEqual(t, subBar2, &WorkloadUpdate{ - Bundle: bundleV1, - Identities: []Identity{ - {Entry: bar}, - }, - }) - - // ensure SVIDs without active subscribers are still cached for remainder of cache size assert.Equal(t, 10, cache.CountSVIDs()) } func TestMaxSVIDCacheSize(t *testing.T) { clk := clock.NewMock() - cache := newTestCacheWithConfig(10, 1*time.Minute, clk) + cache := newTestCacheWithConfig(10, clk) // create entries more than maxSvidCacheSize updateEntries := createUpdateEntries(12, makeBundles(bundleV1)) @@ -733,11 +722,11 @@ func TestMaxSVIDCacheSize(t *testing.T) { assert.Equal(t, 10, cache.CountSVIDs()) // Validate that 
active subscriber will still get SVID even if SVID count is at maxSvidCacheSize - foo := makeRegistrationEntry("FOO", "foo") + foo := makeRegistrationEntry("FOO", "A") updateEntries.RegistrationEntries[foo.EntryId] = foo - subFoo := cache.SubscribeToWorkloadUpdates(foo.Selectors) - defer subFoo.Finish() + subA := cache.SubscribeToWorkloadUpdates(foo.Selectors) + defer subA.Finish() cache.UpdateEntries(updateEntries, nil) require.Len(t, cache.GetStaleEntries(), 1) @@ -752,7 +741,7 @@ func TestMaxSVIDCacheSize(t *testing.T) { func TestSyncSVIDsWithSubscribers(t *testing.T) { clk := clock.NewMock() - cache := newTestCacheWithConfig(5, 1*time.Minute, clk) + cache := newTestCacheWithConfig(5, clk) updateEntries := createUpdateEntries(5, makeBundles(bundleV1)) cache.UpdateEntries(updateEntries, nil) @@ -762,14 +751,14 @@ func TestSyncSVIDsWithSubscribers(t *testing.T) { assert.Equal(t, 5, cache.CountSVIDs()) // Update foo but its SVID is not yet cached - foo := makeRegistrationEntry("FOO", "foo") + foo := makeRegistrationEntry("FOO", "A") updateEntries.RegistrationEntries[foo.EntryId] = foo cache.UpdateEntries(updateEntries, nil) // Create a subscriber for foo - subFoo := cache.SubscribeToWorkloadUpdates(foo.Selectors) - defer subFoo.Finish() + subA := cache.SubscribeToWorkloadUpdates(foo.Selectors) + defer subA.Finish() require.Len(t, cache.GetStaleEntries(), 0) // After SyncSVIDsWithSubscribers foo should be marked as stale, requiring signing @@ -783,29 +772,27 @@ func TestSyncSVIDsWithSubscribers(t *testing.T) { func TestNotify(t *testing.T) { cache := newTestCache() - foo := makeRegistrationEntry("FOO", "foo") + foo := makeRegistrationEntry("FOO", "A") cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV1), RegistrationEntries: makeRegistrationEntries(foo), }, nil) - assert.False(t, cache.Notify(makeSelectors("foo"))) + assert.False(t, cache.Notify(makeSelectors("A"))) cache.UpdateSVIDs(&UpdateSVIDs{ X509SVIDs: makeX509SVIDs(foo), }) - assert.True(t, cache.Notify(makeSelectors("foo"))) + assert.True(t, cache.Notify(makeSelectors("A"))) } func TestNewCache(t *testing.T) { - // negative values - cache := newTestCacheWithConfig(-5, -5, clock.NewMock()) - require.Equal(t, DefaultMaxSvidCacheSize, cache.maxSvidCacheSize) - require.Equal(t, DefaultSVIDCacheExpiryPeriod, cache.svidCacheExpiryPeriod) + // negative value + cache := newTestCacheWithConfig(-5, clock.NewMock()) + require.Equal(t, DefaultSVIDCacheMaxSize, cache.svidCacheMaxSize) - // zero values - cache = newTestCacheWithConfig(0, 0, clock.NewMock()) - require.Equal(t, DefaultMaxSvidCacheSize, cache.maxSvidCacheSize) - require.Equal(t, DefaultSVIDCacheExpiryPeriod, cache.svidCacheExpiryPeriod) + // zero value + cache = newTestCacheWithConfig(0, clock.NewMock()) + require.Equal(t, DefaultSVIDCacheMaxSize, cache.svidCacheMaxSize) } func BenchmarkCacheGlobalNotification(b *testing.B) { @@ -856,13 +843,13 @@ func BenchmarkCacheGlobalNotification(b *testing.B) { func newTestCache() *Cache { log, _ := test.NewNullLogger() return New(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, - telemetry.Blackhole{}, 0, 0, clock.NewMock()) + telemetry.Blackhole{}, 0, clock.NewMock()) } -func newTestCacheWithConfig(maxSvidCacheSize int, svidCacheExpiryPeriod time.Duration, clk clock.Clock) *Cache { +func newTestCacheWithConfig(svidCacheMaxSize int, clk clock.Clock) *Cache { log, _ := test.NewNullLogger() return New(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, telemetry.Blackhole{}, - maxSvidCacheSize, 
svidCacheExpiryPeriod, clk) + svidCacheMaxSize, clk) } // numEntries should not be more than 12 digits diff --git a/pkg/agent/manager/config.go b/pkg/agent/manager/config.go index 53208ebb3c..f33392d6b2 100644 --- a/pkg/agent/manager/config.go +++ b/pkg/agent/manager/config.go @@ -19,21 +19,20 @@ import ( // Config holds a cache manager configuration type Config struct { // Agent SVID and key resulting from successful attestation. - SVID []*x509.Certificate - SVIDKey keymanager.Key - Bundle *cache.Bundle - Catalog catalog.Catalog - TrustDomain spiffeid.TrustDomain - Log logrus.FieldLogger - Metrics telemetry.Metrics - ServerAddr string - SVIDCachePath string - BundleCachePath string - SyncInterval time.Duration - RotationInterval time.Duration - SVIDStoreCache *storecache.Cache - MaxSvidCacheSize int - SVIDCacheExpiryPeriod time.Duration + SVID []*x509.Certificate + SVIDKey keymanager.Key + Bundle *cache.Bundle + Catalog catalog.Catalog + TrustDomain spiffeid.TrustDomain + Log logrus.FieldLogger + Metrics telemetry.Metrics + ServerAddr string + SVIDCachePath string + BundleCachePath string + SyncInterval time.Duration + RotationInterval time.Duration + SVIDStoreCache *storecache.Cache + SVIDCacheMaxSize int // Clk is the clock the manager will use to get time Clk clock.Clock @@ -58,7 +57,7 @@ func newManager(c *Config) *manager { } cache := cache.New(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle, - c.Metrics, c.MaxSvidCacheSize, c.SVIDCacheExpiryPeriod, c.Clk) + c.Metrics, c.SVIDCacheMaxSize, c.Clk) rotCfg := &svid.RotatorConfig{ SVIDKeyManager: keymanager.ForSVID(c.Catalog.GetKeyManager()), diff --git a/pkg/agent/manager/manager.go b/pkg/agent/manager/manager.go index 56070c18b9..7620de3a8b 100644 --- a/pkg/agent/manager/manager.go +++ b/pkg/agent/manager/manager.go @@ -157,9 +157,9 @@ func (m *manager) SubscribeToCacheChanges(selectors cache.Selectors) cache.Subsc if m.cache.Notify(selectors) { return subscriber } - select { - case <-m.clk.After(backoff.NextBackOff()): - } + m.c.Log.WithField(telemetry.Selectors, selectors).Info("Waiting for SVID to get cached") + + <-m.clk.After(backoff.NextBackOff()) } } diff --git a/pkg/agent/manager/manager_test.go b/pkg/agent/manager/manager_test.go index 53279b54e6..ca317132b4 100644 --- a/pkg/agent/manager/manager_test.go +++ b/pkg/agent/manager/manager_test.go @@ -80,7 +80,7 @@ func TestInitializationFailure(t *testing.T) { BundleCachePath: path.Join(dir, "bundle.der"), Clk: clk, Catalog: cat, - MaxSvidCacheSize: 1, + SVIDCacheMaxSize: 1, SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } m := newManager(c) @@ -108,7 +108,7 @@ func TestStoreBundleOnStartup(t *testing.T) { Bundle: bundleutil.BundleFromRootCA(trustDomain, ca), Clk: clk, Catalog: cat, - MaxSvidCacheSize: 1, + SVIDCacheMaxSize: 1, } m := newManager(c) @@ -155,7 +155,7 @@ func TestStoreSVIDOnStartup(t *testing.T) { BundleCachePath: path.Join(dir, "bundle.der"), Clk: clk, Catalog: cat, - MaxSvidCacheSize: 1, + SVIDCacheMaxSize: 1, } _, err := ReadSVID(c.SVIDCachePath) @@ -310,7 +310,7 @@ func TestSVIDRotation(t *testing.T) { RotationInterval: baseTTLSeconds / 2, SyncInterval: 1 * time.Hour, Clk: clk, - MaxSvidCacheSize: 1, + SVIDCacheMaxSize: 1, SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } @@ -419,7 +419,7 @@ func TestSynchronization(t *testing.T) { SyncInterval: time.Hour, Clk: clk, Catalog: cat, - MaxSvidCacheSize: 1, + SVIDCacheMaxSize: 1, 
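		// SVIDCacheMaxSize is a soft limit: values <= 0 fall back to
		// DefaultSVIDCacheMaxSize (1000), and entries with active subscribers keep
		// their SVIDs cached even when the limit has been reached.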
SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } @@ -571,7 +571,7 @@ func TestSynchronizationClearsStaleCacheEntries(t *testing.T) { Metrics: &telemetry.Blackhole{}, Clk: clk, Catalog: cat, - MaxSvidCacheSize: 1, + SVIDCacheMaxSize: 1, SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } @@ -645,7 +645,7 @@ func TestSynchronizationUpdatesRegistrationEntries(t *testing.T) { Metrics: &telemetry.Blackhole{}, Clk: clk, Catalog: cat, - MaxSvidCacheSize: 1, + SVIDCacheMaxSize: 1, SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } @@ -706,7 +706,7 @@ func TestSubscribersGetUpToDateBundle(t *testing.T) { RotationInterval: 1 * time.Hour, SyncInterval: 1 * time.Hour, Clk: clk, - MaxSvidCacheSize: 1, + SVIDCacheMaxSize: 1, Catalog: cat, SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } @@ -729,86 +729,6 @@ func TestSubscribersGetUpToDateBundle(t *testing.T) { }) } -func TestSynchronizationClearsExpiredSVIDCache(t *testing.T) { - dir := spiretest.TempDir(t) - km := fakeagentkeymanager.New(t, dir) - - clk := clock.NewMock(t) - api := newMockAPI(t, &mockAPIConfig{ - km: km, - getAuthorizedEntries: func(h *mockAPI, count int32, req *entryv1.GetAuthorizedEntriesRequest) (*entryv1.GetAuthorizedEntriesResponse, error) { - return makeGetAuthorizedEntriesResponse(t, "resp1", "resp2"), nil - }, - batchNewX509SVIDEntries: func(h *mockAPI, count int32) []*common.RegistrationEntry { - h.rotateCA() - return makeBatchNewX509SVIDEntries("resp1", "resp2") - }, - svidTTL: 200, - clk: clk, - }) - - baseSVID, baseSVIDKey := api.newSVID(joinTokenID, 1*time.Hour) - cat := fakeagentcatalog.New() - cat.SetKeyManager(km) - - c := &Config{ - ServerAddr: api.addr, - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - TrustDomain: trustDomain, - SVIDCachePath: path.Join(dir, "svid.der"), - BundleCachePath: path.Join(dir, "bundle.der"), - Bundle: api.bundle, - Metrics: &telemetry.Blackhole{}, - RotationInterval: 1 * time.Hour, - SyncInterval: 1 * time.Hour, - MaxSvidCacheSize: 1, - SVIDCacheExpiryPeriod: 5 * time.Second, - Clk: clk, - Catalog: cat, - SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), - } - - clk.Add(1 * time.Second) - - m := newManager(c) - if err := m.Initialize(context.Background()); err != nil { - t.Fatal(err) - } - - // After Initialize, just 1 SVID should be cached - require.Equal(t, 1, m.CountSVIDs()) - waitCh := make(chan struct{}) - - closer := runSVIDSync(t, waitCh, m) - defer closer() - - // Keep clk moving so that each subscriber gets SVID after SVID sync - clkCloser := moveClkAfterInterval(clk, 100*time.Millisecond, svidSyncInterval, waitCh) - - sub1 := m.SubscribeToCacheChanges(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) - - sub2 := m.SubscribeToCacheChanges( - cache.Selectors{{Type: "spiffe_id", Value: "spiffe://example.org/spire/agent/join_token/abcd"}}) - - sub1.Finish() - sub2.Finish() - close(waitCh) - clkCloser() - - // All 3 SVIDs should be cached - require.Equal(t, 3, m.CountSVIDs()) - - // Move clock so that svid cache is expired - clk.Add(6 * time.Second) - - require.NoError(t, m.synchronize(context.Background())) - - // Make sure svid count is MaxSvidCacheSize and remaining SVIDs are deleted from cache - require.Equal(t, 1, m.CountSVIDs()) -} - func TestSyncSVIDs(t *testing.T) { dir := spiretest.TempDir(t) km := fakeagentkeymanager.New(t, dir) @@ 
-831,74 +751,6 @@ func TestSyncSVIDs(t *testing.T) { cat := fakeagentcatalog.New() cat.SetKeyManager(km) - c := &Config{ - ServerAddr: api.addr, - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - TrustDomain: trustDomain, - SVIDCachePath: path.Join(dir, "svid.der"), - BundleCachePath: path.Join(dir, "bundle.der"), - Bundle: api.bundle, - Metrics: &telemetry.Blackhole{}, - RotationInterval: 1 * time.Hour, - SyncInterval: 1 * time.Hour, - MaxSvidCacheSize: 1, - SVIDCacheExpiryPeriod: 5 * time.Second, - Clk: clk, - Catalog: cat, - SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), - } - - clk.Add(1 * time.Second) - - m := newManager(c) - closer := initializeAndRunManager(t, m) - defer closer() - - // After Initialize, just 1 SVID should be cached - require.Equal(t, 1, m.CountSVIDs()) - waitCh := make(chan struct{}) - - // Keep clk moving so that each subscriber gets SVID after SVID sync - clkCloser := moveClkAfterInterval(clk, 100*time.Millisecond, svidSyncInterval, waitCh) - defer clkCloser() - - sub1 := m.SubscribeToCacheChanges(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) - defer sub1.Finish() - - sub2 := m.SubscribeToCacheChanges( - cache.Selectors{{Type: "spiffe_id", Value: "spiffe://example.org/spire/agent/join_token/abcd"}}) - defer sub2.Finish() - - close(waitCh) - - // All 3 SVIDs should be cached - require.Equal(t, 3, m.CountSVIDs()) -} - -func TestSubscribersWaitForSVID(t *testing.T) { - dir := spiretest.TempDir(t) - km := fakeagentkeymanager.New(t, dir) - - clk := clock.NewMock(t) - api := newMockAPI(t, &mockAPIConfig{ - km: km, - getAuthorizedEntries: func(h *mockAPI, count int32, req *entryv1.GetAuthorizedEntriesRequest) (*entryv1.GetAuthorizedEntriesResponse, error) { - return makeGetAuthorizedEntriesResponse(t, "resp1", "resp2"), nil - }, - batchNewX509SVIDEntries: func(h *mockAPI, count int32) []*common.RegistrationEntry { - h.rotateCA() - return makeBatchNewX509SVIDEntries("resp1", "resp2") - }, - svidTTL: 200, - clk: clk, - }) - - baseSVID, baseSVIDKey := api.newSVID(joinTokenID, 1*time.Hour) - cat := fakeagentcatalog.New() - cat.SetKeyManager(km) - c := &Config{ ServerAddr: api.addr, SVID: baseSVID, @@ -911,72 +763,64 @@ func TestSubscribersWaitForSVID(t *testing.T) { Metrics: &telemetry.Blackhole{}, RotationInterval: 1 * time.Hour, SyncInterval: 1 * time.Hour, - MaxSvidCacheSize: 1, + SVIDCacheMaxSize: 1, Clk: clk, Catalog: cat, SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } m := newManager(c) - if err := m.Initialize(context.Background()); err != nil { t.Fatal(err) } // After Initialize, just 1 SVID should be cached - require.Equal(t, 1, m.CountSVIDs()) - + assert.Equal(t, 1, m.CountSVIDs()) waitCh := make(chan struct{}) + errCh := make(chan error) - closer := runSVIDSync(t, waitCh, m) + // Run svidSync in separate routine and advance clock. + // It allows SubscribeToCacheChanges to keep checking for SVID in cache as clk advances. 
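	// SubscribeToCacheChanges only returns once cache.Notify confirms the
	// subscriber's SVIDs are cached, retrying with a backoff, so the background
	// syncSVIDs loop below is what eventually unblocks the subscriptions.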
+ closer := runSVIDSync(waitCh, clk, 50*time.Millisecond, m, errCh) defer closer() - // Keep clk moving so that each subscriber gets SVID after SVID sync - clkCloser := moveClkAfterInterval(clk, 100*time.Millisecond, svidSyncInterval, waitCh) - defer clkCloser() + sub1 := m.SubscribeToCacheChanges(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + // Validate the update received by subscribers + u1 := <-sub1.Updates() + if len(u1.Identities) != 2 { + t.Fatalf("expected 2 SVIDs, got: %d", len(u1.Identities)) + } + if !u1.Bundle.EqualTo(c.Bundle) { + t.Fatal("bundles were expected to be equal") + } - go func() { - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - sub1 := m.SubscribeToCacheChanges(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) - defer sub1.Finish() - u := <-sub1.Updates() - if len(u.Identities) != 2 { - t.Fatalf("expected 2 SVIDs, got: %d", len(u.Identities)) - } - if !u.Bundle.EqualTo(c.Bundle) { - t.Fatal("bundles were expected to be equal") - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - sub2 := m.SubscribeToCacheChanges( - cache.Selectors{{Type: "spiffe_id", Value: "spiffe://example.org/spire/agent/join_token/abcd"}}) - defer sub2.Finish() - u := <-sub2.Updates() - if len(u.Identities) != 1 { - t.Fatalf("expected 1 SVID, got: %d", len(u.Identities)) - } - if !u.Bundle.EqualTo(c.Bundle) { - t.Fatal("bundles were expected to be equal") - } - }() + sub2 := m.SubscribeToCacheChanges( + cache.Selectors{{Type: "spiffe_id", Value: "spiffe://example.org/spire/agent/join_token/abcd"}}) + // Validate the update received by subscribers + u2 := <-sub2.Updates() + if len(u2.Identities) != 1 { + t.Fatalf("expected 1 SVID, got: %d", len(u2.Identities)) + } + if !u2.Bundle.EqualTo(c.Bundle) { + t.Fatal("bundles were expected to be equal") + } - wg.Wait() - close(waitCh) - }() + sub1.Finish() + sub2.Finish() + close(waitCh) - select { - case <-waitCh: - case <-time.After(5 * time.Second): - t.Fatalf("subscriber update wait timed out") + err := <-errCh + if err != nil { + t.Fatalf("syncSVIDs method failed with error %v", err) } - require.Equal(t, 3, m.CountSVIDs()) + // All 3 SVIDs should be cached + assert.Equal(t, 3, m.CountSVIDs()) + + assert.NoError(t, m.synchronize(context.Background())) + + // Make sure svid count is SVIDCacheMaxSize and non-active SVIDs are deleted from cache + assert.Equal(t, 1, m.CountSVIDs()) } func TestSurvivesCARotation(t *testing.T) { @@ -1020,7 +864,7 @@ func TestSurvivesCARotation(t *testing.T) { SyncInterval: syncInterval, Clk: clk, Catalog: cat, - MaxSvidCacheSize: 1, + SVIDCacheMaxSize: 1, SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } @@ -1081,7 +925,7 @@ func TestFetchJWTSVID(t *testing.T) { Metrics: &telemetry.Blackhole{}, Catalog: cat, Clk: clk, - MaxSvidCacheSize: 1, + SVIDCacheMaxSize: 1, SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } @@ -1206,7 +1050,7 @@ func TestStorableSVIDsSync(t *testing.T) { Metrics: &telemetry.Blackhole{}, Clk: clk, Catalog: cat, - MaxSvidCacheSize: 1, + SVIDCacheMaxSize: 1, SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } @@ -1245,26 +1089,8 @@ func TestStorableSVIDsSync(t *testing.T) { validateResponse(records, entries) } -func runSVIDSync(t *testing.T, waitCh chan struct{}, m *manager) (closer func()) { - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - for { - select { - case <-waitCh: - return - case 
<-m.clk.After(svidSyncInterval): - require.NoError(t, m.syncSVIDs(context.Background())) - } - } - }() - return func() { - wg.Wait() - } -} - -func moveClkAfterInterval(clk *clock.Mock, interval, period time.Duration, waitCh chan struct{}) (closer func()) { +func runSVIDSync(waitCh chan struct{}, clk *clock.Mock, interval time.Duration, m *manager, + errCh chan<- error) (closer func()) { var wg sync.WaitGroup wg.Add(1) go func() { @@ -1272,9 +1098,15 @@ func moveClkAfterInterval(clk *clock.Mock, interval, period time.Duration, waitC for { select { case <-waitCh: + errCh <- nil return case <-time.After(interval): - clk.Add(period) + clk.Add(interval) + err := m.syncSVIDs(context.Background()) + if err != nil { + errCh <- err + return + } } } }() diff --git a/pkg/agent/manager/sync.go b/pkg/agent/manager/sync.go index 655924ce69..f58fc54d88 100644 --- a/pkg/agent/manager/sync.go +++ b/pkg/agent/manager/sync.go @@ -39,7 +39,11 @@ type Cache interface { func (m *manager) syncSVIDs(ctx context.Context) (err error) { m.cache.SyncSVIDsWithSubscribers() - return m.updateSVIDs(ctx, m.cache.GetStaleEntries(), m.cache) + staleEntries := m.cache.GetStaleEntries() + if len(staleEntries) > 0 { + return m.updateSVIDs(ctx, staleEntries, m.cache) + } + return nil } // synchronize fetches the authorized entries from the server, updates the @@ -116,27 +120,26 @@ func (m *manager) updateCache(ctx context.Context, update *cache.UpdateEntries, func (m *manager) updateSVIDs(ctx context.Context, entries []*cache.StaleEntry, c Cache) error { var csrs []csrRequest - if len(entries) > 0 { - for _, entry := range entries { - // we've exceeded the CSR limit, don't make any more CSRs - if len(csrs) >= limits.SignLimitPerIP { - break - } - - csrs = append(csrs, csrRequest{ - EntryID: entry.Entry.EntryId, - SpiffeID: entry.Entry.SpiffeId, - CurrentSVIDExpiresAt: entry.ExpiresAt, - }) + for _, entry := range entries { + // we've exceeded the CSR limit, don't make any more CSRs + if len(csrs) >= limits.SignLimitPerIP { + break } - update, err := m.fetchSVIDs(ctx, csrs) - if err != nil { - return err - } - // the values in `update` now belong to the cache. DO NOT MODIFY. - c.UpdateSVIDs(update) + csrs = append(csrs, csrRequest{ + EntryID: entry.Entry.EntryId, + SpiffeID: entry.Entry.SpiffeId, + CurrentSVIDExpiresAt: entry.ExpiresAt, + }) } + + update, err := m.fetchSVIDs(ctx, csrs) + if err != nil { + return err + } + // the values in `update` now belong to the cache. DO NOT MODIFY. + c.UpdateSVIDs(update) + return nil } diff --git a/test/integration/common b/test/integration/common index 114ad4914c..858d2970a2 100644 --- a/test/integration/common +++ b/test/integration/common @@ -83,23 +83,24 @@ check-synced-entry() { fail-now "timed out waiting for agent to sync down entry" } -check-svid-count() { +check-x509-svid-count() { MAXCHECKS=50 CHECKINTERVAL=1 - for ((i=0;i<=MAXCHECKS;i++)); do - if (( $i>=$MAXCHECKS )); then - fail-now "svid count validation failed" - fi - log-info "check svid count on agent debug endpoint ($(($i+1)) of $MAXCHECKS max)..." - COUNT=`docker-compose exec -T $1 /opt/spire/conf/agent/debugclient -testCase "printDebugPage" | jq '."svidsCount"'` - log-info "svidsCount: ${COUNT}" + for ((i=1;i<=MAXCHECKS;i++)); do + log-info "check X.509-SVID count on agent debug endpoint ($(($i)) of $MAXCHECKS max)..." 
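    # svidsCount comes from the agent debug endpoint and reflects the number of
    # X.509-SVIDs currently cached; keep polling until it matches the expected
    # count or MAXCHECKS attempts have been used.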
+ COUNT=`docker-compose exec -T $1 /opt/spire/conf/agent/debugclient -testCase "printDebugPage" | jq '.svidsCount'` + log-info "X.509-SVID Count: ${COUNT}" if [ "$COUNT" -eq "$2" ]; then - log-info "SVID count of $COUNT from cache matches the expected count of $2" + log-info "X.509-SVID count of $COUNT from cache matches the expected count of $2" break fi sleep "${CHECKINTERVAL}" done + + if (( $i>$MAXCHECKS )); then + fail-now "X.509-SVID count validation failed" + fi } build-mashup-image() { diff --git a/test/integration/setup/debugagent/main.go b/test/integration/setup/debugagent/main.go index 283119a08b..aa420040e0 100644 --- a/test/integration/setup/debugagent/main.go +++ b/test/integration/setup/debugagent/main.go @@ -55,18 +55,20 @@ func run() error { func agentEndpoints(ctx context.Context) error { s, err := retrieveDebugPage(ctx) - if err == nil { - log.Printf("Debug info: %s", string(s)) + if err != nil { + return err } + log.Printf("Debug info: %s", s) return nil } // printDebugPage allows integration tests to easily parse debug page with jq func printDebugPage(ctx context.Context) error { s, err := retrieveDebugPage(ctx) - if err == nil { - fmt.Println(s) + if err != nil { + return err } + fmt.Println(s) return nil } diff --git a/test/integration/suites/fetch-svids/04-create-registration-entries b/test/integration/suites/fetch-svids/04-create-registration-entries index 6a79e4b6ad..d992df7947 100755 --- a/test/integration/suites/fetch-svids/04-create-registration-entries +++ b/test/integration/suites/fetch-svids/04-create-registration-entries @@ -1,6 +1,6 @@ #!/bin/bash -SIZE=12 +SIZE=10 # Create entries for uid 1001 for ((m=1; m<=$SIZE;m++)); do diff --git a/test/integration/suites/fetch-svids/05-fetch-svids b/test/integration/suites/fetch-svids/05-fetch-svids index 646579ea02..5c9f706d1c 100755 --- a/test/integration/suites/fetch-svids/05-fetch-svids +++ b/test/integration/suites/fetch-svids/05-fetch-svids @@ -1,14 +1,17 @@ #!/bin/bash -ENTRYCOUNT=12 +ENTRYCOUNT=10 CACHESIZE=8 -docker-compose exec -u 1001 -T spire-agent \ +X509SVIDCOUNT=`docker-compose exec -u 1001 -T spire-agent \ /opt/spire/bin/spire-agent api fetch x509 \ - -socketPath /opt/spire/sockets/workload_api.sock || fail-now "x509-SVID check failed" + -socketPath /opt/spire/sockets/workload_api.sock | grep -i "spiffe://domain.test" | wc -l || fail-now "X.509-SVID check failed"` -# Call agent debug endpoints and check if svid count is equal number of entries registered -check-svid-count "spire-agent" $ENTRYCOUNT +if [ "$X509SVIDCOUNT" -ne "$ENTRYCOUNT" ]; then + fail-now "X.509-SVID check failed. 
Expected $ENTRYCOUNT X.509-SVIDs but received $X509SVIDCOUNT for uid 1001"; +else + echo "Expected $ENTRYCOUNT X.509-SVIDs and received $X509SVIDCOUNT for uid 1001"; +fi -# Call agent debug endpoints and check if svids from cache are cleaned up after expiry -check-svid-count "spire-agent" $CACHESIZE +# Call agent debug endpoints and check if extra svids from cache are cleaned up +check-x509-svid-count "spire-agent" $CACHESIZE diff --git a/test/integration/suites/fetch-svids/06-create-registration-entries b/test/integration/suites/fetch-svids/06-create-registration-entries index abacd71a36..0f6365e95d 100755 --- a/test/integration/suites/fetch-svids/06-create-registration-entries +++ b/test/integration/suites/fetch-svids/06-create-registration-entries @@ -1,6 +1,6 @@ #!/bin/bash -SIZE=5 +SIZE=10 # Create entries for uid 1002 for ((m=1; m<=$SIZE;m++)); do diff --git a/test/integration/suites/fetch-svids/07-fetch-svids b/test/integration/suites/fetch-svids/07-fetch-svids index afc484e260..ba55bfe2ce 100755 --- a/test/integration/suites/fetch-svids/07-fetch-svids +++ b/test/integration/suites/fetch-svids/07-fetch-svids @@ -1,28 +1,27 @@ #!/bin/bash CACHESIZE=8 +ENTRYCOUNT=10 -docker-compose exec -u 1002 -T spire-agent \ +X509SVIDCOUNT=`docker-compose exec -u 1002 -T spire-agent \ /opt/spire/bin/spire-agent api fetch x509 \ - -socketPath /opt/spire/sockets/workload_api.sock || fail-now "x509-SVID check failed" + -socketPath /opt/spire/sockets/workload_api.sock | grep -i "spiffe://domain.test" | wc -l || fail-now "X.509-SVID check failed"` -# Call agent debug endpoints and check if svid count is equal to cache size limit -check-svid-count "spire-agent" $CACHESIZE +if [ "$X509SVIDCOUNT" -ne "$ENTRYCOUNT" ]; then + fail-now "X.509-SVID check failed. Expected $ENTRYCOUNT X.509-SVIDs but received $X509SVIDCOUNT for uid 1002"; +else + echo "Expected $ENTRYCOUNT X.509-SVIDs and received $X509SVIDCOUNT for uid 1002"; +fi -# introduce some delay between two fetch calls so that we can validate cache cleanup of svids from first fetch. -sleep 5 - -docker-compose exec -u 1001 -T spire-agent \ +X509SVIDCOUNT=`docker-compose exec -u 1001 -T spire-agent \ /opt/spire/bin/spire-agent api fetch x509 \ - -socketPath /opt/spire/sockets/workload_api.sock || fail-now "x509-SVID check failed" - -# Call agent debug endpoints and check if svid count is equal to 17 (registration entry count for uids 1001 and 1002) -check-svid-count "spire-agent" 17 + -socketPath /opt/spire/sockets/workload_api.sock | grep -i "spiffe://domain.test" | wc -l || fail-now "X.509-SVID check failed"` -# Call agent debug endpoints and check if svid count is equal to 12 -# 17(svid-count) - 5(svids from entries with uuid 1002 will be removed first after svid_cache_expiry_interval) = 12 -check-svid-count "spire-agent" 12 +if [ "$X509SVIDCOUNT" -ne "$ENTRYCOUNT" ]; then + fail-now "X.509-SVID check failed. 
Expected $ENTRYCOUNT X.509-SVIDs but received $X509SVIDCOUNT for uid 1001"; +else + echo "Expected $ENTRYCOUNT X.509-SVIDs and received $X509SVIDCOUNT for uid 1001"; +fi -# Call agent debug endpoints and check if svid count is equal to 8 -# 12 - 8(cache size) = 4 extra svids from entries with uuid 1001 will be removed after svid_cache_expiry_interval -check-svid-count "spire-agent" $CACHESIZE +# Call agent debug endpoints and check if extra svids from cache are cleaned up +check-x509-svid-count "spire-agent" $CACHESIZE diff --git a/test/integration/suites/fetch-svids/conf/agent/agent.conf b/test/integration/suites/fetch-svids/conf/agent/agent.conf index c2ec900236..bdbc803a95 100644 --- a/test/integration/suites/fetch-svids/conf/agent/agent.conf +++ b/test/integration/suites/fetch-svids/conf/agent/agent.conf @@ -8,8 +8,7 @@ agent { trust_domain = "domain.test" admin_socket_path = "/opt/debug.sock" experimental { - max_svid_cache_size = 8 - svid_cache_expiry_interval = "30s" + x509_svid_cache_max_size = 8 } } From 487fc58e2e316e3d0ae4a292e63d95bdb8acf8db Mon Sep 17 00:00:00 2001 From: Prasad Borole Date: Fri, 22 Jul 2022 16:10:43 -0700 Subject: [PATCH 04/19] Updating integ tests Signed-off-by: Prasad Borole --- pkg/agent/manager/cache/cache.go | 6 ++---- .../suites/fetch-svids/04-create-registration-entries | 4 ++-- .../suites/fetch-svids/06-create-registration-entries | 7 +++---- test/integration/suites/fetch-svids/README.md | 4 ++-- 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/pkg/agent/manager/cache/cache.go b/pkg/agent/manager/cache/cache.go index a1228d8fb6..c33365c51b 100644 --- a/pkg/agent/manager/cache/cache.go +++ b/pkg/agent/manager/cache/cache.go @@ -421,7 +421,7 @@ func (c *Cache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.Regi } // entries with active subscribers which are not cached will be put in staleEntries map - activeSubs, recordsWithLastAccessTime := c.syncSVIDs() + activeSubsByEntryID, recordsWithLastAccessTime := c.syncSVIDs() extraSize := len(c.svids) - c.svidCacheMaxSize // delete svids without subscribers and which have not been accessed since svidCacheExpiryTime @@ -435,7 +435,7 @@ func (c *Cache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.Regi break } if _, ok := c.svids[record.id]; ok { - if _, exists := activeSubs[record.id]; !exists { + if _, exists := activeSubsByEntryID[record.id]; !exists { // remove svid c.log.WithField("record_id", record.id). WithField("record_timestamp", record.timestamp). @@ -587,7 +587,6 @@ func (c *Cache) syncSVIDs() (map[string]struct{}, []recordAccessEvent) { activeSubsByEntryID := make(map[string]struct{}) lastAccessTimestamps := make([]recordAccessEvent, 0, len(c.records)) - i := 0 // iterate over all selectors from cached entries and obtain: // 1. 
entries that have active subscribers // 1.1 if those entries don't have corresponding SVID cached then put them in staleEntries @@ -606,7 +605,6 @@ func (c *Cache) syncSVIDs() (map[string]struct{}, []recordAccessEvent) { } } lastAccessTimestamps = append(lastAccessTimestamps, newRecord(record.lastAccessTimestamp, id)) - i++ } remainderSize := c.svidCacheMaxSize - len(c.svids) diff --git a/test/integration/suites/fetch-svids/04-create-registration-entries b/test/integration/suites/fetch-svids/04-create-registration-entries index d992df7947..1866777122 100755 --- a/test/integration/suites/fetch-svids/04-create-registration-entries +++ b/test/integration/suites/fetch-svids/04-create-registration-entries @@ -3,7 +3,7 @@ SIZE=10 # Create entries for uid 1001 -for ((m=1; m<=$SIZE;m++)); do +for ((m=1;m<=$SIZE;m++)); do log-debug "creating registration entry: $m" docker-compose exec -T spire-server \ /opt/spire/bin/spire-server entry create \ @@ -13,6 +13,6 @@ for ((m=1; m<=$SIZE;m++)); do -ttl 0 & done -for ((m=1; m<=$SIZE;m++)); do +for ((m=1;m<=$SIZE;m++)); do check-synced-entry "spire-agent" "spiffe://domain.test/workload-$m" done diff --git a/test/integration/suites/fetch-svids/06-create-registration-entries b/test/integration/suites/fetch-svids/06-create-registration-entries index 0f6365e95d..f93ae19418 100755 --- a/test/integration/suites/fetch-svids/06-create-registration-entries +++ b/test/integration/suites/fetch-svids/06-create-registration-entries @@ -3,8 +3,8 @@ SIZE=10 # Create entries for uid 1002 -for ((m=1; m<=$SIZE;m++)); do - log-debug "creating regular registration entry...($m)" +for ((m=1;m<=$SIZE;m++)); do + log-debug "creating registration entry...($m)" docker-compose exec -T spire-server \ /opt/spire/bin/spire-server entry create \ -parentID "spiffe://domain.test/spire/agent/x509pop/$(fingerprint conf/agent/agent.crt.pem)" \ @@ -13,7 +13,6 @@ for ((m=1; m<=$SIZE;m++)); do -ttl 0 & done -for ((m=1; m<=$SIZE;m++)); do +for ((m=1;m<=$SIZE;m++)); do check-synced-entry "spire-agent" "spiffe://domain.test/workload/$m" - ((m++)) done diff --git a/test/integration/suites/fetch-svids/README.md b/test/integration/suites/fetch-svids/README.md index 509c2d0390..896ed8deeb 100644 --- a/test/integration/suites/fetch-svids/README.md +++ b/test/integration/suites/fetch-svids/README.md @@ -1,5 +1,5 @@ -# Fetch x509 SVID Suite +# Fetch x509-SVID Suite ## Description -This suite validates svid cache operations from spire-agent cache. +This suite validates X.509-SVID cache operations. 
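To make the cache-trimming behavior introduced in the patch above easier to follow, here is a minimal standalone Go sketch of the eviction rule it applies during entry updates: SVIDs are only evicted when the cache exceeds its configured maximum, entries that still have an active subscriber are never evicted, and among the rest the least recently accessed records go first. This is an illustration only, not code from the patch; the names record, evictExcess, and maxSize are hypothetical.

package main

import (
	"fmt"
	"sort"
)

type record struct {
	id         string
	lastAccess int64 // timestamp of the last subscriber access
}

// evictExcess returns the IDs to drop so that no more than maxSize SVIDs stay
// cached. Records with an active subscriber are always kept; among the rest,
// the least recently accessed are evicted first.
func evictExcess(records []record, activeSubs map[string]struct{}, maxSize int) []string {
	extra := len(records) - maxSize
	if extra <= 0 {
		return nil
	}
	// Oldest access first, mirroring the sort-by-timestamp step in the patch.
	sort.Slice(records, func(a, b int) bool {
		return records[a].lastAccess < records[b].lastAccess
	})
	var evicted []string
	for _, r := range records {
		if extra <= 0 {
			break
		}
		if _, active := activeSubs[r.id]; active {
			continue // identity still in use by a subscriber; keep it cached
		}
		evicted = append(evicted, r.id)
		extra--
	}
	return evicted
}

func main() {
	records := []record{
		{id: "entry-a", lastAccess: 100},
		{id: "entry-b", lastAccess: 300},
		{id: "entry-c", lastAccess: 200},
	}
	active := map[string]struct{}{"entry-b": {}}
	fmt.Println(evictExcess(records, active, 1)) // [entry-a entry-c]
}

In the actual cache the same pass also records uncached entries that have active subscribers as stale, so their SVIDs are signed on the next sync regardless of the size limit.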
From 5f1fee7abebf099c4bf19a4d85f8c5e11e0a9978 Mon Sep 17 00:00:00 2001 From: Prasad Borole Date: Fri, 22 Jul 2022 16:52:38 -0700 Subject: [PATCH 05/19] Renamed integ test suite Signed-off-by: Prasad Borole --- .../suites/{fetch-svids => fetch-x509-svids}/00-setup | 0 .../suites/{fetch-svids => fetch-x509-svids}/01-start-server | 0 .../suites/{fetch-svids => fetch-x509-svids}/02-bootstrap-agent | 0 .../suites/{fetch-svids => fetch-x509-svids}/03-start-agent | 0 .../04-create-registration-entries | 0 .../05-fetch-svids => fetch-x509-svids/05-fetch-x509-svids} | 2 +- .../06-create-registration-entries | 0 .../07-fetch-svids => fetch-x509-svids/07-fetch-x509-svids} | 2 +- .../suites/{fetch-svids => fetch-x509-svids}/README.md | 0 .../{fetch-svids => fetch-x509-svids}/conf/agent/agent.conf | 0 .../{fetch-svids => fetch-x509-svids}/conf/server/server.conf | 0 .../{fetch-svids => fetch-x509-svids}/docker-compose.yaml | 0 .../suites/{fetch-svids => fetch-x509-svids}/teardown | 0 13 files changed, 2 insertions(+), 2 deletions(-) rename test/integration/suites/{fetch-svids => fetch-x509-svids}/00-setup (100%) rename test/integration/suites/{fetch-svids => fetch-x509-svids}/01-start-server (100%) rename test/integration/suites/{fetch-svids => fetch-x509-svids}/02-bootstrap-agent (100%) rename test/integration/suites/{fetch-svids => fetch-x509-svids}/03-start-agent (100%) rename test/integration/suites/{fetch-svids => fetch-x509-svids}/04-create-registration-entries (100%) rename test/integration/suites/{fetch-svids/05-fetch-svids => fetch-x509-svids/05-fetch-x509-svids} (87%) rename test/integration/suites/{fetch-svids => fetch-x509-svids}/06-create-registration-entries (100%) rename test/integration/suites/{fetch-svids/07-fetch-svids => fetch-x509-svids/07-fetch-x509-svids} (92%) rename test/integration/suites/{fetch-svids => fetch-x509-svids}/README.md (100%) rename test/integration/suites/{fetch-svids => fetch-x509-svids}/conf/agent/agent.conf (100%) rename test/integration/suites/{fetch-svids => fetch-x509-svids}/conf/server/server.conf (100%) rename test/integration/suites/{fetch-svids => fetch-x509-svids}/docker-compose.yaml (100%) rename test/integration/suites/{fetch-svids => fetch-x509-svids}/teardown (100%) diff --git a/test/integration/suites/fetch-svids/00-setup b/test/integration/suites/fetch-x509-svids/00-setup similarity index 100% rename from test/integration/suites/fetch-svids/00-setup rename to test/integration/suites/fetch-x509-svids/00-setup diff --git a/test/integration/suites/fetch-svids/01-start-server b/test/integration/suites/fetch-x509-svids/01-start-server similarity index 100% rename from test/integration/suites/fetch-svids/01-start-server rename to test/integration/suites/fetch-x509-svids/01-start-server diff --git a/test/integration/suites/fetch-svids/02-bootstrap-agent b/test/integration/suites/fetch-x509-svids/02-bootstrap-agent similarity index 100% rename from test/integration/suites/fetch-svids/02-bootstrap-agent rename to test/integration/suites/fetch-x509-svids/02-bootstrap-agent diff --git a/test/integration/suites/fetch-svids/03-start-agent b/test/integration/suites/fetch-x509-svids/03-start-agent similarity index 100% rename from test/integration/suites/fetch-svids/03-start-agent rename to test/integration/suites/fetch-x509-svids/03-start-agent diff --git a/test/integration/suites/fetch-svids/04-create-registration-entries b/test/integration/suites/fetch-x509-svids/04-create-registration-entries similarity index 100% rename from 
test/integration/suites/fetch-svids/04-create-registration-entries rename to test/integration/suites/fetch-x509-svids/04-create-registration-entries diff --git a/test/integration/suites/fetch-svids/05-fetch-svids b/test/integration/suites/fetch-x509-svids/05-fetch-x509-svids similarity index 87% rename from test/integration/suites/fetch-svids/05-fetch-svids rename to test/integration/suites/fetch-x509-svids/05-fetch-x509-svids index 5c9f706d1c..b4ccf76ab2 100755 --- a/test/integration/suites/fetch-svids/05-fetch-svids +++ b/test/integration/suites/fetch-x509-svids/05-fetch-x509-svids @@ -13,5 +13,5 @@ else echo "Expected $ENTRYCOUNT X.509-SVIDs and received $X509SVIDCOUNT for uid 1001"; fi -# Call agent debug endpoints and check if extra svids from cache are cleaned up +# Call agent debug endpoints and check if extra X.509-SVIDs from cache are cleaned up check-x509-svid-count "spire-agent" $CACHESIZE diff --git a/test/integration/suites/fetch-svids/06-create-registration-entries b/test/integration/suites/fetch-x509-svids/06-create-registration-entries similarity index 100% rename from test/integration/suites/fetch-svids/06-create-registration-entries rename to test/integration/suites/fetch-x509-svids/06-create-registration-entries diff --git a/test/integration/suites/fetch-svids/07-fetch-svids b/test/integration/suites/fetch-x509-svids/07-fetch-x509-svids similarity index 92% rename from test/integration/suites/fetch-svids/07-fetch-svids rename to test/integration/suites/fetch-x509-svids/07-fetch-x509-svids index ba55bfe2ce..d29e54095e 100755 --- a/test/integration/suites/fetch-svids/07-fetch-svids +++ b/test/integration/suites/fetch-x509-svids/07-fetch-x509-svids @@ -23,5 +23,5 @@ else echo "Expected $ENTRYCOUNT X.509-SVIDs and received $X509SVIDCOUNT for uid 1001"; fi -# Call agent debug endpoints and check if extra svids from cache are cleaned up +# Call agent debug endpoints and check if extra X.509-SVIDs from cache are cleaned up check-x509-svid-count "spire-agent" $CACHESIZE diff --git a/test/integration/suites/fetch-svids/README.md b/test/integration/suites/fetch-x509-svids/README.md similarity index 100% rename from test/integration/suites/fetch-svids/README.md rename to test/integration/suites/fetch-x509-svids/README.md diff --git a/test/integration/suites/fetch-svids/conf/agent/agent.conf b/test/integration/suites/fetch-x509-svids/conf/agent/agent.conf similarity index 100% rename from test/integration/suites/fetch-svids/conf/agent/agent.conf rename to test/integration/suites/fetch-x509-svids/conf/agent/agent.conf diff --git a/test/integration/suites/fetch-svids/conf/server/server.conf b/test/integration/suites/fetch-x509-svids/conf/server/server.conf similarity index 100% rename from test/integration/suites/fetch-svids/conf/server/server.conf rename to test/integration/suites/fetch-x509-svids/conf/server/server.conf diff --git a/test/integration/suites/fetch-svids/docker-compose.yaml b/test/integration/suites/fetch-x509-svids/docker-compose.yaml similarity index 100% rename from test/integration/suites/fetch-svids/docker-compose.yaml rename to test/integration/suites/fetch-x509-svids/docker-compose.yaml diff --git a/test/integration/suites/fetch-svids/teardown b/test/integration/suites/fetch-x509-svids/teardown similarity index 100% rename from test/integration/suites/fetch-svids/teardown rename to test/integration/suites/fetch-x509-svids/teardown From d98654b505a8faa4650b617b4512428f5421c52d Mon Sep 17 00:00:00 2001 From: Prasad Borole Date: Tue, 26 Jul 2022 12:47:16 -0700 
Subject: [PATCH 06/19] Updating method signature of SubscribeToCacheChanges to take ctx Signed-off-by: Prasad Borole --- pkg/agent/api/delegatedidentity/v1/service.go | 2 +- .../api/delegatedidentity/v1/service_test.go | 4 ++-- pkg/agent/endpoints/sdsv2/handler.go | 8 +++++-- pkg/agent/endpoints/sdsv2/handler_test.go | 4 ++-- pkg/agent/endpoints/sdsv3/handler.go | 8 +++++-- pkg/agent/endpoints/sdsv3/handler_test.go | 4 ++-- pkg/agent/endpoints/workload/handler.go | 20 ++++++++++++---- pkg/agent/endpoints/workload/handler_test.go | 4 ++-- pkg/agent/manager/manager.go | 12 ++++++---- pkg/agent/manager/manager_test.go | 24 ++++++++++++------- test/integration/common | 2 +- .../fetch-x509-svids/05-fetch-x509-svids | 4 ++-- .../fetch-x509-svids/07-fetch-x509-svids | 8 +++---- 13 files changed, 67 insertions(+), 37 deletions(-) diff --git a/pkg/agent/api/delegatedidentity/v1/service.go b/pkg/agent/api/delegatedidentity/v1/service.go index 887babef8a..bff3510aa5 100644 --- a/pkg/agent/api/delegatedidentity/v1/service.go +++ b/pkg/agent/api/delegatedidentity/v1/service.go @@ -120,7 +120,7 @@ func (s *Service) SubscribeToX509SVIDs(req *delegatedidentityv1.SubscribeToX509S return status.Error(codes.InvalidArgument, "could not parse provided selectors") } - subscriber := s.manager.SubscribeToCacheChanges(selectors) + subscriber, err := s.manager.SubscribeToCacheChanges(ctx, selectors) defer subscriber.Finish() for { diff --git a/pkg/agent/api/delegatedidentity/v1/service_test.go b/pkg/agent/api/delegatedidentity/v1/service_test.go index e503d7194a..923a407106 100644 --- a/pkg/agent/api/delegatedidentity/v1/service_test.go +++ b/pkg/agent/api/delegatedidentity/v1/service_test.go @@ -674,9 +674,9 @@ func (m *FakeManager) subscriberDone() { atomic.AddInt32(&m.subscribers, -1) } -func (m *FakeManager) SubscribeToCacheChanges(selectors cache.Selectors) cache.Subscriber { +func (m *FakeManager) SubscribeToCacheChanges(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) { atomic.AddInt32(&m.subscribers, 1) - return newFakeSubscriber(m, m.updates) + return newFakeSubscriber(m, m.updates), nil } func (m *FakeManager) FetchJWTSVID(ctx context.Context, spiffeID spiffeid.ID, audience []string) (*client.JWTSVID, error) { diff --git a/pkg/agent/endpoints/sdsv2/handler.go b/pkg/agent/endpoints/sdsv2/handler.go index dfff824140..3d96f6ccc1 100644 --- a/pkg/agent/endpoints/sdsv2/handler.go +++ b/pkg/agent/endpoints/sdsv2/handler.go @@ -31,7 +31,7 @@ type Attestor interface { } type Manager interface { - SubscribeToCacheChanges(key cache.Selectors) cache.Subscriber + SubscribeToCacheChanges(ctx context.Context, key cache.Selectors) (cache.Subscriber, error) FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate } @@ -64,7 +64,11 @@ func (h *Handler) StreamSecrets(stream discovery_v2.SecretDiscoveryService_Strea return err } - sub := h.c.Manager.SubscribeToCacheChanges(selectors) + sub, err := h.c.Manager.SubscribeToCacheChanges(stream.Context(), selectors) + if err != nil { + log.WithError(err).Error("Subscribe to cache changes failed") + return err + } defer sub.Finish() updch := sub.Updates() diff --git a/pkg/agent/endpoints/sdsv2/handler_test.go b/pkg/agent/endpoints/sdsv2/handler_test.go index 4ac1a68712..be9bdb96f6 100644 --- a/pkg/agent/endpoints/sdsv2/handler_test.go +++ b/pkg/agent/endpoints/sdsv2/handler_test.go @@ -552,7 +552,7 @@ func NewFakeManager(t *testing.T) *FakeManager { } } -func (m *FakeManager) SubscribeToCacheChanges(selectors cache.Selectors) 
cache.Subscriber { +func (m *FakeManager) SubscribeToCacheChanges(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) { require.Equal(m.t, workloadSelectors, selectors) updch := make(chan *cache.WorkloadUpdate, 1) @@ -568,7 +568,7 @@ func (m *FakeManager) SubscribeToCacheChanges(selectors cache.Selectors) cache.S return NewFakeSubscriber(updch, func() { delete(m.subs, key) close(updch) - }) + }), nil } func (m *FakeManager) FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate { diff --git a/pkg/agent/endpoints/sdsv3/handler.go b/pkg/agent/endpoints/sdsv3/handler.go index 07da317be9..0e98bad562 100644 --- a/pkg/agent/endpoints/sdsv3/handler.go +++ b/pkg/agent/endpoints/sdsv3/handler.go @@ -39,7 +39,7 @@ type Attestor interface { } type Manager interface { - SubscribeToCacheChanges(key cache.Selectors) cache.Subscriber + SubscribeToCacheChanges(ctx context.Context, key cache.Selectors) (cache.Subscriber, error) FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate } @@ -74,7 +74,11 @@ func (h *Handler) StreamSecrets(stream secret_v3.SecretDiscoveryService_StreamSe return err } - sub := h.c.Manager.SubscribeToCacheChanges(selectors) + sub, err := h.c.Manager.SubscribeToCacheChanges(stream.Context(), selectors) + if err != nil { + log.WithError(err).Error("Subscribe to cache changes failed") + return err + } defer sub.Finish() updch := sub.Updates() diff --git a/pkg/agent/endpoints/sdsv3/handler_test.go b/pkg/agent/endpoints/sdsv3/handler_test.go index 937b307bc6..c390f6dcf6 100644 --- a/pkg/agent/endpoints/sdsv3/handler_test.go +++ b/pkg/agent/endpoints/sdsv3/handler_test.go @@ -1288,7 +1288,7 @@ func NewFakeManager(t *testing.T) *FakeManager { } } -func (m *FakeManager) SubscribeToCacheChanges(selectors cache.Selectors) cache.Subscriber { +func (m *FakeManager) SubscribeToCacheChanges(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) { require.Equal(m.t, workloadSelectors, selectors) updch := make(chan *cache.WorkloadUpdate, 1) @@ -1304,7 +1304,7 @@ func (m *FakeManager) SubscribeToCacheChanges(selectors cache.Selectors) cache.S return NewFakeSubscriber(updch, func() { delete(m.subs, key) close(updch) - }) + }), nil } func (m *FakeManager) FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate { diff --git a/pkg/agent/endpoints/workload/handler.go b/pkg/agent/endpoints/workload/handler.go index abc17f8614..81ab3e5570 100644 --- a/pkg/agent/endpoints/workload/handler.go +++ b/pkg/agent/endpoints/workload/handler.go @@ -30,7 +30,7 @@ import ( ) type Manager interface { - SubscribeToCacheChanges(cache.Selectors) cache.Subscriber + SubscribeToCacheChanges(ctx context.Context, key cache.Selectors) (cache.Subscriber, error) MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry FetchJWTSVID(ctx context.Context, spiffeID spiffeid.ID, audience []string) (*client.JWTSVID, error) FetchWorkloadUpdate([]*common.Selector) *cache.WorkloadUpdate @@ -138,7 +138,11 @@ func (h *Handler) FetchJWTBundles(req *workload.JWTBundlesRequest, stream worklo return err } - subscriber := h.c.Manager.SubscribeToCacheChanges(selectors) + subscriber, err := h.c.Manager.SubscribeToCacheChanges(ctx, selectors) + if err != nil { + log.WithError(err).Error("Subscribe to cache changes failed") + return err + } defer subscriber.Finish() var previousResp *workload.JWTBundlesResponse @@ -224,7 +228,11 @@ func (h *Handler) FetchX509SVID(_ *workload.X509SVIDRequest, stream workload.Spi return err } - 
subscriber := h.c.Manager.SubscribeToCacheChanges(selectors) + subscriber, err := h.c.Manager.SubscribeToCacheChanges(ctx, selectors) + if err != nil { + log.WithError(err).Error("Subscribe to cache changes failed") + return err + } defer subscriber.Finish() for { @@ -250,7 +258,11 @@ func (h *Handler) FetchX509Bundles(_ *workload.X509BundlesRequest, stream worklo return err } - subscriber := h.c.Manager.SubscribeToCacheChanges(selectors) + subscriber, err := h.c.Manager.SubscribeToCacheChanges(ctx, selectors) + if err != nil { + log.WithError(err).Error("Subscribe to cache changes failed") + return err + } defer subscriber.Finish() var previousResp *workload.X509BundlesResponse diff --git a/pkg/agent/endpoints/workload/handler_test.go b/pkg/agent/endpoints/workload/handler_test.go index 7d68a38bfb..0c9a520424 100644 --- a/pkg/agent/endpoints/workload/handler_test.go +++ b/pkg/agent/endpoints/workload/handler_test.go @@ -1318,9 +1318,9 @@ func (m *FakeManager) FetchJWTSVID(ctx context.Context, spiffeID spiffeid.ID, au }, nil } -func (m *FakeManager) SubscribeToCacheChanges(selectors cache.Selectors) cache.Subscriber { +func (m *FakeManager) SubscribeToCacheChanges(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) { atomic.AddInt32(&m.subscribers, 1) - return newFakeSubscriber(m, m.updates) + return newFakeSubscriber(m, m.updates), nil } func (m *FakeManager) FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate { diff --git a/pkg/agent/manager/manager.go b/pkg/agent/manager/manager.go index 67a6853e93..25a39fba37 100644 --- a/pkg/agent/manager/manager.go +++ b/pkg/agent/manager/manager.go @@ -37,7 +37,7 @@ type Manager interface { // SubscribeToCacheChanges returns a Subscriber on which cache entry updates are sent // for a particular set of selectors. - SubscribeToCacheChanges(key cache.Selectors) cache.Subscriber + SubscribeToCacheChanges(ctx context.Context, key cache.Selectors) (cache.Subscriber, error) // SubscribeToSVIDChanges returns a new observer.Stream on which svid.State instances are received // each time an SVID rotation finishes. 
@@ -144,17 +144,21 @@ func (m *manager) Run(ctx context.Context) error { } } -func (m *manager) SubscribeToCacheChanges(selectors cache.Selectors) cache.Subscriber { +func (m *manager) SubscribeToCacheChanges(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) { subscriber := m.cache.SubscribeToWorkloadUpdates(selectors) backoff := backoff.NewBackoff(m.clk, svidSyncInterval) // block until all svids are cached and subscriber is notified for { if m.cache.Notify(selectors) { - return subscriber + return subscriber, nil } m.c.Log.WithField(telemetry.Selectors, selectors).Info("Waiting for SVID to get cached") - <-m.clk.After(backoff.NextBackOff()) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-m.clk.After(backoff.NextBackOff()): + } } } diff --git a/pkg/agent/manager/manager_test.go b/pkg/agent/manager/manager_test.go index 169e9b2dee..f5717fd543 100644 --- a/pkg/agent/manager/manager_test.go +++ b/pkg/agent/manager/manager_test.go @@ -250,7 +250,8 @@ func TestHappyPathWithoutSyncNorRotation(t *testing.T) { []*common.RegistrationEntry{matches[0], matches[1]}) util.RunWithTimeout(t, 5*time.Second, func() { - sub := m.SubscribeToCacheChanges(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + require.NoError(t, err) u := <-sub.Updates() if len(u.Identities) != 2 { @@ -340,7 +341,8 @@ func TestRotationWithRSAKey(t *testing.T) { []*common.RegistrationEntry{matches[0], matches[1]}) util.RunWithTimeout(t, 5*time.Second, func() { - sub := m.SubscribeToCacheChanges(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + require.NoError(t, err) u := <-sub.Updates() if len(u.Identities) != 2 { @@ -519,10 +521,11 @@ func TestSynchronization(t *testing.T) { m := newManager(c) - sub := m.SubscribeToCacheChanges(cache.Selectors{ + sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{ {Type: "unix", Value: "uid:1111"}, {Type: "spiffe_id", Value: joinTokenID.String()}, }) + require.NoError(t, err) defer sub.Finish() if err := m.Initialize(context.Background()); err != nil { @@ -807,8 +810,8 @@ func TestSubscribersGetUpToDateBundle(t *testing.T) { m := newManager(c) - sub := m.SubscribeToCacheChanges(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) - + sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + require.NoError(t, err) defer initializeAndRunManager(t, m)() util.RunWithTimeout(t, 1*time.Second, func() { @@ -878,7 +881,8 @@ func TestSyncSVIDs(t *testing.T) { closer := runSVIDSync(waitCh, clk, 50*time.Millisecond, m, errCh) defer closer() - sub1 := m.SubscribeToCacheChanges(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + sub1, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + require.NoError(t, err) // Validate the update received by subscribers u1 := <-sub1.Updates() if len(u1.Identities) != 2 { @@ -888,9 +892,10 @@ func TestSyncSVIDs(t *testing.T) { t.Fatal("bundles were expected to be equal") } - sub2 := m.SubscribeToCacheChanges( + sub2, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{{Type: "spiffe_id", Value: "spiffe://example.org/spire/agent/join_token/abcd"}}) // Validate the update received by subscribers + require.NoError(t, err) u2 := <-sub2.Updates() if 
len(u2.Identities) != 1 { t.Fatalf("expected 1 SVID, got: %d", len(u2.Identities)) @@ -903,7 +908,7 @@ func TestSyncSVIDs(t *testing.T) { sub2.Finish() close(waitCh) - err := <-errCh + err = <-errCh if err != nil { t.Fatalf("syncSVIDs method failed with error %v", err) } @@ -964,7 +969,8 @@ func TestSurvivesCARotation(t *testing.T) { m := newManager(c) - sub := m.SubscribeToCacheChanges(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + require.NoError(t, err) // This should be the update received when Subscribe function was called. updates := sub.Updates() initialUpdate := <-updates diff --git a/test/integration/common b/test/integration/common index 858d2970a2..54e7a7651c 100644 --- a/test/integration/common +++ b/test/integration/common @@ -89,7 +89,7 @@ check-x509-svid-count() { for ((i=1;i<=MAXCHECKS;i++)); do log-info "check X.509-SVID count on agent debug endpoint ($(($i)) of $MAXCHECKS max)..." - COUNT=`docker-compose exec -T $1 /opt/spire/conf/agent/debugclient -testCase "printDebugPage" | jq '.svidsCount'` + COUNT=$(docker-compose exec -T $1 /opt/spire/conf/agent/debugclient -testCase "printDebugPage" | jq '.svidsCount') log-info "X.509-SVID Count: ${COUNT}" if [ "$COUNT" -eq "$2" ]; then log-info "X.509-SVID count of $COUNT from cache matches the expected count of $2" diff --git a/test/integration/suites/fetch-x509-svids/05-fetch-x509-svids b/test/integration/suites/fetch-x509-svids/05-fetch-x509-svids index b4ccf76ab2..71cc902860 100755 --- a/test/integration/suites/fetch-x509-svids/05-fetch-x509-svids +++ b/test/integration/suites/fetch-x509-svids/05-fetch-x509-svids @@ -3,9 +3,9 @@ ENTRYCOUNT=10 CACHESIZE=8 -X509SVIDCOUNT=`docker-compose exec -u 1001 -T spire-agent \ +X509SVIDCOUNT=$(docker-compose exec -u 1001 -T spire-agent \ /opt/spire/bin/spire-agent api fetch x509 \ - -socketPath /opt/spire/sockets/workload_api.sock | grep -i "spiffe://domain.test" | wc -l || fail-now "X.509-SVID check failed"` + -socketPath /opt/spire/sockets/workload_api.sock | grep -i "spiffe://domain.test" | wc -l || fail-now "X.509-SVID check failed") if [ "$X509SVIDCOUNT" -ne "$ENTRYCOUNT" ]; then fail-now "X.509-SVID check failed. Expected $ENTRYCOUNT X.509-SVIDs but received $X509SVIDCOUNT for uid 1001"; diff --git a/test/integration/suites/fetch-x509-svids/07-fetch-x509-svids b/test/integration/suites/fetch-x509-svids/07-fetch-x509-svids index d29e54095e..3d0ecafb74 100755 --- a/test/integration/suites/fetch-x509-svids/07-fetch-x509-svids +++ b/test/integration/suites/fetch-x509-svids/07-fetch-x509-svids @@ -3,9 +3,9 @@ CACHESIZE=8 ENTRYCOUNT=10 -X509SVIDCOUNT=`docker-compose exec -u 1002 -T spire-agent \ +X509SVIDCOUNT=$(docker-compose exec -u 1002 -T spire-agent \ /opt/spire/bin/spire-agent api fetch x509 \ - -socketPath /opt/spire/sockets/workload_api.sock | grep -i "spiffe://domain.test" | wc -l || fail-now "X.509-SVID check failed"` + -socketPath /opt/spire/sockets/workload_api.sock | grep -i "spiffe://domain.test" | wc -l || fail-now "X.509-SVID check failed") if [ "$X509SVIDCOUNT" -ne "$ENTRYCOUNT" ]; then fail-now "X.509-SVID check failed. 
Expected $ENTRYCOUNT X.509-SVIDs but received $X509SVIDCOUNT for uid 1002"; @@ -13,9 +13,9 @@ else echo "Expected $ENTRYCOUNT X.509-SVIDs and received $X509SVIDCOUNT for uid 1002"; fi -X509SVIDCOUNT=`docker-compose exec -u 1001 -T spire-agent \ +X509SVIDCOUNT=$(docker-compose exec -u 1001 -T spire-agent \ /opt/spire/bin/spire-agent api fetch x509 \ - -socketPath /opt/spire/sockets/workload_api.sock | grep -i "spiffe://domain.test" | wc -l || fail-now "X.509-SVID check failed"` + -socketPath /opt/spire/sockets/workload_api.sock | grep -i "spiffe://domain.test" | wc -l || fail-now "X.509-SVID check failed") if [ "$X509SVIDCOUNT" -ne "$ENTRYCOUNT" ]; then fail-now "X.509-SVID check failed. Expected $ENTRYCOUNT X.509-SVIDs but received $X509SVIDCOUNT for uid 1001"; From ec9a133cdb73012ef2d1c3b6f6d1e0577bff2d0f Mon Sep 17 00:00:00 2001 From: Prasad Borole Date: Tue, 26 Jul 2022 13:40:24 -0700 Subject: [PATCH 07/19] fix lint Signed-off-by: Prasad Borole --- pkg/agent/api/delegatedidentity/v1/service.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/agent/api/delegatedidentity/v1/service.go b/pkg/agent/api/delegatedidentity/v1/service.go index bff3510aa5..56a94dbae9 100644 --- a/pkg/agent/api/delegatedidentity/v1/service.go +++ b/pkg/agent/api/delegatedidentity/v1/service.go @@ -121,6 +121,10 @@ func (s *Service) SubscribeToX509SVIDs(req *delegatedidentityv1.SubscribeToX509S } subscriber, err := s.manager.SubscribeToCacheChanges(ctx, selectors) + if err != nil { + log.WithError(err).Error("Subscribe to cache changes failed") + return err + } defer subscriber.Finish() for { From 34dff281f1bff9084f5860a145b58e2423bf4ce0 Mon Sep 17 00:00:00 2001 From: Ryan Turner Date: Thu, 28 Jul 2022 04:50:41 +0000 Subject: [PATCH 08/19] Remove dependence on timing in unit test Signed-off-by: Ryan Turner --- pkg/agent/manager/manager.go | 20 ++++- pkg/agent/manager/manager_test.go | 138 ++++++++++++++++-------------- 2 files changed, 92 insertions(+), 66 deletions(-) diff --git a/pkg/agent/manager/manager.go b/pkg/agent/manager/manager.go index 25a39fba37..aac7e8d319 100644 --- a/pkg/agent/manager/manager.go +++ b/pkg/agent/manager/manager.go @@ -93,6 +93,7 @@ type manager struct { // fetch attempt synchronizeBackoff backoff.BackOff svidSyncBackoff backoff.BackOff + subscribeBackoffFn func() backoff.BackOff client client.Client @@ -111,6 +112,11 @@ func (m *manager) Initialize(ctx context.Context) error { m.synchronizeBackoff = backoff.NewBackoff(m.clk, m.c.SyncInterval) m.svidSyncBackoff = backoff.NewBackoff(m.clk, svidSyncInterval) + if m.subscribeBackoffFn == nil { + m.subscribeBackoffFn = func() backoff.BackOff { + return backoff.NewBackoff(m.clk, svidSyncInterval) + } + } err := m.synchronize(ctx) if nodeutil.ShouldAgentReattest(err) { @@ -145,11 +151,19 @@ func (m *manager) Run(ctx context.Context) error { } func (m *manager) SubscribeToCacheChanges(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) { + return m.subscribeToCacheChanges(ctx, selectors, nil) +} + +func (m *manager) subscribeToCacheChanges(ctx context.Context, selectors cache.Selectors, notifyCallbackFn func()) (cache.Subscriber, error) { subscriber := m.cache.SubscribeToWorkloadUpdates(selectors) - backoff := backoff.NewBackoff(m.clk, svidSyncInterval) + bo := m.subscribeBackoffFn() // block until all svids are cached and subscriber is notified for { - if m.cache.Notify(selectors) { + svidsInCache := m.cache.Notify(selectors) + if notifyCallbackFn != nil { + notifyCallbackFn() + } + if svidsInCache { return 
subscriber, nil } m.c.Log.WithField(telemetry.Selectors, selectors).Info("Waiting for SVID to get cached") @@ -157,7 +171,7 @@ func (m *manager) SubscribeToCacheChanges(ctx context.Context, selectors cache.S select { case <-ctx.Done(): return nil, ctx.Err() - case <-m.clk.After(backoff.NextBackOff()): + case <-m.clk.After(bo.NextBackOff()): } } } diff --git a/pkg/agent/manager/manager_test.go b/pkg/agent/manager/manager_test.go index f5717fd543..5a7720a831 100644 --- a/pkg/agent/manager/manager_test.go +++ b/pkg/agent/manager/manager_test.go @@ -14,6 +14,7 @@ import ( "testing" "time" + backoff "github.com/cenkalti/backoff/v3" testlog "github.com/sirupsen/logrus/hooks/test" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/go-spiffe/v2/svid/x509svid" @@ -867,56 +868,93 @@ func TestSyncSVIDs(t *testing.T) { } m := newManager(c) - if err := m.Initialize(context.Background()); err != nil { - t.Fatal(err) + m.subscribeBackoffFn = func() backoff.BackOff { + return backoff.NewConstantBackOff(svidSyncInterval) } + err := m.Initialize(context.Background()) + require.NoError(t, err) + // After Initialize, just 1 SVID should be cached assert.Equal(t, 1, m.CountSVIDs()) - waitCh := make(chan struct{}) - errCh := make(chan error) + ctx := context.Background() - // Run svidSync in separate routine and advance clock. - // It allows SubscribeToCacheChanges to keep checking for SVID in cache as clk advances. - closer := runSVIDSync(waitCh, clk, 50*time.Millisecond, m, errCh) - defer closer() - - sub1, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{{Type: "unix", Value: "uid:1111"}}) - require.NoError(t, err) // Validate the update received by subscribers - u1 := <-sub1.Updates() - if len(u1.Identities) != 2 { - t.Fatalf("expected 2 SVIDs, got: %d", len(u1.Identities)) - } - if !u1.Bundle.EqualTo(c.Bundle) { - t.Fatal("bundles were expected to be equal") - } + // Spawn subscriber 1 in new goroutine to allow SVID sync to run in parallel + sub1WaitCh := make(chan struct{}, 1) + sub1ErrCh := make(chan error, 1) + go func() { + sub1, err := m.subscribeToCacheChanges(ctx, cache.Selectors{{Type: "unix", Value: "uid:1111"}}, func() { + sub1WaitCh <- struct{}{} + }) + if err != nil { + sub1ErrCh <- err + return + } - sub2, err := m.SubscribeToCacheChanges(context.Background(), - cache.Selectors{{Type: "spiffe_id", Value: "spiffe://example.org/spire/agent/join_token/abcd"}}) - // Validate the update received by subscribers - require.NoError(t, err) - u2 := <-sub2.Updates() - if len(u2.Identities) != 1 { - t.Fatalf("expected 1 SVID, got: %d", len(u2.Identities)) - } - if !u2.Bundle.EqualTo(c.Bundle) { - t.Fatal("bundles were expected to be equal") - } + defer sub1.Finish() + u1 := <-sub1.Updates() - sub1.Finish() - sub2.Finish() - close(waitCh) + if len(u1.Identities) != 2 { + sub1ErrCh <- fmt.Errorf("expected 2 SVIDs, got: %d", len(u1.Identities)) + return + } + if !u1.Bundle.EqualTo(c.Bundle) { + sub1ErrCh <- errors.New("bundles were expected to be equal") + return + } - err = <-errCh - if err != nil { - t.Fatalf("syncSVIDs method failed with error %v", err) - } + sub1ErrCh <- nil + }() + + // Spawn subscriber 2 in new goroutine to allow SVID sync to run in parallel + sub2WaitCh := make(chan struct{}, 1) + sub2ErrCh := make(chan error, 1) + go func() { + sub2, err := m.subscribeToCacheChanges(ctx, cache.Selectors{{Type: "spiffe_id", Value: "spiffe://example.org/spire/agent/join_token/abcd"}}, func() { + sub2WaitCh <- struct{}{} + }) + if err != nil { + sub2ErrCh <- err + return + 
} + + defer sub2.Finish() + u2 := <-sub2.Updates() + + if len(u2.Identities) != 1 { + sub2ErrCh <- fmt.Errorf("expected 1 SVID, got: %d", len(u2.Identities)) + return + } + if !u2.Bundle.EqualTo(c.Bundle) { + sub2ErrCh <- errors.New("bundles were expected to be equal") + return + } + + sub2ErrCh <- nil + }() + + // Wait until subscribers have been created + <-sub1WaitCh + <-sub2WaitCh + + // Sync SVIDs to populate cache + svidSyncErr := m.syncSVIDs(ctx) + require.NoError(t, svidSyncErr, "syncSVIDs method failed") + + // Advance clock so subscribers can check for latest SVIDs in cache + clk.Add(svidSyncInterval) + + sub1Err := <-sub1ErrCh + assert.NoError(t, sub1Err, "subscriber 1 error") + + sub2Err := <-sub2ErrCh + assert.NoError(t, sub2Err, "subscriber 2 error") // All 3 SVIDs should be cached assert.Equal(t, 3, m.CountSVIDs()) - assert.NoError(t, m.synchronize(context.Background())) + assert.NoError(t, m.synchronize(ctx)) // Make sure svid count is SVIDCacheMaxSize and non-active SVIDs are deleted from cache assert.Equal(t, 1, m.CountSVIDs()) @@ -1188,32 +1226,6 @@ func TestStorableSVIDsSync(t *testing.T) { validateResponse(records, entries) } -func runSVIDSync(waitCh chan struct{}, clk *clock.Mock, interval time.Duration, m *manager, - errCh chan<- error) (closer func()) { - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - for { - select { - case <-waitCh: - errCh <- nil - return - case <-time.After(interval): - clk.Add(interval) - err := m.syncSVIDs(context.Background()) - if err != nil { - errCh <- err - return - } - } - } - }() - return func() { - wg.Wait() - } -} - func makeGetAuthorizedEntriesResponse(t *testing.T, respKeys ...string) *entryv1.GetAuthorizedEntriesResponse { var entries []*types.Entry for _, respKey := range respKeys { From 4fc52a091d8c5d4014de0c8f2bb22d8ce88a27c1 Mon Sep 17 00:00:00 2001 From: Prasad Borole Date: Mon, 1 Aug 2022 11:44:51 -0700 Subject: [PATCH 09/19] fix unit tests and address comments Signed-off-by: Prasad Borole --- .../api/delegatedidentity/v1/service_test.go | 13 ++++ pkg/agent/endpoints/workload/handler_test.go | 60 +++++++++++++++++++ pkg/agent/manager/cache/cache.go | 39 ++++++------ pkg/agent/manager/cache/cache_test.go | 14 ++--- pkg/agent/manager/manager.go | 3 +- pkg/agent/manager/manager_test.go | 18 +++--- .../fetch-x509-svids/05-fetch-x509-svids | 2 +- .../fetch-x509-svids/07-fetch-x509-svids | 4 +- 8 files changed, 111 insertions(+), 42 deletions(-) diff --git a/pkg/agent/api/delegatedidentity/v1/service_test.go b/pkg/agent/api/delegatedidentity/v1/service_test.go index 923a407106..0791c256cd 100644 --- a/pkg/agent/api/delegatedidentity/v1/service_test.go +++ b/pkg/agent/api/delegatedidentity/v1/service_test.go @@ -88,6 +88,16 @@ func TestSubscribeToX509SVIDs(t *testing.T) { expectCode: codes.PermissionDenied, expectMsg: "caller not configured as an authorized delegate", }, + { + testName: "subscribe to cache changes error", + authSpiffeID: []string{"spiffe://example.org/one"}, + identities: []cache.Identity{ + identityFromX509SVID(x509SVID1), + }, + managerErr: errors.New("err"), + expectCode: codes.Unknown, + expectMsg: "err", + }, { testName: "workload update with one identity", authSpiffeID: []string{"spiffe://example.org/one"}, @@ -675,6 +685,9 @@ func (m *FakeManager) subscriberDone() { } func (m *FakeManager) SubscribeToCacheChanges(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) { + if m.err != nil { + return nil, m.err + } atomic.AddInt32(&m.subscribers, 1) return 
newFakeSubscriber(m, m.updates), nil } diff --git a/pkg/agent/endpoints/workload/handler_test.go b/pkg/agent/endpoints/workload/handler_test.go index 0c9a520424..92cabb973b 100644 --- a/pkg/agent/endpoints/workload/handler_test.go +++ b/pkg/agent/endpoints/workload/handler_test.go @@ -56,6 +56,7 @@ func TestFetchX509SVID(t *testing.T) { name string updates []*cache.WorkloadUpdate attestErr error + managerErr error asPID int expectCode codes.Code expectMsg string @@ -103,6 +104,23 @@ func TestFetchX509SVID(t *testing.T) { }, }, }, + { + name: "subscribe to cache changes error", + managerErr: errors.New("err"), + expectCode: codes.Unknown, + expectMsg: "err", + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Subscribe to cache changes failed", + Data: logrus.Fields{ + "service": "WorkloadAPI", + "method": "FetchX509SVID", + logrus.ErrorKey: "err", + }, + }, + }, + }, { name: "with identity and federated bundles", updates: []*cache.WorkloadUpdate{{ @@ -167,6 +185,7 @@ func TestFetchX509SVID(t *testing.T) { AttestErr: tt.attestErr, ExpectLogs: tt.expectLogs, AsPID: tt.asPID, + ManagerErr: tt.managerErr, } runTest(t, params, func(ctx context.Context, client workloadPB.SpiffeWorkloadAPIClient) { @@ -195,6 +214,7 @@ func TestFetchX509Bundles(t *testing.T) { testName string updates []*cache.WorkloadUpdate attestErr error + managerErr error expectCode codes.Code expectMsg string expectResp *workloadPB.X509BundlesResponse @@ -235,6 +255,23 @@ func TestFetchX509Bundles(t *testing.T) { }, }, }, + { + testName: "subscribe to cache changes error", + managerErr: errors.New("err"), + expectCode: codes.Unknown, + expectMsg: "err", + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Subscribe to cache changes failed", + Data: logrus.Fields{ + "service": "WorkloadAPI", + "method": "FetchX509Bundles", + logrus.ErrorKey: "err", + }, + }, + }, + }, { testName: "cache update unexpectedly missing bundle", updates: []*cache.WorkloadUpdate{ @@ -307,6 +344,7 @@ func TestFetchX509Bundles(t *testing.T) { AttestErr: tt.attestErr, ExpectLogs: tt.expectLogs, AllowUnauthenticatedVerifiers: tt.allowUnauthenticatedVerifiers, + ManagerErr: tt.managerErr, } runTest(t, params, func(ctx context.Context, client workloadPB.SpiffeWorkloadAPIClient) { @@ -665,6 +703,7 @@ func TestFetchJWTBundles(t *testing.T) { name string updates []*cache.WorkloadUpdate attestErr error + managerErr error expectCode codes.Code expectMsg string expectResp *workloadPB.JWTBundlesResponse @@ -705,6 +744,23 @@ func TestFetchJWTBundles(t *testing.T) { }, }, }, + { + name: "subscribe to cache changes error", + managerErr: errors.New("err"), + expectCode: codes.Unknown, + expectMsg: "err", + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Subscribe to cache changes failed", + Data: logrus.Fields{ + "service": "WorkloadAPI", + "method": "FetchJWTBundles", + logrus.ErrorKey: "err", + }, + }, + }, + }, { name: "cache update unexpectedly missing bundle", updates: []*cache.WorkloadUpdate{ @@ -777,6 +833,7 @@ func TestFetchJWTBundles(t *testing.T) { AttestErr: tt.attestErr, ExpectLogs: tt.expectLogs, AllowUnauthenticatedVerifiers: tt.allowUnauthenticatedVerifiers, + ManagerErr: tt.managerErr, } runTest(t, params, func(ctx context.Context, client workloadPB.SpiffeWorkloadAPIClient) { @@ -1319,6 +1376,9 @@ func (m *FakeManager) FetchJWTSVID(ctx context.Context, spiffeID spiffeid.ID, au } func (m *FakeManager) SubscribeToCacheChanges(ctx context.Context, selectors 
cache.Selectors) (cache.Subscriber, error) { + if m.err != nil { + return nil, m.err + } atomic.AddInt32(&m.subscribers, 1) return newFakeSubscriber(m, m.updates), nil } diff --git a/pkg/agent/manager/cache/cache.go b/pkg/agent/manager/cache/cache.go index c33365c51b..18036b8f8f 100644 --- a/pkg/agent/manager/cache/cache.go +++ b/pkg/agent/manager/cache/cache.go @@ -209,7 +209,7 @@ func (c *Cache) Entries() []*common.RegistrationEntry { for _, record := range c.records { out = append(out, record.entry) } - sortEntries(out) + sortEntriesByID(out) return out } @@ -238,9 +238,9 @@ func (c *Cache) FetchWorkloadUpdate(selectors []*common.Selector) *WorkloadUpdat return c.buildWorkloadUpdate(set) } -// SubscribeToWorkloadUpdates creates a subscriber for given selector set. +// NewSubscriber creates a subscriber for given selector set. // Separately call Notify for the first time after this method is invoked to receive latest updates. -func (c *Cache) SubscribeToWorkloadUpdates(selectors []*common.Selector) Subscriber { +func (c *Cache) NewSubscriber(selectors []*common.Selector) Subscriber { c.mu.Lock() defer c.mu.Unlock() @@ -420,14 +420,15 @@ func (c *Cache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.Regi } } - // entries with active subscribers which are not cached will be put in staleEntries map - activeSubsByEntryID, recordsWithLastAccessTime := c.syncSVIDs() + // entries with active subscribers which are not cached will be put in staleEntries map; + // irrespective of what svid cache size as we cannot deny identity to a subscriber + activeSubsByEntryID, recordsWithLastAccessTime := c.syncSVIDsWithSubscribers() extraSize := len(c.svids) - c.svidCacheMaxSize // delete svids without subscribers and which have not been accessed since svidCacheExpiryTime if extraSize > 0 { // sort recordsWithLastAccessTime - sortTimestamps(recordsWithLastAccessTime) + sortByTimestamps(recordsWithLastAccessTime) for _, record := range recordsWithLastAccessTime { if extraSize <= 0 { @@ -531,7 +532,7 @@ func (c *Cache) SyncSVIDsWithSubscribers() { c.mu.Lock() defer c.mu.Unlock() - c.syncSVIDs() + c.syncSVIDsWithSubscribers() } // Notify subscribers of selector set only if all SVIDs for corresponding selector set are cached @@ -542,29 +543,23 @@ func (c *Cache) Notify(selectors []*common.Selector) bool { defer c.mu.RUnlock() set, setFree := allocSelectorSet(selectors...) 
defer setFree() - if len(c.missingSVIDRecords(set)) == 0 { + if !c.missingSVIDRecords(set) { c.notifyBySelectorSet(set) return true } return false } -func (c *Cache) missingSVIDRecords(set selectorSet) []*StaleEntry { +func (c *Cache) missingSVIDRecords(set selectorSet) bool { records, recordsDone := c.getRecordsForSelectors(set) defer recordsDone() - if len(records) == 0 { - return nil - } - out := make([]*StaleEntry, 0, len(records)) for record := range records { - if _, ok := c.svids[record.entry.EntryId]; !ok { - out = append(out, &StaleEntry{ - Entry: record.entry, - }) + if _, exists := c.svids[record.entry.EntryId]; !exists { + return true } } - return out + return false } func (c *Cache) updateLastAccessTimestamp(selectors []*common.Selector) { @@ -583,7 +578,7 @@ func (c *Cache) updateLastAccessTimestamp(selectors []*common.Selector) { // entries with active subscribers which are not cached will be put in staleEntries map // records which are not cached for remainder of max cache size will also be put in staleEntries map -func (c *Cache) syncSVIDs() (map[string]struct{}, []recordAccessEvent) { +func (c *Cache) syncSVIDsWithSubscribers() (map[string]struct{}, []recordAccessEvent) { activeSubsByEntryID := make(map[string]struct{}) lastAccessTimestamps := make([]recordAccessEvent, 0, len(c.records)) @@ -817,7 +812,7 @@ func (c *Cache) matchingEntries(set selectorSet) []*common.RegistrationEntry { for record := range records { out = append(out, record.entry) } - sortEntries(out) + sortEntriesByID(out) return out } @@ -939,13 +934,13 @@ func sortIdentities(identities []Identity) { }) } -func sortEntries(entries []*common.RegistrationEntry) { +func sortEntriesByID(entries []*common.RegistrationEntry) { sort.Slice(entries, func(a, b int) bool { return entries[a].EntryId < entries[b].EntryId }) } -func sortTimestamps(records []recordAccessEvent) { +func sortByTimestamps(records []recordAccessEvent) { sort.Slice(records, func(a, b int) bool { return records[a].timestamp < records[b].timestamp }) diff --git a/pkg/agent/manager/cache/cache_test.go b/pkg/agent/manager/cache/cache_test.go index 49839a8155..1b71ef3f52 100644 --- a/pkg/agent/manager/cache/cache_test.go +++ b/pkg/agent/manager/cache/cache_test.go @@ -436,7 +436,7 @@ func TestSubscriberOnlyGetsEntriesWithSVID(t *testing.T) { } cache.UpdateEntries(updateEntries, nil) - sub := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + sub := cache.NewSubscriber(makeSelectors("A")) defer sub.Finish() assertNoWorkloadUpdate(t, sub) @@ -679,7 +679,7 @@ func TestSVIDCacheExpiry(t *testing.T) { for id, entry := range updateEntries.RegistrationEntries { // create and close subscribers for remaining entries so that svid cache is full if id != foo.EntryId && id != bar.EntryId { - sub := cache.SubscribeToWorkloadUpdates(entry.Selectors) + sub := cache.NewSubscriber(entry.Selectors) sub.Finish() } } @@ -693,7 +693,7 @@ func TestSVIDCacheExpiry(t *testing.T) { // bar SVID should be cached as it has active subscriber assert.True(t, cache.Notify(makeSelectors("B"))) - subA = cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + subA = cache.NewSubscriber(makeSelectors("A")) defer subA.Finish() cache.UpdateEntries(updateEntries, nil) @@ -725,7 +725,7 @@ func TestMaxSVIDCacheSize(t *testing.T) { foo := makeRegistrationEntry("FOO", "A") updateEntries.RegistrationEntries[foo.EntryId] = foo - subA := cache.SubscribeToWorkloadUpdates(foo.Selectors) + subA := cache.NewSubscriber(foo.Selectors) defer subA.Finish() cache.UpdateEntries(updateEntries, nil) @@ 
-757,7 +757,7 @@ func TestSyncSVIDsWithSubscribers(t *testing.T) { cache.UpdateEntries(updateEntries, nil) // Create a subscriber for foo - subA := cache.SubscribeToWorkloadUpdates(foo.Selectors) + subA := cache.NewSubscriber(foo.Selectors) defer subA.Finish() require.Len(t, cache.GetStaleEntries(), 0) @@ -823,7 +823,7 @@ func BenchmarkCacheGlobalNotification(b *testing.B) { cache.UpdateEntries(updateEntries, nil) for i := 0; i < numWorkloads; i++ { selectors := distinctSelectors(i, selectorsPerWorkload) - cache.SubscribeToWorkloadUpdates(selectors) + cache.NewSubscriber(selectors) } runtime.GC() @@ -987,7 +987,7 @@ func makeFederatesWith(bundles ...*Bundle) []string { } func subscribeToWorkloadUpdatesAndNotify(t *testing.T, cache *Cache, selectors []*common.Selector) Subscriber { - subscriber := cache.SubscribeToWorkloadUpdates(selectors) + subscriber := cache.NewSubscriber(selectors) assert.True(t, cache.Notify(selectors)) return subscriber } diff --git a/pkg/agent/manager/manager.go b/pkg/agent/manager/manager.go index aac7e8d319..fc427f22dd 100644 --- a/pkg/agent/manager/manager.go +++ b/pkg/agent/manager/manager.go @@ -155,11 +155,12 @@ func (m *manager) SubscribeToCacheChanges(ctx context.Context, selectors cache.S } func (m *manager) subscribeToCacheChanges(ctx context.Context, selectors cache.Selectors, notifyCallbackFn func()) (cache.Subscriber, error) { - subscriber := m.cache.SubscribeToWorkloadUpdates(selectors) + subscriber := m.cache.NewSubscriber(selectors) bo := m.subscribeBackoffFn() // block until all svids are cached and subscriber is notified for { svidsInCache := m.cache.Notify(selectors) + // used for testing if notifyCallbackFn != nil { notifyCallbackFn() } diff --git a/pkg/agent/manager/manager_test.go b/pkg/agent/manager/manager_test.go index 5a7720a831..4e4d5c3512 100644 --- a/pkg/agent/manager/manager_test.go +++ b/pkg/agent/manager/manager_test.go @@ -516,12 +516,16 @@ func TestSynchronization(t *testing.T) { Clk: clk, Catalog: cat, WorkloadKeyType: workloadkey.ECP256, - SVIDCacheMaxSize: 1, SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } m := newManager(c) + if err := m.Initialize(context.Background()); err != nil { + t.Fatal(err) + } + require.Equal(t, clk.Now(), m.GetLastSync()) + sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{ {Type: "unix", Value: "uid:1111"}, {Type: "spiffe_id", Value: joinTokenID.String()}, @@ -529,11 +533,6 @@ func TestSynchronization(t *testing.T) { require.NoError(t, err) defer sub.Finish() - if err := m.Initialize(context.Background()); err != nil { - t.Fatal(err) - } - require.Equal(t, clk.Now(), m.GetLastSync()) - // Before synchronization identitiesBefore := identitiesByEntryID(m.cache.Identities()) if len(identitiesBefore) != 3 { @@ -803,7 +802,6 @@ func TestSubscribersGetUpToDateBundle(t *testing.T) { RotationInterval: 1 * time.Hour, SyncInterval: 1 * time.Hour, Clk: clk, - SVIDCacheMaxSize: 1, Catalog: cat, WorkloadKeyType: workloadkey.ECP256, SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), @@ -811,9 +809,9 @@ func TestSubscribersGetUpToDateBundle(t *testing.T) { m := newManager(c) + defer initializeAndRunManager(t, m)() sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{{Type: "unix", Value: "uid:1111"}}) require.NoError(t, err) - defer initializeAndRunManager(t, m)() util.RunWithTimeout(t, 1*time.Second, func() { // Update should contain a new bundle. 
@@ -1001,11 +999,13 @@ func TestSurvivesCARotation(t *testing.T) { Clk: clk, Catalog: cat, WorkloadKeyType: workloadkey.ECP256, - SVIDCacheMaxSize: 1, SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } m := newManager(c) + m.subscribeBackoffFn = func() backoff.BackOff { + return backoff.NewConstantBackOff(svidSyncInterval) + } sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{{Type: "unix", Value: "uid:1111"}}) require.NoError(t, err) diff --git a/test/integration/suites/fetch-x509-svids/05-fetch-x509-svids b/test/integration/suites/fetch-x509-svids/05-fetch-x509-svids index 71cc902860..4bb53c55df 100755 --- a/test/integration/suites/fetch-x509-svids/05-fetch-x509-svids +++ b/test/integration/suites/fetch-x509-svids/05-fetch-x509-svids @@ -10,7 +10,7 @@ X509SVIDCOUNT=$(docker-compose exec -u 1001 -T spire-agent \ if [ "$X509SVIDCOUNT" -ne "$ENTRYCOUNT" ]; then fail-now "X.509-SVID check failed. Expected $ENTRYCOUNT X.509-SVIDs but received $X509SVIDCOUNT for uid 1001"; else - echo "Expected $ENTRYCOUNT X.509-SVIDs and received $X509SVIDCOUNT for uid 1001"; + log-info "Expected $ENTRYCOUNT X.509-SVIDs and received $X509SVIDCOUNT for uid 1001"; fi # Call agent debug endpoints and check if extra X.509-SVIDs from cache are cleaned up diff --git a/test/integration/suites/fetch-x509-svids/07-fetch-x509-svids b/test/integration/suites/fetch-x509-svids/07-fetch-x509-svids index 3d0ecafb74..9a46e29602 100755 --- a/test/integration/suites/fetch-x509-svids/07-fetch-x509-svids +++ b/test/integration/suites/fetch-x509-svids/07-fetch-x509-svids @@ -10,7 +10,7 @@ X509SVIDCOUNT=$(docker-compose exec -u 1002 -T spire-agent \ if [ "$X509SVIDCOUNT" -ne "$ENTRYCOUNT" ]; then fail-now "X.509-SVID check failed. Expected $ENTRYCOUNT X.509-SVIDs but received $X509SVIDCOUNT for uid 1002"; else - echo "Expected $ENTRYCOUNT X.509-SVIDs and received $X509SVIDCOUNT for uid 1002"; + log-info "Expected $ENTRYCOUNT X.509-SVIDs and received $X509SVIDCOUNT for uid 1002"; fi X509SVIDCOUNT=$(docker-compose exec -u 1001 -T spire-agent \ @@ -20,7 +20,7 @@ X509SVIDCOUNT=$(docker-compose exec -u 1001 -T spire-agent \ if [ "$X509SVIDCOUNT" -ne "$ENTRYCOUNT" ]; then fail-now "X.509-SVID check failed. 
Expected $ENTRYCOUNT X.509-SVIDs but received $X509SVIDCOUNT for uid 1001"; else - echo "Expected $ENTRYCOUNT X.509-SVIDs and received $X509SVIDCOUNT for uid 1001"; + log-info "Expected $ENTRYCOUNT X.509-SVIDs and received $X509SVIDCOUNT for uid 1001"; fi # Call agent debug endpoints and check if extra X.509-SVIDs from cache are cleaned up From 38a17132cf2ffd8378735bf85d9a1a916227276d Mon Sep 17 00:00:00 2001 From: Prasad Borole Date: Mon, 1 Aug 2022 15:05:23 -0700 Subject: [PATCH 10/19] Added unit tests and log Signed-off-by: Prasad Borole --- pkg/agent/endpoints/sdsv3/handler_test.go | 34 +++++++++++++++++++++-- pkg/agent/manager/cache/cache.go | 2 ++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/pkg/agent/endpoints/sdsv3/handler_test.go b/pkg/agent/endpoints/sdsv3/handler_test.go index c390f6dcf6..7a3836e2b5 100644 --- a/pkg/agent/endpoints/sdsv3/handler_test.go +++ b/pkg/agent/endpoints/sdsv3/handler_test.go @@ -831,6 +831,22 @@ func TestStreamSecretsBadNonce(t *testing.T) { requireSecrets(t, resp, workloadTLSCertificate2) } +func TestStreamSecretsErrInSubscribeToCacheChanges(t *testing.T) { + err := errors.New("err") + test := setupErrTest(t) + defer test.server.Stop() + + stream, err := test.handler.StreamSecrets(context.Background()) + require.NoError(t, err) + defer func() { + require.NoError(t, stream.CloseSend()) + }() + + resp, err := stream.Recv() + require.Error(t, err) + require.Nil(t, resp) +} + func TestFetchSecrets(t *testing.T) { for _, tt := range []struct { name string @@ -1174,11 +1190,16 @@ func DeltaSecretsTest(t *testing.T) { } func setupTest(t *testing.T) *handlerTest { - return setupTestWithConfig(t, Config{}) + return setupTestWithManager(t, Config{}, NewFakeManager(t)) } -func setupTestWithConfig(t *testing.T, c Config) *handlerTest { +func setupErrTest(t *testing.T) *handlerTest { manager := NewFakeManager(t) + manager.err = errors.New("err") + return setupTestWithManager(t, Config{}, manager) +} + +func setupTestWithManager(t *testing.T, c Config, manager *FakeManager) *handlerTest { defaultConfig := Config{ Manager: manager, Attestor: FakeAttestor(workloadSelectors), @@ -1220,6 +1241,11 @@ func setupTestWithConfig(t *testing.T, c Config) *handlerTest { return test } +func setupTestWithConfig(t *testing.T, c Config) *handlerTest { + manager := NewFakeManager(t) + return setupTestWithManager(t, c, manager) +} + type handlerTest struct { t *testing.T @@ -1279,6 +1305,7 @@ type FakeManager struct { upd *cache.WorkloadUpdate next int subs map[int]chan *cache.WorkloadUpdate + err error } func NewFakeManager(t *testing.T) *FakeManager { @@ -1289,6 +1316,9 @@ func NewFakeManager(t *testing.T) *FakeManager { } func (m *FakeManager) SubscribeToCacheChanges(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) { + if m.err != nil { + return nil, m.err + } require.Equal(m.t, workloadSelectors, selectors) updch := make(chan *cache.WorkloadUpdate, 1) diff --git a/pkg/agent/manager/cache/cache.go b/pkg/agent/manager/cache/cache.go index 18036b8f8f..a094456471 100644 --- a/pkg/agent/manager/cache/cache.go +++ b/pkg/agent/manager/cache/cache.go @@ -454,6 +454,8 @@ func (c *Cache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.Regi c.staleEntries[id] = true } } + c.log.WithField(telemetry.OutdatedSVIDs, len(outdatedEntries)). 
+ Debug("Updating SVIDs with outdated attributes in cache") if bundleRemoved || len(bundleChanged) > 0 { c.BundleCache.Update(c.bundles) From e71f1c96d666e47e1cb3d92269e8adda89b1ad74 Mon Sep 17 00:00:00 2001 From: Prasad Borole Date: Mon, 1 Aug 2022 15:17:18 -0700 Subject: [PATCH 11/19] fix linting Signed-off-by: Prasad Borole --- pkg/agent/endpoints/sdsv3/handler_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pkg/agent/endpoints/sdsv3/handler_test.go b/pkg/agent/endpoints/sdsv3/handler_test.go index 7a3836e2b5..1fd9455b3c 100644 --- a/pkg/agent/endpoints/sdsv3/handler_test.go +++ b/pkg/agent/endpoints/sdsv3/handler_test.go @@ -832,7 +832,6 @@ func TestStreamSecretsBadNonce(t *testing.T) { } func TestStreamSecretsErrInSubscribeToCacheChanges(t *testing.T) { - err := errors.New("err") test := setupErrTest(t) defer test.server.Stop() @@ -1195,7 +1194,7 @@ func setupTest(t *testing.T) *handlerTest { func setupErrTest(t *testing.T) *handlerTest { manager := NewFakeManager(t) - manager.err = errors.New("err") + manager.err = errors.New("bad-error") return setupTestWithManager(t, Config{}, manager) } From 993f1bc0c9667b47152601e6fef6054bea2003cf Mon Sep 17 00:00:00 2001 From: Prasad Borole Date: Wed, 24 Aug 2022 13:45:32 -0700 Subject: [PATCH 12/19] the new LRU cache is enabled if experimental cache size is provided Signed-off-by: Prasad Borole --- .../api/delegatedidentity/v1/service_test.go | 3 +- pkg/agent/manager/cache/cache.go | 316 ++---- pkg/agent/manager/cache/cache_test.go | 346 ++----- pkg/agent/manager/cache/lru_cache.go | 944 +++++++++++++++++ .../manager/cache/lru_cache_subscriber.go | 60 ++ pkg/agent/manager/cache/lru_cache_test.go | 954 ++++++++++++++++++ pkg/agent/manager/cache/sets.go | 46 + pkg/agent/manager/config.go | 15 +- pkg/agent/manager/manager.go | 76 +- pkg/agent/manager/manager_test.go | 432 +++++--- pkg/agent/manager/sync.go | 17 +- 11 files changed, 2498 insertions(+), 711 deletions(-) create mode 100644 pkg/agent/manager/cache/lru_cache.go create mode 100644 pkg/agent/manager/cache/lru_cache_subscriber.go create mode 100644 pkg/agent/manager/cache/lru_cache_test.go diff --git a/pkg/agent/api/delegatedidentity/v1/service_test.go b/pkg/agent/api/delegatedidentity/v1/service_test.go index 0791c256cd..8baf8f20b1 100644 --- a/pkg/agent/api/delegatedidentity/v1/service_test.go +++ b/pkg/agent/api/delegatedidentity/v1/service_test.go @@ -9,7 +9,6 @@ import ( "testing" "time" - "github.com/andres-erbsen/clock" "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/spiffe/go-spiffe/v2/bundle/spiffebundle" @@ -812,5 +811,5 @@ func (m *FakeManager) SubscribeToBundleChanges() *cache.BundleStream { func newTestCache() *cache.Cache { log, _ := test.NewNullLogger() - return cache.New(log, trustDomain1, bundle1, telemetry.Blackhole{}, 0, clock.NewMock()) + return cache.New(log, trustDomain1, bundle1, telemetry.Blackhole{}) } diff --git a/pkg/agent/manager/cache/cache.go b/pkg/agent/manager/cache/cache.go index a094456471..fb5372d199 100644 --- a/pkg/agent/manager/cache/cache.go +++ b/pkg/agent/manager/cache/cache.go @@ -1,13 +1,13 @@ package cache import ( + "context" "crypto" "crypto/x509" "sort" "sync" "time" - "github.com/andres-erbsen/clock" "github.com/sirupsen/logrus" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/common/bundleutil" @@ -15,10 +15,6 @@ import ( "github.com/spiffe/spire/proto/spire/common" ) -const ( - DefaultSVIDCacheMaxSize = 1000 -) - type Selectors []*common.Selector type Bundle 
= bundleutil.Bundle @@ -63,9 +59,9 @@ type X509SVID struct { PrivateKey crypto.Signer } -// Cache caches each registration entry, bundles, and JWT SVIDs for the agent. -// The signed X509-SVIDs for those entries are stored in LRU-like cache. -// It allows subscriptions by (workload) selector sets and notifies subscribers when: +// Cache caches each registration entry, signed X509-SVIDs for those entries, +// bundles, and JWT SVIDs for the agent. It allows subscriptions by (workload) +// selector sets and notifies subscribers when: // // 1) a registration entry related to the selectors: // * is modified @@ -80,22 +76,6 @@ type X509SVID struct { // selector it encounters. Each selector index tracks the subscribers (i.e // workloads) and registration entries that have that selector. // -// The LRU-like SVID cache has configurable size limit and expiry period. -// 1. Size limit of SVID cache is a soft limit. If SVID has a subscriber present then -// that SVID is never removed from cache. -// 2. Least recently used SVIDs are removed from cache only after the cache expiry period has passed. -// This is done to reduce the overall cache churn. -// 3. Last access timestamp for SVID cache entry is updated when a new subscriber is created -// 4. When a new subscriber is created and there is a cache miss -// then subscriber needs to wait for next SVID sync event to receive WorkloadUpdate with newly minted SVID -// -// The advantage of above approach is that if agent has entry count less than cache size -// then all SVIDs are cached at all times. If agent has entry count greater than cache size then -// subscribers will continue to get SVID updates (potential delay for first WorkloadUpdate if cache miss) -// and least used SVIDs will be removed from cache which will save memory usage. -// This allows agent to support environments where the active simultaneous workload count -// is a small percentage of the large number of registrations assigned to the agent. -// // When registration entries are added/updated/removed, the set of relevant // selectors are gathered and the indexes for those selectors are combed for // all relevant subscribers. @@ -123,7 +103,6 @@ type Cache struct { log logrus.FieldLogger trustDomain spiffeid.TrustDomain - clk clock.Clock metrics telemetry.Metrics @@ -135,20 +114,14 @@ type Cache struct { // selectors holds the selector indices, keyed by a selector key selectors map[selector]*selectorIndex - // staleEntries holds stale or new registration entries which require new SVID to be stored in cache + // staleEntries holds stale registration entries staleEntries map[string]bool // bundles holds the trust bundles, keyed by trust domain id (i.e. 
"spiffe://domain.test") bundles map[spiffeid.TrustDomain]*bundleutil.Bundle - - // svids are stored by entry IDs - svids map[string]*X509SVID - - // svidCacheMaxSize is a soft limit of max number of SVIDs that would be stored in cache - svidCacheMaxSize int } -// StaleEntry holds stale or outdated entries which require new SVID with old SVIDs expiration time (if present) +// StaleEntry holds stale entries with SVIDs expiration time type StaleEntry struct { // Entry stale registration entry Entry *common.RegistrationEntry @@ -156,12 +129,7 @@ type StaleEntry struct { ExpiresAt time.Time } -func New(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundle, metrics telemetry.Metrics, - svidCacheMaxSize int, clk clock.Clock) *Cache { - if svidCacheMaxSize <= 0 { - svidCacheMaxSize = DefaultSVIDCacheMaxSize - } - +func New(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundle, metrics telemetry.Metrics) *Cache { return &Cache{ BundleCache: NewBundleCache(trustDomain, bundle), JWTSVIDCache: NewJWTSVIDCache(), @@ -175,9 +143,6 @@ func New(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundl bundles: map[spiffeid.TrustDomain]*bundleutil.Bundle{ trustDomain: bundle, }, - svids: make(map[string]*X509SVID), - svidCacheMaxSize: svidCacheMaxSize, - clk: clk, } } @@ -189,44 +154,41 @@ func (c *Cache) Identities() []Identity { out := make([]Identity, 0, len(c.records)) for _, record := range c.records { - svid, ok := c.svids[record.entry.EntryId] - if !ok { + if record.svid == nil { // The record does not have an SVID yet and should not be returned // from the cache. continue } - out = append(out, makeIdentity(record, svid)) + out = append(out, makeIdentity(record)) } sortIdentities(out) return out } -func (c *Cache) Entries() []*common.RegistrationEntry { +func (c *Cache) CountSVIDs() int { c.mu.RLock() defer c.mu.RUnlock() - out := make([]*common.RegistrationEntry, 0, len(c.records)) + var records int for _, record := range c.records { - out = append(out, record.entry) + if record.svid == nil { + // The record does not have an SVID yet and should not be returned + // from the cache. + continue + } + records++ } - sortEntriesByID(out) - return out -} - -func (c *Cache) CountSVIDs() int { - c.mu.RLock() - defer c.mu.RUnlock() - return len(c.svids) + return records } -func (c *Cache) MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry { +func (c *Cache) MatchingIdentities(selectors []*common.Selector) []Identity { set, setDone := allocSelectorSet(selectors...) defer setDone() c.mu.RLock() defer c.mu.RUnlock() - return c.matchingEntries(set) + return c.matchingIdentities(set) } func (c *Cache) FetchWorkloadUpdate(selectors []*common.Selector) *WorkloadUpdate { @@ -238,19 +200,8 @@ func (c *Cache) FetchWorkloadUpdate(selectors []*common.Selector) *WorkloadUpdat return c.buildWorkloadUpdate(set) } -// NewSubscriber creates a subscriber for given selector set. -// Separately call Notify for the first time after this method is invoked to receive latest updates. 
-func (c *Cache) NewSubscriber(selectors []*common.Selector) Subscriber { - c.mu.Lock() - defer c.mu.Unlock() - - sub := newSubscriber(c, selectors) - for s := range sub.set { - c.addSelectorIndexSub(s, sub) - } - // update lastAccessTimestamp of records containing provided selectors - c.updateLastAccessTimestamp(selectors) - return sub +func (c *Cache) SubscribeToWorkloadUpdates(ctx context.Context, selectors Selectors) (Subscriber, error) { + return c.subscribeToWorkloadUpdates(selectors), nil } // UpdateEntries updates the cache with the provided registration entries and bundles and @@ -324,14 +275,11 @@ func (c *Cache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.Regi c.delSelectorIndicesRecord(selRem, record) notifySets = append(notifySets, selRem) delete(c.records, id) - delete(c.svids, id) // Remove stale entry since, registration entry is no longer on cache. delete(c.staleEntries, id) } } - outdatedEntries := make(map[string]struct{}) - // Add/update records for registration entries in the update for _, newEntry := range update.RegistrationEntries { clearSelectorSet(selAdd) @@ -389,9 +337,9 @@ func (c *Cache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.Regi notifySets = append(notifySets, notifySet) } - // Identify stale/outdated entries - if existingEntry != nil && existingEntry.RevisionNumber != newEntry.RevisionNumber { - outdatedEntries[newEntry.EntryId] = struct{}{} + // Invoke the svid checker callback for this record + if checkSVID != nil && checkSVID(existingEntry, newEntry, record.svid) { + c.staleEntries[newEntry.EntryId] = true } // Log all the details of the update to the DEBUG log @@ -420,43 +368,6 @@ func (c *Cache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.Regi } } - // entries with active subscribers which are not cached will be put in staleEntries map; - // irrespective of what svid cache size as we cannot deny identity to a subscriber - activeSubsByEntryID, recordsWithLastAccessTime := c.syncSVIDsWithSubscribers() - extraSize := len(c.svids) - c.svidCacheMaxSize - - // delete svids without subscribers and which have not been accessed since svidCacheExpiryTime - if extraSize > 0 { - // sort recordsWithLastAccessTime - sortByTimestamps(recordsWithLastAccessTime) - - for _, record := range recordsWithLastAccessTime { - if extraSize <= 0 { - // no need to delete SVIDs any further as cache size <= svidCacheMaxSize - break - } - if _, ok := c.svids[record.id]; ok { - if _, exists := activeSubsByEntryID[record.id]; !exists { - // remove svid - c.log.WithField("record_id", record.id). - WithField("record_timestamp", record.timestamp). - Debug("Removing SVID record") - delete(c.svids, record.id) - extraSize-- - } - } - } - } - - // Update all stale svids or svids whose registration entry is outdated - for id, svid := range c.svids { - if _, ok := outdatedEntries[id]; ok || (checkSVID != nil && checkSVID(nil, c.records[id].entry, svid)) { - c.staleEntries[id] = true - } - } - c.log.WithField(telemetry.OutdatedSVIDs, len(outdatedEntries)). - Debug("Updating SVIDs with outdated attributes in cache") - if bundleRemoved || len(bundleChanged) > 0 { c.BundleCache.Update(c.bundles) } @@ -484,7 +395,7 @@ func (c *Cache) UpdateSVIDs(update *UpdateSVIDs) { continue } - c.svids[entryID] = svid + record.svid = svid notifySet.Merge(record.entry.Selectors...) 
log := c.log.WithFields(logrus.Fields{ telemetry.Entry: record.entry.EntryId, @@ -514,8 +425,8 @@ func (c *Cache) GetStaleEntries() []*StaleEntry { } var expiresAt time.Time - if cachedSvid, ok := c.svids[entryID]; ok { - expiresAt = cachedSvid.Chain[0].NotAfter + if cachedEntry.svid != nil { + expiresAt = cachedEntry.svid.Chain[0].NotAfter } staleEntries = append(staleEntries, &StaleEntry{ @@ -527,97 +438,58 @@ func (c *Cache) GetStaleEntries() []*StaleEntry { return staleEntries } -// SyncSVIDsWithSubscribers will sync svid cache: -// entries with active subscribers which are not cached will be put in staleEntries map -// records which are not cached for remainder of max cache size will also be put in staleEntries map -func (c *Cache) SyncSVIDsWithSubscribers() { - c.mu.Lock() - defer c.mu.Unlock() - - c.syncSVIDsWithSubscribers() -} - -// Notify subscribers of selector set only if all SVIDs for corresponding selector set are cached -// It returns whether all SVIDs are cached or not. -// This method should be retried with backoff to avoid lock contention. -func (c *Cache) Notify(selectors []*common.Selector) bool { +func (c *Cache) MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry { c.mu.RLock() defer c.mu.RUnlock() - set, setFree := allocSelectorSet(selectors...) - defer setFree() - if !c.missingSVIDRecords(set) { - c.notifyBySelectorSet(set) - return true - } - return false -} -func (c *Cache) missingSVIDRecords(set selectorSet) bool { + set, setDone := allocSelectorSet(selectors...) + defer setDone() + records, recordsDone := c.getRecordsForSelectors(set) defer recordsDone() + if len(records) == 0 { + return nil + } + + // Return identities in ascending "entry id" order to maintain a consistent + // ordering. + // TODO: figure out how to determine the "default" identity + out := make([]*common.RegistrationEntry, 0, len(records)) for record := range records { - if _, exists := c.svids[record.entry.EntryId]; !exists { - return true - } + out = append(out, record.entry) } - return false + sortEntriesByID(out) + return out } -func (c *Cache) updateLastAccessTimestamp(selectors []*common.Selector) { - set, setFree := allocSelectorSet(selectors...) - defer setFree() - - records, recordsDone := c.getRecordsForSelectors(set) - defer recordsDone() +func (c *Cache) Entries() []*common.RegistrationEntry { + c.mu.RLock() + defer c.mu.RUnlock() - now := c.clk.Now().UnixMilli() - for record := range records { - // Set lastAccessTimestamp so that svid LRU cache can be cleaned based on this timestamp - record.lastAccessTimestamp = now + out := make([]*common.RegistrationEntry, 0, len(c.records)) + for _, record := range c.records { + out = append(out, record.entry) } + sortEntriesByID(out) + return out } -// entries with active subscribers which are not cached will be put in staleEntries map -// records which are not cached for remainder of max cache size will also be put in staleEntries map -func (c *Cache) syncSVIDsWithSubscribers() (map[string]struct{}, []recordAccessEvent) { - activeSubsByEntryID := make(map[string]struct{}) - lastAccessTimestamps := make([]recordAccessEvent, 0, len(c.records)) +func (c *Cache) SyncSVIDsWithSubscribers() { + c.log.Error("SyncSVIDsWithSubscribers method is not implemented") + return +} - // iterate over all selectors from cached entries and obtain: - // 1. 
entries that have active subscribers - // 1.1 if those entries don't have corresponding SVID cached then put them in staleEntries - // so that SVID will be cached in next sync - // 2. get lastAccessTimestamp of each entry - for id, record := range c.records { - for _, sel := range record.entry.Selectors { - if index, ok := c.selectors[makeSelector(sel)]; ok && index != nil { - if len(index.subs) > 0 { - if _, ok := c.svids[record.entry.EntryId]; !ok { - c.staleEntries[id] = true - } - activeSubsByEntryID[id] = struct{}{} - break - } - } - } - lastAccessTimestamps = append(lastAccessTimestamps, newRecord(record.lastAccessTimestamp, id)) - } +func (c *Cache) subscribeToWorkloadUpdates(selectors []*common.Selector) Subscriber { + c.mu.Lock() + defer c.mu.Unlock() - remainderSize := c.svidCacheMaxSize - len(c.svids) - // add records which are not cached for remainder of cache size - for id := range c.records { - if len(c.staleEntries) >= remainderSize { - break - } - if _, svidCached := c.svids[id]; !svidCached { - if _, ok := c.staleEntries[id]; !ok { - c.staleEntries[id] = true - } - } + sub := newSubscriber(c, selectors) + for s := range sub.set { + c.addSelectorIndexSub(s, sub) } - - return activeSubsByEntryID, lastAccessTimestamps + c.notify(sub) + return sub } func (c *Cache) updateOrCreateRecord(newEntry *common.RegistrationEntry) (*cacheRecord, *common.RegistrationEntry) { @@ -791,33 +663,12 @@ func (c *Cache) matchingIdentities(set selectorSet) []Identity { // TODO: figure out how to determine the "default" identity out := make([]Identity, 0, len(records)) for record := range records { - if svid, ok := c.svids[record.entry.EntryId]; ok { - out = append(out, makeIdentity(record, svid)) - } + out = append(out, makeIdentity(record)) } sortIdentities(out) return out } -func (c *Cache) matchingEntries(set selectorSet) []*common.RegistrationEntry { - records, recordsDone := c.getRecordsForSelectors(set) - defer recordsDone() - - if len(records) == 0 { - return nil - } - - // Return identities in ascending "entry id" order to maintain a consistent - // ordering. - // TODO: figure out how to determine the "default" identity - out := make([]*common.RegistrationEntry, 0, len(records)) - for record := range records { - out = append(out, record.entry) - } - sortEntriesByID(out) - return out -} - func (c *Cache) buildWorkloadUpdate(set selectorSet) *WorkloadUpdate { w := &WorkloadUpdate{ Bundle: c.bundles[c.trustDomain], @@ -852,13 +703,17 @@ func (c *Cache) buildWorkloadUpdate(set selectorSet) *WorkloadUpdate { } func (c *Cache) getRecordsForSelectors(set selectorSet) (recordSet, func()) { - // Build and dedup a list of candidate entries. Don't check for selector set inclusion yet, since + // Build and dedup a list of candidate entries. Ignore those without an + // SVID but otherwise don't check for selector set inclusion yet, since // that is a more expensive operation and we could easily have duplicate // entries to check. 
records, recordsDone := allocRecordSet() for selector := range set { if index := c.getSelectorIndexForRead(selector); index != nil { for record := range index.records { + if record.svid == nil { + continue + } records[record] = struct{}{} } } @@ -900,9 +755,9 @@ func (c *Cache) getSelectorIndexForRead(s selector) *selectorIndex { } type cacheRecord struct { - entry *common.RegistrationEntry - subs map[*subscriber]struct{} - lastAccessTimestamp int64 + entry *common.RegistrationEntry + svid *X509SVID + subs map[*subscriber]struct{} } func newCacheRecord() *cacheRecord { @@ -936,31 +791,10 @@ func sortIdentities(identities []Identity) { }) } -func sortEntriesByID(entries []*common.RegistrationEntry) { - sort.Slice(entries, func(a, b int) bool { - return entries[a].EntryId < entries[b].EntryId - }) -} - -func sortByTimestamps(records []recordAccessEvent) { - sort.Slice(records, func(a, b int) bool { - return records[a].timestamp < records[b].timestamp - }) -} - -func makeIdentity(record *cacheRecord, svid *X509SVID) Identity { +func makeIdentity(record *cacheRecord) Identity { return Identity{ Entry: record.entry, - SVID: svid.Chain, - PrivateKey: svid.PrivateKey, + SVID: record.svid.Chain, + PrivateKey: record.svid.PrivateKey, } } - -type recordAccessEvent struct { - timestamp int64 - id string -} - -func newRecord(timestamp int64, id string) recordAccessEvent { - return recordAccessEvent{timestamp: timestamp, id: id} -} diff --git a/pkg/agent/manager/cache/cache_test.go b/pkg/agent/manager/cache/cache_test.go index 1b71ef3f52..0a048e5c9c 100644 --- a/pkg/agent/manager/cache/cache_test.go +++ b/pkg/agent/manager/cache/cache_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/andres-erbsen/clock" "github.com/sirupsen/logrus/hooks/test" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/common/bundleutil" @@ -59,7 +58,7 @@ func TestFetchWorkloadUpdate(t *testing.T) { }, workloadUpdate) } -func TestMatchingRegistrationIdentities(t *testing.T) { +func TestMatchingIdentities(t *testing.T) { cache := newTestCache() // populate the cache with FOO and BAR without SVIDS @@ -71,21 +70,19 @@ func TestMatchingRegistrationIdentities(t *testing.T) { } cache.UpdateEntries(updateEntries, nil) - assert.Equal(t, []*common.RegistrationEntry{bar, foo}, - cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) + identities := cache.MatchingIdentities(makeSelectors("A", "B")) + assert.Len(t, identities, 0, "identities should not be returned that don't have SVIDs") - // Update SVIDs and MatchingRegistrationEntries should return both entries updateSVIDs := &UpdateSVIDs{ X509SVIDs: makeX509SVIDs(foo, bar), } cache.UpdateSVIDs(updateSVIDs) - assert.Equal(t, []*common.RegistrationEntry{bar, foo}, - cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) - // Remove SVIDs and MatchingRegistrationEntries should still return both entries - cache.UpdateSVIDs(&UpdateSVIDs{}) - assert.Equal(t, []*common.RegistrationEntry{bar, foo}, - cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) + identities = cache.MatchingIdentities(makeSelectors("A", "B")) + assert.Equal(t, []Identity{ + {Entry: bar}, + {Entry: foo}, + }, identities) } func TestCountSVIDs(t *testing.T) { @@ -140,11 +137,11 @@ func TestAllSubscribersNotifiedOnBundleChange(t *testing.T) { cache := newTestCache() // create some subscribers and assert they get the initial bundle - subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + subA := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer 
subA.Finish() assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{Bundle: bundleV1}) - subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) + subB := cache.subscribeToWorkloadUpdates(makeSelectors("B")) defer subB.Finish() assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{Bundle: bundleV1}) @@ -171,11 +168,11 @@ func TestSomeSubscribersNotifiedOnFederatedBundleChange(t *testing.T) { }) // subscribe to A and B and assert initial updates are received. - subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + subA := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer subA.Finish() assertAnyWorkloadUpdate(t, subA) - subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) + subB := cache.subscribeToWorkloadUpdates(makeSelectors("B")) defer subB.Finish() assertAnyWorkloadUpdate(t, subB) @@ -234,11 +231,11 @@ func TestSubscribersGetEntriesWithSelectorSubsets(t *testing.T) { cache := newTestCache() // create subscribers for each combination of selectors - subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + subA := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer subA.Finish() - subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) + subB := cache.subscribeToWorkloadUpdates(makeSelectors("B")) defer subB.Finish() - subAB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A", "B")) + subAB := cache.subscribeToWorkloadUpdates(makeSelectors("A", "B")) defer subAB.Finish() // assert all subscribers get the initial update @@ -291,7 +288,7 @@ func TestSubscriberIsNotNotifiedIfNothingChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo), }) - sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -317,7 +314,7 @@ func TestSubscriberNotifiedOnSVIDChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo), }) - sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -332,7 +329,7 @@ func TestSubscriberNotifiedOnSVIDChanges(t *testing.T) { }) } -func TestSubscriberNotificationsOnSelectorChanges(t *testing.T) { +func TestSubcriberNotificationsOnSelectorChanges(t *testing.T) { cache := newTestCache() // initialize the cache with a FOO entry with selector A and an SVID @@ -346,7 +343,7 @@ func TestSubscriberNotificationsOnSelectorChanges(t *testing.T) { }) // create subscribers for A and make sure the initial update has FOO - sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer sub.Finish() assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ Bundle: bundleV1, @@ -383,16 +380,21 @@ func TestSubscriberNotificationsOnSelectorChanges(t *testing.T) { }) } -func TestSubscriberNotifiedWhenEntryDropped(t *testing.T) { +func newTestCache() *Cache { + log, _ := test.NewNullLogger() + return New(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, telemetry.Blackhole{}) +} + +func TestSubcriberNotifiedWhenEntryDropped(t *testing.T) { cache := newTestCache() - subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + subA := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer subA.Finish() assertAnyWorkloadUpdate(t, subA) // subB's job here is to just make sure we don't notify unrelated // subscribers when 
dropping registration entries - subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) + subB := cache.subscribeToWorkloadUpdates(makeSelectors("B")) defer subB.Finish() assertAnyWorkloadUpdate(t, subB) @@ -426,7 +428,7 @@ func TestSubscriberNotifiedWhenEntryDropped(t *testing.T) { assertNoWorkloadUpdate(t, subB) } -func TestSubscriberOnlyGetsEntriesWithSVID(t *testing.T) { +func TestSubcriberOnlyGetsEntriesWithSVID(t *testing.T) { cache := newTestCache() foo := makeRegistrationEntry("FOO", "A") @@ -436,9 +438,13 @@ func TestSubscriberOnlyGetsEntriesWithSVID(t *testing.T) { } cache.UpdateEntries(updateEntries, nil) - sub := cache.NewSubscriber(makeSelectors("A")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer sub.Finish() - assertNoWorkloadUpdate(t, sub) + + // workload update does not include the identity because it has no SVID. + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV1, + }) // update to include the SVID and now we should get the update cache.UpdateSVIDs(&UpdateSVIDs{ @@ -453,7 +459,7 @@ func TestSubscriberOnlyGetsEntriesWithSVID(t *testing.T) { func TestSubscribersDoNotBlockNotifications(t *testing.T) { cache := newTestCache() - sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer sub.Finish() cache.UpdateEntries(&UpdateEntries{ @@ -483,23 +489,34 @@ func TestCheckSVIDCallback(t *testing.T) { foo := makeRegistrationEntryWithTTL("FOO", 60) + // called once for FOO with no SVID + callCount := 0 cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV2), RegistrationEntries: makeRegistrationEntries(foo), }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { - // should not get invoked - assert.Fail(t, "should not be called as no SVIDs are cached yet") + callCount++ + assert.Equal(t, "FOO", newEntry.EntryId) + + // there is no already existing entry, only the new entry + assert.Nil(t, existingEntry) + assert.Equal(t, foo, newEntry) + assert.Nil(t, svid) + return false }) + assert.Equal(t, 1, callCount) + assert.Empty(t, cache.staleEntries) // called once for FOO with new SVID + callCount = 0 svids := makeX509SVIDs(foo) cache.UpdateSVIDs(&UpdateSVIDs{ X509SVIDs: svids, }) // called once for FOO with existing SVID - callCount := 0 + callCount = 0 cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV2), RegistrationEntries: makeRegistrationEntries(foo), @@ -521,15 +538,22 @@ func TestGetStaleEntries(t *testing.T) { foo := makeRegistrationEntryWithTTL("FOO", 60) - // Create entry but don't mark it stale from checkSVID method; - // it will be marked stale cause it does not have SVID cached + // Create entry but don't mark it stale cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV2), RegistrationEntries: makeRegistrationEntries(foo), }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { return false }) + assert.Empty(t, cache.GetStaleEntries()) + // Update entry and mark it as stale + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + RegistrationEntries: makeRegistrationEntries(foo), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + return true + }) // Assert that the entry is returned as stale. The `ExpiresAt` field should be unset since there is no SVID. 
expectedEntries := []*StaleEntry{{Entry: cache.records[foo.EntryId].entry}} assert.Equal(t, expectedEntries, cache.GetStaleEntries()) @@ -583,7 +607,7 @@ func TestSubscriberNotNotifiedOnDifferentSVIDChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo, bar), }) - sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -608,7 +632,7 @@ func TestSubscriberNotNotifiedOnOverlappingSVIDChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo, bar), }) - sub := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A", "B")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A", "B")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -620,181 +644,6 @@ func TestSubscriberNotNotifiedOnOverlappingSVIDChanges(t *testing.T) { assertNoWorkloadUpdate(t, sub) } -func TestSVIDCacheExpiry(t *testing.T) { - clk := clock.NewMock() - cache := newTestCacheWithConfig(10, clk) - - clk.Add(1 * time.Second) - foo := makeRegistrationEntry("FOO", "A") - // validate workload update for foo - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo), - }, nil) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo), - }) - subA := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("A")) - assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ - Bundle: bundleV1, - Identities: []Identity{{Entry: foo}}, - }) - subA.Finish() - - // move clk by 1 sec so that SVID access time will be different - clk.Add(1 * time.Second) - bar := makeRegistrationEntry("BAR", "B") - // validate workload update for bar - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo, bar), - }, nil) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(bar), - }) - - // not closing subscriber immediately - subB := subscribeToWorkloadUpdatesAndNotify(t, cache, makeSelectors("B")) - defer subB.Finish() - assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{ - Bundle: bundleV1, - Identities: []Identity{ - {Entry: bar}, - }, - }) - - // Move clk by 2 seconds - clk.Add(2 * time.Second) - // update total of 12 entries - updateEntries := createUpdateEntries(10, makeBundles(bundleV1)) - updateEntries.RegistrationEntries[foo.EntryId] = foo - updateEntries.RegistrationEntries[bar.EntryId] = bar - - cache.UpdateEntries(updateEntries, nil) - - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDsFromMap(updateEntries.RegistrationEntries), - }) - - for id, entry := range updateEntries.RegistrationEntries { - // create and close subscribers for remaining entries so that svid cache is full - if id != foo.EntryId && id != bar.EntryId { - sub := cache.NewSubscriber(entry.Selectors) - sub.Finish() - } - } - assert.Equal(t, 12, cache.CountSVIDs()) - - cache.UpdateEntries(updateEntries, nil) - assert.Equal(t, 10, cache.CountSVIDs()) - - // foo SVID should be removed from cache as it does not have active subscriber - assert.False(t, cache.Notify(makeSelectors("A"))) - // bar SVID should be cached as it has active subscriber - assert.True(t, cache.Notify(makeSelectors("B"))) - - subA = cache.NewSubscriber(makeSelectors("A")) - defer subA.Finish() - - cache.UpdateEntries(updateEntries, nil) - - // Make sure foo is marked as stale entry which does not have svid cached - require.Len(t, cache.GetStaleEntries(), 1) - assert.Equal(t, foo, cache.GetStaleEntries()[0].Entry) - - 
assert.Equal(t, 10, cache.CountSVIDs()) -} - -func TestMaxSVIDCacheSize(t *testing.T) { - clk := clock.NewMock() - cache := newTestCacheWithConfig(10, clk) - - // create entries more than maxSvidCacheSize - updateEntries := createUpdateEntries(12, makeBundles(bundleV1)) - cache.UpdateEntries(updateEntries, nil) - - require.Len(t, cache.GetStaleEntries(), 10) - - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDsFromStaleEntries(cache.GetStaleEntries()), - }) - require.Len(t, cache.GetStaleEntries(), 0) - assert.Equal(t, 10, cache.CountSVIDs()) - - // Validate that active subscriber will still get SVID even if SVID count is at maxSvidCacheSize - foo := makeRegistrationEntry("FOO", "A") - updateEntries.RegistrationEntries[foo.EntryId] = foo - - subA := cache.NewSubscriber(foo.Selectors) - defer subA.Finish() - - cache.UpdateEntries(updateEntries, nil) - require.Len(t, cache.GetStaleEntries(), 1) - assert.Equal(t, 10, cache.CountSVIDs()) - - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo), - }) - assert.Equal(t, 11, cache.CountSVIDs()) - require.Len(t, cache.GetStaleEntries(), 0) -} - -func TestSyncSVIDsWithSubscribers(t *testing.T) { - clk := clock.NewMock() - cache := newTestCacheWithConfig(5, clk) - - updateEntries := createUpdateEntries(5, makeBundles(bundleV1)) - cache.UpdateEntries(updateEntries, nil) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDsFromStaleEntries(cache.GetStaleEntries()), - }) - assert.Equal(t, 5, cache.CountSVIDs()) - - // Update foo but its SVID is not yet cached - foo := makeRegistrationEntry("FOO", "A") - updateEntries.RegistrationEntries[foo.EntryId] = foo - - cache.UpdateEntries(updateEntries, nil) - - // Create a subscriber for foo - subA := cache.NewSubscriber(foo.Selectors) - defer subA.Finish() - require.Len(t, cache.GetStaleEntries(), 0) - - // After SyncSVIDsWithSubscribers foo should be marked as stale, requiring signing - cache.SyncSVIDsWithSubscribers() - require.Len(t, cache.GetStaleEntries(), 1) - assert.Equal(t, []*StaleEntry{{Entry: cache.records[foo.EntryId].entry}}, cache.GetStaleEntries()) - - assert.Equal(t, 5, cache.CountSVIDs()) -} - -func TestNotify(t *testing.T) { - cache := newTestCache() - - foo := makeRegistrationEntry("FOO", "A") - cache.UpdateEntries(&UpdateEntries{ - Bundles: makeBundles(bundleV1), - RegistrationEntries: makeRegistrationEntries(foo), - }, nil) - - assert.False(t, cache.Notify(makeSelectors("A"))) - cache.UpdateSVIDs(&UpdateSVIDs{ - X509SVIDs: makeX509SVIDs(foo), - }) - assert.True(t, cache.Notify(makeSelectors("A"))) -} - -func TestNewCache(t *testing.T) { - // negative value - cache := newTestCacheWithConfig(-5, clock.NewMock()) - require.Equal(t, DefaultSVIDCacheMaxSize, cache.svidCacheMaxSize) - - // zero value - cache = newTestCacheWithConfig(0, clock.NewMock()) - require.Equal(t, DefaultSVIDCacheMaxSize, cache.svidCacheMaxSize) -} - func BenchmarkCacheGlobalNotification(b *testing.B) { cache := newTestCache() @@ -823,7 +672,7 @@ func BenchmarkCacheGlobalNotification(b *testing.B) { cache.UpdateEntries(updateEntries, nil) for i := 0; i < numWorkloads; i++ { selectors := distinctSelectors(i, selectorsPerWorkload) - cache.NewSubscriber(selectors) + cache.subscribeToWorkloadUpdates(selectors) } runtime.GC() @@ -840,35 +689,40 @@ func BenchmarkCacheGlobalNotification(b *testing.B) { } } -func newTestCache() *Cache { - log, _ := test.NewNullLogger() - return New(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, - telemetry.Blackhole{}, 0, clock.NewMock()) -} +func 
TestMatchingRegistrationEntries(t *testing.T) { + cache := newTestCache() -func newTestCacheWithConfig(svidCacheMaxSize int, clk clock.Clock) *Cache { - log, _ := test.NewNullLogger() - return New(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, telemetry.Blackhole{}, - svidCacheMaxSize, clk) + // populate the cache with FOO and BAR without SVIDS + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + } + cache.UpdateEntries(updateEntries, nil) + + // Update SVIDs and MatchingRegistrationEntries should return both entries + updateSVIDs := &UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + } + cache.UpdateSVIDs(updateSVIDs) + assert.Equal(t, []*common.RegistrationEntry{bar, foo}, + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) } -// numEntries should not be more than 12 digits -func createUpdateEntries(numEntries int, bundles map[spiffeid.TrustDomain]*bundleutil.Bundle) *UpdateEntries { +func TestEntries(t *testing.T) { + cache := newTestCache() + + // populate the cache with FOO and BAR without SVIDS + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") updateEntries := &UpdateEntries{ - Bundles: bundles, - RegistrationEntries: make(map[string]*common.RegistrationEntry, numEntries), + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), } + cache.UpdateEntries(updateEntries, nil) - for i := 0; i < numEntries; i++ { - entryID := fmt.Sprintf("00000000-0000-0000-0000-%012d", i) - updateEntries.RegistrationEntries[entryID] = &common.RegistrationEntry{ - EntryId: entryID, - ParentId: "spiffe://domain.test/node", - SpiffeId: fmt.Sprintf("spiffe://domain.test/workload-%d", i), - Selectors: distinctSelectors(i, 1), - } - } - return updateEntries + assert.Equal(t, []*common.RegistrationEntry{bar, foo}, cache.Entries()) } func distinctSelectors(id, n int) []*common.Selector { @@ -926,22 +780,6 @@ func makeX509SVIDs(entries ...*common.RegistrationEntry) map[string]*X509SVID { return out } -func makeX509SVIDsFromMap(entries map[string]*common.RegistrationEntry) map[string]*X509SVID { - out := make(map[string]*X509SVID) - for _, entry := range entries { - out[entry.EntryId] = &X509SVID{} - } - return out -} - -func makeX509SVIDsFromStaleEntries(entries []*StaleEntry) map[string]*X509SVID { - out := make(map[string]*X509SVID) - for _, entry := range entries { - out[entry.Entry.EntryId] = &X509SVID{} - } - return out -} - func makeRegistrationEntry(id string, selectors ...string) *common.RegistrationEntry { return &common.RegistrationEntry{ EntryId: id, @@ -985,9 +823,3 @@ func makeFederatesWith(bundles ...*Bundle) []string { } return out } - -func subscribeToWorkloadUpdatesAndNotify(t *testing.T, cache *Cache, selectors []*common.Selector) Subscriber { - subscriber := cache.NewSubscriber(selectors) - assert.True(t, cache.Notify(selectors)) - return subscriber -} diff --git a/pkg/agent/manager/cache/lru_cache.go b/pkg/agent/manager/cache/lru_cache.go new file mode 100644 index 0000000000..8516ca51d5 --- /dev/null +++ b/pkg/agent/manager/cache/lru_cache.go @@ -0,0 +1,944 @@ +package cache + +import ( + "context" + "sort" + "sync" + "time" + + "github.com/andres-erbsen/clock" + "github.com/sirupsen/logrus" + "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/spiffe/spire/pkg/agent/common/backoff" + 
"github.com/spiffe/spire/pkg/common/bundleutil" + "github.com/spiffe/spire/pkg/common/telemetry" + "github.com/spiffe/spire/proto/spire/common" +) + +const ( + DefaultSVIDCacheMaxSize = 1000 + SvidSyncInterval = 500 * time.Millisecond +) + +// Cache caches each registration entry, bundles, and JWT SVIDs for the agent. +// The signed X509-SVIDs for those entries are stored in LRU-like cache. +// It allows subscriptions by (workload) selector sets and notifies subscribers when: +// +// 1) a registration entry related to the selectors: +// * is modified +// * has a new X509-SVID signed for it +// * federates with a federated bundle that is updated +// 2) the trust bundle for the agent trust domain is updated +// +// When notified, the subscriber is given a WorkloadUpdate containing +// related identities and trust bundles. +// +// The cache does this efficiently by building an index for each unique +// selector it encounters. Each selector index tracks the subscribers (i.e +// workloads) and registration entries that have that selector. +// +// The LRU-like SVID cache has configurable size limit and expiry period. +// 1. Size limit of SVID cache is a soft limit. If SVID has a subscriber present then +// that SVID is never removed from cache. +// 2. Least recently used SVIDs are removed from cache only after the cache expiry period has passed. +// This is done to reduce the overall cache churn. +// 3. Last access timestamp for SVID cache entry is updated when a new subscriber is created +// 4. When a new subscriber is created and there is a cache miss +// then subscriber needs to wait for next SVID sync event to receive WorkloadUpdate with newly minted SVID +// +// The advantage of above approach is that if agent has entry count less than cache size +// then all SVIDs are cached at all times. If agent has entry count greater than cache size then +// subscribers will continue to get SVID updates (potential delay for first WorkloadUpdate if cache miss) +// and least used SVIDs will be removed from cache which will save memory usage. +// This allows agent to support environments where the active simultaneous workload count +// is a small percentage of the large number of registrations assigned to the agent. +// +// When registration entries are added/updated/removed, the set of relevant +// selectors are gathered and the indexes for those selectors are combed for +// all relevant subscribers. +// +// For each relevant subscriber, the selector index for each selector of the +// subscriber is combed for registration whose selectors are a subset of the +// subscriber selector set. Identities for those entries are added to the +// workload update returned to the subscriber. +// +// NOTE: The cache is intended to be able to handle thousands of workload +// subscriptions, which can involve thousands of certificates, keys, bundles, +// and registration entries, etc. The selector index itself is intended to be +// scalable, but the objects themselves can take a considerable amount of +// memory. For maximal safety, the objects should be cloned both coming in and +// leaving the cache. However, during global updates (e.g. trust bundle is +// updated for the agent trust domain) in particular, cloning all of the +// relevant objects for each subscriber causes HUGE amounts of memory pressure +// which adds non-trivial amounts of latency and causes a giant memory spike +// that could OOM the agent on smaller VMs. For this reason, the cache is +// presumed to own ALL data passing in and out of the cache. 
Producers and +// consumers MUST NOT mutate the data. +type LRUCache struct { + *BundleCache + *JWTSVIDCache + + log logrus.FieldLogger + trustDomain spiffeid.TrustDomain + clk clock.Clock + + metrics telemetry.Metrics + + mu sync.RWMutex + + // records holds the records for registration entries, keyed by registration entry ID + records map[string]*lruCacheRecord + + // selectors holds the selector indices, keyed by a selector key + selectors map[selector]*selectorsMapIndex + + // staleEntries holds stale or new registration entries which require new SVID to be stored in cache + staleEntries map[string]bool + + // bundles holds the trust bundles, keyed by trust domain id (i.e. "spiffe://domain.test") + bundles map[spiffeid.TrustDomain]*bundleutil.Bundle + + // svids are stored by entry IDs + svids map[string]*X509SVID + + // svidCacheMaxSize is a soft limit of max number of SVIDs that would be stored in cache + svidCacheMaxSize int + subscribeBackoffFn func() backoff.BackOff +} + +func NewLRUCache(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundle, metrics telemetry.Metrics, + svidCacheMaxSize int, clk clock.Clock) *LRUCache { + if svidCacheMaxSize <= 0 { + svidCacheMaxSize = DefaultSVIDCacheMaxSize + } + + return &LRUCache{ + BundleCache: NewBundleCache(trustDomain, bundle), + JWTSVIDCache: NewJWTSVIDCache(), + + log: log, + metrics: metrics, + trustDomain: trustDomain, + records: make(map[string]*lruCacheRecord), + selectors: make(map[selector]*selectorsMapIndex), + staleEntries: make(map[string]bool), + bundles: map[spiffeid.TrustDomain]*bundleutil.Bundle{ + trustDomain: bundle, + }, + svids: make(map[string]*X509SVID), + svidCacheMaxSize: svidCacheMaxSize, + clk: clk, + subscribeBackoffFn: func() backoff.BackOff { + return backoff.NewBackoff(clk, SvidSyncInterval) + }, + } +} + +// Identities is only used by manager tests +// TODO: We should remove this and find a better way +func (c *LRUCache) Identities() []Identity { + c.mu.RLock() + defer c.mu.RUnlock() + + out := make([]Identity, 0, len(c.records)) + for _, record := range c.records { + svid, ok := c.svids[record.entry.EntryId] + if !ok { + // The record does not have an SVID yet and should not be returned + // from the cache. + continue + } + out = append(out, makeNewIdentity(record, svid)) + } + sortIdentities(out) + return out +} + +func (c *LRUCache) Entries() []*common.RegistrationEntry { + c.mu.RLock() + defer c.mu.RUnlock() + + out := make([]*common.RegistrationEntry, 0, len(c.records)) + for _, record := range c.records { + out = append(out, record.entry) + } + sortEntriesByID(out) + return out +} + +func (c *LRUCache) CountSVIDs() int { + c.mu.RLock() + defer c.mu.RUnlock() + + return len(c.svids) +} + +func (c *LRUCache) MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry { + set, setDone := allocSelectorSet(selectors...) + defer setDone() + + c.mu.RLock() + defer c.mu.RUnlock() + return c.matchingEntries(set) +} + +func (c *LRUCache) FetchWorkloadUpdate(selectors []*common.Selector) *WorkloadUpdate { + set, setDone := allocSelectorSet(selectors...) + defer setDone() + + c.mu.RLock() + defer c.mu.RUnlock() + return c.buildWorkloadUpdate(set) +} + +// NewSubscriber creates a subscriber for given selector set. +// Separately call Notify for the first time after this method is invoked to receive latest updates. 
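The two-step contract described above (create the subscriber, then call Notify until it reports that the SVIDs backing the selector set are cached) can be made concrete with a short caller sketch. This is illustrative only and not part of the patch: the function name waitForIdentity and the fixed time.After wait are assumptions (the manager retries with a backoff), while LRUCache, NewSubscriber, Notify, Subscriber.Finish, and SvidSyncInterval are the APIs introduced by this change.

    package example // hypothetical caller, for illustration only

    import (
        "context"
        "time"

        "github.com/spiffe/spire/pkg/agent/manager/cache"
        "github.com/spiffe/spire/proto/spire/common"
    )

    // waitForIdentity blocks until every X509-SVID backing the given selectors
    // is cached, mirroring the retry loop used when subscribing to the cache.
    func waitForIdentity(ctx context.Context, c *cache.LRUCache, selectors []*common.Selector) (cache.Subscriber, error) {
        sub := c.NewSubscriber(selectors)
        for !c.Notify(selectors) {
            // Cache miss: an SVID for a matching entry has not been minted yet;
            // wait for the next SVID sync instead of spinning on the cache lock.
            select {
            case <-ctx.Done():
                sub.Finish()
                return nil, ctx.Err()
            case <-time.After(cache.SvidSyncInterval):
            }
        }
        return sub, nil
    }

The production equivalent is subscribeToWorkloadUpdates further down in this file, which retries with the backoff package and honors context cancellation in the same way.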
+func (c *LRUCache) NewSubscriber(selectors []*common.Selector) Subscriber { + c.mu.Lock() + defer c.mu.Unlock() + + sub := newLRUCacheSubscriber(c, selectors) + for s := range sub.set { + c.addSelectorIndexSub(s, sub) + } + // update lastAccessTimestamp of records containing provided selectors + c.updateLastAccessTimestamp(selectors) + return sub +} + +// UpdateEntries updates the cache with the provided registration entries and bundles and +// notifies impacted subscribers. The checkSVID callback, if provided, is used to determine +// if the SVID for the entry is stale, or otherwise in need of rotation. Entries marked stale +// through the checkSVID callback are returned from GetStaleEntries() until the SVID is +// updated through a call to UpdateSVIDs. +func (c *LRUCache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.RegistrationEntry, *common.RegistrationEntry, *X509SVID) bool) { + c.mu.Lock() + defer c.mu.Unlock() + + // Remove bundles that no longer exist. The bundle for the agent trust + // domain should NOT be removed even if not present (which should only be + // the case if there is a bug on the server) since it is necessary to + // authenticate the server. + bundleRemoved := false + for id := range c.bundles { + if _, ok := update.Bundles[id]; !ok && id != c.trustDomain { + bundleRemoved = true + // bundle no longer exists. + c.log.WithField(telemetry.TrustDomainID, id).Debug("Bundle removed") + delete(c.bundles, id) + } + } + + // Update bundles with changes, populating a "changed" set that we can + // check when processing registration entries to know if they need to spawn + // a notification. + bundleChanged := make(map[spiffeid.TrustDomain]bool) + for id, bundle := range update.Bundles { + existing, ok := c.bundles[id] + if !(ok && existing.EqualTo(bundle)) { + if !ok { + c.log.WithField(telemetry.TrustDomainID, id).Debug("Bundle added") + } else { + c.log.WithField(telemetry.TrustDomainID, id).Debug("Bundle updated") + } + bundleChanged[id] = true + c.bundles[id] = bundle + } + } + trustDomainBundleChanged := bundleChanged[c.trustDomain] + + // Allocate sets from the pool to track changes to selectors and + // federatesWith declarations. These sets must be cleared after EACH use + // and returned to their respective pools when done processing the + // updates. + notifySets := make([]selectorSet, 0) + selAdd, selAddDone := allocSelectorSet() + defer selAddDone() + selRem, selRemDone := allocSelectorSet() + defer selRemDone() + fedAdd, fedAddDone := allocStringSet() + defer fedAddDone() + fedRem, fedRemDone := allocStringSet() + defer fedRemDone() + + // Remove records for registration entries that no longer exist + for id, record := range c.records { + if _, ok := update.RegistrationEntries[id]; !ok { + c.log.WithFields(logrus.Fields{ + telemetry.Entry: id, + telemetry.SPIFFEID: record.entry.SpiffeId, + }).Debug("Entry removed") + + // built a set of selectors for the record being removed, drop the + // record for each selector index, and add the entry selectors to + // the notify set. + clearSelectorSet(selRem) + selRem.Merge(record.entry.Selectors...) + c.delSelectorIndicesRecord(selRem, record) + notifySets = append(notifySets, selRem) + delete(c.records, id) + delete(c.svids, id) + // Remove stale entry since, registration entry is no longer on cache. 
+ delete(c.staleEntries, id) + } + } + + outdatedEntries := make(map[string]struct{}) + + // Add/update records for registration entries in the update + for _, newEntry := range update.RegistrationEntries { + clearSelectorSet(selAdd) + clearSelectorSet(selRem) + clearStringSet(fedAdd) + clearStringSet(fedRem) + + record, existingEntry := c.updateOrCreateRecord(newEntry) + + // Calculate the difference in selectors, add/remove the record + // from impacted selector indices, and add the selector diff to the + // notify set. + c.diffSelectors(existingEntry, newEntry, selAdd, selRem) + selectorsChanged := len(selAdd) > 0 || len(selRem) > 0 + c.addSelectorIndicesRecord(selAdd, record) + c.delSelectorIndicesRecord(selRem, record) + + // Determine if there were changes to FederatesWith declarations or + // if any federated bundles related to the entry were updated. + c.diffFederatesWith(existingEntry, newEntry, fedAdd, fedRem) + federatedBundlesChanged := len(fedAdd) > 0 || len(fedRem) > 0 + if !federatedBundlesChanged { + for _, id := range newEntry.FederatesWith { + td, err := spiffeid.TrustDomainFromString(id) + if err != nil { + c.log.WithFields(logrus.Fields{ + telemetry.TrustDomainID: id, + logrus.ErrorKey: err, + }).Warn("Invalid federated trust domain") + continue + } + if bundleChanged[td] { + federatedBundlesChanged = true + break + } + } + } + + // If any selectors or federated bundles were changed, then make + // sure subscribers for the new and extisting entry selector sets + // are notified. + if selectorsChanged { + if existingEntry != nil { + notifySet, selSetDone := allocSelectorSet() + defer selSetDone() + notifySet.Merge(existingEntry.Selectors...) + notifySets = append(notifySets, notifySet) + } + } + + if federatedBundlesChanged || selectorsChanged { + notifySet, selSetDone := allocSelectorSet() + defer selSetDone() + notifySet.Merge(newEntry.Selectors...) 
+ notifySets = append(notifySets, notifySet) + } + + // Identify stale/outdated entries + if existingEntry != nil && existingEntry.RevisionNumber != newEntry.RevisionNumber { + outdatedEntries[newEntry.EntryId] = struct{}{} + } + + // Log all the details of the update to the DEBUG log + if federatedBundlesChanged || selectorsChanged { + log := c.log.WithFields(logrus.Fields{ + telemetry.Entry: newEntry.EntryId, + telemetry.SPIFFEID: newEntry.SpiffeId, + }) + if len(selAdd) > 0 { + log = log.WithField(telemetry.SelectorsAdded, len(selAdd)) + } + if len(selRem) > 0 { + log = log.WithField(telemetry.SelectorsRemoved, len(selRem)) + } + if len(fedAdd) > 0 { + log = log.WithField(telemetry.FederatedAdded, len(fedAdd)) + } + if len(fedRem) > 0 { + log = log.WithField(telemetry.FederatedRemoved, len(fedRem)) + } + if existingEntry != nil { + log.Debug("Entry updated") + } else { + log.Debug("Entry created") + } + } + } + + // entries with active subscribers which are not cached will be put in staleEntries map; + // irrespective of what svid cache size as we cannot deny identity to a subscriber + activeSubsByEntryID, recordsWithLastAccessTime := c.syncSVIDsWithSubscribers() + extraSize := len(c.svids) - c.svidCacheMaxSize + + // delete svids without subscribers and which have not been accessed since svidCacheExpiryTime + if extraSize > 0 { + // sort recordsWithLastAccessTime + sortByTimestamps(recordsWithLastAccessTime) + + for _, record := range recordsWithLastAccessTime { + if extraSize <= 0 { + // no need to delete SVIDs any further as cache size <= svidCacheMaxSize + break + } + if _, ok := c.svids[record.id]; ok { + if _, exists := activeSubsByEntryID[record.id]; !exists { + // remove svid + c.log.WithField("record_id", record.id). + WithField("record_timestamp", record.timestamp). + Debug("Removing SVID record") + delete(c.svids, record.id) + extraSize-- + } + } + } + } + + // Update all stale svids or svids whose registration entry is outdated + for id, svid := range c.svids { + if _, ok := outdatedEntries[id]; ok || (checkSVID != nil && checkSVID(nil, c.records[id].entry, svid)) { + c.staleEntries[id] = true + } + } + c.log.WithField(telemetry.OutdatedSVIDs, len(outdatedEntries)). + Debug("Updating SVIDs with outdated attributes in cache") + + if bundleRemoved || len(bundleChanged) > 0 { + c.BundleCache.Update(c.bundles) + } + + if trustDomainBundleChanged { + c.notifyAll() + } else { + c.notifyBySelectorSet(notifySets...) + } +} + +func (c *LRUCache) UpdateSVIDs(update *UpdateSVIDs) { + c.mu.Lock() + defer c.mu.Unlock() + + // Allocate a set of selectors that + notifySet, selSetDone := allocSelectorSet() + defer selSetDone() + + // Add/update records for registration entries in the update + for entryID, svid := range update.X509SVIDs { + record, existingEntry := c.records[entryID] + if !existingEntry { + c.log.WithField(telemetry.RegistrationID, entryID).Error("Entry not found") + continue + } + + c.svids[entryID] = svid + notifySet.Merge(record.entry.Selectors...) 
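// Illustrative sketch of the round trip that feeds this method: the sync code
// asks the cache which entries still need signing, obtains fresh SVIDs, and
// hands them back here. signX509SVID is a hypothetical helper standing in for
// the server call, which lives in the manager rather than in this package.
//
//	newSVIDs := make(map[string]*X509SVID)
//	for _, stale := range c.GetStaleEntries() {
//		newSVIDs[stale.Entry.EntryId] = signX509SVID(stale.Entry)
//	}
//	c.UpdateSVIDs(&UpdateSVIDs{X509SVIDs: newSVIDs})
//	// matching subscribers are notified and the stale markers are cleared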
+ log := c.log.WithFields(logrus.Fields{ + telemetry.Entry: record.entry.EntryId, + telemetry.SPIFFEID: record.entry.SpiffeId, + }) + log.Debug("SVID updated") + + // Registration entry is updated, remove it from stale map + delete(c.staleEntries, entryID) + c.notifyBySelectorSet(notifySet) + clearSelectorSet(notifySet) + } +} + +// GetStaleEntries obtains a list of stale entries +func (c *LRUCache) GetStaleEntries() []*StaleEntry { + c.mu.Lock() + defer c.mu.Unlock() + + var staleEntries []*StaleEntry + for entryID := range c.staleEntries { + cachedEntry, ok := c.records[entryID] + if !ok { + c.log.WithField(telemetry.RegistrationID, entryID).Debug("Stale marker found for unknown entry. Please fill a bug") + delete(c.staleEntries, entryID) + continue + } + + var expiresAt time.Time + if cachedSvid, ok := c.svids[entryID]; ok { + expiresAt = cachedSvid.Chain[0].NotAfter + } + + staleEntries = append(staleEntries, &StaleEntry{ + Entry: cachedEntry.entry, + ExpiresAt: expiresAt, + }) + } + + return staleEntries +} + +// SyncSVIDsWithSubscribers will sync svid cache: +// entries with active subscribers which are not cached will be put in staleEntries map +// records which are not cached for remainder of max cache size will also be put in staleEntries map +func (c *LRUCache) SyncSVIDsWithSubscribers() { + c.mu.Lock() + defer c.mu.Unlock() + + c.syncSVIDsWithSubscribers() +} + +// Notify subscribers of selector set only if all SVIDs for corresponding selector set are cached +// It returns whether all SVIDs are cached or not. +// This method should be retried with backoff to avoid lock contention. +func (c *LRUCache) Notify(selectors []*common.Selector) bool { + c.mu.RLock() + defer c.mu.RUnlock() + set, setFree := allocSelectorSet(selectors...) + defer setFree() + if !c.missingSVIDRecords(set) { + c.notifyBySelectorSet(set) + return true + } + return false +} + +func (c *LRUCache) SubscribeToWorkloadUpdates(ctx context.Context, selectors Selectors) (Subscriber, error) { + return c.subscribeToWorkloadUpdates(ctx, selectors, nil) +} + +func (c *LRUCache) subscribeToWorkloadUpdates(ctx context.Context, selectors Selectors, notifyCallbackFn func()) (Subscriber, error) { + subscriber := c.NewSubscriber(selectors) + bo := c.subscribeBackoffFn() + // block until all svids are cached and subscriber is notified + for { + // notifyCallbackFn is used for testing + if c.Notify(selectors) { + if notifyCallbackFn != nil { + notifyCallbackFn() + } + return subscriber, nil + } + c.log.WithField(telemetry.Selectors, selectors).Info("Waiting for SVID to get cached") + // used for testing + if notifyCallbackFn != nil { + notifyCallbackFn() + } + + select { + case <-ctx.Done(): + subscriber.Finish() + return nil, ctx.Err() + case <-c.clk.After(bo.NextBackOff()): + } + } +} + +func (c *LRUCache) missingSVIDRecords(set selectorSet) bool { + records, recordsDone := c.getRecordsForSelectors(set) + defer recordsDone() + + for record := range records { + if _, exists := c.svids[record.entry.EntryId]; !exists { + return true + } + } + return false +} + +func (c *LRUCache) updateLastAccessTimestamp(selectors []*common.Selector) { + set, setFree := allocSelectorSet(selectors...) 
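// Illustrative sketch of how a consumer, such as the workload API handler,
// might use the subscription API defined above. The selector values and the
// handle helper are assumptions; the call blocks until every SVID for the
// matching entries is cached or the context is cancelled.
//
//	sub, err := c.SubscribeToWorkloadUpdates(ctx, []*common.Selector{
//		{Type: "unix", Value: "uid:1000"},
//	})
//	if err != nil {
//		return err // typically the context expired while waiting for SVIDs
//	}
//	defer sub.Finish()
//	for update := range sub.Updates() {
//		// update.Identities only includes entries whose SVIDs are cached
//		handle(update)
//	}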
+ defer setFree() + + records, recordsDone := c.getRecordsForSelectors(set) + defer recordsDone() + + now := c.clk.Now().UnixMilli() + for record := range records { + // Set lastAccessTimestamp so that svid LRU cache can be cleaned based on this timestamp + record.lastAccessTimestamp = now + } +} + +// entries with active subscribers which are not cached will be put in staleEntries map +// records which are not cached for remainder of max cache size will also be put in staleEntries map +func (c *LRUCache) syncSVIDsWithSubscribers() (map[string]struct{}, []recordAccessEvent) { + activeSubsByEntryID := make(map[string]struct{}) + lastAccessTimestamps := make([]recordAccessEvent, 0, len(c.records)) + + // iterate over all selectors from cached entries and obtain: + // 1. entries that have active subscribers + // 1.1 if those entries don't have corresponding SVID cached then put them in staleEntries + // so that SVID will be cached in next sync + // 2. get lastAccessTimestamp of each entry + for id, record := range c.records { + for _, sel := range record.entry.Selectors { + if index, ok := c.selectors[makeSelector(sel)]; ok && index != nil { + if len(index.subs) > 0 { + if _, ok := c.svids[record.entry.EntryId]; !ok { + c.staleEntries[id] = true + } + activeSubsByEntryID[id] = struct{}{} + break + } + } + } + lastAccessTimestamps = append(lastAccessTimestamps, newRecordAccessEvent(record.lastAccessTimestamp, id)) + } + + remainderSize := c.svidCacheMaxSize - len(c.svids) + // add records which are not cached for remainder of cache size + for id := range c.records { + if len(c.staleEntries) >= remainderSize { + break + } + if _, svidCached := c.svids[id]; !svidCached { + if _, ok := c.staleEntries[id]; !ok { + c.staleEntries[id] = true + } + } + } + + return activeSubsByEntryID, lastAccessTimestamps +} + +func (c *LRUCache) updateOrCreateRecord(newEntry *common.RegistrationEntry) (*lruCacheRecord, *common.RegistrationEntry) { + var existingEntry *common.RegistrationEntry + record, recordExists := c.records[newEntry.EntryId] + if !recordExists { + record = newLRUCacheRecord() + c.records[newEntry.EntryId] = record + } else { + existingEntry = record.entry + } + record.entry = newEntry + return record, existingEntry +} + +func (c *LRUCache) diffSelectors(existingEntry, newEntry *common.RegistrationEntry, added, removed selectorSet) { + // Make a set of all the selectors being added + if newEntry != nil { + added.Merge(newEntry.Selectors...) + } + + // Make a set of all the selectors that are being removed + if existingEntry != nil { + for _, selector := range existingEntry.Selectors { + s := makeSelector(selector) + if _, ok := added[s]; ok { + // selector already exists in entry + delete(added, s) + } else { + // selector has been removed from entry + removed[s] = struct{}{} + } + } + } +} + +func (c *LRUCache) diffFederatesWith(existingEntry, newEntry *common.RegistrationEntry, added, removed stringSet) { + // Make a set of all the selectors being added + if newEntry != nil { + added.Merge(newEntry.FederatesWith...) 
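// Worked example for the two diff helpers (illustrative values): given an
// existing entry with selectors {A, B} and an update carrying {B, C},
// diffSelectors leaves added = {C} and removed = {A}; B appears in both sets
// and is dropped from added. diffFederatesWith applies the same set
// difference to the FederatesWith trust domain IDs.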
+ } + + // Make a set of all the selectors that are being removed + if existingEntry != nil { + for _, id := range existingEntry.FederatesWith { + if _, ok := added[id]; ok { + // Bundle already exists in entry + delete(added, id) + } else { + // Bundle has been removed from entry + removed[id] = struct{}{} + } + } + } +} + +func (c *LRUCache) addSelectorIndicesRecord(selectors selectorSet, record *lruCacheRecord) { + for selector := range selectors { + c.addSelectorIndexRecord(selector, record) + } +} + +func (c *LRUCache) addSelectorIndexRecord(s selector, record *lruCacheRecord) { + index := c.getSelectorIndexForWrite(s) + index.records[record] = struct{}{} +} + +func (c *LRUCache) delSelectorIndicesRecord(selectors selectorSet, record *lruCacheRecord) { + for selector := range selectors { + c.delSelectorIndexRecord(selector, record) + } +} + +// delSelectorIndexRecord removes the record from the selector index. If +// the selector index is empty afterwards, it is also removed. +func (c *LRUCache) delSelectorIndexRecord(s selector, record *lruCacheRecord) { + index, ok := c.selectors[s] + if ok { + delete(index.records, record) + if index.isEmpty() { + delete(c.selectors, s) + } + } +} + +func (c *LRUCache) addSelectorIndexSub(s selector, sub *lruCacheSubscriber) { + index := c.getSelectorIndexForWrite(s) + index.subs[sub] = struct{}{} +} + +// delSelectorIndexSub removes the subscription from the selector index. If +// the selector index is empty afterwards, it is also removed. +func (c *LRUCache) delSelectorIndexSub(s selector, sub *lruCacheSubscriber) { + index, ok := c.selectors[s] + if ok { + delete(index.subs, sub) + if index.isEmpty() { + delete(c.selectors, s) + } + } +} + +func (c *LRUCache) unsubscribe(sub *lruCacheSubscriber) { + c.mu.Lock() + defer c.mu.Unlock() + for selector := range sub.set { + c.delSelectorIndexSub(selector, sub) + } +} + +func (c *LRUCache) notifyAll() { + subs, subsDone := c.allSubscribers() + defer subsDone() + for sub := range subs { + c.notify(sub) + } +} + +func (c *LRUCache) notifyBySelectorSet(sets ...selectorSet) { + notifiedSubs, notifiedSubsDone := allocLRUCacheSubscriberSet() + defer notifiedSubsDone() + for _, set := range sets { + subs, subsDone := c.getSubscribers(set) + defer subsDone() + for sub := range subs { + if _, notified := notifiedSubs[sub]; !notified && sub.set.SuperSetOf(set) { + c.notify(sub) + notifiedSubs[sub] = struct{}{} + } + } + } +} + +func (c *LRUCache) notify(sub *lruCacheSubscriber) { + update := c.buildWorkloadUpdate(sub.set) + sub.notify(update) +} + +func (c *LRUCache) allSubscribers() (lruCacheSubscriberSet, func()) { + subs, subsDone := allocLRUCacheSubscriberSet() + for _, index := range c.selectors { + for sub := range index.subs { + subs[sub] = struct{}{} + } + } + return subs, subsDone +} + +func (c *LRUCache) getSubscribers(set selectorSet) (lruCacheSubscriberSet, func()) { + subs, subsDone := allocLRUCacheSubscriberSet() + for s := range set { + if index := c.getSelectorIndexForRead(s); index != nil { + for sub := range index.subs { + subs[sub] = struct{}{} + } + } + } + return subs, subsDone +} + +func (c *LRUCache) matchingIdentities(set selectorSet) []Identity { + records, recordsDone := c.getRecordsForSelectors(set) + defer recordsDone() + + if len(records) == 0 { + return nil + } + + // Return identities in ascending "entry id" order to maintain a consistent + // ordering. 
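// Notification semantics used by notifyBySelectorSet above (illustrative): a
// subscriber registered with selectors {A, B} is a superset of an entry's
// selector set {A}, so it is notified for that entry, while a subscriber
// registered with only {A} is not notified for an entry whose selectors are
// {A, C}, since its set does not cover the entry's full selector set.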
+ // TODO: figure out how to determine the "default" identity + out := make([]Identity, 0, len(records)) + for record := range records { + if svid, ok := c.svids[record.entry.EntryId]; ok { + out = append(out, makeNewIdentity(record, svid)) + } + } + sortIdentities(out) + return out +} + +func (c *LRUCache) matchingEntries(set selectorSet) []*common.RegistrationEntry { + records, recordsDone := c.getRecordsForSelectors(set) + defer recordsDone() + + if len(records) == 0 { + return nil + } + + // Return identities in ascending "entry id" order to maintain a consistent + // ordering. + // TODO: figure out how to determine the "default" identity + out := make([]*common.RegistrationEntry, 0, len(records)) + for record := range records { + out = append(out, record.entry) + } + sortEntriesByID(out) + return out +} + +func (c *LRUCache) buildWorkloadUpdate(set selectorSet) *WorkloadUpdate { + w := &WorkloadUpdate{ + Bundle: c.bundles[c.trustDomain], + FederatedBundles: make(map[spiffeid.TrustDomain]*bundleutil.Bundle), + Identities: c.matchingIdentities(set), + } + + // Add in the bundles the workload is federated with. + for _, identity := range w.Identities { + for _, federatesWith := range identity.Entry.FederatesWith { + td, err := spiffeid.TrustDomainFromString(federatesWith) + if err != nil { + c.log.WithFields(logrus.Fields{ + telemetry.TrustDomainID: federatesWith, + logrus.ErrorKey: err, + }).Warn("Invalid federated trust domain") + continue + } + if federatedBundle := c.bundles[td]; federatedBundle != nil { + w.FederatedBundles[td] = federatedBundle + } else { + c.log.WithFields(logrus.Fields{ + telemetry.RegistrationID: identity.Entry.EntryId, + telemetry.SPIFFEID: identity.Entry.SpiffeId, + telemetry.FederatedBundle: federatesWith, + }).Warn("Federated bundle contents missing") + } + } + } + + return w +} + +func (c *LRUCache) getRecordsForSelectors(set selectorSet) (lruCacheRecordSet, func()) { + // Build and dedup a list of candidate entries. Don't check for selector set inclusion yet, since + // that is a more expensive operation and we could easily have duplicate + // entries to check. + records, recordsDone := allocLRUCacheRecordSet() + for selector := range set { + if index := c.getSelectorIndexForRead(selector); index != nil { + for record := range index.records { + records[record] = struct{}{} + } + } + } + + // Filter out records whose registration entry selectors are not within + // inside the selector set. + for record := range records { + for _, s := range record.entry.Selectors { + if !set.In(s) { + delete(records, record) + } + } + } + return records, recordsDone +} + +// getSelectorIndexForWrite gets the selector index for the selector. If one +// doesn't exist, it is created. Callers must hold the write lock. If the index +// is only being read, then getSelectorIndexForRead should be used instead. +func (c *LRUCache) getSelectorIndexForWrite(s selector) *selectorsMapIndex { + index, ok := c.selectors[s] + if !ok { + index = newSelectorsMapIndex() + c.selectors[s] = index + } + return index +} + +// getSelectorIndexForRead gets the selector index for the selector. If one +// doesn't exist, nil is returned. Callers should hold the read or write lock. +// If the index is being modified, callers should use getSelectorIndexForWrite +// instead. 
+func (c *LRUCache) getSelectorIndexForRead(s selector) *selectorsMapIndex { + if index, ok := c.selectors[s]; ok { + return index + } + return nil +} + +type lruCacheRecord struct { + entry *common.RegistrationEntry + subs map[*lruCacheSubscriber]struct{} + lastAccessTimestamp int64 +} + +func newLRUCacheRecord() *lruCacheRecord { + return &lruCacheRecord{ + subs: make(map[*lruCacheSubscriber]struct{}), + } +} + +type selectorsMapIndex struct { + // subs holds the subscriptions related to this selector + subs map[*lruCacheSubscriber]struct{} + + // records holds the cache records related to this selector + records map[*lruCacheRecord]struct{} +} + +func (x *selectorsMapIndex) isEmpty() bool { + return len(x.subs) == 0 && len(x.records) == 0 +} + +func newSelectorsMapIndex() *selectorsMapIndex { + return &selectorsMapIndex{ + subs: make(map[*lruCacheSubscriber]struct{}), + records: make(map[*lruCacheRecord]struct{}), + } +} + +func sortEntriesByID(entries []*common.RegistrationEntry) { + sort.Slice(entries, func(a, b int) bool { + return entries[a].EntryId < entries[b].EntryId + }) +} + +func sortByTimestamps(records []recordAccessEvent) { + sort.Slice(records, func(a, b int) bool { + return records[a].timestamp < records[b].timestamp + }) +} + +func makeNewIdentity(record *lruCacheRecord, svid *X509SVID) Identity { + return Identity{ + Entry: record.entry, + SVID: svid.Chain, + PrivateKey: svid.PrivateKey, + } +} + +type recordAccessEvent struct { + timestamp int64 + id string +} + +func newRecordAccessEvent(timestamp int64, id string) recordAccessEvent { + return recordAccessEvent{timestamp: timestamp, id: id} +} diff --git a/pkg/agent/manager/cache/lru_cache_subscriber.go b/pkg/agent/manager/cache/lru_cache_subscriber.go new file mode 100644 index 0000000000..00556f89a9 --- /dev/null +++ b/pkg/agent/manager/cache/lru_cache_subscriber.go @@ -0,0 +1,60 @@ +package cache + +import ( + "sync" + + "github.com/spiffe/spire/proto/spire/common" +) + +type lruCacheSubscriber struct { + cache *LRUCache + set selectorSet + setFree func() + + mu sync.Mutex + c chan *WorkloadUpdate + done bool +} + +func newLRUCacheSubscriber(cache *LRUCache, selectors []*common.Selector) *lruCacheSubscriber { + set, setFree := allocSelectorSet(selectors...) 
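// Illustrative note on eviction ordering: a record's lastAccessTimestamp is
// refreshed (in UnixMilli) whenever a subscriber is created for its
// selectors, and each sync collects one recordAccessEvent per record.
// sortByTimestamps orders those events oldest first, for example timestamps
// [250, 100, 175] sort to [100, 175, 250], so the least recently accessed
// SVIDs without active subscribers are the first candidates for removal.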
+ return &lruCacheSubscriber{ + cache: cache, + set: set, + setFree: setFree, + c: make(chan *WorkloadUpdate, 1), + } +} + +func (s *lruCacheSubscriber) Updates() <-chan *WorkloadUpdate { + return s.c +} + +func (s *lruCacheSubscriber) Finish() { + s.mu.Lock() + done := s.done + if !done { + s.done = true + close(s.c) + } + s.mu.Unlock() + if !done { + s.cache.unsubscribe(s) + s.setFree() + s.set = nil + } +} + +func (s *lruCacheSubscriber) notify(update *WorkloadUpdate) { + s.mu.Lock() + defer s.mu.Unlock() + if s.done { + return + } + + select { + case <-s.c: + default: + } + s.c <- update +} diff --git a/pkg/agent/manager/cache/lru_cache_test.go b/pkg/agent/manager/cache/lru_cache_test.go new file mode 100644 index 0000000000..c270dd3aee --- /dev/null +++ b/pkg/agent/manager/cache/lru_cache_test.go @@ -0,0 +1,954 @@ +package cache + +import ( + "context" + "crypto/x509" + "fmt" + "runtime" + "testing" + "time" + + "github.com/andres-erbsen/clock" + "github.com/sirupsen/logrus/hooks/test" + "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/spiffe/spire/pkg/common/bundleutil" + "github.com/spiffe/spire/pkg/common/telemetry" + "github.com/spiffe/spire/proto/spire/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLRUCacheFetchWorkloadUpdate(t *testing.T) { + cache := newTestLRUCache() + // populate the cache with FOO and BAR without SVIDS + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") + bar.FederatesWith = makeFederatesWith(otherBundleV1) + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1, otherBundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + } + cache.UpdateEntries(updateEntries, nil) + + workloadUpdate := cache.FetchWorkloadUpdate(makeSelectors("A", "B")) + assert.Len(t, workloadUpdate.Identities, 0, "identities should not be returned that don't have SVIDs") + + updateSVIDs := &UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + } + cache.UpdateSVIDs(updateSVIDs) + + workloadUpdate = cache.FetchWorkloadUpdate(makeSelectors("A", "B")) + assert.Equal(t, &WorkloadUpdate{ + Bundle: bundleV1, + FederatedBundles: makeBundles(otherBundleV1), + Identities: []Identity{ + {Entry: bar}, + {Entry: foo}, + }, + }, workloadUpdate) +} + +func TestLRUCacheMatchingRegistrationIdentities(t *testing.T) { + cache := newTestLRUCache() + + // populate the cache with FOO and BAR without SVIDS + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + } + cache.UpdateEntries(updateEntries, nil) + + assert.Equal(t, []*common.RegistrationEntry{bar, foo}, + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) + + // Update SVIDs and MatchingRegistrationEntries should return both entries + updateSVIDs := &UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + } + cache.UpdateSVIDs(updateSVIDs) + assert.Equal(t, []*common.RegistrationEntry{bar, foo}, + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) + + // Remove SVIDs and MatchingRegistrationEntries should still return both entries + cache.UpdateSVIDs(&UpdateSVIDs{}) + assert.Equal(t, []*common.RegistrationEntry{bar, foo}, + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) +} + +func TestLRUCacheCountSVIDs(t *testing.T) { + cache := newTestLRUCache() + + // populate the cache with FOO and BAR without SVIDS + foo := makeRegistrationEntry("FOO", "A") + 
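// The notify method above relies on a "latest update wins" idiom: the update
// channel has capacity 1, any pending update is drained first, and the new
// update is then sent. A minimal standalone sketch of the pattern, which only
// avoids blocking because callers are serialized (the real method holds s.mu):
//
//	ch := make(chan *WorkloadUpdate, 1)
//	push := func(u *WorkloadUpdate) {
//		select {
//		case <-ch: // drop a stale pending update, if any
//		default:
//		}
//		ch <- u
//	}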
bar := makeRegistrationEntry("BAR", "B") + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + } + cache.UpdateEntries(updateEntries, nil) + + // No SVIDs expected + require.Equal(t, 0, cache.CountSVIDs()) + + updateSVIDs := &UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + } + cache.UpdateSVIDs(updateSVIDs) + + // Only one SVID expected + require.Equal(t, 1, cache.CountSVIDs()) +} + +func TestLRUCacheBundleChanges(t *testing.T) { + cache := newTestLRUCache() + + bundleStream := cache.SubscribeToBundleChanges() + assert.Equal(t, makeBundles(bundleV1), bundleStream.Value()) + + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1, otherBundleV1), + }, nil) + if assert.True(t, bundleStream.HasNext(), "has new bundle value after adding bundle") { + bundleStream.Next() + assert.Equal(t, makeBundles(bundleV1, otherBundleV1), bundleStream.Value()) + } + + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + }, nil) + + if assert.True(t, bundleStream.HasNext(), "has new bundle value after removing bundle") { + bundleStream.Next() + assert.Equal(t, makeBundles(bundleV1), bundleStream.Value()) + } +} + +func TestLRUCacheAllSubscribersNotifiedOnBundleChange(t *testing.T) { + cache := newTestLRUCache() + + // create some subscribers and assert they get the initial bundle + subA := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer subA.Finish() + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{Bundle: bundleV1}) + + subB := subscribeToWorkloadUpdates(t, cache, makeSelectors("B")) + defer subB.Finish() + assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{Bundle: bundleV1}) + + // update the bundle and assert all subscribers gets the updated bundle + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + }, nil) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{Bundle: bundleV2}) + assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{Bundle: bundleV2}) +} + +func TestLRUCacheSomeSubscribersNotifiedOnFederatedBundleChange(t *testing.T) { + cache := newTestLRUCache() + + // initialize the cache with an entry FOO that has a valid SVID and + // selector "A" + foo := makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + // subscribe to A and B and assert initial updates are received. + subA := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer subA.Finish() + assertAnyWorkloadUpdate(t, subA) + + subB := subscribeToWorkloadUpdates(t, cache, makeSelectors("B")) + defer subB.Finish() + assertAnyWorkloadUpdate(t, subB) + + // add the federated bundle with no registration entries federating with + // it and make sure nobody is notified. + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1, otherBundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + assertNoWorkloadUpdate(t, subA) + assertNoWorkloadUpdate(t, subB) + + // update FOO to federate with otherdomain.test and make sure subA is + // notified but not subB. 
+ foo = makeRegistrationEntry("FOO", "A") + foo.FederatesWith = makeFederatesWith(otherBundleV1) + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1, otherBundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + FederatedBundles: makeBundles(otherBundleV1), + Identities: []Identity{{Entry: foo}}, + }) + assertNoWorkloadUpdate(t, subB) + + // now change the federated bundle and make sure subA gets notified, but + // again, not subB. + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1, otherBundleV2), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + FederatedBundles: makeBundles(otherBundleV2), + Identities: []Identity{{Entry: foo}}, + }) + assertNoWorkloadUpdate(t, subB) + + // now drop the federation and make sure subA is again notified and no + // longer has the federated bundle. + foo = makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1, otherBundleV2), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) + assertNoWorkloadUpdate(t, subB) +} + +func TestLRUCacheSubscribersGetEntriesWithSelectorSubsets(t *testing.T) { + cache := newTestLRUCache() + + // create subscribers for each combination of selectors + subA := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer subA.Finish() + subB := subscribeToWorkloadUpdates(t, cache, makeSelectors("B")) + defer subB.Finish() + subAB := subscribeToWorkloadUpdates(t, cache, makeSelectors("A", "B")) + defer subAB.Finish() + + // assert all subscribers get the initial update + initialUpdate := &WorkloadUpdate{Bundle: bundleV1} + assertWorkloadUpdateEqual(t, subA, initialUpdate) + assertWorkloadUpdateEqual(t, subB, initialUpdate) + assertWorkloadUpdateEqual(t, subAB, initialUpdate) + + // create entry FOO that will target any subscriber with containing (A) + foo := makeRegistrationEntry("FOO", "A") + + // create entry BAR that will target any subscriber with containing (A,C) + bar := makeRegistrationEntry("BAR", "A", "C") + + // update the cache with foo and bar + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + }) + + // subA selector set contains (A), but not (A, C), so it should only get FOO + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) + + // subB selector set does not contain either (A) or (A,C) so it isn't even + // notified. 
+ assertNoWorkloadUpdate(t, subB) + + // subAB selector set contains (A) but not (A, C), so it should get FOO + assertWorkloadUpdateEqual(t, subAB, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) +} + +func TestLRUCacheSubscriberIsNotNotifiedIfNothingChanges(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + sub := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer sub.Finish() + assertAnyWorkloadUpdate(t, sub) + + // Second update is the same (other than X509SVIDs, which, when set, + // always constitute a "change" for the impacted registration entries. + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + + assertNoWorkloadUpdate(t, sub) +} + +func TestLRUCacheSubscriberNotifiedOnSVIDChanges(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + sub := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer sub.Finish() + assertAnyWorkloadUpdate(t, sub) + + // Update SVID + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) +} + +func TestLRUCacheSubscriberNotificationsOnSelectorChanges(t *testing.T) { + cache := newTestLRUCache() + + // initialize the cache with a FOO entry with selector A and an SVID + foo := makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + // create subscribers for A and make sure the initial update has FOO + sub := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer sub.Finish() + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) + + // update FOO to have selectors (A,B) and make sure the subscriber loses + // FOO, since (A,B) is not a subset of the subscriber set (A). 
+ foo = makeRegistrationEntry("FOO", "A", "B") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV1, + }) + + // update FOO to drop B and make sure the subscriber regains FOO + foo = makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) +} + +func TestLRUCacheSubscriberNotifiedWhenEntryDropped(t *testing.T) { + cache := newTestLRUCache() + + subA := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer subA.Finish() + assertAnyWorkloadUpdate(t, subA) + + // subB's job here is to just make sure we don't notify unrelated + // subscribers when dropping registration entries + subB := subscribeToWorkloadUpdates(t, cache, makeSelectors("B")) + defer subB.Finish() + assertAnyWorkloadUpdate(t, subB) + + foo := makeRegistrationEntry("FOO", "A") + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + } + cache.UpdateEntries(updateEntries, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + // make sure subA gets notified with FOO but not subB + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) + assertNoWorkloadUpdate(t, subB) + + updateEntries.RegistrationEntries = nil + cache.UpdateEntries(updateEntries, nil) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + }) + assertNoWorkloadUpdate(t, subB) + + // Make sure trying to update SVIDs of removed entry does not notify + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + assertNoWorkloadUpdate(t, subB) +} + +func TestLRUCacheSubscriberOnlyGetsEntriesWithSVID(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntry("FOO", "A") + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + } + cache.UpdateEntries(updateEntries, nil) + + sub := cache.NewSubscriber(makeSelectors("A")) + defer sub.Finish() + assertNoWorkloadUpdate(t, sub) + + // update to include the SVID and now we should get the update + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) +} + +func TestLRUCacheSubscribersDoNotBlockNotifications(t *testing.T) { + cache := newTestLRUCache() + + sub := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer sub.Finish() + + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + }, nil) + + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV3), + }, nil) + + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV3, + }) +} + +func TestLRUCacheCheckSVIDCallback(t *testing.T) { + cache := newTestLRUCache() + + // no calls because there are no registration entries + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + assert.Fail(t, 
"should not be called if there are no registration entries") + + return false + }) + + foo := makeRegistrationEntryWithTTL("FOO", 60) + + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + RegistrationEntries: makeRegistrationEntries(foo), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + // should not get invoked + assert.Fail(t, "should not be called as no SVIDs are cached yet") + return false + }) + + // called once for FOO with new SVID + svids := makeX509SVIDs(foo) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: svids, + }) + + // called once for FOO with existing SVID + callCount := 0 + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + RegistrationEntries: makeRegistrationEntries(foo), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + callCount++ + assert.Equal(t, "FOO", newEntry.EntryId) + if assert.NotNil(t, svid) { + assert.Exactly(t, svids["FOO"], svid) + } + + return true + }) + assert.Equal(t, 1, callCount) + assert.Equal(t, map[string]bool{foo.EntryId: true}, cache.staleEntries) +} + +func TestLRUCacheGetStaleEntries(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntryWithTTL("FOO", 60) + + // Create entry but don't mark it stale from checkSVID method; + // it will be marked stale cause it does not have SVID cached + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + RegistrationEntries: makeRegistrationEntries(foo), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + return false + }) + + // Assert that the entry is returned as stale. The `ExpiresAt` field should be unset since there is no SVID. + expectedEntries := []*StaleEntry{{Entry: cache.records[foo.EntryId].entry}} + assert.Equal(t, expectedEntries, cache.GetStaleEntries()) + + // Update the SVID for the stale entry + svids := make(map[string]*X509SVID) + expiredAt := time.Now() + svids[foo.EntryId] = &X509SVID{ + Chain: []*x509.Certificate{{NotAfter: expiredAt}}, + } + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: svids, + }) + // Assert that updating the SVID removes stale marker from entry + assert.Empty(t, cache.GetStaleEntries()) + + // Update entry again and mark it as stale + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + RegistrationEntries: makeRegistrationEntries(foo), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + return true + }) + + // Assert that the entry again returns as stale. This time the `ExpiresAt` field should be populated with the expiration of the SVID. 
+ expectedEntries = []*StaleEntry{{ + Entry: cache.records[foo.EntryId].entry, + ExpiresAt: expiredAt, + }} + assert.Equal(t, expectedEntries, cache.GetStaleEntries()) + + // Remove registration entry and assert that it is no longer returned as stale + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + return true + }) + assert.Empty(t, cache.GetStaleEntries()) +} + +func TestLRUCacheSubscriberNotNotifiedOnDifferentSVIDChanges(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + }) + + sub := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer sub.Finish() + assertAnyWorkloadUpdate(t, sub) + + // Update SVID + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(bar), + }) + + assertNoWorkloadUpdate(t, sub) +} + +func TestLRUCacheSubscriberNotNotifiedOnOverlappingSVIDChanges(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntry("FOO", "A", "C") + bar := makeRegistrationEntry("FOO", "A", "B") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + }) + + sub := subscribeToWorkloadUpdates(t, cache, makeSelectors("A", "B")) + defer sub.Finish() + assertAnyWorkloadUpdate(t, sub) + + // Update SVID + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + assertNoWorkloadUpdate(t, sub) +} + +func TestLRUCacheSVIDCacheExpiry(t *testing.T) { + clk := clock.NewMock() + cache := newTestLRUCacheWithConfig(10, clk) + + clk.Add(1 * time.Second) + foo := makeRegistrationEntry("FOO", "A") + // validate workload update for foo + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + subA := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) + subA.Finish() + + // move clk by 1 sec so that SVID access time will be different + clk.Add(1 * time.Second) + bar := makeRegistrationEntry("BAR", "B") + // validate workload update for bar + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(bar), + }) + + // not closing subscriber immediately + subB := subscribeToWorkloadUpdates(t, cache, makeSelectors("B")) + defer subB.Finish() + assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{ + {Entry: bar}, + }, + }) + + // Move clk by 2 seconds + clk.Add(2 * time.Second) + // update total of 12 entries + updateEntries := createUpdateEntries(10, makeBundles(bundleV1)) + updateEntries.RegistrationEntries[foo.EntryId] = foo + updateEntries.RegistrationEntries[bar.EntryId] = bar + + cache.UpdateEntries(updateEntries, nil) + + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDsFromMap(updateEntries.RegistrationEntries), + }) + + for id, entry := range 
updateEntries.RegistrationEntries { + // create and close subscribers for remaining entries so that svid cache is full + if id != foo.EntryId && id != bar.EntryId { + sub := cache.NewSubscriber(entry.Selectors) + sub.Finish() + } + } + assert.Equal(t, 12, cache.CountSVIDs()) + + cache.UpdateEntries(updateEntries, nil) + assert.Equal(t, 10, cache.CountSVIDs()) + + // foo SVID should be removed from cache as it does not have active subscriber + assert.False(t, cache.Notify(makeSelectors("A"))) + // bar SVID should be cached as it has active subscriber + assert.True(t, cache.Notify(makeSelectors("B"))) + + subA = cache.NewSubscriber(makeSelectors("A")) + defer subA.Finish() + + cache.UpdateEntries(updateEntries, nil) + + // Make sure foo is marked as stale entry which does not have svid cached + require.Len(t, cache.GetStaleEntries(), 1) + assert.Equal(t, foo, cache.GetStaleEntries()[0].Entry) + + assert.Equal(t, 10, cache.CountSVIDs()) +} + +func TestLRUCacheMaxSVIDCacheSize(t *testing.T) { + clk := clock.NewMock() + cache := newTestLRUCacheWithConfig(10, clk) + + // create entries more than maxSvidCacheSize + updateEntries := createUpdateEntries(12, makeBundles(bundleV1)) + cache.UpdateEntries(updateEntries, nil) + + require.Len(t, cache.GetStaleEntries(), 10) + + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDsFromStaleEntries(cache.GetStaleEntries()), + }) + require.Len(t, cache.GetStaleEntries(), 0) + assert.Equal(t, 10, cache.CountSVIDs()) + + // Validate that active subscriber will still get SVID even if SVID count is at maxSvidCacheSize + foo := makeRegistrationEntry("FOO", "A") + updateEntries.RegistrationEntries[foo.EntryId] = foo + + subA := cache.NewSubscriber(foo.Selectors) + defer subA.Finish() + + cache.UpdateEntries(updateEntries, nil) + require.Len(t, cache.GetStaleEntries(), 1) + assert.Equal(t, 10, cache.CountSVIDs()) + + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + assert.Equal(t, 11, cache.CountSVIDs()) + require.Len(t, cache.GetStaleEntries(), 0) +} + +func TestSyncSVIDsWithSubscribers(t *testing.T) { + clk := clock.NewMock() + cache := newTestLRUCacheWithConfig(5, clk) + + updateEntries := createUpdateEntries(5, makeBundles(bundleV1)) + cache.UpdateEntries(updateEntries, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDsFromStaleEntries(cache.GetStaleEntries()), + }) + assert.Equal(t, 5, cache.CountSVIDs()) + + // Update foo but its SVID is not yet cached + foo := makeRegistrationEntry("FOO", "A") + updateEntries.RegistrationEntries[foo.EntryId] = foo + + cache.UpdateEntries(updateEntries, nil) + + // Create a subscriber for foo + subA := cache.NewSubscriber(foo.Selectors) + defer subA.Finish() + require.Len(t, cache.GetStaleEntries(), 0) + + // After SyncSVIDsWithSubscribers foo should be marked as stale, requiring signing + cache.SyncSVIDsWithSubscribers() + require.Len(t, cache.GetStaleEntries(), 1) + assert.Equal(t, []*StaleEntry{{Entry: cache.records[foo.EntryId].entry}}, cache.GetStaleEntries()) + + assert.Equal(t, 5, cache.CountSVIDs()) +} + +func TestNotify(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + + assert.False(t, cache.Notify(makeSelectors("A"))) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + assert.True(t, cache.Notify(makeSelectors("A"))) +} + +func TestSubscribeToLRUCacheChanges(t *testing.T) { + 
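// This test exercises the blocking subscription path end to end: two
// subscribers are created for entries whose SVIDs are not cached yet, so
// subscribeToWorkloadUpdates keeps waiting; SyncSVIDsWithSubscribers then
// marks both entries stale, UpdateSVIDs caches the freshly provided SVIDs,
// and each subscriber finally receives a workload update carrying its
// identity.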
clk := clock.NewMock() + cache := newTestLRUCacheWithConfig(1, clk) + + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + }, nil) + + sub1WaitCh := make(chan struct{}, 1) + sub1ErrCh := make(chan error, 1) + go func() { + sub1, err := cache.subscribeToWorkloadUpdates(context.Background(), foo.Selectors, func() { + sub1WaitCh <- struct{}{} + }) + if err != nil { + sub1ErrCh <- err + return + } + + defer sub1.Finish() + u1 := <-sub1.Updates() + if len(u1.Identities) != 1 { + sub1ErrCh <- fmt.Errorf("expected 1 SVID, got: %d", len(u1.Identities)) + return + } + sub1ErrCh <- nil + }() + + sub2WaitCh := make(chan struct{}, 1) + sub2ErrCh := make(chan error, 1) + go func() { + sub2, err := cache.subscribeToWorkloadUpdates(context.Background(), bar.Selectors, func() { + sub2WaitCh <- struct{}{} + }) + if err != nil { + sub2ErrCh <- err + return + } + + defer sub2.Finish() + u2 := <-sub2.Updates() + if len(u2.Identities) != 1 { + sub1ErrCh <- fmt.Errorf("expected 1 SVID, got: %d", len(u2.Identities)) + return + } + sub2ErrCh <- nil + }() + + <-sub1WaitCh + <-sub2WaitCh + cache.SyncSVIDsWithSubscribers() + + assert.Len(t, cache.GetStaleEntries(), 2) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + }) + assert.Equal(t, 2, cache.CountSVIDs()) + + clk.Add(SvidSyncInterval * 2) + + sub1Err := <-sub1ErrCh + assert.NoError(t, sub1Err, "subscriber 1 error") + + sub2Err := <-sub2ErrCh + assert.NoError(t, sub2Err, "subscriber 2 error") +} + +func TestNewLRUCache(t *testing.T) { + // negative value + cache := newTestLRUCacheWithConfig(-5, clock.NewMock()) + require.Equal(t, DefaultSVIDCacheMaxSize, cache.svidCacheMaxSize) + + // zero value + cache = newTestLRUCacheWithConfig(0, clock.NewMock()) + require.Equal(t, DefaultSVIDCacheMaxSize, cache.svidCacheMaxSize) +} + +func BenchmarkLRUCacheGlobalNotification(b *testing.B) { + cache := newTestLRUCache() + + const numEntries = 1000 + const numWorkloads = 1000 + const selectorsPerEntry = 3 + const selectorsPerWorkload = 10 + + // build a set of 1000 registration entries with distinct selectors + bundlesV1 := makeBundles(bundleV1) + bundlesV2 := makeBundles(bundleV2) + updateEntries := &UpdateEntries{ + Bundles: bundlesV1, + RegistrationEntries: make(map[string]*common.RegistrationEntry, numEntries), + } + for i := 0; i < numEntries; i++ { + entryID := fmt.Sprintf("00000000-0000-0000-0000-%012d", i) + updateEntries.RegistrationEntries[entryID] = &common.RegistrationEntry{ + EntryId: entryID, + ParentId: "spiffe://domain.test/node", + SpiffeId: fmt.Sprintf("spiffe://domain.test/workload-%d", i), + Selectors: distinctSelectors(i, selectorsPerEntry), + } + } + + cache.UpdateEntries(updateEntries, nil) + for i := 0; i < numWorkloads; i++ { + selectors := distinctSelectors(i, selectorsPerWorkload) + cache.NewSubscriber(selectors) + } + + runtime.GC() + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + if i%2 == 0 { + updateEntries.Bundles = bundlesV2 + } else { + updateEntries.Bundles = bundlesV1 + } + cache.UpdateEntries(updateEntries, nil) + } +} + +func newTestLRUCache() *LRUCache { + log, _ := test.NewNullLogger() + return NewLRUCache(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, + telemetry.Blackhole{}, 0, clock.NewMock()) +} + +func newTestLRUCacheWithConfig(svidCacheMaxSize int, clk clock.Clock) *LRUCache { + log, _ := 
test.NewNullLogger() + return NewLRUCache(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, telemetry.Blackhole{}, + svidCacheMaxSize, clk) +} + +// numEntries should not be more than 12 digits +func createUpdateEntries(numEntries int, bundles map[spiffeid.TrustDomain]*bundleutil.Bundle) *UpdateEntries { + updateEntries := &UpdateEntries{ + Bundles: bundles, + RegistrationEntries: make(map[string]*common.RegistrationEntry, numEntries), + } + + for i := 0; i < numEntries; i++ { + entryID := fmt.Sprintf("00000000-0000-0000-0000-%012d", i) + updateEntries.RegistrationEntries[entryID] = &common.RegistrationEntry{ + EntryId: entryID, + ParentId: "spiffe://domain.test/node", + SpiffeId: fmt.Sprintf("spiffe://domain.test/workload-%d", i), + Selectors: distinctSelectors(i, 1), + } + } + return updateEntries +} + +func makeX509SVIDsFromMap(entries map[string]*common.RegistrationEntry) map[string]*X509SVID { + out := make(map[string]*X509SVID) + for _, entry := range entries { + out[entry.EntryId] = &X509SVID{} + } + return out +} + +func makeX509SVIDsFromStaleEntries(entries []*StaleEntry) map[string]*X509SVID { + out := make(map[string]*X509SVID) + for _, entry := range entries { + out[entry.Entry.EntryId] = &X509SVID{} + } + return out +} + +func subscribeToWorkloadUpdates(t *testing.T, cache *LRUCache, selectors []*common.Selector) Subscriber { + subscriber, err := cache.subscribeToWorkloadUpdates(context.Background(), selectors, nil) + assert.NoError(t, err) + return subscriber +} diff --git a/pkg/agent/manager/cache/sets.go b/pkg/agent/manager/cache/sets.go index 6c9e1701bb..c7cc0d6895 100644 --- a/pkg/agent/manager/cache/sets.go +++ b/pkg/agent/manager/cache/sets.go @@ -30,6 +30,18 @@ var ( return make(recordSet) }, } + + lruCacheRecordSetPool = sync.Pool{ + New: func() interface{} { + return make(lruCacheRecordSet) + }, + } + + lruCacheSubscriberSetPool = sync.Pool{ + New: func() interface{} { + return make(lruCacheSubscriberSet) + }, + } ) // unique set of strings, allocated from a pool @@ -149,3 +161,37 @@ func clearRecordSet(set recordSet) { delete(set, k) } } + +// unique set of LRU cache records, allocated from a pool +type lruCacheRecordSet map[*lruCacheRecord]struct{} + +func allocLRUCacheRecordSet() (lruCacheRecordSet, func()) { + set := lruCacheRecordSetPool.Get().(lruCacheRecordSet) + return set, func() { + clearLRUCacheRecordSet(set) + lruCacheRecordSetPool.Put(set) + } +} + +func clearLRUCacheRecordSet(set lruCacheRecordSet) { + for k := range set { + delete(set, k) + } +} + +// unique set of LRU cache subscribers, allocated from a pool +type lruCacheSubscriberSet map[*lruCacheSubscriber]struct{} + +func allocLRUCacheSubscriberSet() (lruCacheSubscriberSet, func()) { + set := lruCacheSubscriberSetPool.Get().(lruCacheSubscriberSet) + return set, func() { + clearLRUCacheSubscriberSet(set) + lruCacheSubscriberSetPool.Put(set) + } +} + +func clearLRUCacheSubscriberSet(set lruCacheSubscriberSet) { + for k := range set { + delete(set, k) + } +} diff --git a/pkg/agent/manager/config.go b/pkg/agent/manager/config.go index 353cdbab7d..b192a1ef92 100644 --- a/pkg/agent/manager/config.go +++ b/pkg/agent/manager/config.go @@ -58,8 +58,15 @@ func newManager(c *Config) *manager { c.Clk = clock.New() } - cache := cache.New(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle, - c.Metrics, c.SVIDCacheMaxSize, c.Clk) + var x509SVIDCache Cache + if c.SVIDCacheMaxSize > 0 { + // use LRU cache implementation + x509SVIDCache = 
cache.NewLRUCache(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle, + c.Metrics, c.SVIDCacheMaxSize, c.Clk) + } else { + x509SVIDCache = cache.New(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle, + c.Metrics) + } rotCfg := &svid.RotatorConfig{ SVIDKeyManager: keymanager.ForSVID(c.Catalog.GetKeyManager()), @@ -67,7 +74,7 @@ func newManager(c *Config) *manager { Metrics: c.Metrics, SVID: c.SVID, SVIDKey: c.SVIDKey, - BundleStream: cache.SubscribeToBundleChanges(), + BundleStream: x509SVIDCache.SubscribeToBundleChanges(), ServerAddr: c.ServerAddr, TrustDomain: c.TrustDomain, Interval: c.RotationInterval, @@ -76,7 +83,7 @@ func newManager(c *Config) *manager { svidRotator, client := svid.NewRotator(rotCfg) m := &manager{ - cache: cache, + cache: x509SVIDCache, c: c, mtx: new(sync.RWMutex), svid: svidRotator, diff --git a/pkg/agent/manager/manager.go b/pkg/agent/manager/manager.go index fc427f22dd..775623acbf 100644 --- a/pkg/agent/manager/manager.go +++ b/pkg/agent/manager/manager.go @@ -25,8 +25,6 @@ import ( "github.com/spiffe/spire/proto/spire/common" ) -const svidSyncInterval = 500 * time.Millisecond - // Manager provides cache management functionalities for agents. type Manager interface { // Initialize initializes the manager. @@ -78,13 +76,52 @@ type Manager interface { GetBundle() *cache.Bundle } +// Cache stores each registration entry, signed X509-SVIDs for those entries, +// bundles, and JWT SVIDs for the agent. +type Cache interface { + SVIDCache + + // Bundle gets latest cached bundle + Bundle() *bundleutil.Bundle + + // SyncSVIDsWithSubscribers syncs SVID cache + SyncSVIDsWithSubscribers() + + // SubscribeToWorkloadUpdates creates a subscriber for given selector set. + SubscribeToWorkloadUpdates(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) + + // SubscribeToBundleChanges creates a stream for providing bundle changes + SubscribeToBundleChanges() *cache.BundleStream + + // MatchingRegistrationEntries with given selectors + MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry + + // CountSVIDs in cache stored + CountSVIDs() int + + // FetchWorkloadUpdate for giveb selectors + FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate + + // GetJWTSVID provides JWTSVID + GetJWTSVID(id spiffeid.ID, audience []string) (*client.JWTSVID, bool) + + // SetJWTSVID adds JWTSVID to cache + SetJWTSVID(id spiffeid.ID, audience []string, svid *client.JWTSVID) + + // Entries get all registration entries + Entries() []*common.RegistrationEntry + + // Identities get all identities in cache + Identities() []cache.Identity +} + type manager struct { c *Config // Fields protected by mtx mutex. 
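// Both concrete caches satisfy this interface: newManager (config.go) builds
// a cache.LRUCache when SVIDCacheMaxSize is greater than zero and falls back
// to the existing cache.New implementation otherwise, so the manager below
// depends only on the methods declared here.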
mtx *sync.RWMutex - cache *cache.Cache + cache Cache svid svid.Rotator storage storage.Storage @@ -93,7 +130,6 @@ type manager struct { // fetch attempt synchronizeBackoff backoff.BackOff svidSyncBackoff backoff.BackOff - subscribeBackoffFn func() backoff.BackOff client client.Client @@ -111,12 +147,7 @@ func (m *manager) Initialize(ctx context.Context) error { m.storeBundle(m.cache.Bundle()) m.synchronizeBackoff = backoff.NewBackoff(m.clk, m.c.SyncInterval) - m.svidSyncBackoff = backoff.NewBackoff(m.clk, svidSyncInterval) - if m.subscribeBackoffFn == nil { - m.subscribeBackoffFn = func() backoff.BackOff { - return backoff.NewBackoff(m.clk, svidSyncInterval) - } - } + m.svidSyncBackoff = backoff.NewBackoff(m.clk, cache.SvidSyncInterval) err := m.synchronize(ctx) if nodeutil.ShouldAgentReattest(err) { @@ -151,30 +182,7 @@ func (m *manager) Run(ctx context.Context) error { } func (m *manager) SubscribeToCacheChanges(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) { - return m.subscribeToCacheChanges(ctx, selectors, nil) -} - -func (m *manager) subscribeToCacheChanges(ctx context.Context, selectors cache.Selectors, notifyCallbackFn func()) (cache.Subscriber, error) { - subscriber := m.cache.NewSubscriber(selectors) - bo := m.subscribeBackoffFn() - // block until all svids are cached and subscriber is notified - for { - svidsInCache := m.cache.Notify(selectors) - // used for testing - if notifyCallbackFn != nil { - notifyCallbackFn() - } - if svidsInCache { - return subscriber, nil - } - m.c.Log.WithField(telemetry.Selectors, selectors).Info("Waiting for SVID to get cached") - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-m.clk.After(bo.NextBackOff()): - } - } + return m.cache.SubscribeToWorkloadUpdates(ctx, selectors) } func (m *manager) SubscribeToSVIDChanges() observer.Stream { diff --git a/pkg/agent/manager/manager_test.go b/pkg/agent/manager/manager_test.go index 4e4d5c3512..49a7a0c9a7 100644 --- a/pkg/agent/manager/manager_test.go +++ b/pkg/agent/manager/manager_test.go @@ -14,7 +14,6 @@ import ( "testing" "time" - backoff "github.com/cenkalti/backoff/v3" testlog "github.com/sirupsen/logrus/hooks/test" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/go-spiffe/v2/svid/x509svid" @@ -73,16 +72,15 @@ func TestInitializationFailure(t *testing.T) { cat.SetKeyManager(km) c := &Config{ - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - Metrics: &telemetry.Blackhole{}, - TrustDomain: trustDomain, - Storage: openStorage(t, dir), - Clk: clk, - Catalog: cat, - SVIDCacheMaxSize: 1, - SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + Metrics: &telemetry.Blackhole{}, + TrustDomain: trustDomain, + Storage: openStorage(t, dir), + Clk: clk, + Catalog: cat, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } m := newManager(c) require.Error(t, m.Initialize(context.Background())) @@ -101,16 +99,15 @@ func TestStoreBundleOnStartup(t *testing.T) { sto := openStorage(t, dir) c := &Config{ - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - Metrics: &telemetry.Blackhole{}, - TrustDomain: trustDomain, - Storage: sto, - Bundle: bundleutil.BundleFromRootCA(trustDomain, ca), - Clk: clk, - Catalog: cat, - SVIDCacheMaxSize: 1, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + Metrics: &telemetry.Blackhole{}, + TrustDomain: trustDomain, + Storage: sto, + Bundle: 
bundleutil.BundleFromRootCA(trustDomain, ca), + Clk: clk, + Catalog: cat, } m := newManager(c) @@ -150,15 +147,14 @@ func TestStoreSVIDOnStartup(t *testing.T) { sto := openStorage(t, dir) c := &Config{ - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - Metrics: &telemetry.Blackhole{}, - TrustDomain: trustDomain, - Storage: sto, - Clk: clk, - Catalog: cat, - SVIDCacheMaxSize: 1, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + Metrics: &telemetry.Blackhole{}, + TrustDomain: trustDomain, + Storage: sto, + Clk: clk, + Catalog: cat, } if _, err := sto.LoadSVID(); !errors.Is(err, storage.ErrNotCached) { @@ -403,7 +399,6 @@ func TestSVIDRotation(t *testing.T) { SyncInterval: 1 * time.Hour, Clk: clk, WorkloadKeyType: workloadkey.ECP256, - SVIDCacheMaxSize: 1, SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } @@ -521,11 +516,6 @@ func TestSynchronization(t *testing.T) { m := newManager(c) - if err := m.Initialize(context.Background()); err != nil { - t.Fatal(err) - } - require.Equal(t, clk.Now(), m.GetLastSync()) - sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{ {Type: "unix", Value: "uid:1111"}, {Type: "spiffe_id", Value: joinTokenID.String()}, @@ -533,6 +523,11 @@ func TestSynchronization(t *testing.T) { require.NoError(t, err) defer sub.Finish() + if err := m.Initialize(context.Background()); err != nil { + t.Fatal(err) + } + require.Equal(t, clk.Now(), m.GetLastSync()) + // Before synchronization identitiesBefore := identitiesByEntryID(m.cache.Identities()) if len(identitiesBefore) != 3 { @@ -657,19 +652,18 @@ func TestSynchronizationClearsStaleCacheEntries(t *testing.T) { cat.SetKeyManager(km) c := &Config{ - ServerAddr: api.addr, - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - TrustDomain: trustDomain, - Storage: openStorage(t, dir), - Bundle: api.bundle, - Metrics: &telemetry.Blackhole{}, - Clk: clk, - Catalog: cat, - WorkloadKeyType: workloadkey.ECP256, - SVIDCacheMaxSize: 1, - SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + ServerAddr: api.addr, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + TrustDomain: trustDomain, + Storage: openStorage(t, dir), + Bundle: api.bundle, + Metrics: &telemetry.Blackhole{}, + Clk: clk, + Catalog: cat, + WorkloadKeyType: workloadkey.ECP256, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } m := newManager(c) @@ -731,19 +725,18 @@ func TestSynchronizationUpdatesRegistrationEntries(t *testing.T) { cat.SetKeyManager(km) c := &Config{ - ServerAddr: api.addr, - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - TrustDomain: trustDomain, - Storage: openStorage(t, dir), - Bundle: api.bundle, - Metrics: &telemetry.Blackhole{}, - Clk: clk, - Catalog: cat, - WorkloadKeyType: workloadkey.ECP256, - SVIDCacheMaxSize: 1, - SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + ServerAddr: api.addr, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + TrustDomain: trustDomain, + Storage: openStorage(t, dir), + Bundle: api.bundle, + Metrics: &telemetry.Blackhole{}, + Clk: clk, + Catalog: cat, + WorkloadKeyType: workloadkey.ECP256, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } m := newManager(c) @@ -825,21 +818,21 @@ func TestSubscribersGetUpToDateBundle(t *testing.T) { }) } -func TestSyncSVIDs(t *testing.T) { +func 
TestSynchronizationWithLRUCache(t *testing.T) { dir := spiretest.TempDir(t) km := fakeagentkeymanager.New(t, dir) clk := clock.NewMock(t) + ttl := 3 api := newMockAPI(t, &mockAPIConfig{ km: km, - getAuthorizedEntries: func(h *mockAPI, count int32, req *entryv1.GetAuthorizedEntriesRequest) (*entryv1.GetAuthorizedEntriesResponse, error) { + getAuthorizedEntries: func(*mockAPI, int32, *entryv1.GetAuthorizedEntriesRequest) (*entryv1.GetAuthorizedEntriesResponse, error) { return makeGetAuthorizedEntriesResponse(t, "resp1", "resp2"), nil }, - batchNewX509SVIDEntries: func(h *mockAPI, count int32) []*common.RegistrationEntry { - h.rotateCA() + batchNewX509SVIDEntries: func(*mockAPI, int32) []*common.RegistrationEntry { return makeBatchNewX509SVIDEntries("resp1", "resp2") }, - svidTTL: 200, + svidTTL: ttl, clk: clk, }) @@ -856,106 +849,218 @@ func TestSyncSVIDs(t *testing.T) { Storage: openStorage(t, dir), Bundle: api.bundle, Metrics: &telemetry.Blackhole{}, - RotationInterval: 1 * time.Hour, - SyncInterval: 1 * time.Hour, - SVIDCacheMaxSize: 1, + RotationInterval: time.Hour, + SyncInterval: time.Hour, Clk: clk, Catalog: cat, WorkloadKeyType: workloadkey.ECP256, + SVIDCacheMaxSize: 10, SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } m := newManager(c) - m.subscribeBackoffFn = func() backoff.BackOff { - return backoff.NewConstantBackOff(svidSyncInterval) + + if err := m.Initialize(context.Background()); err != nil { + t.Fatal(err) } + require.Equal(t, clk.Now(), m.GetLastSync()) - err := m.Initialize(context.Background()) + sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{ + {Type: "unix", Value: "uid:1111"}, + {Type: "spiffe_id", Value: joinTokenID.String()}, + }) require.NoError(t, err) + defer sub.Finish() - // After Initialize, just 1 SVID should be cached - assert.Equal(t, 1, m.CountSVIDs()) - ctx := context.Background() + // Before synchronization + identitiesBefore := identitiesByEntryID(m.cache.Identities()) + if len(identitiesBefore) != 3 { + t.Fatalf("3 cached identities were expected; got %d", len(identitiesBefore)) + } - // Validate the update received by subscribers - // Spawn subscriber 1 in new goroutine to allow SVID sync to run in parallel - sub1WaitCh := make(chan struct{}, 1) - sub1ErrCh := make(chan error, 1) - go func() { - sub1, err := m.subscribeToCacheChanges(ctx, cache.Selectors{{Type: "unix", Value: "uid:1111"}}, func() { - sub1WaitCh <- struct{}{} - }) - if err != nil { - sub1ErrCh <- err - return - } + // This is the initial update based on the selector set + u := <-sub.Updates() + if len(u.Identities) != 3 { + t.Fatalf("expected 3 identities, got: %d", len(u.Identities)) + } - defer sub1.Finish() - u1 := <-sub1.Updates() + if len(u.Bundle.RootCAs()) != 1 { + t.Fatal("expected 1 bundle root CA") + } - if len(u1.Identities) != 2 { - sub1ErrCh <- fmt.Errorf("expected 2 SVIDs, got: %d", len(u1.Identities)) - return - } - if !u1.Bundle.EqualTo(c.Bundle) { - sub1ErrCh <- errors.New("bundles were expected to be equal") - return + if !u.Bundle.EqualTo(api.bundle) { + t.Fatal("received bundle should be equals to the server bundle") + } + + for key, eu := range identitiesByEntryID(u.Identities) { + eb, ok := identitiesBefore[key] + if !ok { + t.Fatalf("an update was received for an inexistent entry on the cache with EntryId=%v", key) } + require.Equal(t, eb, eu, "identity received does not match identity on cache") + } - sub1ErrCh <- nil - }() + require.Equal(t, clk.Now(), m.GetLastSync()) - // Spawn 
subscriber 2 in new goroutine to allow SVID sync to run in parallel - sub2WaitCh := make(chan struct{}, 1) - sub2ErrCh := make(chan error, 1) - go func() { - sub2, err := m.subscribeToCacheChanges(ctx, cache.Selectors{{Type: "spiffe_id", Value: "spiffe://example.org/spire/agent/join_token/abcd"}}, func() { - sub2WaitCh <- struct{}{} - }) - if err != nil { - sub2ErrCh <- err - return - } + // SVIDs expire after 3 seconds, so we shouldn't expect any updates after + // 1 second has elapsed. + clk.Add(time.Second) + require.NoError(t, m.synchronize(context.Background())) + select { + case <-sub.Updates(): + t.Fatal("update unexpected after 1 second") + default: + } + + // After advancing another second, the SVIDs should have been refreshed, + // since the half-time has been exceeded. + clk.Add(time.Second) + require.NoError(t, m.synchronize(context.Background())) + select { + case u = <-sub.Updates(): + default: + t.Fatal("update expected after 2 seconds") + } - defer sub2.Finish() - u2 := <-sub2.Updates() + // Make sure the update contains the updated entries and that the cache + // has a consistent view. + identitiesAfter := identitiesByEntryID(m.cache.Identities()) + if len(identitiesAfter) != 3 { + t.Fatalf("expected 3 identities, got: %d", len(identitiesAfter)) + } - if len(u2.Identities) != 1 { - sub2ErrCh <- fmt.Errorf("expected 1 SVID, got: %d", len(u2.Identities)) - return + for key, eb := range identitiesBefore { + ea, ok := identitiesAfter[key] + if !ok { + t.Fatalf("expected identity with EntryId=%v after synchronization", key) } - if !u2.Bundle.EqualTo(c.Bundle) { - sub2ErrCh <- errors.New("bundles were expected to be equal") - return + require.NotEqual(t, eb, ea, "there is at least one identity that was not refreshed: %v", ea) + } + + if len(u.Identities) != 3 { + t.Fatalf("expected 3 identities, got: %d", len(u.Identities)) + } + + if len(u.Bundle.RootCAs()) != 1 { + t.Fatal("expected 1 bundle root CA") + } + + if !u.Bundle.EqualTo(api.bundle) { + t.Fatal("received bundle should be equals to the server bundle") + } + + for key, eu := range identitiesByEntryID(u.Identities) { + ea, ok := identitiesAfter[key] + if !ok { + t.Fatalf("an update was received for an inexistent entry on the cache with EntryId=%v", key) } + require.Equal(t, eu, ea, "entry received does not match entry on cache") + } - sub2ErrCh <- nil - }() + require.Equal(t, clk.Now(), m.GetLastSync()) +} + +func TestSyncSVIDsWithLRUCache(t *testing.T) { + dir := spiretest.TempDir(t) + km := fakeagentkeymanager.New(t, dir) + + clk := clock.NewMock(t) + api := newMockAPI(t, &mockAPIConfig{ + km: km, + getAuthorizedEntries: func(h *mockAPI, count int32, _ *entryv1.GetAuthorizedEntriesRequest) (*entryv1.GetAuthorizedEntriesResponse, error) { + switch count { + case 1: + return makeGetAuthorizedEntriesResponse(t, "resp2"), nil + case 2: + return makeGetAuthorizedEntriesResponse(t, "resp2"), nil + default: + return nil, fmt.Errorf("unexpected getAuthorizedEntries call count: %d", count) + } + }, + batchNewX509SVIDEntries: func(h *mockAPI, count int32) []*common.RegistrationEntry { + switch count { + case 1: + return makeBatchNewX509SVIDEntries("resp2") + case 2: + return makeBatchNewX509SVIDEntries("resp2") + default: + return nil + } + }, + svidTTL: 3, + clk: clk, + }) + + baseSVID, baseSVIDKey := api.newSVID(joinTokenID, 1*time.Hour) + cat := fakeagentcatalog.New() + cat.SetKeyManager(km) + + c := &Config{ + ServerAddr: api.addr, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + TrustDomain: trustDomain, + 
Storage: openStorage(t, dir), + Bundle: api.bundle, + Metrics: &telemetry.Blackhole{}, + Clk: clk, + Catalog: cat, + WorkloadKeyType: workloadkey.ECP256, + SVIDCacheMaxSize: 1, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + } - // Wait until subscribers have been created - <-sub1WaitCh - <-sub2WaitCh + m := newManager(c) - // Sync SVIDs to populate cache - svidSyncErr := m.syncSVIDs(ctx) - require.NoError(t, svidSyncErr, "syncSVIDs method failed") + if err := m.Initialize(context.Background()); err != nil { + t.Fatal(err) + } - // Advance clock so subscribers can check for latest SVIDs in cache - clk.Add(svidSyncInterval) + ctx, cancel := context.WithCancel(context.Background()) + subErrCh := make(chan error, 1) + go func(ctx context.Context) { + sub, err := m.SubscribeToCacheChanges(ctx, cache.Selectors{ + {Type: "unix", Value: "uid:1111"}, + }) + if err != nil { + subErrCh <- err + return + } + defer sub.Finish() + subErrCh <- nil + }(ctx) + + syncErrCh := make(chan error, 1) + // run svid sync + go func(ctx context.Context) { + if err := m.runSyncSVIDs(ctx); err != nil { + syncErrCh <- err + } + syncErrCh <- nil + }(ctx) - sub1Err := <-sub1ErrCh - assert.NoError(t, sub1Err, "subscriber 1 error") + // keep clk moving so that subscriber keeps looking for svid + go func(ctx context.Context) { + for { + clk.Add(cache.SvidSyncInterval) + if ctx.Err() != nil { + return + } + } + }(ctx) - sub2Err := <-sub2ErrCh - assert.NoError(t, sub2Err, "subscriber 2 error") + subErr := <-subErrCh + assert.NoError(t, subErr, "subscriber error") - // All 3 SVIDs should be cached - assert.Equal(t, 3, m.CountSVIDs()) + // ensure 2 SVIDs corresponding to selectors are cached. + assert.Equal(t, 2, m.cache.CountSVIDs()) - assert.NoError(t, m.synchronize(ctx)) + // cancel the ctx to stop go routines + cancel() - // Make sure svid count is SVIDCacheMaxSize and non-active SVIDs are deleted from cache - assert.Equal(t, 1, m.CountSVIDs()) + syncErr := <-syncErrCh + assert.NoError(t, syncErr, "svid sync error") } func TestSurvivesCARotation(t *testing.T) { @@ -1003,9 +1108,6 @@ func TestSurvivesCARotation(t *testing.T) { } m := newManager(c) - m.subscribeBackoffFn = func() backoff.BackOff { - return backoff.NewConstantBackOff(svidSyncInterval) - } sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{{Type: "unix", Value: "uid:1111"}}) require.NoError(t, err) @@ -1052,19 +1154,18 @@ func TestFetchJWTSVID(t *testing.T) { baseSVID, baseSVIDKey := api.newSVID(joinTokenID, 1*time.Hour) c := &Config{ - ServerAddr: api.addr, - SVID: baseSVID, - SVIDKey: baseSVIDKey, - Log: testLogger, - TrustDomain: trustDomain, - Storage: openStorage(t, dir), - Bundle: api.bundle, - Metrics: &telemetry.Blackhole{}, - Catalog: cat, - Clk: clk, - WorkloadKeyType: workloadkey.ECP256, - SVIDCacheMaxSize: 1, - SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + ServerAddr: api.addr, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + TrustDomain: trustDomain, + Storage: openStorage(t, dir), + Bundle: api.bundle, + Metrics: &telemetry.Blackhole{}, + Catalog: cat, + Clk: clk, + WorkloadKeyType: workloadkey.ECP256, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } m := newManager(c) @@ -1176,19 +1277,18 @@ func TestStorableSVIDsSync(t *testing.T) { cat.SetKeyManager(fakeagentkeymanager.New(t, dir)) c := &Config{ - ServerAddr: api.addr, - SVID: baseSVID, - SVIDKey: 
baseSVIDKey, - Log: testLogger, - TrustDomain: trustDomain, - Storage: openStorage(t, dir), - Bundle: api.bundle, - Metrics: &telemetry.Blackhole{}, - Clk: clk, - Catalog: cat, - WorkloadKeyType: workloadkey.ECP256, - SVIDCacheMaxSize: 1, - SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + ServerAddr: api.addr, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + TrustDomain: trustDomain, + Storage: openStorage(t, dir), + Bundle: api.bundle, + Metrics: &telemetry.Blackhole{}, + Clk: clk, + Catalog: cat, + WorkloadKeyType: workloadkey.ECP256, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), } m, closer := initializeAndRunNewManager(t, c) diff --git a/pkg/agent/manager/sync.go b/pkg/agent/manager/sync.go index 56a9a633b9..8660dd50cf 100644 --- a/pkg/agent/manager/sync.go +++ b/pkg/agent/manager/sync.go @@ -25,7 +25,7 @@ type csrRequest struct { CurrentSVIDExpiresAt time.Time } -type Cache interface { +type SVIDCache interface { // UpdateEntries updates entries on cache UpdateEntries(update *cache.UpdateEntries, checkSVID func(*common.RegistrationEntry, *common.RegistrationEntry, *cache.X509SVID) bool) @@ -37,10 +37,13 @@ type Cache interface { } func (m *manager) syncSVIDs(ctx context.Context) (err error) { - m.cache.SyncSVIDsWithSubscribers() - staleEntries := m.cache.GetStaleEntries() - if len(staleEntries) > 0 { - return m.updateSVIDs(ctx, staleEntries, m.cache) + // perform syncSVIDs only if using LRU cache + if m.c.SVIDCacheMaxSize > 0 { + m.cache.SyncSVIDsWithSubscribers() + staleEntries := m.cache.GetStaleEntries() + if len(staleEntries) > 0 { + return m.updateSVIDs(ctx, staleEntries, m.cache) + } } return nil } @@ -66,7 +69,7 @@ func (m *manager) synchronize(ctx context.Context) (err error) { return nil } -func (m *manager) updateCache(ctx context.Context, update *cache.UpdateEntries, log logrus.FieldLogger, cacheType string, c Cache) error { +func (m *manager) updateCache(ctx context.Context, update *cache.UpdateEntries, log logrus.FieldLogger, cacheType string, c SVIDCache) error { // update the cache and build a list of CSRs that need to be processed // in this interval. 
// @@ -117,7 +120,7 @@ func (m *manager) updateCache(ctx context.Context, update *cache.UpdateEntries, return nil } -func (m *manager) updateSVIDs(ctx context.Context, entries []*cache.StaleEntry, c Cache) error { +func (m *manager) updateSVIDs(ctx context.Context, entries []*cache.StaleEntry, c SVIDCache) error { var csrs []csrRequest for _, entry := range entries { // we've exceeded the CSR limit, don't make any more CSRs From 37b6175ddd5a818474ef79c5f68383859e4f575e Mon Sep 17 00:00:00 2001 From: Prasad Borole Date: Wed, 24 Aug 2022 14:13:11 -0700 Subject: [PATCH 13/19] fix lint Signed-off-by: Prasad Borole --- pkg/agent/manager/cache/cache.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/agent/manager/cache/cache.go b/pkg/agent/manager/cache/cache.go index fb5372d199..19a7e7a954 100644 --- a/pkg/agent/manager/cache/cache.go +++ b/pkg/agent/manager/cache/cache.go @@ -477,7 +477,6 @@ func (c *Cache) Entries() []*common.RegistrationEntry { func (c *Cache) SyncSVIDsWithSubscribers() { c.log.Error("SyncSVIDsWithSubscribers method is not implemented") - return } func (c *Cache) subscribeToWorkloadUpdates(selectors []*common.Selector) Subscriber { From d84b3260bedf435560e394aaaf28fbaa21f84b89 Mon Sep 17 00:00:00 2001 From: Prasad Borole Date: Thu, 25 Aug 2022 11:30:47 -0700 Subject: [PATCH 14/19] fix linting Signed-off-by: Prasad Borole --- pkg/agent/manager/cache/lru_cache_test.go | 12 ++++++------ pkg/agent/manager/cache/sets.go | 3 +-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/pkg/agent/manager/cache/lru_cache_test.go b/pkg/agent/manager/cache/lru_cache_test.go index c270dd3aee..fdbd880365 100644 --- a/pkg/agent/manager/cache/lru_cache_test.go +++ b/pkg/agent/manager/cache/lru_cache_test.go @@ -509,25 +509,25 @@ func TestLRUCacheCheckSVIDCallback(t *testing.T) { func TestLRUCacheGetStaleEntries(t *testing.T) { cache := newTestLRUCache() - foo := makeRegistrationEntryWithTTL("FOO", 60) + bar := makeRegistrationEntryWithTTL("BAR", 60) // Create entry but don't mark it stale from checkSVID method; // it will be marked stale cause it does not have SVID cached cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV2), - RegistrationEntries: makeRegistrationEntries(foo), + RegistrationEntries: makeRegistrationEntries(bar), }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { return false }) // Assert that the entry is returned as stale. The `ExpiresAt` field should be unset since there is no SVID. - expectedEntries := []*StaleEntry{{Entry: cache.records[foo.EntryId].entry}} + expectedEntries := []*StaleEntry{{Entry: cache.records[bar.EntryId].entry}} assert.Equal(t, expectedEntries, cache.GetStaleEntries()) // Update the SVID for the stale entry svids := make(map[string]*X509SVID) expiredAt := time.Now() - svids[foo.EntryId] = &X509SVID{ + svids[bar.EntryId] = &X509SVID{ Chain: []*x509.Certificate{{NotAfter: expiredAt}}, } cache.UpdateSVIDs(&UpdateSVIDs{ @@ -539,14 +539,14 @@ func TestLRUCacheGetStaleEntries(t *testing.T) { // Update entry again and mark it as stale cache.UpdateEntries(&UpdateEntries{ Bundles: makeBundles(bundleV2), - RegistrationEntries: makeRegistrationEntries(foo), + RegistrationEntries: makeRegistrationEntries(bar), }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { return true }) // Assert that the entry again returns as stale. This time the `ExpiresAt` field should be populated with the expiration of the SVID. 
expectedEntries = []*StaleEntry{{ - Entry: cache.records[foo.EntryId].entry, + Entry: cache.records[bar.EntryId].entry, ExpiresAt: expiredAt, }} assert.Equal(t, expectedEntries, cache.GetStaleEntries()) diff --git a/pkg/agent/manager/cache/sets.go b/pkg/agent/manager/cache/sets.go index c7cc0d6895..98baf01682 100644 --- a/pkg/agent/manager/cache/sets.go +++ b/pkg/agent/manager/cache/sets.go @@ -47,9 +47,8 @@ var ( // unique set of strings, allocated from a pool type stringSet map[string]struct{} -func allocStringSet(ss ...string) (stringSet, func()) { +func allocStringSet() (stringSet, func()) { set := stringSetPool.Get().(stringSet) - set.Merge(ss...) return set, func() { clearStringSet(set) stringSetPool.Put(set) From b9d4685562cab762b4c1d847144d01f20db9e327 Mon Sep 17 00:00:00 2001 From: Prasad Borole Date: Thu, 25 Aug 2022 11:57:16 -0700 Subject: [PATCH 15/19] fix linting Signed-off-by: Prasad Borole --- pkg/agent/manager/cache/lru_cache_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/agent/manager/cache/lru_cache_test.go b/pkg/agent/manager/cache/lru_cache_test.go index fdbd880365..9df6828e7b 100644 --- a/pkg/agent/manager/cache/lru_cache_test.go +++ b/pkg/agent/manager/cache/lru_cache_test.go @@ -509,7 +509,7 @@ func TestLRUCacheCheckSVIDCallback(t *testing.T) { func TestLRUCacheGetStaleEntries(t *testing.T) { cache := newTestLRUCache() - bar := makeRegistrationEntryWithTTL("BAR", 60) + bar := makeRegistrationEntryWithTTL("BAR", 120) // Create entry but don't mark it stale from checkSVID method; // it will be marked stale cause it does not have SVID cached From b7313b35bd093defaeae091a836fa4434f94e37f Mon Sep 17 00:00:00 2001 From: Prasad Borole Date: Thu, 25 Aug 2022 12:11:18 -0700 Subject: [PATCH 16/19] update makeRegistrationEntryWithTTL Signed-off-by: Prasad Borole --- pkg/agent/manager/cache/lru_cache_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/agent/manager/cache/lru_cache_test.go b/pkg/agent/manager/cache/lru_cache_test.go index 9df6828e7b..088a191389 100644 --- a/pkg/agent/manager/cache/lru_cache_test.go +++ b/pkg/agent/manager/cache/lru_cache_test.go @@ -509,7 +509,7 @@ func TestLRUCacheCheckSVIDCallback(t *testing.T) { func TestLRUCacheGetStaleEntries(t *testing.T) { cache := newTestLRUCache() - bar := makeRegistrationEntryWithTTL("BAR", 120) + bar := makeRegistrationEntryWithTTL("BAR", 120, "B") // Create entry but don't mark it stale from checkSVID method; // it will be marked stale cause it does not have SVID cached From b2efd63f4970c127fabd3ebb39fa7d6cdc958033 Mon Sep 17 00:00:00 2001 From: Prasad Borole Date: Thu, 25 Aug 2022 12:44:03 -0700 Subject: [PATCH 17/19] update constant Signed-off-by: Prasad Borole --- pkg/agent/manager/cache/lru_cache.go | 25 ++++++++++++----------- pkg/agent/manager/cache/lru_cache_test.go | 2 +- pkg/agent/manager/manager.go | 2 +- pkg/agent/manager/manager_test.go | 2 +- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/pkg/agent/manager/cache/lru_cache.go b/pkg/agent/manager/cache/lru_cache.go index 8516ca51d5..0db4836d91 100644 --- a/pkg/agent/manager/cache/lru_cache.go +++ b/pkg/agent/manager/cache/lru_cache.go @@ -17,7 +17,7 @@ import ( const ( DefaultSVIDCacheMaxSize = 1000 - SvidSyncInterval = 500 * time.Millisecond + SVIDSyncInterval = 500 * time.Millisecond ) // Cache caches each registration entry, bundles, and JWT SVIDs for the agent. 
@@ -25,9 +25,10 @@ const ( // It allows subscriptions by (workload) selector sets and notifies subscribers when: // // 1) a registration entry related to the selectors: -// * is modified -// * has a new X509-SVID signed for it -// * federates with a federated bundle that is updated +// - is modified +// - has a new X509-SVID signed for it +// - federates with a federated bundle that is updated +// // 2) the trust bundle for the agent trust domain is updated // // When notified, the subscriber is given a WorkloadUpdate containing @@ -38,13 +39,13 @@ const ( // workloads) and registration entries that have that selector. // // The LRU-like SVID cache has configurable size limit and expiry period. -// 1. Size limit of SVID cache is a soft limit. If SVID has a subscriber present then -// that SVID is never removed from cache. -// 2. Least recently used SVIDs are removed from cache only after the cache expiry period has passed. -// This is done to reduce the overall cache churn. -// 3. Last access timestamp for SVID cache entry is updated when a new subscriber is created -// 4. When a new subscriber is created and there is a cache miss -// then subscriber needs to wait for next SVID sync event to receive WorkloadUpdate with newly minted SVID +// 1. Size limit of SVID cache is a soft limit. If SVID has a subscriber present then +// that SVID is never removed from cache. +// 2. Least recently used SVIDs are removed from cache only after the cache expiry period has passed. +// This is done to reduce the overall cache churn. +// 3. Last access timestamp for SVID cache entry is updated when a new subscriber is created +// 4. When a new subscriber is created and there is a cache miss +// then subscriber needs to wait for next SVID sync event to receive WorkloadUpdate with newly minted SVID // // The advantage of above approach is that if agent has entry count less than cache size // then all SVIDs are cached at all times. 
If agent has entry count greater than cache size then @@ -129,7 +130,7 @@ func NewLRUCache(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundl svidCacheMaxSize: svidCacheMaxSize, clk: clk, subscribeBackoffFn: func() backoff.BackOff { - return backoff.NewBackoff(clk, SvidSyncInterval) + return backoff.NewBackoff(clk, SVIDSyncInterval) }, } } diff --git a/pkg/agent/manager/cache/lru_cache_test.go b/pkg/agent/manager/cache/lru_cache_test.go index 088a191389..8fd5ea2bce 100644 --- a/pkg/agent/manager/cache/lru_cache_test.go +++ b/pkg/agent/manager/cache/lru_cache_test.go @@ -836,7 +836,7 @@ func TestSubscribeToLRUCacheChanges(t *testing.T) { }) assert.Equal(t, 2, cache.CountSVIDs()) - clk.Add(SvidSyncInterval * 2) + clk.Add(SVIDSyncInterval * 2) sub1Err := <-sub1ErrCh assert.NoError(t, sub1Err, "subscriber 1 error") diff --git a/pkg/agent/manager/manager.go b/pkg/agent/manager/manager.go index 775623acbf..2ea76551fc 100644 --- a/pkg/agent/manager/manager.go +++ b/pkg/agent/manager/manager.go @@ -147,7 +147,7 @@ func (m *manager) Initialize(ctx context.Context) error { m.storeBundle(m.cache.Bundle()) m.synchronizeBackoff = backoff.NewBackoff(m.clk, m.c.SyncInterval) - m.svidSyncBackoff = backoff.NewBackoff(m.clk, cache.SvidSyncInterval) + m.svidSyncBackoff = backoff.NewBackoff(m.clk, cache.SVIDSyncInterval) err := m.synchronize(ctx) if nodeutil.ShouldAgentReattest(err) { diff --git a/pkg/agent/manager/manager_test.go b/pkg/agent/manager/manager_test.go index 49a7a0c9a7..71b2ba65bf 100644 --- a/pkg/agent/manager/manager_test.go +++ b/pkg/agent/manager/manager_test.go @@ -1043,7 +1043,7 @@ func TestSyncSVIDsWithLRUCache(t *testing.T) { // keep clk moving so that subscriber keeps looking for svid go func(ctx context.Context) { for { - clk.Add(cache.SvidSyncInterval) + clk.Add(cache.SVIDSyncInterval) if ctx.Err() != nil { return } From d1aacc4c9a6f80cdbb049a89881f94c46e260d34 Mon Sep 17 00:00:00 2001 From: Prasad Borole Date: Wed, 31 Aug 2022 10:37:52 -0700 Subject: [PATCH 18/19] addressed comments Signed-off-by: Prasad Borole --- pkg/agent/manager/cache/cache.go | 4 ---- pkg/agent/manager/cache/cache_test.go | 5 +++++ pkg/agent/manager/cache/lru_cache.go | 6 ------ pkg/agent/manager/cache/util.go | 13 +++++++++++++ pkg/agent/manager/config.go | 14 +++++++------- pkg/agent/manager/manager.go | 6 +++--- pkg/agent/manager/manager_test.go | 5 +---- 7 files changed, 29 insertions(+), 24 deletions(-) create mode 100644 pkg/agent/manager/cache/util.go diff --git a/pkg/agent/manager/cache/cache.go b/pkg/agent/manager/cache/cache.go index 1cac536861..7f98a7396b 100644 --- a/pkg/agent/manager/cache/cache.go +++ b/pkg/agent/manager/cache/cache.go @@ -449,10 +449,6 @@ func (c *Cache) MatchingRegistrationEntries(selectors []*common.Selector) []*com records, recordsDone := c.getRecordsForSelectors(set) defer recordsDone() - if len(records) == 0 { - return nil - } - // Return identities in ascending "entry id" order to maintain a consistent // ordering. 
// TODO: figure out how to determine the "default" identity diff --git a/pkg/agent/manager/cache/cache_test.go b/pkg/agent/manager/cache/cache_test.go index 0a048e5c9c..8f8372842e 100644 --- a/pkg/agent/manager/cache/cache_test.go +++ b/pkg/agent/manager/cache/cache_test.go @@ -695,6 +695,11 @@ func TestMatchingRegistrationEntries(t *testing.T) { // populate the cache with FOO and BAR without SVIDS foo := makeRegistrationEntry("FOO", "A") bar := makeRegistrationEntry("BAR", "B") + + // check empty result + assert.Equal(t, []*common.RegistrationEntry{}, + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) + updateEntries := &UpdateEntries{ Bundles: makeBundles(bundleV1), RegistrationEntries: makeRegistrationEntries(foo, bar), diff --git a/pkg/agent/manager/cache/lru_cache.go b/pkg/agent/manager/cache/lru_cache.go index 0db4836d91..ad09719ac0 100644 --- a/pkg/agent/manager/cache/lru_cache.go +++ b/pkg/agent/manager/cache/lru_cache.go @@ -915,12 +915,6 @@ func newSelectorsMapIndex() *selectorsMapIndex { } } -func sortEntriesByID(entries []*common.RegistrationEntry) { - sort.Slice(entries, func(a, b int) bool { - return entries[a].EntryId < entries[b].EntryId - }) -} - func sortByTimestamps(records []recordAccessEvent) { sort.Slice(records, func(a, b int) bool { return records[a].timestamp < records[b].timestamp diff --git a/pkg/agent/manager/cache/util.go b/pkg/agent/manager/cache/util.go new file mode 100644 index 0000000000..ab365514fd --- /dev/null +++ b/pkg/agent/manager/cache/util.go @@ -0,0 +1,13 @@ +package cache + +import ( + "sort" + + "github.com/spiffe/spire/proto/spire/common" +) + +func sortEntriesByID(entries []*common.RegistrationEntry) { + sort.Slice(entries, func(a, b int) bool { + return entries[a].EntryId < entries[b].EntryId + }) +} diff --git a/pkg/agent/manager/config.go b/pkg/agent/manager/config.go index c05963bc32..3a3fe11eee 100644 --- a/pkg/agent/manager/config.go +++ b/pkg/agent/manager/config.go @@ -9,7 +9,7 @@ import ( "github.com/sirupsen/logrus" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/catalog" - "github.com/spiffe/spire/pkg/agent/manager/cache" + managerCache "github.com/spiffe/spire/pkg/agent/manager/cache" "github.com/spiffe/spire/pkg/agent/manager/storecache" "github.com/spiffe/spire/pkg/agent/plugin/keymanager" "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor" @@ -24,7 +24,7 @@ type Config struct { // Agent SVID and key resulting from successful attestation. 
SVID []*x509.Certificate SVIDKey keymanager.Key - Bundle *cache.Bundle + Bundle *managerCache.Bundle Reattestable bool Catalog catalog.Catalog TrustDomain spiffeid.TrustDomain @@ -61,13 +61,13 @@ func newManager(c *Config) *manager { c.Clk = clock.New() } - var x509SVIDCache Cache + var cache Cache if c.SVIDCacheMaxSize > 0 { // use LRU cache implementation - x509SVIDCache = cache.NewLRUCache(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle, + cache = managerCache.NewLRUCache(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle, c.Metrics, c.SVIDCacheMaxSize, c.Clk) } else { - x509SVIDCache = cache.New(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle, + cache = managerCache.New(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle, c.Metrics) } @@ -77,7 +77,7 @@ func newManager(c *Config) *manager { Metrics: c.Metrics, SVID: c.SVID, SVIDKey: c.SVIDKey, - BundleStream: x509SVIDCache.SubscribeToBundleChanges(), + BundleStream: cache.SubscribeToBundleChanges(), ServerAddr: c.ServerAddr, TrustDomain: c.TrustDomain, Interval: c.RotationInterval, @@ -88,7 +88,7 @@ func newManager(c *Config) *manager { svidRotator, client := svid.NewRotator(rotCfg) m := &manager{ - cache: x509SVIDCache, + cache: cache, c: c, mtx: new(sync.RWMutex), svid: svidRotator, diff --git a/pkg/agent/manager/manager.go b/pkg/agent/manager/manager.go index c2fbc10fe7..eada770c98 100644 --- a/pkg/agent/manager/manager.go +++ b/pkg/agent/manager/manager.go @@ -99,13 +99,13 @@ type Cache interface { // CountSVIDs in cache stored CountSVIDs() int - // FetchWorkloadUpdate for giveb selectors + // FetchWorkloadUpdate for given selectors FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate - // GetJWTSVID provides JWTSVID + // GetJWTSVID provides JWT-SVID GetJWTSVID(id spiffeid.ID, audience []string) (*client.JWTSVID, bool) - // SetJWTSVID adds JWTSVID to cache + // SetJWTSVID adds JWT-SVID to cache SetJWTSVID(id spiffeid.ID, audience []string, svid *client.JWTSVID) // Entries get all registration entries diff --git a/pkg/agent/manager/manager_test.go b/pkg/agent/manager/manager_test.go index 4c2e59e1d3..73eef1edce 100644 --- a/pkg/agent/manager/manager_test.go +++ b/pkg/agent/manager/manager_test.go @@ -1036,10 +1036,7 @@ func TestSyncSVIDsWithLRUCache(t *testing.T) { syncErrCh := make(chan error, 1) // run svid sync go func(ctx context.Context) { - if err := m.runSyncSVIDs(ctx); err != nil { - syncErrCh <- err - } - syncErrCh <- nil + syncErrCh <- m.runSyncSVIDs(ctx) }(ctx) // keep clk moving so that subscriber keeps looking for svid From 5b9d6b1b8084b3e72ec2a638cb45d653fd686abf Mon Sep 17 00:00:00 2001 From: Prasad Borole Date: Fri, 9 Sep 2022 14:58:51 -0700 Subject: [PATCH 19/19] putting updateSVID under lock Signed-off-by: Prasad Borole --- pkg/agent/manager/manager.go | 2 ++ pkg/agent/manager/sync.go | 53 ++++++++++++++++++------------------ 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/pkg/agent/manager/manager.go b/pkg/agent/manager/manager.go index eada770c98..c82aba287d 100644 --- a/pkg/agent/manager/manager.go +++ b/pkg/agent/manager/manager.go @@ -120,6 +120,8 @@ type manager struct { // Fields protected by mtx mutex. 
mtx *sync.RWMutex + // Protects multiple goroutines from requesting SVID signings at the same time + updateSVIDMu sync.RWMutex cache Cache svid svid.Rotator diff --git a/pkg/agent/manager/sync.go b/pkg/agent/manager/sync.go index 8660dd50cf..25b1dfb402 100644 --- a/pkg/agent/manager/sync.go +++ b/pkg/agent/manager/sync.go @@ -40,10 +40,7 @@ func (m *manager) syncSVIDs(ctx context.Context) (err error) { // perform syncSVIDs only if using LRU cache if m.c.SVIDCacheMaxSize > 0 { m.cache.SyncSVIDsWithSubscribers() - staleEntries := m.cache.GetStaleEntries() - if len(staleEntries) > 0 { - return m.updateSVIDs(ctx, staleEntries, m.cache) - } + return m.updateSVIDs(ctx, m.c.Log.WithField(telemetry.CacheType, "workload"), m.cache) } return nil } @@ -109,39 +106,41 @@ func (m *manager) updateCache(ctx context.Context, update *cache.UpdateEntries, log.WithField(telemetry.OutdatedSVIDs, outdated).Debug("Updating SVIDs with outdated attributes in cache") } + return m.updateSVIDs(ctx, log, c) +} + +func (m *manager) updateSVIDs(ctx context.Context, log logrus.FieldLogger, c SVIDCache) error { + m.updateSVIDMu.Lock() + defer m.updateSVIDMu.Unlock() + staleEntries := c.GetStaleEntries() if len(staleEntries) > 0 { + var csrs []csrRequest log.WithFields(logrus.Fields{ telemetry.Count: len(staleEntries), telemetry.Limit: limits.SignLimitPerIP, }).Debug("Renewing stale entries") - return m.updateSVIDs(ctx, staleEntries, c) - } - return nil -} -func (m *manager) updateSVIDs(ctx context.Context, entries []*cache.StaleEntry, c SVIDCache) error { - var csrs []csrRequest - for _, entry := range entries { - // we've exceeded the CSR limit, don't make any more CSRs - if len(csrs) >= limits.SignLimitPerIP { - break + for _, entry := range staleEntries { + // we've exceeded the CSR limit, don't make any more CSRs + if len(csrs) >= limits.SignLimitPerIP { + break + } + + csrs = append(csrs, csrRequest{ + EntryID: entry.Entry.EntryId, + SpiffeID: entry.Entry.SpiffeId, + CurrentSVIDExpiresAt: entry.ExpiresAt, + }) } - csrs = append(csrs, csrRequest{ - EntryID: entry.Entry.EntryId, - SpiffeID: entry.Entry.SpiffeId, - CurrentSVIDExpiresAt: entry.ExpiresAt, - }) - } - - update, err := m.fetchSVIDs(ctx, csrs) - if err != nil { - return err + update, err := m.fetchSVIDs(ctx, csrs) + if err != nil { + return err + } + // the values in `update` now belong to the cache. DO NOT MODIFY. + c.UpdateSVIDs(update) } - // the values in `update` now belong to the cache. DO NOT MODIFY. - c.UpdateSVIDs(update) - return nil }
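As a rough companion to the diffs above, the following is a minimal, self-contained Go sketch of the soft-limit eviction policy that the new LRU cache's doc comment describes: the size limit is a soft limit, SVIDs with an active subscriber are never removed, and least-recently-used records are only dropped once the cache expiry period has passed. All names here (cachedSVID, evictable, and their fields) are hypothetical illustrations for this sketch, not the actual lru_cache.go implementation.

    package main

    import (
        "fmt"
        "sort"
        "time"
    )

    // cachedSVID is a hypothetical stand-in for an SVID record tracked by an
    // LRU-style cache: when a subscriber last touched it and whether any
    // subscriber currently references it.
    type cachedSVID struct {
        entryID       string
        lastAccessed  time.Time
        hasSubscriber bool
    }

    // evictable returns the entry IDs that a soft-limit LRU policy would remove:
    // nothing is evicted while the cache is at or under maxSize, records with an
    // active subscriber are never evicted, and only records idle for longer than
    // expiryPeriod are candidates, least recently used first.
    func evictable(records []cachedSVID, maxSize int, expiryPeriod time.Duration, now time.Time) []string {
        if len(records) <= maxSize {
            return nil
        }

        // Oldest access first, mirroring "least recently used SVIDs are removed".
        sorted := append([]cachedSVID(nil), records...)
        sort.Slice(sorted, func(i, j int) bool {
            return sorted[i].lastAccessed.Before(sorted[j].lastAccessed)
        })

        var out []string
        excess := len(records) - maxSize
        for _, r := range sorted {
            if len(out) == excess {
                break
            }
            if r.hasSubscriber || now.Sub(r.lastAccessed) < expiryPeriod {
                continue // soft limit: keep in-use or recently used SVIDs
            }
            out = append(out, r.entryID)
        }
        return out
    }

    func main() {
        now := time.Now()
        records := []cachedSVID{
            {entryID: "entry-a", lastAccessed: now.Add(-2 * time.Hour)},
            {entryID: "entry-b", lastAccessed: now.Add(-90 * time.Minute), hasSubscriber: true},
            {entryID: "entry-c", lastAccessed: now.Add(-time.Minute)},
        }
        // With a max size of 1 and a 1h expiry period, only entry-a is evicted:
        // entry-b has a subscriber and entry-c was accessed too recently.
        fmt.Println(evictable(records, 1, time.Hour, now))
    }

The soft limit is the key design choice: it keeps cache churn low and guarantees that a workload with an open subscription never loses its SVID, while still bounding memory once idle entries age past the expiry period.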