Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[client] Fix state manager race conditions #2890

Merged
merged 15 commits into from
Nov 15, 2024
20 changes: 7 additions & 13 deletions client/internal/dns/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"runtime"
"strings"
"sync"
"time"

"github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2"
Expand Down Expand Up @@ -323,13 +322,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
log.Error(err)
}

// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()

// don't block
go func() {
if err := s.stateManager.PersistState(ctx); err != nil {
// persist dns state right away
if err := s.stateManager.PersistState(s.ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}
}()
Expand Down Expand Up @@ -537,12 +532,11 @@ func (s *DefaultServer) upstreamCallbacks(
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
}

// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()
if err := s.stateManager.PersistState(ctx); err != nil {
l.Errorf("Failed to persist dns state: %v", err)
}
go func() {
if err := s.stateManager.PersistState(s.ctx); err != nil {
l.Errorf("Failed to persist dns state: %v", err)
}
}()

if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
s.addHostRootZone()
Expand Down
3 changes: 1 addition & 2 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"


nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
Expand Down Expand Up @@ -297,7 +296,7 @@ func (e *Engine) Stop() error {
if err := e.stateManager.Stop(ctx); err != nil {
return fmt.Errorf("failed to stop state manager: %w", err)
}
if err := e.stateManager.PersistState(ctx); err != nil {
if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}

Expand Down
58 changes: 28 additions & 30 deletions client/internal/routemanager/refcounter/refcounter.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,9 @@ type RemoveFunc[Key, O any] func(key Key, out O) error
type Counter[Key comparable, I, O any] struct {
// refCountMap keeps track of the reference Ref for keys
refCountMap map[Key]Ref[O]
refCountMu sync.Mutex
mu sync.Mutex
// idMap keeps track of the keys associated with an ID for removal
idMap map[string][]Key
idMu sync.Mutex
add AddFunc[Key, I, O]
remove RemoveFunc[Key, O]
}
Expand All @@ -75,10 +74,8 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
func (rm *Counter[Key, I, O]) LoadData(
existingCounter *Counter[Key, I, O],
) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()

rm.refCountMap = existingCounter.refCountMap
rm.idMap = existingCounter.idMap
Expand All @@ -87,8 +84,8 @@ func (rm *Counter[Key, I, O]) LoadData(
// Get retrieves the current reference count and associated data for a key.
// If the key doesn't exist, it returns a zero value Ref and false.
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()

ref, ok := rm.refCountMap[key]
return ref, ok
Expand All @@ -97,9 +94,13 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
// Increment increments the reference count for the given key.
// If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()

return rm.increment(key, in)
}

func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) {
ref := rm.refCountMap[key]
logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)

Expand All @@ -126,10 +127,10 @@ func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
// IncrementWithID increments the reference count for the given key and groups it under the given ID.
// If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()

ref, err := rm.Increment(key, in)
ref, err := rm.increment(key, in)
if err != nil {
return ref, fmt.Errorf("with ID: %w", err)
}
Expand All @@ -141,9 +142,12 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O],
// Decrement decrements the reference count for the given key.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()
return rm.decrement(key)
}

