Skip to content

Commit

Permalink
Dial from your own listener
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoPolo committed Aug 24, 2024
1 parent 412daa4 commit 7566f9f
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 25 deletions.
68 changes: 68 additions & 0 deletions libp2p_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/rand"
"errors"
"fmt"
"io"
"net"
"net/netip"
"regexp"
Expand Down Expand Up @@ -587,3 +588,70 @@ func TestWebRTCReuseAddrWithQUIC(t *testing.T) {
require.Contains(t, h1.Addrs()[0].String(), "quic-v1")
})
}

func TestUseCorrectTransportForDialOut(t *testing.T) {
listAddrOrder := [][]string{
{"/ip4/127.0.0.1/udp/0/quic-v1", "/ip4/127.0.0.1/udp/0/quic-v1/webtransport"},
{"/ip4/127.0.0.1/udp/0/quic-v1/webtransport", "/ip4/127.0.0.1/udp/0/quic-v1"},
{"/ip4/0.0.0.0/udp/0/quic-v1", "/ip4/0.0.0.0/udp/0/quic-v1/webtransport"},
{"/ip4/0.0.0.0/udp/0/quic-v1/webtransport", "/ip4/0.0.0.0/udp/0/quic-v1"},
}
for _, order := range listAddrOrder {
h1, err := New(ListenAddrStrings(order...), Transport(quic.NewTransport), Transport(webtransport.New))
require.NoError(t, err)
t.Cleanup(func() {
h1.Close()
})

go func() {
h1.SetStreamHandler("/echo-port", func(s network.Stream) {
m := s.Conn().RemoteMultiaddr()
v, err := m.ValueForProtocol(ma.P_UDP)
if err != nil {
s.Reset()
return
}
s.Write([]byte(v))
s.Close()
})
}()

for _, addr := range h1.Addrs() {
t.Run("order "+strings.Join(order, ",")+" Dial to "+addr.String(), func(t *testing.T) {
h2, err := New(ListenAddrStrings(
"/ip4/0.0.0.0/udp/0/quic-v1",
"/ip4/0.0.0.0/udp/0/quic-v1/webtransport",
), Transport(quic.NewTransport), Transport(webtransport.New))
require.NoError(t, err)
defer h2.Close()
t.Log("H2 Addrs", h2.Addrs())
var myExpectedDialOutAddr ma.Multiaddr
addrIsWT, _ := webtransport.IsWebtransportMultiaddr(addr)
isLocal := func(a ma.Multiaddr) bool {
return strings.Contains(a.String(), "127.0.0.1")
}
addrIsLocal := isLocal(addr)
for _, a := range h2.Addrs() {
aIsWT, _ := webtransport.IsWebtransportMultiaddr(a)
if addrIsWT == aIsWT && isLocal(a) == addrIsLocal {
myExpectedDialOutAddr = a
break
}
}

err = h2.Connect(context.Background(), peer.AddrInfo{ID: h1.ID(), Addrs: []ma.Multiaddr{addr}})
require.NoError(t, err)

s, err := h2.NewStream(context.Background(), h1.ID(), "/echo-port")
require.NoError(t, err)

port, err := io.ReadAll(s)
require.NoError(t, err)

myExpectedPort, err := myExpectedDialOutAddr.ValueForProtocol(ma.P_UDP)
require.NoError(t, err)
require.Equal(t, myExpectedPort, string(port))
})
}
}
}
5 changes: 3 additions & 2 deletions p2p/transport/quic/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee
}

