Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: obey the context when sending messages to peers #462

Merged
merged 1 commit into from
Mar 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions ctx_mutex.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package dht

import (
"context"
)

type ctxMutex chan struct{}

func newCtxMutex() ctxMutex {
return make(ctxMutex, 1)
}

func (m ctxMutex) Lock(ctx context.Context) error {
select {
case m <- struct{}{}:
return nil
case <-ctx.Done():
return ctx.Err()
}
}

func (m ctxMutex) Unlock() {
select {
case <-m:
default:
panic("not locked")
}
}
19 changes: 14 additions & 5 deletions dht_net.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa
dht.smlk.Unlock()
return ms, nil
}
ms = &messageSender{p: p, dht: dht}
ms = &messageSender{p: p, dht: dht, lk: newCtxMutex()}
dht.strmap[p] = ms
dht.smlk.Unlock()

Expand Down Expand Up @@ -262,7 +262,7 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa
type messageSender struct {
s network.Stream
r msgio.ReadCloser
lk sync.Mutex
lk ctxMutex
p peer.ID
dht *IpfsDHT

Expand All @@ -282,8 +282,11 @@ func (ms *messageSender) invalidate() {
}

func (ms *messageSender) prepOrInvalidate(ctx context.Context) error {
ms.lk.Lock()
if err := ms.lk.Lock(ctx); err != nil {
return err
}
defer ms.lk.Unlock()

if err := ms.prep(ctx); err != nil {
ms.invalidate()
return err
Expand Down Expand Up @@ -316,8 +319,11 @@ func (ms *messageSender) prep(ctx context.Context) error {
const streamReuseTries = 3

func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) error {
ms.lk.Lock()
if err := ms.lk.Lock(ctx); err != nil {
return err
}
defer ms.lk.Unlock()

retry := false
for {
if err := ms.prep(ctx); err != nil {
Expand Down Expand Up @@ -351,8 +357,11 @@ func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) erro
}

func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) {
ms.lk.Lock()
if err := ms.lk.Lock(ctx); err != nil {
return nil, err
}
defer ms.lk.Unlock()

retry := false
for {
if err := ms.prep(ctx); err != nil {
Expand Down
43 changes: 43 additions & 0 deletions ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,49 @@ import (
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
)

func TestHang(t *testing.T) {
ctx := context.Background()
mn, err := mocknet.FullMeshConnected(ctx, 2)
if err != nil {
t.Fatal(err)
}
hosts := mn.Hosts()

os := []opts.Option{opts.DisableAutoRefresh()}
d, err := New(ctx, hosts[0], os...)
if err != nil {
t.Fatal(err)
}
// Hang on every request.
hosts[1].SetStreamHandler(d.protocols[0], func(s network.Stream) {
defer s.Reset()
<-ctx.Done()
})
d.Update(ctx, hosts[1].ID())

ctx1, cancel1 := context.WithTimeout(ctx, 1*time.Second)
defer cancel1()

peers, err := d.GetClosestPeers(ctx1, testCaseCids[0].KeyString())
if err != nil {
t.Fatal(err)
}

time.Sleep(100 * time.Millisecond)
ctx2, cancel2 := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel2()
_ = d.Provide(ctx2, testCaseCids[0], true)
if ctx2.Err() != context.DeadlineExceeded {
t.Errorf("expected to fail with deadline exceeded, got: %s", ctx2.Err())
}
select {
case <-peers:
t.Error("GetClosestPeers should not have returned yet")
default:
}

}

func TestGetFailures(t *testing.T) {
if testing.Short() {
t.SkipNow()
Expand Down
4 changes: 3 additions & 1 deletion notif.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package dht

import (
"context"

"github.com/libp2p/go-libp2p-core/helpers"
"github.com/libp2p/go-libp2p-core/network"

Expand Down Expand Up @@ -130,7 +132,7 @@ func (nn *netNotifiee) Disconnected(n network.Network, v network.Conn) {

// Do this asynchronously as ms.lk can block for a while.
go func() {
ms.lk.Lock()
ms.lk.Lock(context.Background())
defer ms.lk.Unlock()
ms.invalidate()
}()
Expand Down