Skip to content

Commit

Permalink
only start hole punching service after the host has a public address
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Nov 16, 2021
1 parent 79db68c commit ac2d335
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 32 deletions.
78 changes: 65 additions & 13 deletions p2p/protocol/holepunch/coordination.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ type Service struct {
closed bool
refCount sync.WaitGroup

hasPublicAddrsChan chan struct{} // this chan is closed as soon as we have a public address

// active hole punches for deduplicating
activeMx sync.Mutex
active map[peer.ID]struct{}
Expand All @@ -71,11 +73,12 @@ func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service,

ctx, cancel := context.WithCancel(context.Background())
hs := &Service{
ctx: ctx,
ctxCancel: cancel,
host: h,
ids: ids,
active: make(map[peer.ID]struct{}),
ctx: ctx,
ctxCancel: cancel,
host: h,
ids: ids,
active: make(map[peer.ID]struct{}),
hasPublicAddrsChan: make(chan struct{}),
}

for _, opt := range opts {
Expand All @@ -85,11 +88,47 @@ func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service,
}
}

h.SetStreamHandler(Protocol, hs.handleNewStream)
hs.refCount.Add(1)
go hs.watchForPublicAddr()

h.Network().Notify((*netNotifiee)(hs))
return hs, nil
}

func (hs *Service) watchForPublicAddr() {
defer hs.refCount.Done()

log.Debug("waiting until we have at least one public address", "peer", hs.host.ID())

// TODO: We should have an event here that fires when identify discovers a new
// address (and when autonat confirms that address).
// As we currently don't have an event like this, just check our observed addresses
// regularly (exponential backoff starting at 250 ms, capped at 5s).
duration := 250 * time.Millisecond
const maxDuration = 5 * time.Second
t := time.NewTimer(duration)
defer t.Stop()
for {
if containsPublicAddr(hs.ids.OwnObservedAddrs()) {
log.Debug("Host now has a public address. Starting holepunch protocol.")
hs.host.SetStreamHandler(Protocol, hs.handleNewStream)
close(hs.hasPublicAddrsChan)
return
}

select {
case <-hs.ctx.Done():
return
case <-t.C:
duration *= 2
if duration > maxDuration {
duration = maxDuration
}
t.Reset(duration)
}
}
}