tlsConf, keyCh := t.identity.ConfigForPeer(p)
ctx = quicreuse.WithAssociation(ctx, t)
pconn, err := t.connManager.DialQUIC(ctx, raddr, tlsConf, t.allowWindowIncrease)
if err != nil {
return nil, err
Expand Down Expand Up @@ -196,7 +197,7 @@ func (t *transport) holePunch(ctx context.Context, raddr ma.Multiaddr, p peer.ID
if err != nil {
return nil, err
}
tr, err := t.connManager.TransportForDial(network, addr)
tr, err := t.connManager.TransportWithAssociationForDial(t, network, addr)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -313,7 +314,7 @@ func (t *transport) Listen(addr ma.Multiaddr) (tpt.Listener, error) {
return nil, fmt.Errorf("can't listen on quic version %v, underlying listener doesn't support it", version)
}
} else {
ln, err := t.connManager.ListenQUIC(addr, &tlsConf, t.allowWindowIncrease)
ln, err := t.connManager.ListenQUICAndAssociate(t, addr, &tlsConf, t.allowWindowIncrease)
if err != nil {
return nil, err
}
Expand Down
38 changes: 33 additions & 5 deletions p2p/transport/quicreuse/connmgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ func (c *ConnManager) getReuse(network string) (*reuse, error) {
}

func (c *ConnManager) ListenQUIC(addr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (Listener, error) {
return c.ListenQUICAndAssociate(nil, addr, tlsConf, allowWindowIncrease)
}

// ListenQUICAndAssociate returns a QUIC listener and associates the underlying transport with the given association.
func (c *ConnManager) ListenQUICAndAssociate(association any, addr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (Listener, error) {
netw, host, err := manet.DialArgs(addr)
if err != nil {
return nil, err
Expand All @@ -117,7 +122,7 @@ func (c *ConnManager) ListenQUIC(addr ma.Multiaddr, tlsConf *tls.Config, allowWi
key := laddr.String()
entry, ok := c.quicListeners[key]
if !ok {
tr, err := c.transportForListen(netw, laddr)
tr, err := c.transportForListen(association, netw, laddr)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -176,13 +181,18 @@ func (c *ConnManager) SharedNonQUICPacketConn(network string, laddr *net.UDPAddr
return nil, errors.New("expected to be able to share with a QUIC listener, but the QUIC listener is not using a refcountedTransport. `DisableReuseport` should not be set")
}

func (c *ConnManager) transportForListen(network string, laddr *net.UDPAddr) (refCountedQuicTransport, error) {
func (c *ConnManager) transportForListen(association any, network string, laddr *net.UDPAddr) (refCountedQuicTransport, error) {
if c.enableReuseport {
reuse, err := c.getReuse(network)
if err != nil {
return nil, err
}
return reuse.TransportForListen(network, laddr)
tr, err := reuse.TransportForListen(network, laddr)
if err != nil {
return nil, err
}
tr.associate(association)
return tr, nil
}

conn, err := net.ListenUDP(network, laddr)
Expand All @@ -199,6 +209,14 @@ func (c *ConnManager) transportForListen(network string, laddr *net.UDPAddr) (re
}, nil
}

type associationKey struct{}

// WithAssociation returns a new context with the given association. Used in
// DialQUIC to prefer a transport that has the given association.
func WithAssociation(ctx context.Context, association any) context.Context {
return context.WithValue(ctx, associationKey{}, association)
}

func (c *ConnManager) DialQUIC(ctx context.Context, raddr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (quic.Connection, error) {
naddr, v, err := FromQuicMultiaddr(raddr)
if err != nil {
Expand All @@ -219,7 +237,12 @@ func (c *ConnManager) DialQUIC(ctx context.Context, raddr ma.Multiaddr, tlsConf
return nil, errors.New("unknown QUIC version")
}

tr, err := c.TransportForDial(netw, naddr)
var tr refCountedQuicTransport
if association := ctx.Value(associationKey{}); association != nil {
tr, err = c.TransportWithAssociationForDial(association, netw, naddr)
} else {
tr, err = c.TransportForDial(netw, naddr)
}
if err != nil {
return nil, err
}
Expand All @@ -232,12 +255,17 @@ func (c *ConnManager) DialQUIC(ctx context.Context, raddr ma.Multiaddr, tlsConf
}

func (c *ConnManager) TransportForDial(network string, raddr *net.UDPAddr) (refCountedQuicTransport, error) {
return c.TransportWithAssociationForDial(nil, network, raddr)
}

// TransportWithAssociationForDial returns a QUIC transport for dialing, preferring a transport with the given association.
func (c *ConnManager) TransportWithAssociationForDial(association any, network string, raddr *net.UDPAddr) (refCountedQuicTransport, error) {
if c.enableReuseport {
reuse, err := c.getReuse(network)
if err != nil {
return nil, err
}
return reuse.TransportForDial(network, raddr)
return reuse.transportWithAssociationForDial(association, network, raddr)
}

var laddr *net.UDPAddr
Expand Down
4 changes: 1 addition & 3 deletions p2p/transport/quicreuse/connmgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ func testListenOnSameProto(t *testing.T, enableReuseport bool) {

const alpn = "proto"

var tlsConf tls.Config
tlsConf.NextProtos = []string{alpn}
ln1, err := cm.ListenQUIC(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), &tls.Config{NextProtos: []string{alpn}}, nil)
require.NoError(t, err)
defer ln1.Close()
Expand Down Expand Up @@ -96,7 +94,7 @@ func TestConnectionPassedToQUICForListening(t *testing.T) {

_, err = cm.ListenQUIC(raddr, &tls.Config{NextProtos: []string{"proto"}}, nil)
require.NoError(t, err)
quicTr, err := cm.transportForListen(netw, naddr)
quicTr, err := cm.transportForListen(nil, netw, naddr)
require.NoError(t, err)
defer quicTr.Close()
if _, ok := quicTr.(*singleOwnerTransport).Transport.Conn.(quic.OOBCapablePacketConn); !ok {
Expand Down
49 changes: 42 additions & 7 deletions p2p/transport/quicreuse/reuse.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,36 @@ type refcountedTransport struct {
mutex sync.Mutex
refCount int
unusedSince time.Time

assocations map[any]struct{}
}

// associate an arbitrary value with this transport.
// This lets us "tag" the refcountedTransport when listening so we can use it
// later for dialing. Necessary for holepunching and learning about our own
// observed listening address.
func (c *refcountedTransport) associate(a any) {
if a == nil {
return
}
c.mutex.Lock()
defer c.mutex.Unlock()
if c.assocations == nil {
c.assocations = make(map[any]struct{})
}
c.assocations[a] = struct{}{}
}

// hasAssociation returns true if the transport has the given association.
// If it is a nil association, it will always return true.
func (c *refcountedTransport) hasAssociation(a any) bool {
if a == nil {
return true
}
c.mutex.Lock()
defer c.mutex.Unlock()
_, ok := c.assocations[a]
return ok
}

func (c *refcountedTransport) IncreaseCount() {
Expand Down Expand Up @@ -204,7 +234,7 @@ func (r *reuse) gc() {
}
}

func (r *reuse) TransportForDial(network string, raddr *net.UDPAddr) (*refcountedTransport, error) {
func (r *reuse) transportWithAssociationForDial(association any, network string, raddr *net.UDPAddr) (*refcountedTransport, error) {
var ip *net.IP

// Only bother looking up the source address if we actually _have_ non 0.0.0.0 listeners.
Expand All @@ -224,29 +254,34 @@ func (r *reuse) TransportForDial(network string, raddr *net.UDPAddr) (*refcounte
r.mutex.Lock()
defer r.mutex.Unlock()

tr, err := r.transportForDialLocked(network, ip)
tr, err := r.transportForDialLocked(association, network, ip)
if err != nil {
return nil, err
}
tr.IncreaseCount()
return tr, nil
}

func (r *reuse) transportForDialLocked(network string, source *net.IP) (*refcountedTransport, error) {
func (r *reuse) transportForDialLocked(association any, network string, source *net.IP) (*refcountedTransport, error) {
if source != nil {
// We already have at least one suitable transport...
if trs, ok := r.unicast[source.String()]; ok {
// ... we don't care which port we're dialing from. Just use the first.
// Prefer a transport that has the given association. We want to
// reuse the transport the association used for listening.
for _, tr := range trs {
return tr, nil
if tr.hasAssociation(association) {
return tr, nil
}
}
}
}

// Use a transport listening on 0.0.0.0 (or ::).
// Again, we don't care about the port number.
// Again, prefer a transport that has the given association.
for _, tr := range r.globalListeners {
return tr, nil
if tr.hasAssociation(association) {
return tr, nil
}
}

// Use a transport we've previously dialed from
Expand Down
14 changes: 7 additions & 7 deletions p2p/transport/quicreuse/reuse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func TestReuseCreateNewGlobalConnOnDial(t *testing.T) {

addr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
require.NoError(t, err)
conn, err := reuse.TransportForDial("udp4", addr)
conn, err := reuse.transportWithAssociationForDial(nil, "udp4", addr)
require.NoError(t, err)
require.Equal(t, 1, conn.GetCount())
laddr := conn.LocalAddr().(*net.UDPAddr)
Expand All @@ -111,7 +111,7 @@ func TestReuseConnectionWhenDialing(t *testing.T) {
// dial
raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
require.NoError(t, err)
conn, err := reuse.TransportForDial("udp4", raddr)
conn, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr)
require.NoError(t, err)
require.Equal(t, 2, conn.GetCount())
}
Expand All @@ -122,7 +122,7 @@ func TestReuseConnectionWhenListening(t *testing.T) {

raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
require.NoError(t, err)
tr, err := reuse.TransportForDial("udp4", raddr)
tr, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr)
require.NoError(t, err)
laddr := &net.UDPAddr{IP: net.IPv4zero, Port: tr.LocalAddr().(*net.UDPAddr).Port}
lconn, err := reuse.TransportForListen("udp4", laddr)
Expand All @@ -138,7 +138,7 @@ func TestReuseConnectionWhenDialBeforeListen(t *testing.T) {
// dial any address
raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
require.NoError(t, err)
rTr, err := reuse.TransportForDial("udp4", raddr)
rTr, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr)
require.NoError(t, err)

// open a listener
Expand All @@ -149,7 +149,7 @@ func TestReuseConnectionWhenDialBeforeListen(t *testing.T) {
// new dials should go via the listener connection
raddr, err = net.ResolveUDPAddr("udp4", "1.1.1.1:1235")
require.NoError(t, err)
tr, err := reuse.TransportForDial("udp4", raddr)
tr, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr)
require.NoError(t, err)
require.Equal(t, lTr, tr)
require.Equal(t, 2, tr.GetCount())
Expand Down Expand Up @@ -183,7 +183,7 @@ func TestReuseListenOnSpecificInterface(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 1, lconn.GetCount())
// dial
conn, err := reuse.TransportForDial("udp4", raddr)
conn, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr)
require.NoError(t, err)
require.Equal(t, 1, conn.GetCount())
}
Expand Down Expand Up @@ -214,7 +214,7 @@ func TestReuseGarbageCollect(t *testing.T) {

raddr, err := net.ResolveUDPAddr("udp4", "1.2.3.4:1234")
require.NoError(t, err)
dTr, err := reuse.TransportForDial("udp4", raddr)
dTr, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr)
require.NoError(t, err)
require.Equal(t, 1, dTr.GetCount())

Expand Down
3 changes: 2 additions & 1 deletion p2p/transport/webtransport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string
return verifyRawCerts(rawCerts, certHashes)
}
}
ctx = quicreuse.WithAssociation(ctx, t)
conn, err := t.connManager.DialQUIC(ctx, addr, tlsConf, t.allowWindowIncrease)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -331,7 +332,7 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) {
}
tlsConf.NextProtos = append(tlsConf.NextProtos, http3.NextProtoH3)

ln, err := t.connManager.ListenQUIC(laddr, tlsConf, t.allowWindowIncrease)
ln, err := t.connManager.ListenQUICAndAssociate(t, laddr, tlsConf, t.allowWindowIncrease)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 7566f9f

Please sign in to comment.