Skip to content

Commit

Permalink
Merge pull request #243 from libp2p/fix/direct-connect
Browse files Browse the repository at this point in the history
New Dialer
  • Loading branch information
vyzo authored Apr 1, 2021
2 parents 70b63da + ad620bf commit ce7c0bc
Show file tree
Hide file tree
Showing 8 changed files with 741 additions and 400 deletions.
142 changes: 66 additions & 76 deletions p2p/net/swarm/dial_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,132 +5,122 @@ import (
"errors"
"sync"

"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
)

// TODO: change this text when we fix the bug
var errDialCanceled = errors.New("dial was aborted internally, likely due to https://git.io/Je2wW")

// DialFunc is the type of function expected by DialSync.
type DialFunc func(context.Context, peer.ID) (*Conn, error)
// DialWorerFunc is used by DialSync to spawn a new dial worker
type dialWorkerFunc func(context.Context, peer.ID, <-chan dialRequest) error

// NewDialSync constructs a new DialSync
func NewDialSync(dfn DialFunc) *DialSync {
// newDialSync constructs a new DialSync
func newDialSync(worker dialWorkerFunc) *DialSync {
return &DialSync{
dials: make(map[peer.ID]*activeDial),
dialFunc: dfn,
dials: make(map[peer.ID]*activeDial),
dialWorker: worker,
}
}

// DialSync is a dial synchronization helper that ensures that at most one dial
// to any given peer is active at any given time.
type DialSync struct {
dials map[peer.ID]*activeDial
dialsLk sync.Mutex
dialFunc DialFunc
dials map[peer.ID]*activeDial
dialsLk sync.Mutex
dialWorker dialWorkerFunc
}

type activeDial struct {
id peer.ID
refCnt int
refCntLk sync.Mutex
cancel func()
id peer.ID
refCnt int

err error
conn *Conn
waitch chan struct{}
ctx context.Context
cancel func()

reqch chan dialRequest

ds *DialSync
}

func (ad *activeDial) wait(ctx context.Context) (*Conn, error) {
defer ad.decref()
select {
case <-ad.waitch:
return ad.conn, ad.err
case <-ctx.Done():
return nil, ctx.Err()
func (ad *activeDial) decref() {
ad.ds.dialsLk.Lock()
ad.refCnt--
if ad.refCnt == 0 {
ad.cancel()
close(ad.reqch)
delete(ad.ds.dials, ad.id)
}
ad.ds.dialsLk.Unlock()
}

func (ad *activeDial) incref() {
ad.refCntLk.Lock()
defer ad.refCntLk.Unlock()
ad.refCnt++
}
func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) {
dialCtx := ad.ctx

func (ad *activeDial) decref() {
ad.refCntLk.Lock()
ad.refCnt--
maybeZero := (ad.refCnt <= 0)
ad.refCntLk.Unlock()

// make sure to always take locks in correct order.
if maybeZero {
ad.ds.dialsLk.Lock()
ad.refCntLk.Lock()
// check again after lock swap drop to make sure nobody else called incref
// in between locks
if ad.refCnt <= 0 {
ad.cancel()
delete(ad.ds.dials, ad.id)
}
ad.refCntLk.Unlock()
ad.ds.dialsLk.Unlock()
if forceDirect, reason := network.GetForceDirectDial(ctx); forceDirect {
dialCtx = network.WithForceDirectDial(dialCtx, reason)
}
if simConnect, reason := network.GetSimultaneousConnect(ctx); simConnect {
dialCtx = network.WithSimultaneousConnect(dialCtx, reason)
}
}

func (ad *activeDial) start(ctx context.Context) {
ad.conn, ad.err = ad.ds.dialFunc(ctx, ad.id)

// This isn't the user's context so we should fix the error.
switch ad.err {
case context.Canceled:
// The dial was canceled with `CancelDial`.
ad.err = errDialCanceled
case context.DeadlineExceeded:
// We hit an internal timeout, not a context timeout.
ad.err = ErrDialTimeout
resch := make(chan dialResponse, 1)
select {
case ad.reqch <- dialRequest{ctx: dialCtx, resch: resch}:
case <-ctx.Done():
return nil, ctx.Err()
}

select {
case res := <-resch:
return res.conn, res.err
case <-ctx.Done():
return nil, ctx.Err()
}
close(ad.waitch)
ad.cancel()
}

func (ds *DialSync) getActiveDial(ctx context.Context, p peer.ID) *activeDial {
func (ds *DialSync) getActiveDial(p peer.ID) (*activeDial, error) {
ds.dialsLk.Lock()
defer ds.dialsLk.Unlock()

actd, ok := ds.dials[p]
if !ok {
adctx, cancel := context.WithCancel(ctx)
// This code intentionally uses the background context. Otherwise, if the first call
// to Dial is canceled, subsequent dial calls will also be canceled.
// XXX: this also breaks direct connection logic. We will need to pipe the
// information through some other way.
adctx, cancel := context.WithCancel(context.Background())
actd = &activeDial{
id: p,
ctx: adctx,
cancel: cancel,
waitch: make(chan struct{}),
reqch: make(chan dialRequest),
ds: ds,
}
ds.dials[p] = actd

go actd.start(adctx)
err := ds.dialWorker(adctx, p, actd.reqch)
if err != nil {
cancel()
return nil, err
}

ds.dials[p] = actd
}

// increase ref count before dropping dialsLk
actd.incref()
actd.refCnt++

return actd
return actd, nil
}

// DialLock initiates a dial to the given peer if there are none in progress
// then waits for the dial to that peer to complete.
func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) {
return ds.getActiveDial(ctx, p).wait(ctx)
}

// CancelDial cancels all in-progress dials to the given peer.
func (ds *DialSync) CancelDial(p peer.ID) {
ds.dialsLk.Lock()
defer ds.dialsLk.Unlock()
if ad, ok := ds.dials[p]; ok {
ad.cancel()
ad, err := ds.getActiveDial(p)
if err != nil {
return nil, err
}

defer ad.decref()
return ad.dial(ctx, p)
}
112 changes: 88 additions & 24 deletions p2p/net/swarm/dial_sync_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package swarm_test
package swarm

import (
"context"
Expand All @@ -7,24 +7,37 @@ import (
"testing"
"time"

. "github.com/libp2p/go-libp2p-swarm"

"github.com/libp2p/go-libp2p-core/peer"
)

func getMockDialFunc() (DialFunc, func(), context.Context, <-chan struct{}) {
func getMockDialFunc() (dialWorkerFunc, func(), context.Context, <-chan struct{}) {
dfcalls := make(chan struct{}, 512) // buffer it large enough that we won't care
dialctx, cancel := context.WithCancel(context.Background())
ch := make(chan struct{})
f := func(ctx context.Context, p peer.ID) (*Conn, error) {
f := func(ctx context.Context, p peer.ID, reqch <-chan dialRequest) error {
dfcalls <- struct{}{}
defer cancel()
select {
case <-ch:
return new(Conn), nil
case <-ctx.Done():
return nil, ctx.Err()
}
go func() {
defer cancel()
for {
select {
case req, ok := <-reqch:
if !ok {
return
}

select {
case <-ch:
req.resch <- dialResponse{conn: new(Conn)}
case <-ctx.Done():
req.resch <- dialResponse{err: ctx.Err()}
return
}
case <-ctx.Done():
return
}
}
}()
return nil
}

o := new(sync.Once)
Expand All @@ -35,7 +48,7 @@ func getMockDialFunc() (DialFunc, func(), context.Context, <-chan struct{}) {
func TestBasicDialSync(t *testing.T) {
df, done, _, callsch := getMockDialFunc()

dsync := NewDialSync(df)
dsync := newDialSync(df)

p := peer.ID("testpeer")

Expand Down Expand Up @@ -73,7 +86,7 @@ func TestBasicDialSync(t *testing.T) {
func TestDialSyncCancel(t *testing.T) {
df, done, _, dcall := getMockDialFunc()

dsync := NewDialSync(df)
dsync := newDialSync(df)

p := peer.ID("testpeer")

Expand Down Expand Up @@ -124,7 +137,7 @@ func TestDialSyncCancel(t *testing.T) {
func TestDialSyncAllCancel(t *testing.T) {
df, done, dctx, _ := getMockDialFunc()

dsync := NewDialSync(df)
dsync := newDialSync(df)

p := peer.ID("testpeer")

Expand Down Expand Up @@ -174,15 +187,31 @@ func TestDialSyncAllCancel(t *testing.T) {

func TestFailFirst(t *testing.T) {
var count int
f := func(ctx context.Context, p peer.ID) (*Conn, error) {
if count > 0 {
return new(Conn), nil
}
count++
return nil, fmt.Errorf("gophers ate the modem")
f := func(ctx context.Context, p peer.ID, reqch <-chan dialRequest) error {
go func() {
for {
select {
case req, ok := <-reqch:
if !ok {
return
}

if count > 0 {
req.resch <- dialResponse{conn: new(Conn)}
} else {
req.resch <- dialResponse{err: fmt.Errorf("gophers ate the modem")}
}
count++

case <-ctx.Done():
return
}
}
}()
return nil
}

ds := NewDialSync(f)
ds := newDialSync(f)

p := peer.ID("testing")

Expand All @@ -205,8 +234,22 @@ func TestFailFirst(t *testing.T) {
}

func TestStressActiveDial(t *testing.T) {
ds := NewDialSync(func(ctx context.Context, p peer.ID) (*Conn, error) {
return nil, nil
ds := newDialSync(func(ctx context.Context, p peer.ID, reqch <-chan dialRequest) error {
go func() {
for {
select {
case req, ok := <-reqch:
if !ok {
return
}

req.resch <- dialResponse{}
case <-ctx.Done():
return
}
}
}()
return nil
})

wg := sync.WaitGroup{}
Expand All @@ -227,3 +270,24 @@ func TestStressActiveDial(t *testing.T) {

wg.Wait()
}

func TestDialSelf(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

self := peer.ID("ABC")
s := NewSwarm(ctx, self, nil, nil)
defer s.Close()

// this should fail
_, err := s.dsync.DialLock(ctx, self)
if err != ErrDialToSelf {
t.Fatal("expected error from self dial")
}

// do it twice to make sure we get a new active dial object that fails again
_, err = s.dsync.DialLock(ctx, self)
if err != ErrDialToSelf {
t.Fatal("expected error from self dial")
}
}
Loading

0 comments on commit ce7c0bc

Please sign in to comment.