From 5b38c56faa7551889c5c1a27125320523a418cee Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 11:18:28 +0100 Subject: [PATCH 01/15] Fix routes state race condition --- .../internal/routemanager/systemops/state.go | 23 ++++++++----------- .../systemops/systemops_generic.go | 6 +---- 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go index 42590892297..8e158711e50 100644 --- a/client/internal/routemanager/systemops/state.go +++ b/client/internal/routemanager/systemops/state.go @@ -2,31 +2,28 @@ package systemops import ( "net/netip" - "sync" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" ) -type ShutdownState struct { - Counter *ExclusionCounter `json:"counter,omitempty"` - mu sync.RWMutex -} +type ShutdownState ExclusionCounter func (s *ShutdownState) Name() string { return "route_state" } func (s *ShutdownState) Cleanup() error { - s.mu.RLock() - defer s.mu.RUnlock() - - if s.Counter == nil { - return nil - } - sysops := NewSysOps(nil, nil) sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) - sysops.refCounter.LoadData(s.Counter) + sysops.refCounter.LoadData((*ExclusionCounter)(s)) return sysops.refCounter.Flush() } + +func (s *ShutdownState) MarshalJSON() ([]byte, error) { + return (*ExclusionCounter)(s).MarshalJSON() +} + +func (s *ShutdownState) UnmarshalJSON(data []byte) error { + return (*ExclusionCounter)(s).UnmarshalJSON(data) +} diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 4ff34aa5162..d1e1bf0fd8a 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -76,11 +76,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana } func (r *SysOps) updateState(stateManager *statemanager.Manager) { - state := getState(stateManager) - - state.Counter = r.refCounter - - if err := stateManager.UpdateState(state); err != nil { + if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil { log.Errorf("failed to update state: %v", err) } } From 3c95f6fc20a1048d8de8d86d6dc0d679eed9ef9b Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 12:59:39 +0100 Subject: [PATCH 02/15] Ensure lock is in place during marshaling --- client/internal/statemanager/manager.go | 8 ++++++-- util/file.go | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 580ccdfc78a..8b085b882d2 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -179,14 +179,18 @@ func (m *Manager) PersistState(ctx context.Context) error { return nil } + bs, err := json.MarshalIndent(m.states, "", " ") + if err != nil { + return fmt.Errorf("marshal states: %w", err) + } + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() done := make(chan error, 1) - start := time.Now() go func() { - done <- util.WriteJsonWithRestrictedPermission(ctx, m.filePath, m.states) + done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs) }() select { diff --git a/util/file.go b/util/file.go index 4641cc1b825..7be5742b3e4 100644 --- a/util/file.go +++ b/util/file.go @@ -14,6 +14,19 @@ import ( log "github.com/sirupsen/logrus" ) +func WriteBytesWithRestrictedPermission(ctx context.Context, file string, bs []byte) error { + configDir, configFileName, err := prepareConfigFileDir(file) + if err != nil { + return fmt.Errorf("prepare config file dir: %w", err) + } + + if err = EnforcePermission(file); err != nil { + return fmt.Errorf("enfore permission: %w", err) + } + + return writeBytes(ctx, file, err, configDir, configFileName, bs) +} + // WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error { configDir, configFileName, err := prepareConfigFileDir(file) @@ -91,6 +104,10 @@ func writeJson(ctx context.Context, file string, obj interface{}, configDir stri return err } + return writeBytes(ctx, file, err, configDir, configFileName, bs) +} + +func writeBytes(ctx context.Context, file string, err error, configDir string, configFileName string, bs []byte) error { if ctx.Err() != nil { return ctx.Err() } From 00a4edc812f58613c4d2a2cbd129553865439de2 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 13:43:34 +0100 Subject: [PATCH 03/15] Add file deadline --- util/file.go | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/util/file.go b/util/file.go index 7be5742b3e4..f547cd76c41 100644 --- a/util/file.go +++ b/util/file.go @@ -114,14 +114,26 @@ func writeBytes(ctx context.Context, file string, err error, configDir string, c tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) if err != nil { - return err + return fmt.Errorf("create temp: %w", err) } tempFileName := tempFile.Name() - // closing file ops as windows doesn't allow to move it - err = tempFile.Close() + + if deadline, ok := ctx.Deadline(); ok { + if err := tempFile.SetDeadline(deadline); err != nil { + //if err := tempFile.SetDeadline(deadline); err != nil && !errors.Is(err, os.ErrNoDeadline) { + log.Warnf("failed to set write deadline: %v", err) + } + } + + _, err = tempFile.Write(bs) if err != nil { - return err + _ = tempFile.Close() + return fmt.Errorf("write: %w", err) + } + + if err = tempFile.Close(); err != nil { + return fmt.Errorf("close %s: %w", tempFileName, err) } defer func() { @@ -131,19 +143,13 @@ func writeBytes(ctx context.Context, file string, err error, configDir string, c } }() - err = os.WriteFile(tempFileName, bs, 0600) - if err != nil { - return err - } - // Check context again if ctx.Err() != nil { return ctx.Err() } - err = os.Rename(tempFileName, file) - if err != nil { - return err + if err = os.Rename(tempFileName, file); err != nil { + return fmt.Errorf("move %s to %s: %w", tempFileName, file, err) } return nil From cd0dbae1ecdf3f29ad8bece4e51c54ee6fc84a6c Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 14:00:13 +0100 Subject: [PATCH 04/15] Add test --- client/internal/statemanager/manager_test.go | 82 ++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 client/internal/statemanager/manager_test.go diff --git a/client/internal/statemanager/manager_test.go b/client/internal/statemanager/manager_test.go new file mode 100644 index 00000000000..f3ca8187fa6 --- /dev/null +++ b/client/internal/statemanager/manager_test.go @@ -0,0 +1,82 @@ +package statemanager + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// MockState implements the State interface for testing +type MockState struct { +} + +func (m MockState) Name() string { + return "mock_state" +} + +func (m MockState) Cleanup() error { + return nil +} + +func TestManager_PersistState_SlowWrite(t *testing.T) { + tmpDir := t.TempDir() + + tests := []struct { + name string + contextTimeout time.Duration + expectError bool + errorType error + }{ + { + name: "write completes before deadline", + contextTimeout: 1 * time.Second, + expectError: false, + }, + { + name: "write exceeds deadline", + contextTimeout: 0, + expectError: true, + errorType: context.DeadlineExceeded, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stateFile := filepath.Join(tmpDir, tt.name+"-state.json") + + file, err := os.Create(stateFile) + require.NoError(t, err) + defer file.Close() + + m := New(stateFile) + + // Register and update mock state + mockState := &MockState{} + m.RegisterState(mockState) + err = m.UpdateState(mockState) + require.NoError(t, err) + + // Create context with timeout + ctx, cancel := context.WithTimeout(context.Background(), tt.contextTimeout) + defer cancel() + + // Attempt to persist state + err = m.PersistState(ctx) + + if tt.expectError { + assert.Error(t, err) + assert.Equal(t, context.DeadlineExceeded, err) + assert.Len(t, m.dirty, 1) + } else { + assert.NoError(t, err) + assert.FileExists(t, stateFile) + assert.Empty(t, m.dirty) + } + }) + } +} From e07caa8f0514bcd81d6a9dff0feedda286a80224 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 14:18:47 +0100 Subject: [PATCH 05/15] Remove unused function --- .../routemanager/systemops/systemops_generic.go | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index d1e1bf0fd8a..f8b3ebbb8c1 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -62,7 +62,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return nexthop, err }, func(prefix netip.Prefix, nexthop Nexthop) error { - // remove from state even if we have trouble removing it from the route table + // update state even if we have trouble removing it from the route table // it could be already gone r.updateState(stateManager) @@ -75,6 +75,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return r.setupHooks(initAddresses) } +// updateState updates state on every change so it will be persisted regularly func (r *SysOps) updateState(stateManager *statemanager.Manager) { if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil { log.Errorf("failed to update state: %v", err) @@ -528,14 +529,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P // Return true if the longest matching prefix is from vpnRoutes return isVpn, longestPrefix } - -func getState(stateManager *statemanager.Manager) *ShutdownState { - var shutdownState *ShutdownState - if state := stateManager.GetState(shutdownState); state != nil { - shutdownState = state.(*ShutdownState) - } else { - shutdownState = &ShutdownState{} - } - - return shutdownState -} From 9a56fc0137aa132cac83cf88236379f592a7012b Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 15:30:26 +0100 Subject: [PATCH 06/15] Catch marshal panics --- client/internal/statemanager/manager.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 8b085b882d2..263806ab0ad 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -179,7 +179,7 @@ func (m *Manager) PersistState(ctx context.Context) error { return nil } - bs, err := json.MarshalIndent(m.states, "", " ") + bs, err := marshalWithPanicRecovery(m.states) if err != nil { return fmt.Errorf("marshal states: %w", err) } @@ -290,3 +290,19 @@ func (m *Manager) PerformCleanup() error { return nberrors.FormatErrorOrNil(merr) } + +func marshalWithPanicRecovery(v any) ([]byte, error) { + var bs []byte + var err error + + func() { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic during marshal: %v", r) + } + }() + bs, err = json.MarshalIndent(v, "", " ") + }() + + return bs, err +} From 81f0810918a03b70a06b6ccc88dcd793469f5b5d Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 20:14:24 +0100 Subject: [PATCH 07/15] Don't prettify json --- client/internal/statemanager/manager.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 263806ab0ad..7c9d8742720 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -301,7 +301,7 @@ func marshalWithPanicRecovery(v any) ([]byte, error) { err = fmt.Errorf("panic during marshal: %v", r) } }() - bs, err = json.MarshalIndent(v, "", " ") + bs, err = json.Marshal(v) }() return bs, err From 41c9c395b45409fcb0837aa131c2976bbacddbf6 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 20:14:31 +0100 Subject: [PATCH 08/15] Add error context --- util/file.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/util/file.go b/util/file.go index f547cd76c41..e75c988de06 100644 --- a/util/file.go +++ b/util/file.go @@ -95,13 +95,13 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error { func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error { // Check context before expensive operations if ctx.Err() != nil { - return ctx.Err() + return fmt.Errorf("write json start: %w", ctx.Err()) } // make it pretty bs, err := json.MarshalIndent(obj, "", " ") if err != nil { - return err + return fmt.Errorf("marshal: %w", err) } return writeBytes(ctx, file, err, configDir, configFileName, bs) @@ -109,7 +109,7 @@ func writeJson(ctx context.Context, file string, obj interface{}, configDir stri func writeBytes(ctx context.Context, file string, err error, configDir string, configFileName string, bs []byte) error { if ctx.Err() != nil { - return ctx.Err() + return fmt.Errorf("write bytes start: %w", ctx.Err()) } tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) @@ -145,7 +145,7 @@ func writeBytes(ctx context.Context, file string, err error, configDir string, c // Check context again if ctx.Err() != nil { - return ctx.Err() + return fmt.Errorf("after temp file: %w", ctx.Err()) } if err = os.Rename(tempFileName, file); err != nil { From 3c581d8fdca1763c954a78f5f3a3eeed0fe0019f Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 20:28:21 +0100 Subject: [PATCH 09/15] Don't use cancelled contexts --- client/internal/dns/server.go | 20 +++++++------------- client/internal/engine.go | 3 +-- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 6c4dccae74a..f0277319cd5 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -7,7 +7,6 @@ import ( "runtime" "strings" "sync" - "time" "github.com/miekg/dns" "github.com/mitchellh/hashstructure/v2" @@ -323,13 +322,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { log.Error(err) } - // persist dns state right away - ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) - defer cancel() - - // don't block go func() { - if err := s.stateManager.PersistState(ctx); err != nil { + // persist dns state right away + if err := s.stateManager.PersistState(s.ctx); err != nil { log.Errorf("Failed to persist dns state: %v", err) } }() @@ -537,12 +532,11 @@ func (s *DefaultServer) upstreamCallbacks( l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) } - // persist dns state right away - ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) - defer cancel() - if err := s.stateManager.PersistState(ctx); err != nil { - l.Errorf("Failed to persist dns state: %v", err) - } + go func() { + if err := s.stateManager.PersistState(s.ctx); err != nil { + l.Errorf("Failed to persist dns state: %v", err) + } + }() if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { s.addHostRootZone() diff --git a/client/internal/engine.go b/client/internal/engine.go index 190d795cdbe..cce69b6d79a 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -38,7 +38,6 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" - nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" @@ -297,7 +296,7 @@ func (e *Engine) Stop() error { if err := e.stateManager.Stop(ctx); err != nil { return fmt.Errorf("failed to stop state manager: %w", err) } - if err := e.stateManager.PersistState(ctx); err != nil { + if err := e.stateManager.PersistState(context.Background()); err != nil { log.Errorf("failed to persist state: %v", err) } From e1af05654b80f8e2247558217183dff2f024f775 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 20:30:24 +0100 Subject: [PATCH 10/15] Ignore deadline not supported error --- util/file.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/util/file.go b/util/file.go index e75c988de06..fcb1b5184a4 100644 --- a/util/file.go +++ b/util/file.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "os" @@ -120,9 +121,8 @@ func writeBytes(ctx context.Context, file string, err error, configDir string, c tempFileName := tempFile.Name() if deadline, ok := ctx.Deadline(); ok { - if err := tempFile.SetDeadline(deadline); err != nil { - //if err := tempFile.SetDeadline(deadline); err != nil && !errors.Is(err, os.ErrNoDeadline) { - log.Warnf("failed to set write deadline: %v", err) + if err := tempFile.SetDeadline(deadline); err != nil && !errors.Is(err, os.ErrNoDeadline) { + log.Warnf("failed to set deadline: %v", err) } } From eceab3669791a3b2fb6e67c8827b4b5c9786e5f0 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 14 Nov 2024 21:24:34 +0100 Subject: [PATCH 11/15] Avoid deadlock on stop --- client/internal/statemanager/manager.go | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 7c9d8742720..4feef38e09c 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -71,18 +71,20 @@ func (m *Manager) Stop(ctx context.Context) error { return nil } + var cancel context.CancelFunc m.mu.Lock() - defer m.mu.Unlock() + cancel = m.cancel + m.mu.Unlock() - if m.cancel != nil { - m.cancel() + if cancel == nil { + return nil + } + cancel() - select { - case <-ctx.Done(): - return ctx.Err() - case <-m.done: - return nil - } + select { + case <-ctx.Done(): + return ctx.Err() + case <-m.done: } return nil From 75a3f80a2e522ab77ae52c4efa685d9036c36db8 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 15 Nov 2024 11:08:52 +0100 Subject: [PATCH 12/15] Fix deadlock from mutex ordering --- .../routemanager/refcounter/refcounter.go | 58 +++++++++---------- 1 file changed, 28 insertions(+), 30 deletions(-) diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index 0e230ef4039..f2f0a169df0 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -47,10 +47,9 @@ type RemoveFunc[Key, O any] func(key Key, out O) error type Counter[Key comparable, I, O any] struct { // refCountMap keeps track of the reference Ref for keys refCountMap map[Key]Ref[O] - refCountMu sync.Mutex + mu sync.Mutex // idMap keeps track of the keys associated with an ID for removal idMap map[string][]Key - idMu sync.Mutex add AddFunc[Key, I, O] remove RemoveFunc[Key, O] } @@ -75,10 +74,8 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key func (rm *Counter[Key, I, O]) LoadData( existingCounter *Counter[Key, I, O], ) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() rm.refCountMap = existingCounter.refCountMap rm.idMap = existingCounter.idMap @@ -87,8 +84,8 @@ func (rm *Counter[Key, I, O]) LoadData( // Get retrieves the current reference count and associated data for a key. // If the key doesn't exist, it returns a zero value Ref and false. func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() ref, ok := rm.refCountMap[key] return ref, ok @@ -97,9 +94,13 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { // Increment increments the reference count for the given key. // If this is the first reference to the key, the AddFunc is called. func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() + return rm.increment(key, in) +} + +func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) { ref := rm.refCountMap[key] logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out) @@ -126,10 +127,10 @@ func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) { // IncrementWithID increments the reference count for the given key and groups it under the given ID. // If this is the first reference to the key, the AddFunc is called. func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) { - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() - ref, err := rm.Increment(key, in) + ref, err := rm.increment(key, in) if err != nil { return ref, fmt.Errorf("with ID: %w", err) } @@ -141,9 +142,12 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], // Decrement decrements the reference count for the given key. // If the reference count reaches 0, the RemoveFunc is called. func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() + return rm.decrement(key) +} +func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) { ref, ok := rm.refCountMap[key] if !ok { logCallerF("No reference found for key %v", key) @@ -168,12 +172,12 @@ func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) { // DecrementWithID decrements the reference count for all keys associated with the given ID. // If the reference count reaches 0, the RemoveFunc is called. func (rm *Counter[Key, I, O]) DecrementWithID(id string) error { - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() var merr *multierror.Error for _, key := range rm.idMap[id] { - if _, err := rm.Decrement(key); err != nil { + if _, err := rm.decrement(key); err != nil { merr = multierror.Append(merr, err) } } @@ -184,10 +188,8 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error { // Flush removes all references and calls RemoveFunc for each key. func (rm *Counter[Key, I, O]) Flush() error { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() var merr *multierror.Error for key := range rm.refCountMap { @@ -206,10 +208,8 @@ func (rm *Counter[Key, I, O]) Flush() error { // Clear removes all references without calling RemoveFunc. func (rm *Counter[Key, I, O]) Clear() { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() clear(rm.refCountMap) clear(rm.idMap) @@ -217,10 +217,8 @@ func (rm *Counter[Key, I, O]) Clear() { // MarshalJSON implements the json.Marshaler interface for Counter. func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() return json.Marshal(struct { RefCountMap map[Key]Ref[O] `json:"refCountMap"` From c8ec1bc06470d6da362da7f23d0e9c4ecd274384 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 15 Nov 2024 15:48:24 +0100 Subject: [PATCH 13/15] Remove broken test --- client/internal/statemanager/manager_test.go | 82 -------------------- 1 file changed, 82 deletions(-) delete mode 100644 client/internal/statemanager/manager_test.go diff --git a/client/internal/statemanager/manager_test.go b/client/internal/statemanager/manager_test.go deleted file mode 100644 index f3ca8187fa6..00000000000 --- a/client/internal/statemanager/manager_test.go +++ /dev/null @@ -1,82 +0,0 @@ -package statemanager - -import ( - "context" - "os" - "path/filepath" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// MockState implements the State interface for testing -type MockState struct { -} - -func (m MockState) Name() string { - return "mock_state" -} - -func (m MockState) Cleanup() error { - return nil -} - -func TestManager_PersistState_SlowWrite(t *testing.T) { - tmpDir := t.TempDir() - - tests := []struct { - name string - contextTimeout time.Duration - expectError bool - errorType error - }{ - { - name: "write completes before deadline", - contextTimeout: 1 * time.Second, - expectError: false, - }, - { - name: "write exceeds deadline", - contextTimeout: 0, - expectError: true, - errorType: context.DeadlineExceeded, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - stateFile := filepath.Join(tmpDir, tt.name+"-state.json") - - file, err := os.Create(stateFile) - require.NoError(t, err) - defer file.Close() - - m := New(stateFile) - - // Register and update mock state - mockState := &MockState{} - m.RegisterState(mockState) - err = m.UpdateState(mockState) - require.NoError(t, err) - - // Create context with timeout - ctx, cancel := context.WithTimeout(context.Background(), tt.contextTimeout) - defer cancel() - - // Attempt to persist state - err = m.PersistState(ctx) - - if tt.expectError { - assert.Error(t, err) - assert.Equal(t, context.DeadlineExceeded, err) - assert.Len(t, m.dirty, 1) - } else { - assert.NoError(t, err) - assert.FileExists(t, stateFile) - assert.Empty(t, m.dirty) - } - }) - } -} From 773dbb80c02564b2435acfd2df6c7ce07015d85e Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 15 Nov 2024 15:55:48 +0100 Subject: [PATCH 14/15] Lock the whole Stop method --- client/internal/statemanager/manager.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 4feef38e09c..da6dd022fc2 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -71,15 +71,13 @@ func (m *Manager) Stop(ctx context.Context) error { return nil } - var cancel context.CancelFunc m.mu.Lock() - cancel = m.cancel - m.mu.Unlock() + defer m.mu.Unlock() - if cancel == nil { + if m.cancel == nil { return nil } - cancel() + m.cancel() select { case <-ctx.Done(): From 18cfe2f4c7c60d9cea28e9f8f9461c1b9c1f619b Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 15 Nov 2024 16:02:02 +0100 Subject: [PATCH 15/15] Fix spelling --- util/file.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/util/file.go b/util/file.go index fcb1b5184a4..f7de7ede26b 100644 --- a/util/file.go +++ b/util/file.go @@ -22,7 +22,7 @@ func WriteBytesWithRestrictedPermission(ctx context.Context, file string, bs []b } if err = EnforcePermission(file); err != nil { - return fmt.Errorf("enfore permission: %w", err) + return fmt.Errorf("enforce permission: %w", err) } return writeBytes(ctx, file, err, configDir, configFileName, bs)