diff --git a/dht_net.go b/dht_net.go index 18085389f..8513db3cf 100644 --- a/dht_net.go +++ b/dht_net.go @@ -78,7 +78,10 @@ func (dht *IpfsDHT) handleNewMessage(s inet.Stream) { // measure the RTT for latency measurements. func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) { - ms := dht.messageSenderForPeer(p) + ms, err := dht.messageSenderForPeer(p) + if err != nil { + return nil, err + } start := time.Now() @@ -97,8 +100,10 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message // sendMessage sends out a message func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error { - - ms := dht.messageSenderForPeer(p) + ms, err := dht.messageSenderForPeer(p) + if err != nil { + return err + } if err := ms.SendMessage(ctx, pmes); err != nil { return err @@ -112,17 +117,36 @@ func (dht *IpfsDHT) updateFromMessage(ctx context.Context, p peer.ID, mes *pb.Me return nil } -func (dht *IpfsDHT) messageSenderForPeer(p peer.ID) *messageSender { +func (dht *IpfsDHT) messageSenderForPeer(p peer.ID) (*messageSender, error) { dht.smlk.Lock() - defer dht.smlk.Unlock() - ms, ok := dht.strmap[p] - if !ok { - ms = dht.newMessageSender(p) - dht.strmap[p] = ms + if ok { + dht.smlk.Unlock() + return ms, nil } - - return ms + ms = &messageSender{p: p, dht: dht} + dht.strmap[p] = ms + dht.smlk.Unlock() + + if err := ms.prepOrInvalidate(); err != nil { + dht.smlk.Lock() + defer dht.smlk.Unlock() + + if msCur, ok := dht.strmap[p]; ok { + // Changed. Use the new one, old one is invalid and + // not in the map so we can just throw it away. + if ms != msCur { + return msCur, nil + } + // Not changed, remove the now invalid stream from the + // map. + delete(dht.strmap, p) + } + // Invalid but not in map. Must have been removed by a disconnect. + return nil, err + } + // All ready to go. + return ms, nil } type messageSender struct { @@ -133,14 +157,35 @@ type messageSender struct { p peer.ID dht *IpfsDHT + invalid bool singleMes int } -func (dht *IpfsDHT) newMessageSender(p peer.ID) *messageSender { - return &messageSender{p: p, dht: dht} +// invalidate is called before this messageSender is removed from the strmap. +// It prevents the messageSender from being reused/reinitialized and then +// forgotten (leaving the stream open). +func (ms *messageSender) invalidate() { + ms.invalid = true + if ms.s != nil { + ms.s.Reset() + ms.s = nil + } +} + +func (ms *messageSender) prepOrInvalidate() error { + ms.lk.Lock() + defer ms.lk.Unlock() + if err := ms.prep(); err != nil { + ms.invalidate() + return err + } + return nil } func (ms *messageSender) prep() error { + if ms.invalid { + return fmt.Errorf("message sender has been invalidated") + } if ms.s != nil { return nil } diff --git a/dht_test.go b/dht_test.go index 6b5d4d0ec..cb1241b79 100644 --- a/dht_test.go +++ b/dht_test.go @@ -198,6 +198,22 @@ func TestValueGetSet(t *testing.T) { } } +func TestInvalidMessageSenderTracking(t *testing.T) { + ctx := context.Background() + dht := setupDHT(ctx, t, false) + foo := peer.ID("asdasd") + _, err := dht.messageSenderForPeer(foo) + if err == nil { + t.Fatal("that shouldnt have succeeded") + } + + dht.smlk.Lock() + defer dht.smlk.Unlock() + if len(dht.strmap) > 0 { + t.Fatal("should have no message senders in map") + } +} + func TestProvides(t *testing.T) { // t.Skip("skipping test to debug another") ctx := context.Background() diff --git a/notif.go b/notif.go index 804b019bc..75f57230b 100644 --- a/notif.go +++ b/notif.go @@ -92,21 +92,40 @@ func (nn *netNotifiee) Disconnected(n inet.Network, v inet.Conn) { default: } - dht.plk.Lock() - defer dht.plk.Unlock() + p := v.RemotePeer() - conn, ok := nn.peers[v.RemotePeer()] + func() { + dht.plk.Lock() + defer dht.plk.Unlock() + + conn, ok := nn.peers[p] + if !ok { + // Unmatched disconnects are fine. It just means that we were + // already connected when we registered the listener. + return + } + conn.refcount -= 1 + if conn.refcount == 0 { + delete(nn.peers, p) + conn.cancel() + dht.routingTable.Remove(p) + } + }() + + dht.smlk.Lock() + defer dht.smlk.Unlock() + ms, ok := dht.strmap[p] if !ok { - // Unmatched disconnects are fine. It just means that we were - // already connected when we registered the listener. return } - conn.refcount -= 1 - if conn.refcount == 0 { - delete(nn.peers, v.RemotePeer()) - conn.cancel() - dht.routingTable.Remove(v.RemotePeer()) - } + delete(dht.strmap, p) + + // Do this asynchronously as ms.lk can block for a while. + go func() { + ms.lk.Lock() + defer ms.lk.Unlock() + ms.invalidate() + }() } func (nn *netNotifiee) OpenedStream(n inet.Network, v inet.Stream) {}