// Close closes the Hole Punch Service.
func (hs *Service) Close() error {
hs.closeMx.Lock()
Expand Down Expand Up @@ -176,7 +215,6 @@ func (hs *Service) beginDirectConnect(p peer.ID) error {
// It first attempts a direct dial (if we have a public address of that peer), and then
// coordinates a hole punch over the given relay connection.
func (hs *Service) DirectConnect(p peer.ID) error {
log.Debugw("got inbound proxy conn", "peer", p)
if err := hs.beginDirectConnect(p); err != nil {
return err
}
Expand Down Expand Up @@ -221,8 +259,16 @@ func (hs *Service) directConnect(rp peer.ID) error {
}
}

if len(hs.ids.OwnObservedAddrs()) == 0 {
log.Debugw("got inbound proxy conn", "peer", rp)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
select {
case <-hs.ctx.Done():
return hs.ctx.Err()
case <-ctx.Done():
log.Debug("didn't find any public host address")
return errors.New("can't initiate hole punch, as we don't have any public addresses")
case <-hs.hasPublicAddrsChan:
}

// hole punch
Expand Down Expand Up @@ -341,11 +387,6 @@ func (hs *Service) handleNewStream(s network.Stream) {
err = hs.holePunchConnect(pi, false)
dt := time.Since(start)
hs.tracer.EndHolePunch(rp, dt, err)
if err != nil {
log.Debugw("hole punching failed", "peer", rp, "time", dt, "error", err)
} else {
log.Debugw("hole punching succeeded", "peer", rp, "time", dt)
}
}

func (hs *Service) holePunchConnect(pi peer.AddrInfo, isClient bool) error {
Expand All @@ -363,6 +404,16 @@ func (hs *Service) holePunchConnect(pi peer.AddrInfo, isClient bool) error {
return nil
}

func containsPublicAddr(addrs []ma.Multiaddr) bool {
for _, addr := range addrs {
if isRelayAddress(addr) || !manet.IsPublicAddr(addr) {
continue
}
return true
}
return false
}

func removeRelayAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
result := make([]ma.Multiaddr, 0, len(addrs))
for _, addr := range addrs {
Expand Down Expand Up @@ -414,6 +465,7 @@ func (nn *netNotifiee) Connected(_ network.Network, conn network.Conn) {
// that we can dial to for a hole punch.
case <-hs.ids.IdentifyWait(conn):
case <-hs.ctx.Done():
return
}

_ = hs.DirectConnect(conn.RemotePeer())
Expand Down
80 changes: 61 additions & 19 deletions p2p/protocol/holepunch/coordination_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,22 @@ func (m *mockEventTracer) getEvents() []*holepunch.Event {

var _ holepunch.EventTracer = &mockEventTracer{}

type mockIDService struct {
identify.IDService
}

var _ identify.IDService = &mockIDService{}

func newMockIDService(t *testing.T, h host.Host) identify.IDService {
ids, err := identify.NewIDService(h)
require.NoError(t, err)
return &mockIDService{IDService: ids}
}

func (s *mockIDService) OwnObservedAddrs() []ma.Multiaddr {
return append(s.IDService.OwnObservedAddrs(), ma.StringCast("/ip4/1.1.1.1/tcp/1234"))
}

func TestNoHolePunchIfDirectConnExists(t *testing.T) {
tr := &mockEventTracer{}
h1, hps := mkHostWithHolePunchSvc(t, holepunch.WithTracer(tr))
Expand Down Expand Up @@ -95,7 +111,7 @@ func TestDirectDialWorks(t *testing.T) {

func TestEndToEndSimConnect(t *testing.T) {
tr := &mockEventTracer{}
h1, h2, relay, _ := makeRelayedHosts(t, holepunch.WithTracer(tr), true)
h1, h2, relay, _ := makeRelayedHosts(t, nil, holepunch.WithTracer(tr), true)
defer h1.Close()
defer h2.Close()
defer relay.Close()
Expand Down Expand Up @@ -158,11 +174,11 @@ func TestFailuresOnInitiator(t *testing.T) {
}

tr := &mockEventTracer{}
h1, h2, relay, _ := makeRelayedHosts(t, holepunch.WithTracer(tr), false)
h1, h2, relay, _ := makeRelayedHosts(t, nil, nil, false)
defer h1.Close()
defer h2.Close()
defer relay.Close()
hps := addHolePunchService(t, h2)
hps := addHolePunchService(t, h2, holepunch.WithTracer(tr))

if tc.rhandler != nil {
h1.SetStreamHandler(holepunch.Protocol, tc.rhandler)
Expand All @@ -180,6 +196,14 @@ func TestFailuresOnInitiator(t *testing.T) {
}
}

func addrsToBytes(as []ma.Multiaddr) [][]byte {
bzs := make([][]byte, 0, len(as))
for _, a := range as {
bzs = append(bzs, a.Bytes())
}
return bzs
}

func TestFailuresOnResponder(t *testing.T) {
tcs := map[string]struct {
initiator func(s network.Stream)
Expand All @@ -192,22 +216,36 @@ func TestFailuresOnResponder(t *testing.T) {
},
errMsg: "expected CONNECT message",
},
"initiator does NOT send a SYNC message after a Connect message": {
"initiator does NOT send a SYNC message after a CONNECT message": {
initiator: func(s network.Stream) {
w := protoio.NewDelimitedWriter(s)
w.WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()})
w.WriteMsg(&holepunch_pb.HolePunch{
Type: holepunch_pb.HolePunch_CONNECT.Enum(),
ObsAddrs: addrsToBytes([]ma.Multiaddr{ma.StringCast("/ip4/127.0.0.1/tcp/1234")}),
})
w.WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()})
},
errMsg: "expected SYNC message",
},
"initiator does NOT reply within hole punch deadline": {
holePunchTimeout: 10 * time.Millisecond,
initiator: func(s network.Stream) {
protoio.NewDelimitedWriter(s).WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()})
protoio.NewDelimitedWriter(s).WriteMsg(&holepunch_pb.HolePunch{
Type: holepunch_pb.HolePunch_CONNECT.Enum(),
ObsAddrs: addrsToBytes([]ma.Multiaddr{ma.StringCast("/ip4/127.0.0.1/tcp/1234")}),
})
time.Sleep(10 * time.Second)
},
errMsg: "i/o deadline reached",
},
"initiator does NOT send any addresses in CONNECT": {
holePunchTimeout: 10 * time.Millisecond,
initiator: func(s network.Stream) {
protoio.NewDelimitedWriter(s).WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()})
time.Sleep(10 * time.Second)
},
errMsg: "expected CONNECT message to contain at least one address",
},
}

for name, tc := range tcs {
Expand All @@ -219,7 +257,7 @@ func TestFailuresOnResponder(t *testing.T) {
}

tr := &mockEventTracer{}
h1, h2, relay, _ := makeRelayedHosts(t, holepunch.WithTracer(tr), false)
h1, h2, relay, _ := makeRelayedHosts(t, holepunch.WithTracer(tr), nil, false)
defer h1.Close()
defer h2.Close()
defer relay.Close()
Expand Down Expand Up @@ -293,7 +331,7 @@ func ensureDirectConn(t *testing.T, h1, h2 host.Host) {
}, 5*time.Second, 50*time.Millisecond)
}

func mkHostWithStaticAutoRelay(t *testing.T, ctx context.Context, relay host.Host) host.Host {
func mkHostWithStaticAutoRelay(t *testing.T, relay host.Host) host.Host {
if race.WithRace() {
t.Skip("modifying manet.Private4 is racy")
}
Expand Down Expand Up @@ -327,19 +365,23 @@ func mkHostWithStaticAutoRelay(t *testing.T, ctx context.Context, relay host.Hos
return h
}

func makeRelayedHosts(t *testing.T, h1Opt holepunch.Option, addHolePuncher bool) (h1, h2, relay host.Host, hps *holepunch.Service) {
func makeRelayedHosts(t *testing.T, h1opt, h2opt holepunch.Option, addHolePuncher bool) (h1, h2, relay host.Host, hps *holepunch.Service) {
t.Helper()
h1, _ = mkHostWithHolePunchSvc(t, h1Opt)
var h1opts []holepunch.Option
if h1opt != nil {
h1opts = append(h1opts, h1opt)
}
h1, _ = mkHostWithHolePunchSvc(t, h1opts...)
var err error
relay, err = libp2p.New(libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0")), libp2p.DisableRelay())
require.NoError(t, err)

_, err = relayv1.NewRelay(relay)
require.NoError(t, err)

h2 = mkHostWithStaticAutoRelay(t, context.Background(), relay)
h2 = mkHostWithStaticAutoRelay(t, relay)
if addHolePuncher {
hps = addHolePunchService(t, h2)
hps = addHolePunchService(t, h2, h2opt)
}

// h1 has a relay addr
Expand All @@ -359,11 +401,13 @@ func makeRelayedHosts(t *testing.T, h1Opt holepunch.Option, addHolePuncher bool)
return
}

func addHolePunchService(t *testing.T, h host.Host) *holepunch.Service {
func addHolePunchService(t *testing.T, h host.Host, opt holepunch.Option) *holepunch.Service {
t.Helper()
ids, err := identify.NewIDService(h)
require.NoError(t, err)
hps, err := holepunch.NewService(h, ids)
var opts []holepunch.Option
if opt != nil {
opts = append(opts, opt)
}
hps, err := holepunch.NewService(h, newMockIDService(t, h), opts...)
require.NoError(t, err)
return hps
}
Expand All @@ -372,9 +416,7 @@ func mkHostWithHolePunchSvc(t *testing.T, opts ...holepunch.Option) (host.Host,
t.Helper()
h, err := libp2p.New(libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0"), ma.StringCast("/ip6/::1/tcp/0")))
require.NoError(t, err)
ids, err := identify.NewIDService(h)
require.NoError(t, err)
hps, err := holepunch.NewService(h, ids, opts...)
hps, err := holepunch.NewService(h, newMockIDService(t, h), opts...)
require.NoError(t, err)
return h, hps
}

0 comments on commit ac2d335

Please sign in to comment.