From 06b0c46a5d67ceb200d65c957ea12c61355f80f8 Mon Sep 17 00:00:00 2001 From: braginini Date: Thu, 10 Jun 2021 17:08:40 +0200 Subject: [PATCH] chore: [Signal] synchronize peer registry --- signal/peer/peer.go | 45 ++++++++++++++++++++++++++-------------- signal/peer/peer_test.go | 30 ++++++++++++++++----------- signal/signal.go | 6 +++--- 3 files changed, 51 insertions(+), 30 deletions(-) diff --git a/signal/peer/peer.go b/signal/peer/peer.go index 68150beb37e..e0ee8d19d54 100644 --- a/signal/peer/peer.go +++ b/signal/peer/peer.go @@ -3,6 +3,7 @@ package peer import ( log "github.com/sirupsen/logrus" "github.com/wiretrustee/wiretrustee/signal/proto" + "sync" ) // Peer representation of a connected Peer @@ -25,32 +26,46 @@ func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer) *Peer { // Registry registry that holds all currently connected Peers type Registry struct { // Peer.key -> Peer - Peers map[string]*Peer + Peers sync.Map } // NewRegistry creates a new connected Peer registry func NewRegistry() *Registry { - return &Registry{ - Peers: make(map[string]*Peer), + return &Registry{} +} + +// Get gets a peer from the registry +func (registry *Registry) Get(peerId string) (*Peer, bool) { + if load, ok := registry.Peers.Load(peerId); ok { + return load.(*Peer), ok } + return nil, false + } -// Register registers peer in the registry -func (reg *Registry) Register(peer *Peer) { - if _, exists := reg.Peers[peer.Id]; exists { - log.Warnf("peer [%s] has been already registered", peer.Id) - } else { - log.Printf("registering new peer [%s]", peer.Id) +func (registry *Registry) IsPeerRegistered(peerId string) bool { + if _, ok := registry.Peers.Load(peerId); ok { + return ok } - //replace Peer even if exists - //todo should we really replace? - reg.Peers[peer.Id] = peer + return false +} + +// Register registers peer in the registry +func (registry *Registry) Register(peer *Peer) { + // can be that peer already exists but it is fine (e.g. reconnect) + // todo investigate what happens to the old peer (especially Peer.Stream) when we override it + registry.Peers.Store(peer.Id, peer) + log.Printf("registered peer [%s]", peer.Id) + } // Deregister deregister Peer from the Registry (usually once it disconnects) -func (reg *Registry) Deregister(peer *Peer) { - if _, ok := reg.Peers[peer.Id]; ok { - delete(reg.Peers, peer.Id) +func (registry *Registry) Deregister(peer *Peer) { + _, loaded := registry.Peers.LoadAndDelete(peer.Id) + if loaded { log.Printf("deregistered peer [%s]", peer.Id) + } else { + log.Warnf("attempted to remove non-existent peer [%s]", peer.Id) } + } diff --git a/signal/peer/peer_test.go b/signal/peer/peer_test.go index 650cdf11acc..bf301bae743 100644 --- a/signal/peer/peer_test.go +++ b/signal/peer/peer_test.go @@ -4,6 +4,20 @@ import ( "testing" ) +func TestRegistry_GetNonExistentPeer(t *testing.T) { + r := NewRegistry() + + peer, ok := r.Get("non_existent_peer") + + if peer != nil { + t.Errorf("expected non_existent_peer not found in the registry") + } + + if ok { + t.Errorf("expected non_existent_peer not found in the registry") + } +} + func TestRegistry_Register(t *testing.T) { r := NewRegistry() peer1 := NewPeer("test_peer_1", nil) @@ -11,15 +25,11 @@ func TestRegistry_Register(t *testing.T) { r.Register(peer1) r.Register(peer2) - if len(r.Peers) != 2 { - t.Errorf("expected 2 registered peers") - } - - if _, ok := r.Peers["test_peer_1"]; !ok { + if _, ok := r.Get("test_peer_1"); !ok { t.Errorf("expected test_peer_1 not found in the registry") } - if _, ok := r.Peers["test_peer_2"]; !ok { + if _, ok := r.Get("test_peer_2"); !ok { t.Errorf("expected test_peer_2 not found in the registry") } } @@ -33,15 +43,11 @@ func TestRegistry_Deregister(t *testing.T) { r.Deregister(peer1) - if len(r.Peers) != 1 { - t.Errorf("expected 1 registered peers after deregistring") - } - - if _, ok := r.Peers["test_peer_1"]; ok { + if _, ok := r.Get("test_peer_1"); ok { t.Errorf("expected test_peer_1 to absent in the registry after deregistering") } - if _, ok := r.Peers["test_peer_2"]; !ok { + if _, ok := r.Get("test_peer_2"); !ok { t.Errorf("expected test_peer_2 not found in the registry") } diff --git a/signal/signal.go b/signal/signal.go index efd381f443d..a75c2050ef8 100644 --- a/signal/signal.go +++ b/signal/signal.go @@ -27,11 +27,11 @@ func NewServer() *SignalExchangeServer { // Send forwards a message to the signal peer func (s *SignalExchangeServer) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { - if _, found := s.registry.Peers[msg.Key]; !found { + if !s.registry.IsPeerRegistered(msg.Key) { return nil, fmt.Errorf("unknown peer %s", msg.Key) } - if dstPeer, found := s.registry.Peers[msg.RemoteKey]; found { + if dstPeer, found := s.registry.Get(msg.RemoteKey); found { //forward the message to the target peer err := dstPeer.Stream.Send(msg) if err != nil { @@ -63,7 +63,7 @@ func (s *SignalExchangeServer) ConnectStream(stream proto.SignalExchange_Connect } log.Debugf("received a new message from peer [%s] to peer [%s]", p.Id, msg.RemoteKey) // lookup the target peer where the message is going to - if dstPeer, found := s.registry.Peers[msg.RemoteKey]; found { + if dstPeer, found := s.registry.Get(msg.RemoteKey); found { //forward the message to the target peer err := dstPeer.Stream.Send(msg) if err != nil {