diff --git a/p2p/transport/webrtc/udpmux/mux.go b/p2p/transport/webrtc/udpmux/mux.go index 4dd0bf78c2..ca54c18593 100644 --- a/p2p/transport/webrtc/udpmux/mux.go +++ b/p2p/transport/webrtc/udpmux/mux.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "strings" "sync" logging "github.com/ipfs/go-log/v2" @@ -42,9 +43,15 @@ type UDPMux struct { queue chan Candidate - mx sync.Mutex + mx sync.Mutex + // ufragMap allows us to multiplex incoming STUN packets based on ufrag ufragMap map[ufragConnKey]*muxedConnection - addrMap map[string]*muxedConnection + // addrMap allows us to correctly direct incoming packets after the connection + // is established and ufrag isn't available on all packets + addrMap map[string]*muxedConnection + // ufragAddrMap allows cleaning up all addresses from the addrMap once the connection is closed + // During the ICE connectivity checks, the same ufrag might be used on multiple addresses. + ufragAddrMap map[ufragConnKey][]net.Addr // the context controls the lifecycle of the mux wg sync.WaitGroup @@ -57,12 +64,13 @@ var _ ice.UDPMux = &UDPMux{} func NewUDPMux(socket net.PacketConn) *UDPMux { ctx, cancel := context.WithCancel(context.Background()) mux := &UDPMux{ - ctx: ctx, - cancel: cancel, - socket: socket, - ufragMap: make(map[ufragConnKey]*muxedConnection), - addrMap: make(map[string]*muxedConnection), - queue: make(chan Candidate, 32), + ctx: ctx, + cancel: cancel, + socket: socket, + ufragMap: make(map[ufragConnKey]*muxedConnection), + addrMap: make(map[string]*muxedConnection), + ufragAddrMap: make(map[ufragConnKey][]net.Addr), + queue: make(chan Candidate, 32), } return mux @@ -130,7 +138,11 @@ func (mux *UDPMux) readLoop() { n, addr, err := mux.socket.ReadFrom(buf) if err != nil { - log.Errorf("error reading from socket: %v", err) + if strings.Contains(err.Error(), "use of closed network connection") { + log.Debugf("readLoop exiting: socket %s closed", mux.socket.LocalAddr()) + } else { + log.Errorf("error reading from socket %s: %v", mux.socket.LocalAddr(), err) + } pool.Put(buf) return } @@ -157,7 +169,7 @@ func (mux *UDPMux) processPacket(buf []byte, addr net.Addr) (processed bool) { conn, ok := mux.addrMap[addr.String()] mux.mx.Unlock() if ok { - if err := conn.Push(buf); err != nil { + if err := conn.Push(buf, addr); err != nil { log.Debugf("could not push packet: %v", err) return false } @@ -196,7 +208,7 @@ func (mux *UDPMux) processPacket(buf []byte, addr net.Addr) (processed bool) { } } - if err := conn.Push(buf); err != nil { + if err := conn.Push(buf, addr); err != nil { log.Debugf("could not push packet: %v", err) return false } @@ -250,9 +262,12 @@ func (mux *UDPMux) RemoveConnByUfrag(ufrag string) { for _, isIPv6 := range [...]bool{true, false} { key := ufragConnKey{ufrag: ufrag, isIPv6: isIPv6} - if conn, ok := mux.ufragMap[key]; ok { + if _, ok := mux.ufragMap[key]; ok { delete(mux.ufragMap, key) - delete(mux.addrMap, conn.RemoteAddr().String()) + for _, addr := range mux.ufragAddrMap[key] { + delete(mux.addrMap, addr.String()) + } + delete(mux.ufragAddrMap, key) } } } @@ -264,12 +279,14 @@ func (mux *UDPMux) getOrCreateConn(ufrag string, isIPv6 bool, _ *UDPMux, addr ne defer mux.mx.Unlock() if conn, ok := mux.ufragMap[key]; ok { + mux.addrMap[addr.String()] = conn + mux.ufragAddrMap[key] = append(mux.ufragAddrMap[key], addr) return false, conn } - conn := newMuxedConnection(mux, func() { mux.RemoveConnByUfrag(ufrag) }, addr) + conn := newMuxedConnection(mux, func() { mux.RemoveConnByUfrag(ufrag) }) mux.ufragMap[key] = conn mux.addrMap[addr.String()] = conn - + mux.ufragAddrMap[key] = append(mux.ufragAddrMap[key], addr) return true, conn } diff --git a/p2p/transport/webrtc/udpmux/mux_test.go b/p2p/transport/webrtc/udpmux/mux_test.go index 4121b6fdf5..cb24cf9f03 100644 --- a/p2p/transport/webrtc/udpmux/mux_test.go +++ b/p2p/transport/webrtc/udpmux/mux_test.go @@ -1,89 +1,227 @@ package udpmux import ( + "context" + "fmt" "net" + "sync" "testing" "time" + "github.com/pion/stun" "github.com/stretchr/testify/require" ) -var _ net.PacketConn = dummyPacketConn{} - -type dummyPacketConn struct{} - -// Close implements net.PacketConn -func (dummyPacketConn) Close() error { - return nil -} - -// LocalAddr implements net.PacketConn -func (dummyPacketConn) LocalAddr() net.Addr { - return nil -} - -// ReadFrom implements net.PacketConn -func (dummyPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - return 0, &net.UDPAddr{}, nil -} - -// SetDeadline implements net.PacketConn -func (dummyPacketConn) SetDeadline(t time.Time) error { - return nil +func getSTUNBindingRequest(ufrag string) *stun.Message { + msg := stun.New() + msg.SetType(stun.BindingRequest) + uattr := stun.RawAttribute{ + Type: stun.AttrUsername, + Value: []byte(fmt.Sprintf("%s:%s", ufrag, ufrag)), // This is the format we expect in our connections + } + uattr.AddTo(msg) + msg.Encode() + return msg } -// SetReadDeadline implements net.PacketConn -func (dummyPacketConn) SetReadDeadline(t time.Time) error { - return nil +func setupMapping(t *testing.T, ufrag string, from net.PacketConn, m *UDPMux) { + t.Helper() + msg := getSTUNBindingRequest(ufrag) + _, err := from.WriteTo(msg.Raw, m.GetListenAddresses()[0]) + require.NoError(t, err) } -// SetWriteDeadline implements net.PacketConn -func (dummyPacketConn) SetWriteDeadline(t time.Time) error { - return nil +func newPacketConn(t *testing.T) net.PacketConn { + t.Helper() + udpPort0 := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0} + c, err := net.ListenUDP("udp", udpPort0) + require.NoError(t, err) + t.Cleanup(func() { c.Close() }) + return c } -// WriteTo implements net.PacketConn -func (dummyPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - return 0, nil +func TestAccept(t *testing.T) { + c := newPacketConn(t) + defer c.Close() + m := NewUDPMux(c) + m.Start() + defer m.Close() + + ufrags := []string{"a", "b", "c", "d"} + conns := make([]net.PacketConn, len(ufrags)) + for i, ufrag := range ufrags { + conns[i] = newPacketConn(t) + setupMapping(t, ufrag, conns[i], m) + } + for i, ufrag := range ufrags { + c, err := m.Accept(context.Background()) + require.NoError(t, err) + require.Equal(t, c.Ufrag, ufrag) + require.Equal(t, c.Addr, conns[i].LocalAddr()) + } + + for i, ufrag := range ufrags { + // should not be accepted + setupMapping(t, ufrag, conns[i], m) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + _, err := m.Accept(ctx) + require.Error(t, err) + + // should not be accepted + cc := newPacketConn(t) + setupMapping(t, ufrag, cc, m) + ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + _, err = m.Accept(ctx) + require.Error(t, err) + } } -func hasConn(m *UDPMux, ufrag string, isIPv6 bool) bool { - m.mx.Lock() - _, ok := m.ufragMap[ufragConnKey{ufrag: ufrag, isIPv6: isIPv6}] - m.mx.Unlock() - return ok +func TestGetConn(t *testing.T) { + c := newPacketConn(t) + m := NewUDPMux(c) + m.Start() + defer m.Close() + + ufrags := []string{"a", "b", "c", "d"} + conns := make([]net.PacketConn, len(ufrags)) + for i, ufrag := range ufrags { + conns[i] = newPacketConn(t) + setupMapping(t, ufrag, conns[i], m) + } + for i, ufrag := range ufrags { + c, err := m.Accept(context.Background()) + require.NoError(t, err) + require.Equal(t, c.Ufrag, ufrag) + require.Equal(t, c.Addr, conns[i].LocalAddr()) + } + + for i, ufrag := range ufrags { + c, err := m.GetConn(ufrag, conns[i].LocalAddr()) + require.NoError(t, err) + msg := make([]byte, 100) + _, _, err = c.ReadFrom(msg) + require.NoError(t, err) + } + + for i, ufrag := range ufrags { + cc := newPacketConn(t) + // setupMapping of cc to ufrags[0] and remove the stun binding request from the queue + setupMapping(t, ufrag, cc, m) + mc, err := m.GetConn(ufrag, cc.LocalAddr()) + require.NoError(t, err) + msg := make([]byte, 100) + _, _, err = mc.ReadFrom(msg) + require.NoError(t, err) + + // Write from new connection should provide the new address on ReadFrom + _, err = cc.WriteTo([]byte("test1"), c.LocalAddr()) + require.NoError(t, err) + n, addr, err := mc.ReadFrom(msg) + require.NoError(t, err) + require.Equal(t, addr, cc.LocalAddr()) + require.Equal(t, string(msg[:n]), "test1") + + // Write from original connection should provide the original address + _, err = conns[i].WriteTo([]byte("test2"), c.LocalAddr()) + require.NoError(t, err) + n, addr, err = mc.ReadFrom(msg) + require.NoError(t, err) + require.Equal(t, addr, conns[i].LocalAddr()) + require.Equal(t, string(msg[:n]), "test2") + } } -var ( - addrV4 = net.UDPAddr{IP: net.IPv4zero, Port: 1234} - addrV6 = net.UDPAddr{IP: net.IPv6zero, Port: 1234} -) - -func TestUDPMux_GetConn(t *testing.T) { - m := NewUDPMux(dummyPacketConn{}) - require.False(t, hasConn(m, "test", false)) - conn, err := m.GetConn("test", &addrV4) +func TestRemoveConnByUfrag(t *testing.T) { + c := newPacketConn(t) + m := NewUDPMux(c) + m.Start() + defer m.Close() + + // Map each ufrag to two addresses + ufrag := "a" + count := 10 + conns := make([]net.PacketConn, count) + for i := 0; i < 10; i++ { + conns[i] = newPacketConn(t) + setupMapping(t, ufrag, conns[i], m) + } + mc, err := m.GetConn(ufrag, conns[0].LocalAddr()) require.NoError(t, err) - require.NotNil(t, conn) - - require.False(t, hasConn(m, "test", true)) - connv6, err := m.GetConn("test", &addrV6) - require.NoError(t, err) - require.NotNil(t, connv6) - - require.NotEqual(t, conn, connv6) -} - -func TestUDPMux_RemoveConnectionOnClose(t *testing.T) { - mux := NewUDPMux(dummyPacketConn{}) - conn, err := mux.GetConn("test", &addrV4) + for i := 0; i < 10; i++ { + mc1, err := m.GetConn(ufrag, conns[i].LocalAddr()) + require.NoError(t, err) + require.Equal(t, mc1, mc) + } + + // Now remove the ufrag + m.RemoveConnByUfrag(ufrag) + + // All connections should now be associated with b + ufrag = "b" + for i := 0; i < 10; i++ { + setupMapping(t, ufrag, conns[i], m) + } + mc, err = m.GetConn(ufrag, conns[0].LocalAddr()) require.NoError(t, err) - require.NotNil(t, conn) - - require.True(t, hasConn(mux, "test", false)) - - err = conn.Close() + for i := 0; i < 10; i++ { + mc1, err := m.GetConn(ufrag, conns[i].LocalAddr()) + require.NoError(t, err) + require.Equal(t, mc1, mc) + } + + // Should be different even if the address is the same + mc1, err := m.GetConn("a", conns[0].LocalAddr()) require.NoError(t, err) + require.NotEqual(t, mc1, mc) +} - require.False(t, hasConn(mux, "test", false)) +func TestMuxedConnection(t *testing.T) { + c := newPacketConn(t) + m := NewUDPMux(c) + m.Start() + defer m.Close() + + msgCount := 3 + connCount := 3 + + ufrags := []string{"a", "b", "c"} + var mu sync.Mutex + addrUfragMap := make(map[string]string) + for _, ufrag := range ufrags { + go func(ufrag string) { + for i := 0; i < connCount; i++ { + cc := newPacketConn(t) + mu.Lock() + addrUfragMap[cc.LocalAddr().String()] = ufrag + mu.Unlock() + setupMapping(t, ufrag, cc, m) + for j := 0; j < msgCount; j++ { + cc.WriteTo([]byte(ufrag), c.LocalAddr()) + } + } + }(ufrag) + } + + for _, ufrag := range ufrags { + mc, err := m.GetConn(ufrag, c.LocalAddr()) // the address is irrelevant + require.NoError(t, err) + for i := 0; i < connCount; i++ { + msg := make([]byte, 100) + // Read the binding request + _, addr1, err := mc.ReadFrom(msg) + require.NoError(t, err) + require.Equal(t, addrUfragMap[addr1.String()], ufrag) + // Read individual msgs + for i := 0; i < msgCount; i++ { + n, addr2, err := mc.ReadFrom(msg) + require.NoError(t, err) + require.Equal(t, addr2, addr1) + require.Equal(t, ufrag, string(msg[:n])) + } + delete(addrUfragMap, addr1.String()) + } + } + require.Equal(t, len(addrUfragMap), 0) } diff --git a/p2p/transport/webrtc/udpmux/muxed_connection.go b/p2p/transport/webrtc/udpmux/muxed_connection.go index 5d86912aa1..2af5d33253 100644 --- a/p2p/transport/webrtc/udpmux/muxed_connection.go +++ b/p2p/transport/webrtc/udpmux/muxed_connection.go @@ -9,6 +9,11 @@ import ( pool "github.com/libp2p/go-buffer-pool" ) +type packet struct { + buf []byte + addr net.Addr +} + var _ net.PacketConn = &muxedConnection{} const queueLen = 128 @@ -21,48 +26,46 @@ type muxedConnection struct { ctx context.Context cancel context.CancelFunc onClose func() - queue chan []byte - remote net.Addr + queue chan packet mux *UDPMux } var _ net.PacketConn = &muxedConnection{} -func newMuxedConnection(mux *UDPMux, onClose func(), remote net.Addr) *muxedConnection { +func newMuxedConnection(mux *UDPMux, onClose func()) *muxedConnection { ctx, cancel := context.WithCancel(mux.ctx) return &muxedConnection{ ctx: ctx, cancel: cancel, - queue: make(chan []byte, queueLen), + queue: make(chan packet, queueLen), onClose: onClose, - remote: remote, mux: mux, } } -func (c *muxedConnection) Push(buf []byte) error { +func (c *muxedConnection) Push(buf []byte, addr net.Addr) error { select { case <-c.ctx.Done(): return errors.New("closed") default: } select { - case c.queue <- buf: + case c.queue <- packet{buf: buf, addr: addr}: return nil default: return errors.New("queue full") } } -func (c *muxedConnection) ReadFrom(p []byte) (int, net.Addr, error) { +func (c *muxedConnection) ReadFrom(buf []byte) (int, net.Addr, error) { select { - case buf := <-c.queue: - n := copy(p, buf) // This might discard parts of the packet, if p is too short - if n < len(buf) { - log.Debugf("short read, had %d, read %d", len(buf), n) + case p := <-c.queue: + n := copy(buf, p.buf) // This might discard parts of the packet, if p is too short + if n < len(p.buf) { + log.Debugf("short read, had %d, read %d", len(p.buf), n) } - pool.Put(buf) - return n, c.remote, nil + pool.Put(p.buf) + return n, p.addr, nil case <-c.ctx.Done(): return 0, nil, c.ctx.Err() } @@ -83,15 +86,15 @@ func (c *muxedConnection) Close() error { // drain the packet queue for { select { - case <-c.queue: + case p := <-c.queue: + pool.Put(p.buf) default: return nil } } } -func (c *muxedConnection) LocalAddr() net.Addr { return c.mux.socket.LocalAddr() } -func (c *muxedConnection) RemoteAddr() net.Addr { return c.remote } +func (c *muxedConnection) LocalAddr() net.Addr { return c.mux.socket.LocalAddr() } func (*muxedConnection) SetDeadline(t time.Time) error { // no deadline is desired here