From 0b029388bd0708272a023a831c9e66204d154310 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Sat, 22 Feb 2020 16:10:12 -0500 Subject: [PATCH] fix: obey the context when sending messages to peers Related to #453 but not a fix. This will cause us to actually return early when we start blocking on sending to some peers, but it won't really _unblock_ those peers. For that, we need to write with a context. --- ctx_mutex.go | 28 ++++++++++++++++++++++++++++ dht_net.go | 19 ++++++++++++++----- ext_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ notif.go | 4 +++- 4 files changed, 88 insertions(+), 6 deletions(-) create mode 100644 ctx_mutex.go diff --git a/ctx_mutex.go b/ctx_mutex.go new file mode 100644 index 000000000..c28d89875 --- /dev/null +++ b/ctx_mutex.go @@ -0,0 +1,28 @@ +package dht + +import ( + "context" +) + +type ctxMutex chan struct{} + +func newCtxMutex() ctxMutex { + return make(ctxMutex, 1) +} + +func (m ctxMutex) Lock(ctx context.Context) error { + select { + case m <- struct{}{}: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (m ctxMutex) Unlock() { + select { + case <-m: + default: + panic("not locked") + } +} diff --git a/dht_net.go b/dht_net.go index 31775ae8f..6ae0db934 100644 --- a/dht_net.go +++ b/dht_net.go @@ -234,7 +234,7 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa dht.smlk.Unlock() return ms, nil } - ms = &messageSender{p: p, dht: dht} + ms = &messageSender{p: p, dht: dht, lk: newCtxMutex()} dht.strmap[p] = ms dht.smlk.Unlock() @@ -262,7 +262,7 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa type messageSender struct { s network.Stream r msgio.ReadCloser - lk sync.Mutex + lk ctxMutex p peer.ID dht *IpfsDHT @@ -282,8 +282,11 @@ func (ms *messageSender) invalidate() { } func (ms *messageSender) prepOrInvalidate(ctx context.Context) error { - ms.lk.Lock() + if err := ms.lk.Lock(ctx); err != nil { + return err + } defer ms.lk.Unlock() + if err := ms.prep(ctx); err != nil { ms.invalidate() return err @@ -316,8 +319,11 @@ func (ms *messageSender) prep(ctx context.Context) error { const streamReuseTries = 3 func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) error { - ms.lk.Lock() + if err := ms.lk.Lock(ctx); err != nil { + return err + } defer ms.lk.Unlock() + retry := false for { if err := ms.prep(ctx); err != nil { @@ -351,8 +357,11 @@ func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) erro } func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) { - ms.lk.Lock() + if err := ms.lk.Lock(ctx); err != nil { + return nil, err + } defer ms.lk.Unlock() + retry := false for { if err := ms.prep(ctx); err != nil { diff --git a/ext_test.go b/ext_test.go index 91d54d9af..b01d0b5dc 100644 --- a/ext_test.go +++ b/ext_test.go @@ -18,6 +18,49 @@ import ( mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" ) +func TestHang(t *testing.T) { + ctx := context.Background() + mn, err := mocknet.FullMeshConnected(ctx, 2) + if err != nil { + t.Fatal(err) + } + hosts := mn.Hosts() + + os := []opts.Option{opts.DisableAutoRefresh()} + d, err := New(ctx, hosts[0], os...) + if err != nil { + t.Fatal(err) + } + // Hang on every request. + hosts[1].SetStreamHandler(d.protocols[0], func(s network.Stream) { + defer s.Reset() + <-ctx.Done() + }) + d.Update(ctx, hosts[1].ID()) + + ctx1, cancel1 := context.WithTimeout(ctx, 1*time.Second) + defer cancel1() + + peers, err := d.GetClosestPeers(ctx1, testCaseCids[0].KeyString()) + if err != nil { + t.Fatal(err) + } + + time.Sleep(100 * time.Millisecond) + ctx2, cancel2 := context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel2() + _ = d.Provide(ctx2, testCaseCids[0], true) + if ctx2.Err() != context.DeadlineExceeded { + t.Errorf("expected to fail with deadline exceeded, got: %s", ctx2.Err()) + } + select { + case <-peers: + t.Error("GetClosestPeers should not have returned yet") + default: + } + +} + func TestGetFailures(t *testing.T) { if testing.Short() { t.SkipNow() diff --git a/notif.go b/notif.go index a7913a5f5..04000e31e 100644 --- a/notif.go +++ b/notif.go @@ -1,6 +1,8 @@ package dht import ( + "context" + "github.com/libp2p/go-libp2p-core/helpers" "github.com/libp2p/go-libp2p-core/network" @@ -130,7 +132,7 @@ func (nn *netNotifiee) Disconnected(n network.Network, v network.Conn) { // Do this asynchronously as ms.lk can block for a while. go func() { - ms.lk.Lock() + ms.lk.Lock(context.Background()) defer ms.lk.Unlock() ms.invalidate() }()