Skip to content

Commit

Permalink
[client] Fix state manager race conditions (#2890)
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal authored Nov 15, 2024
1 parent a1c5287 commit 121dfda
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 100 deletions.
20 changes: 7 additions & 13 deletions client/internal/dns/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"runtime"
"strings"
"sync"
"time"

"github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2"
Expand Down Expand Up @@ -323,13 +322,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
log.Error(err)
}

// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()

// don't block
go func() {
if err := s.stateManager.PersistState(ctx); err != nil {
// persist dns state right away
if err := s.stateManager.PersistState(s.ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}
}()
Expand Down Expand Up @@ -537,12 +532,11 @@ func (s *DefaultServer) upstreamCallbacks(
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
}

// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()
if err := s.stateManager.PersistState(ctx); err != nil {
l.Errorf("Failed to persist dns state: %v", err)
}
go func() {
if err := s.stateManager.PersistState(s.ctx); err != nil {
l.Errorf("Failed to persist dns state: %v", err)
}
}()

if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
s.addHostRootZone()
Expand Down
2 changes: 1 addition & 1 deletion client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ func (e *Engine) Stop() error {
if err := e.stateManager.Stop(ctx); err != nil {
return fmt.Errorf("failed to stop state manager: %w", err)
}
if err := e.stateManager.PersistState(ctx); err != nil {
if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}

Expand Down
58 changes: 28 additions & 30 deletions client/internal/routemanager/refcounter/refcounter.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,9 @@ type RemoveFunc[Key, O any] func(key Key, out O) error
type Counter[Key comparable, I, O any] struct {
// refCountMap keeps track of the reference Ref for keys
refCountMap map[Key]Ref[O]
refCountMu sync.Mutex
mu sync.Mutex
// idMap keeps track of the keys associated with an ID for removal
idMap map[string][]Key
idMu sync.Mutex
add AddFunc[Key, I, O]
remove RemoveFunc[Key, O]
}
Expand All @@ -75,10 +74,8 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
func (rm *Counter[Key, I, O]) LoadData(
existingCounter *Counter[Key, I, O],
) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()

rm.refCountMap = existingCounter.refCountMap
rm.idMap = existingCounter.idMap
Expand All @@ -87,8 +84,8 @@ func (rm *Counter[Key, I, O]) LoadData(
// Get retrieves the current reference count and associated data for a key.
// If the key doesn't exist, it returns a zero value Ref and false.
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()

ref, ok := rm.refCountMap[key]
return ref, ok
Expand All @@ -97,9 +94,13 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
// Increment increments the reference count for the given key.
// If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()

return rm.increment(key, in)
}

func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) {
ref := rm.refCountMap[key]
logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)

Expand All @@ -126,10 +127,10 @@ func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
// IncrementWithID increments the reference count for the given key and groups it under the given ID.
// If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()

ref, err := rm.Increment(key, in)
ref, err := rm.increment(key, in)
if err != nil {
return ref, fmt.Errorf("with ID: %w", err)
}
Expand All @@ -141,9 +142,12 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O],
// Decrement decrements the reference count for the given key.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()
return rm.decrement(key)
}

func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) {
ref, ok := rm.refCountMap[key]
if !ok {
logCallerF("No reference found for key %v", key)
Expand All @@ -168,12 +172,12 @@ func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
// DecrementWithID decrements the reference count for all keys associated with the given ID.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()

var merr *multierror.Error
for _, key := range rm.idMap[id] {
if _, err := rm.Decrement(key); err != nil {
if _, err := rm.decrement(key); err != nil {
merr = multierror.Append(merr, err)
}
}
Expand All @@ -184,10 +188,8 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {

// Flush removes all references and calls RemoveFunc for each key.
func (rm *Counter[Key, I, O]) Flush() error {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()

var merr *multierror.Error
for key := range rm.refCountMap {
Expand All @@ -206,21 +208,17 @@ func (rm *Counter[Key, I, O]) Flush() error {

// Clear removes all references without calling RemoveFunc.
func (rm *Counter[Key, I, O]) Clear() {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()

clear(rm.refCountMap)
clear(rm.idMap)
}

// MarshalJSON implements the json.Marshaler interface for Counter.
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()

return json.Marshal(struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
Expand Down
23 changes: 10 additions & 13 deletions client/internal/routemanager/systemops/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,28 @@ package systemops

import (
"net/netip"
"sync"

"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)

type ShutdownState struct {
Counter *ExclusionCounter `json:"counter,omitempty"`
mu sync.RWMutex
}
type ShutdownState ExclusionCounter

func (s *ShutdownState) Name() string {
return "route_state"
}

func (s *ShutdownState) Cleanup() error {
s.mu.RLock()
defer s.mu.RUnlock()

if s.Counter == nil {
return nil
}

sysops := NewSysOps(nil, nil)
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
sysops.refCounter.LoadData(s.Counter)
sysops.refCounter.LoadData((*ExclusionCounter)(s))

return sysops.refCounter.Flush()
}

func (s *ShutdownState) MarshalJSON() ([]byte, error) {
return (*ExclusionCounter)(s).MarshalJSON()
}

func (s *ShutdownState) UnmarshalJSON(data []byte) error {
return (*ExclusionCounter)(s).UnmarshalJSON(data)
}
20 changes: 3 additions & 17 deletions client/internal/routemanager/systemops/systemops_generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return nexthop, err
},
func(prefix netip.Prefix, nexthop Nexthop) error {
// remove from state even if we have trouble removing it from the route table
// update state even if we have trouble removing it from the route table
// it could be already gone
r.updateState(stateManager)

Expand All @@ -75,12 +75,9 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return r.setupHooks(initAddresses)
}

// updateState updates state on every change so it will be persisted regularly
func (r *SysOps) updateState(stateManager *statemanager.Manager) {
state := getState(stateManager)

state.Counter = r.refCounter

if err := stateManager.UpdateState(state); err != nil {
if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
Expand Down Expand Up @@ -532,14 +529,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P
// Return true if the longest matching prefix is from vpnRoutes
return isVpn, longestPrefix
}

func getState(stateManager *statemanager.Manager) *ShutdownState {
var shutdownState *ShutdownState
if state := stateManager.GetState(shutdownState); state != nil {
shutdownState = state.(*ShutdownState)
} else {
shutdownState = &ShutdownState{}
}

return shutdownState
}
40 changes: 30 additions & 10 deletions client/internal/statemanager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,15 @@ func (m *Manager) Stop(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()

if m.cancel != nil {
m.cancel()
if m.cancel == nil {
return nil
}
m.cancel()

select {
case <-ctx.Done():
return ctx.Err()
case <-m.done:
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
case <-m.done:
}

return nil
Expand Down Expand Up @@ -179,14 +179,18 @@ func (m *Manager) PersistState(ctx context.Context) error {
return nil
}

bs, err := marshalWithPanicRecovery(m.states)
if err != nil {
return fmt.Errorf("marshal states: %w", err)
}

ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()

done := make(chan error, 1)

start := time.Now()
go func() {
done <- util.WriteJsonWithRestrictedPermission(ctx, m.filePath, m.states)
done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs)
}()

select {
Expand Down Expand Up @@ -286,3 +290,19 @@ func (m *Manager) PerformCleanup() error {

return nberrors.FormatErrorOrNil(merr)
}

func marshalWithPanicRecovery(v any) ([]byte, error) {
var bs []byte
var err error

func() {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic during marshal: %v", r)
}
}()
bs, err = json.Marshal(v)
}()

return bs, err
}
Loading

0 comments on commit 121dfda

Please sign in to comment.