Skip to content

Commit

Permalink
Add safe read/write to route map (netbirdio#1760)
Browse files Browse the repository at this point in the history
  • Loading branch information
hurricanehrndz authored Apr 11, 2024
1 parent 77313df commit beaaa3c
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 11 deletions.
1 change: 1 addition & 0 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
FQDN: offlinePeer.GetFqdn(),
ConnStatus: peer.StatusDisconnected,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
}
e.statusRecorder.ReplaceOfflinePeers(replacement)
Expand Down
5 changes: 4 additions & 1 deletion client/internal/peer/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,6 @@ func (conn *Conn) reCreateAgent() error {
}

conn.agent, err = ice.NewAgent(agentConfig)

if err != nil {
return err
}
Expand Down Expand Up @@ -285,6 +284,7 @@ func (conn *Conn) Open() error {
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
ConnStatusUpdate: time.Now(),
ConnStatus: conn.status,
Mux: new(sync.RWMutex),
}
err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
Expand Down Expand Up @@ -344,6 +344,7 @@ func (conn *Conn) Open() error {
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
Expand Down Expand Up @@ -468,6 +469,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()),
Direct: !isRelayCandidate(pair.Local),
RosenpassEnabled: rosenpassEnabled,
Mux: new(sync.RWMutex),
}
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true
Expand Down Expand Up @@ -558,6 +560,7 @@ func (conn *Conn) cleanup() error {
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
Expand Down
40 changes: 36 additions & 4 deletions client/internal/peer/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

// State contains the latest state of a peer
type State struct {
Mux *sync.RWMutex
IP string
PubKey string
FQDN string
Expand All @@ -30,7 +31,38 @@ type State struct {
BytesRx int64
Latency time.Duration
RosenpassEnabled bool
Routes map[string]struct{}
routes map[string]struct{}
}

// AddRoute add a single route to routes map
func (s *State) AddRoute(network string) {
s.Mux.Lock()
if s.routes == nil {
s.routes = make(map[string]struct{})
}
s.routes[network] = struct{}{}
s.Mux.Unlock()
}

// SetRoutes set state routes
func (s *State) SetRoutes(routes map[string]struct{}) {
s.Mux.Lock()
s.routes = routes
s.Mux.Unlock()
}

// DeleteRoute removes a route from the network amp
func (s *State) DeleteRoute(network string) {
s.Mux.Lock()
delete(s.routes, network)
s.Mux.Unlock()
}

// GetRoutes return routes map
func (s *State) GetRoutes() map[string]struct{} {
s.Mux.RLock()
defer s.Mux.RUnlock()
return s.routes
}

// LocalPeerState contains the latest state of the local peer
Expand Down Expand Up @@ -143,6 +175,7 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
PubKey: peerPubKey,
ConnStatus: StatusDisconnected,
FQDN: fqdn,
Mux: new(sync.RWMutex),
}
d.peerListChangedForNotification = true
return nil
Expand Down Expand Up @@ -189,8 +222,8 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.IP = receivedState.IP
}

if receivedState.Routes != nil {
peerState.Routes = receivedState.Routes
if receivedState.GetRoutes() != nil {
peerState.SetRoutes(receivedState.GetRoutes())
}

skipNotification := shouldSkipNotify(receivedState, peerState)
Expand Down Expand Up @@ -440,7 +473,6 @@ func (d *Status) IsLoginRequired() bool {
s, ok := gstatus.FromError(d.managementError)
if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return true

}
return false
}
Expand Down
5 changes: 5 additions & 0 deletions client/internal/peer/status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package peer
import (
"errors"
"testing"
"sync"

"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -42,6 +43,7 @@ func TestUpdatePeerState(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}

status.peers[key] = peerState
Expand All @@ -62,6 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}

status.peers[key] = peerState
Expand All @@ -80,6 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}

status.peers[key] = peerState
Expand All @@ -104,6 +108,7 @@ func TestRemovePeer(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}

status.peers[key] = peerState
Expand Down
7 changes: 2 additions & 5 deletions client/internal/routemanager/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
return fmt.Errorf("get peer state: %v", err)
}

delete(state.Routes, c.network.String())
state.DeleteRoute(c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
Expand Down Expand Up @@ -268,10 +268,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
} else {
if state.Routes == nil {
state.Routes = map[string]struct{}{}
}
state.Routes[c.network.String()] = struct{}{}
state.AddRoute(c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion client/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Routes: maps.Keys(peerState.Routes),
Routes: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
Expand Down

0 comments on commit beaaa3c

Please sign in to comment.