func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) {
ref, ok := rm.refCountMap[key]
if !ok {
logCallerF("No reference found for key %v", key)
Expand All @@ -168,12 +172,12 @@ func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
// DecrementWithID decrements the reference count for all keys associated with the given ID.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()

var merr *multierror.Error
for _, key := range rm.idMap[id] {
if _, err := rm.Decrement(key); err != nil {
if _, err := rm.decrement(key); err != nil {
merr = multierror.Append(merr, err)
}
}
Expand All @@ -184,10 +188,8 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {

// Flush removes all references and calls RemoveFunc for each key.
func (rm *Counter[Key, I, O]) Flush() error {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()

var merr *multierror.Error
for key := range rm.refCountMap {
Expand All @@ -206,21 +208,17 @@ func (rm *Counter[Key, I, O]) Flush() error {

// Clear removes all references without calling RemoveFunc.
func (rm *Counter[Key, I, O]) Clear() {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()

clear(rm.refCountMap)
clear(rm.idMap)
}

// MarshalJSON implements the json.Marshaler interface for Counter.
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()

return json.Marshal(struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
Expand Down
23 changes: 10 additions & 13 deletions client/internal/routemanager/systemops/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,28 @@ package systemops

import (
"net/netip"
"sync"

"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)

type ShutdownState struct {
Counter *ExclusionCounter `json:"counter,omitempty"`
mu sync.RWMutex
}
type ShutdownState ExclusionCounter

func (s *ShutdownState) Name() string {
return "route_state"
}

func (s *ShutdownState) Cleanup() error {
s.mu.RLock()
defer s.mu.RUnlock()

if s.Counter == nil {
return nil
}

sysops := NewSysOps(nil, nil)
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
sysops.refCounter.LoadData(s.Counter)
sysops.refCounter.LoadData((*ExclusionCounter)(s))

return sysops.refCounter.Flush()
}

func (s *ShutdownState) MarshalJSON() ([]byte, error) {
return (*ExclusionCounter)(s).MarshalJSON()
}

func (s *ShutdownState) UnmarshalJSON(data []byte) error {
return (*ExclusionCounter)(s).UnmarshalJSON(data)
}
20 changes: 3 additions & 17 deletions client/internal/routemanager/systemops/systemops_generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return nexthop, err
},
func(prefix netip.Prefix, nexthop Nexthop) error {
// remove from state even if we have trouble removing it from the route table
// update state even if we have trouble removing it from the route table
// it could be already gone
r.updateState(stateManager)

Expand All @@ -75,12 +75,9 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return r.setupHooks(initAddresses)
}

// updateState updates state on every change so it will be persisted regularly
func (r *SysOps) updateState(stateManager *statemanager.Manager) {
state := getState(stateManager)

state.Counter = r.refCounter

if err := stateManager.UpdateState(state); err != nil {
if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
Expand Down Expand Up @@ -532,14 +529,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P
// Return true if the longest matching prefix is from vpnRoutes
return isVpn, longestPrefix
}

func getState(stateManager *statemanager.Manager) *ShutdownState {
var shutdownState *ShutdownState
if state := stateManager.GetState(shutdownState); state != nil {
shutdownState = state.(*ShutdownState)
} else {
shutdownState = &ShutdownState{}
}

return shutdownState
}
40 changes: 30 additions & 10 deletions client/internal/statemanager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,15 @@ func (m *Manager) Stop(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()

if m.cancel != nil {
m.cancel()
if m.cancel == nil {
return nil
}
m.cancel()

select {
case <-ctx.Done():
return ctx.Err()
case <-m.done:
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
case <-m.done:
}

return nil
Expand Down Expand Up @@ -179,14 +179,18 @@ func (m *Manager) PersistState(ctx context.Context) error {
return nil
}

bs, err := marshalWithPanicRecovery(m.states)
if err != nil {
return fmt.Errorf("marshal states: %w", err)
}

pappz marked this conversation as resolved.
Show resolved Hide resolved
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()

done := make(chan error, 1)

start := time.Now()
go func() {
done <- util.WriteJsonWithRestrictedPermission(ctx, m.filePath, m.states)
done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What will be if this function running more then 10 sec? The ticker will start a PersistState call and will be a conflict in the file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's what the ctx check and deadline is for in this fn

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But the check and move is not an atomic operation. If the code runs parallel this lines with two different ctx the outcome is unpredictable.

	// Check context again
	if ctx.Err() != nil {
		return ctx.Err()
	}

	if err = os.Rename(tempFileName, file); err != nil {
		return fmt.Errorf("move %s to %s: %w", tempFileName, file, err)
	}
	```

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets do it in separated PR

}()

select {
Expand Down Expand Up @@ -286,3 +290,19 @@ func (m *Manager) PerformCleanup() error {

return nberrors.FormatErrorOrNil(merr)
}

func marshalWithPanicRecovery(v any) ([]byte, error) {
var bs []byte
var err error

func() {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic during marshal: %v", r)
}
}()
bs, err = json.Marshal(v)
}()

return bs, err
}
Loading
Loading