Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Register/Deregister race on Signal #431

Merged
merged 8 commits into from
Aug 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,6 @@ func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtyp
},
})
if err != nil {
log.Errorf("failed signaling candidate to the remote peer %s %s", remoteKey.String(), err)
// todo ??
return err
}

Expand Down Expand Up @@ -704,6 +702,7 @@ func (e Engine) peerExists(peerKey string) bool {
}

func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, error) {
log.Debugf("creating peer connection %s", pubKey)
var stunTurn []*ice.URL
stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...)
Expand Down
1 change: 1 addition & 0 deletions signal/client/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error {
// we need this reset because after a successful connection and a consequent error, backoff lib doesn't
// reset times and next try will start with a long delay
backOff.Reset()
log.Warnf("disconnected from the Signal service but will retry silently. Reason: %v", err)
return err
}

Expand Down
50 changes: 36 additions & 14 deletions signal/peer/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,42 @@ import (
"github.com/netbirdio/netbird/signal/proto"
log "github.com/sirupsen/logrus"
"sync"
"time"
)

// Peer representation of a connected Peer
type Peer struct {
// a unique id of the Peer (e.g. sha256 fingerprint of the Wireguard public key)
Id string

StreamID int64

//a gRpc connection stream to the Peer
Stream proto.SignalExchange_ConnectStreamServer
}

// NewPeer creates a new instance of a connected Peer
func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer) *Peer {
return &Peer{
Id: id,
Stream: stream,
Id: id,
Stream: stream,
StreamID: time.Now().UnixNano(),
}
}

// Registry registry that holds all currently connected Peers
// Registry that holds all currently connected Peers
type Registry struct {
// Peer.key -> Peer
Peers sync.Map
// regMutex ensures that registration and de-registrations are safe
regMutex sync.Mutex
}

// NewRegistry creates a new connected Peer registry
func NewRegistry() *Registry {
return &Registry{}
return &Registry{
regMutex: sync.Mutex{},
}
}

// Get gets a peer from the registry
Expand All @@ -52,20 +60,34 @@ func (registry *Registry) IsPeerRegistered(peerId string) bool {

// Register registers peer in the registry
func (registry *Registry) Register(peer *Peer) {
// can be that peer already exists but it is fine (e.g. reconnect)
// todo investigate what happens to the old peer (especially Peer.Stream) when we override it
registry.Peers.Store(peer.Id, peer)
log.Debugf("peer registered [%s]", peer.Id)
registry.regMutex.Lock()
defer registry.regMutex.Unlock()

// can be that peer already exists, but it is fine (e.g. reconnect)
p, loaded := registry.Peers.LoadOrStore(peer.Id, peer)
if loaded {
pp := p.(*Peer)
log.Warnf("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.",
peer.Id, peer.StreamID, pp.StreamID)
registry.Peers.Store(peer.Id, peer)
}
log.Debugf("peer registered [%s]", peer.Id)
}

// Deregister deregister Peer from the Registry (usually once it disconnects)
// Deregister Peer from the Registry (usually once it disconnects)
func (registry *Registry) Deregister(peer *Peer) {
_, loaded := registry.Peers.LoadAndDelete(peer.Id)
registry.regMutex.Lock()
defer registry.regMutex.Unlock()

p, loaded := registry.Peers.LoadAndDelete(peer.Id)
if loaded {
log.Debugf("peer deregistered [%s]", peer.Id)
} else {
log.Warnf("attempted to remove non-existent peer [%s]", peer.Id)
pp := p.(*Peer)
if peer.StreamID < pp.StreamID {
registry.Peers.Store(peer.Id, p)
log.Warnf("attempted to remove newer registered stream of a peer [%s] [newer streamID %d, previous StreamID %d]. Ignoring.",
peer.Id, pp.StreamID, peer.StreamID)
return
}
}

log.Debugf("peer deregistered [%s]", peer.Id)
}
25 changes: 25 additions & 0 deletions signal/peer/peer_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,34 @@
package peer

import (
"github.com/stretchr/testify/assert"
"testing"
"time"
)

func TestRegistry_ShouldNotDeregisterWhenHasNewerStreamRegistered(t *testing.T) {
r := NewRegistry()

peerID := "peer"

olderPeer := NewPeer(peerID, nil)
r.Register(olderPeer)
time.Sleep(time.Nanosecond)

newerPeer := NewPeer(peerID, nil)
r.Register(newerPeer)
registered, _ := r.Get(olderPeer.Id)

assert.NotNil(t, registered, "peer can't be nil")
assert.Equal(t, newerPeer, registered)

r.Deregister(olderPeer)
registered, _ = r.Get(olderPeer.Id)

assert.NotNil(t, registered, "peer can't be nil")
assert.Equal(t, newerPeer, registered)
}

func TestRegistry_GetNonExistentPeer(t *testing.T) {
r := NewRegistry()

Expand Down
4 changes: 2 additions & 2 deletions signal/server/signal.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
}

defer func() {
log.Infof("peer disconnected [%s] ", p.Id)
log.Infof("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID)
s.registry.Deregister(p)
}()

Expand All @@ -66,7 +66,7 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
return err
}

log.Infof("peer connected [%s]", p.Id)
log.Infof("peer connected [%s] [streamID %d] ", p.Id, p.StreamID)

for {
//read incoming messages
Expand Down