diff --git a/p2p/protocol/holepunch/coordination.go b/p2p/protocol/holepunch/coordination.go index 5243f37fda..fd0c445298 100644 --- a/p2p/protocol/holepunch/coordination.go +++ b/p2p/protocol/holepunch/coordination.go @@ -56,6 +56,8 @@ type Service struct { closed bool refCount sync.WaitGroup + hasPublicAddrsChan chan struct{} // this chan is closed as soon as we have a public address + // active hole punches for deduplicating activeMx sync.Mutex active map[peer.ID]struct{} @@ -71,11 +73,12 @@ func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service, ctx, cancel := context.WithCancel(context.Background()) hs := &Service{ - ctx: ctx, - ctxCancel: cancel, - host: h, - ids: ids, - active: make(map[peer.ID]struct{}), + ctx: ctx, + ctxCancel: cancel, + host: h, + ids: ids, + active: make(map[peer.ID]struct{}), + hasPublicAddrsChan: make(chan struct{}), } for _, opt := range opts { @@ -85,11 +88,47 @@ func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service, } } - h.SetStreamHandler(Protocol, hs.handleNewStream) + hs.refCount.Add(1) + go hs.watchForPublicAddr() + h.Network().Notify((*netNotifiee)(hs)) return hs, nil } +func (hs *Service) watchForPublicAddr() { + defer hs.refCount.Done() + + log.Debug("waiting until we have at least one public address", "peer", hs.host.ID()) + + // TODO: We should have an event here that fires when identify discovers a new + // address (and when autonat confirms that address). + // As we currently don't have an event like this, just check our observed addresses + // regularly (exponential backoff starting at 250 ms, capped at 5s). + duration := 250 * time.Millisecond + const maxDuration = 5 * time.Second + t := time.NewTimer(duration) + defer t.Stop() + for { + if containsPublicAddr(hs.ids.OwnObservedAddrs()) { + log.Debug("Host now has a public address. Starting holepunch protocol.") + hs.host.SetStreamHandler(Protocol, hs.handleNewStream) + close(hs.hasPublicAddrsChan) + return + } + + select { + case <-hs.ctx.Done(): + return + case <-t.C: + duration *= 2 + if duration > maxDuration { + duration = maxDuration + } + t.Reset(duration) + } + } +} + // Close closes the Hole Punch Service. func (hs *Service) Close() error { hs.closeMx.Lock() @@ -176,7 +215,6 @@ func (hs *Service) beginDirectConnect(p peer.ID) error { // It first attempts a direct dial (if we have a public address of that peer), and then // coordinates a hole punch over the given relay connection. func (hs *Service) DirectConnect(p peer.ID) error { - log.Debugw("got inbound proxy conn", "peer", p) if err := hs.beginDirectConnect(p); err != nil { return err } @@ -221,8 +259,16 @@ func (hs *Service) directConnect(rp peer.ID) error { } } - if len(hs.ids.OwnObservedAddrs()) == 0 { + log.Debugw("got inbound proxy conn", "peer", rp) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + select { + case <-hs.ctx.Done(): + return hs.ctx.Err() + case <-ctx.Done(): + log.Debug("didn't find any public host address") return errors.New("can't initiate hole punch, as we don't have any public addresses") + case <-hs.hasPublicAddrsChan: } // hole punch @@ -341,11 +387,6 @@ func (hs *Service) handleNewStream(s network.Stream) { err = hs.holePunchConnect(pi, false) dt := time.Since(start) hs.tracer.EndHolePunch(rp, dt, err) - if err != nil { - log.Debugw("hole punching failed", "peer", rp, "time", dt, "error", err) - } else { - log.Debugw("hole punching succeeded", "peer", rp, "time", dt) - } } func (hs *Service) holePunchConnect(pi peer.AddrInfo, isClient bool) error { @@ -363,6 +404,16 @@ func (hs *Service) holePunchConnect(pi peer.AddrInfo, isClient bool) error { return nil } +func containsPublicAddr(addrs []ma.Multiaddr) bool { + for _, addr := range addrs { + if isRelayAddress(addr) || !manet.IsPublicAddr(addr) { + continue + } + return true + } + return false +} + func removeRelayAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { result := make([]ma.Multiaddr, 0, len(addrs)) for _, addr := range addrs { @@ -414,6 +465,7 @@ func (nn *netNotifiee) Connected(_ network.Network, conn network.Conn) { // that we can dial to for a hole punch. case <-hs.ids.IdentifyWait(conn): case <-hs.ctx.Done(): + return } _ = hs.DirectConnect(conn.RemotePeer()) diff --git a/p2p/protocol/holepunch/coordination_test.go b/p2p/protocol/holepunch/coordination_test.go index 4786e3416d..da0db1ae19 100644 --- a/p2p/protocol/holepunch/coordination_test.go +++ b/p2p/protocol/holepunch/coordination_test.go @@ -47,6 +47,22 @@ func (m *mockEventTracer) getEvents() []*holepunch.Event { var _ holepunch.EventTracer = &mockEventTracer{} +type mockIDService struct { + identify.IDService +} + +var _ identify.IDService = &mockIDService{} + +func newMockIDService(t *testing.T, h host.Host) identify.IDService { + ids, err := identify.NewIDService(h) + require.NoError(t, err) + return &mockIDService{IDService: ids} +} + +func (s *mockIDService) OwnObservedAddrs() []ma.Multiaddr { + return append(s.IDService.OwnObservedAddrs(), ma.StringCast("/ip4/1.1.1.1/tcp/1234")) +} + func TestNoHolePunchIfDirectConnExists(t *testing.T) { tr := &mockEventTracer{} h1, hps := mkHostWithHolePunchSvc(t, holepunch.WithTracer(tr)) @@ -95,7 +111,7 @@ func TestDirectDialWorks(t *testing.T) { func TestEndToEndSimConnect(t *testing.T) { tr := &mockEventTracer{} - h1, h2, relay, _ := makeRelayedHosts(t, holepunch.WithTracer(tr), true) + h1, h2, relay, _ := makeRelayedHosts(t, nil, holepunch.WithTracer(tr), true) defer h1.Close() defer h2.Close() defer relay.Close() @@ -158,11 +174,11 @@ func TestFailuresOnInitiator(t *testing.T) { } tr := &mockEventTracer{} - h1, h2, relay, _ := makeRelayedHosts(t, holepunch.WithTracer(tr), false) + h1, h2, relay, _ := makeRelayedHosts(t, nil, nil, false) defer h1.Close() defer h2.Close() defer relay.Close() - hps := addHolePunchService(t, h2) + hps := addHolePunchService(t, h2, holepunch.WithTracer(tr)) if tc.rhandler != nil { h1.SetStreamHandler(holepunch.Protocol, tc.rhandler) @@ -180,6 +196,14 @@ func TestFailuresOnInitiator(t *testing.T) { } } +func addrsToBytes(as []ma.Multiaddr) [][]byte { + bzs := make([][]byte, 0, len(as)) + for _, a := range as { + bzs = append(bzs, a.Bytes()) + } + return bzs +} + func TestFailuresOnResponder(t *testing.T) { tcs := map[string]struct { initiator func(s network.Stream) @@ -192,10 +216,13 @@ func TestFailuresOnResponder(t *testing.T) { }, errMsg: "expected CONNECT message", }, - "initiator does NOT send a SYNC message after a Connect message": { + "initiator does NOT send a SYNC message after a CONNECT message": { initiator: func(s network.Stream) { w := protoio.NewDelimitedWriter(s) - w.WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()}) + w.WriteMsg(&holepunch_pb.HolePunch{ + Type: holepunch_pb.HolePunch_CONNECT.Enum(), + ObsAddrs: addrsToBytes([]ma.Multiaddr{ma.StringCast("/ip4/127.0.0.1/tcp/1234")}), + }) w.WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()}) }, errMsg: "expected SYNC message", @@ -203,11 +230,22 @@ func TestFailuresOnResponder(t *testing.T) { "initiator does NOT reply within hole punch deadline": { holePunchTimeout: 10 * time.Millisecond, initiator: func(s network.Stream) { - protoio.NewDelimitedWriter(s).WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()}) + protoio.NewDelimitedWriter(s).WriteMsg(&holepunch_pb.HolePunch{ + Type: holepunch_pb.HolePunch_CONNECT.Enum(), + ObsAddrs: addrsToBytes([]ma.Multiaddr{ma.StringCast("/ip4/127.0.0.1/tcp/1234")}), + }) time.Sleep(10 * time.Second) }, errMsg: "i/o deadline reached", }, + "initiator does NOT send any addresses in CONNECT": { + holePunchTimeout: 10 * time.Millisecond, + initiator: func(s network.Stream) { + protoio.NewDelimitedWriter(s).WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()}) + time.Sleep(10 * time.Second) + }, + errMsg: "expected CONNECT message to contain at least one address", + }, } for name, tc := range tcs { @@ -219,7 +257,7 @@ func TestFailuresOnResponder(t *testing.T) { } tr := &mockEventTracer{} - h1, h2, relay, _ := makeRelayedHosts(t, holepunch.WithTracer(tr), false) + h1, h2, relay, _ := makeRelayedHosts(t, holepunch.WithTracer(tr), nil, false) defer h1.Close() defer h2.Close() defer relay.Close() @@ -293,7 +331,7 @@ func ensureDirectConn(t *testing.T, h1, h2 host.Host) { }, 5*time.Second, 50*time.Millisecond) } -func mkHostWithStaticAutoRelay(t *testing.T, ctx context.Context, relay host.Host) host.Host { +func mkHostWithStaticAutoRelay(t *testing.T, relay host.Host) host.Host { if race.WithRace() { t.Skip("modifying manet.Private4 is racy") } @@ -327,9 +365,13 @@ func mkHostWithStaticAutoRelay(t *testing.T, ctx context.Context, relay host.Hos return h } -func makeRelayedHosts(t *testing.T, h1Opt holepunch.Option, addHolePuncher bool) (h1, h2, relay host.Host, hps *holepunch.Service) { +func makeRelayedHosts(t *testing.T, h1opt, h2opt holepunch.Option, addHolePuncher bool) (h1, h2, relay host.Host, hps *holepunch.Service) { t.Helper() - h1, _ = mkHostWithHolePunchSvc(t, h1Opt) + var h1opts []holepunch.Option + if h1opt != nil { + h1opts = append(h1opts, h1opt) + } + h1, _ = mkHostWithHolePunchSvc(t, h1opts...) var err error relay, err = libp2p.New(libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0")), libp2p.DisableRelay()) require.NoError(t, err) @@ -337,9 +379,9 @@ func makeRelayedHosts(t *testing.T, h1Opt holepunch.Option, addHolePuncher bool) _, err = relayv1.NewRelay(relay) require.NoError(t, err) - h2 = mkHostWithStaticAutoRelay(t, context.Background(), relay) + h2 = mkHostWithStaticAutoRelay(t, relay) if addHolePuncher { - hps = addHolePunchService(t, h2) + hps = addHolePunchService(t, h2, h2opt) } // h1 has a relay addr @@ -359,11 +401,13 @@ func makeRelayedHosts(t *testing.T, h1Opt holepunch.Option, addHolePuncher bool) return } -func addHolePunchService(t *testing.T, h host.Host) *holepunch.Service { +func addHolePunchService(t *testing.T, h host.Host, opt holepunch.Option) *holepunch.Service { t.Helper() - ids, err := identify.NewIDService(h) - require.NoError(t, err) - hps, err := holepunch.NewService(h, ids) + var opts []holepunch.Option + if opt != nil { + opts = append(opts, opt) + } + hps, err := holepunch.NewService(h, newMockIDService(t, h), opts...) require.NoError(t, err) return hps } @@ -372,9 +416,7 @@ func mkHostWithHolePunchSvc(t *testing.T, opts ...holepunch.Option) (host.Host, t.Helper() h, err := libp2p.New(libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0"), ma.StringCast("/ip6/::1/tcp/0"))) require.NoError(t, err) - ids, err := identify.NewIDService(h) - require.NoError(t, err) - hps, err := holepunch.NewService(h, ids, opts...) + hps, err := holepunch.NewService(h, newMockIDService(t, h), opts...) require.NoError(t, err) return h, hps }