diff --git a/listener.go b/listener.go index d12f1f3..3756ea3 100644 --- a/listener.go +++ b/listener.go @@ -77,22 +77,18 @@ func (l *listener) Accept() (tpt.CapableConn, error) { // return through active hole punching if any key := sess.RemoteAddr().String() + var wasHolePunch bool l.transport.holePunchingMx.Lock() holePunch, ok := l.transport.holePunching[key] + if ok && !holePunch.fulfilled { + holePunch.connCh <- conn + wasHolePunch = true + l.transport.holePunching[key].fulfilled = true + } l.transport.holePunchingMx.Unlock() - if ok { - select { - case holePunch.connCh <- conn: - // We need to delete the entry from the map here, - // in case we accept two connections from the same address. - l.transport.holePunchingMx.Lock() - delete(l.transport.holePunching, key) - l.transport.holePunchingMx.Unlock() - continue - default: - } + if wasHolePunch { + continue } - return conn, nil } } diff --git a/transport.go b/transport.go index f96b55e..4c7d9e5 100644 --- a/transport.go +++ b/transport.go @@ -11,14 +11,13 @@ import ( "sync" "time" - "github.com/libp2p/go-libp2p-core/connmgr" - n "github.com/libp2p/go-libp2p-core/network" - "github.com/minio/sha256-simd" "golang.org/x/crypto/hkdf" logging "github.com/ipfs/go-log" + "github.com/libp2p/go-libp2p-core/connmgr" ic "github.com/libp2p/go-libp2p-core/crypto" + n "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/pnet" tpt "github.com/libp2p/go-libp2p-core/transport" @@ -106,13 +105,14 @@ type transport struct { gater connmgr.ConnectionGater holePunchingMx sync.Mutex - holePunching map[string]activeHolePunch + holePunching map[string]*activeHolePunch } var _ tpt.Transport = &transport{} type activeHolePunch struct { - connCh chan tpt.CapableConn + connCh chan tpt.CapableConn + fulfilled bool } // NewTransport creates a new QUIC transport @@ -153,7 +153,7 @@ func NewTransport(key ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater) ( serverConfig: config, clientConfig: config.Clone(), gater: gater, - holePunching: make(map[string]activeHolePunch), + holePunching: make(map[string]*activeHolePunch), }, nil } @@ -235,26 +235,34 @@ func (t *transport) holePunch(ctx context.Context, network string, addr *net.UDP ctx, cancel := context.WithTimeout(ctx, HolePunchTimeout) defer cancel() - connCh := make(chan tpt.CapableConn) - key := addr.String() t.holePunchingMx.Lock() - t.holePunching[key] = activeHolePunch{connCh: connCh} + if _, ok := t.holePunching[key]; ok { + t.holePunchingMx.Unlock() + return nil, fmt.Errorf("already punching hole for %s", addr) + } + connCh := make(chan tpt.CapableConn, 1) + t.holePunching[key] = &activeHolePunch{connCh: connCh} t.holePunchingMx.Unlock() - payload := make([]byte, 64) var timer *time.Timer defer func() { if timer != nil { timer.Stop() } }() + + payload := make([]byte, 64) + var punchErr error +loop: for i := 0; ; i++ { if _, err := rand.Read(payload); err != nil { - return nil, err + punchErr = err + break } if _, err := pconn.UDPConn.WriteToUDP(payload, addr); err != nil { - return nil, err + punchErr = err + break } maxSleep := 10 * (i + 1) * (i + 1) // in ms @@ -269,15 +277,28 @@ func (t *transport) holePunch(ctx context.Context, network string, addr *net.UDP } select { case c := <-connCh: - return c, nil - case <-timer.C: - case <-ctx.Done(): t.holePunchingMx.Lock() delete(t.holePunching, key) t.holePunchingMx.Unlock() - return nil, ErrHolePunching + return c, nil + case <-timer.C: + case <-ctx.Done(): + punchErr = ErrHolePunching + break loop } } + // we only arrive here if punchErr != nil + t.holePunchingMx.Lock() + defer func() { + delete(t.holePunching, key) + t.holePunchingMx.Unlock() + }() + select { + case c := <-t.holePunching[key].connCh: + return c, nil + default: + return nil, punchErr + } } // Don't use mafmt.QUIC as we don't want to dial DNS addresses. Just /ip{4,6}/udp/quic