diff --git a/agent.go b/agent.go index e9c4a5c4..26ed93fc 100644 --- a/agent.go +++ b/agent.go @@ -120,8 +120,8 @@ type Agent struct { loggerFactory logging.LoggerFactory log logging.LeveledLogger - net *vnet.Net - tcp *tcpIPMux + net *vnet.Net + tcpMux TCPMux interfaceFilter func(string) bool @@ -306,11 +306,10 @@ func NewAgent(config *AgentConfig) (*Agent, error) { insecureSkipVerify: config.InsecureSkipVerify, } - a.tcp = newTCPIPMux(tcpIPMuxParams{ - ListenPort: config.TCPListenPort, - Logger: log, - ReadBufferSize: 8, - }) + a.tcpMux = config.TCPMux + if a.tcpMux == nil { + a.tcpMux = newInvalidTCPMux() + } if a.net == nil { a.net = vnet.NewNet(nil) @@ -887,7 +886,9 @@ func (a *Agent) Close() error { a.gatherCandidateCancel() a.err.Store(ErrClosed) - a.tcp.RemoveUfrag(a.localUfrag) + + a.tcpMux.RemoveConnByUfrag(a.localUfrag) + close(a.done) <-done diff --git a/agent_config.go b/agent_config.go index 25c38f04..321907d0 100644 --- a/agent_config.go +++ b/agent_config.go @@ -139,10 +139,10 @@ type AgentConfig struct { // to TURN servers via TLS or DTLS InsecureSkipVerify bool - // TCPListenPort will be used to start a TCP listener on all allowed interfaces for - // ICE TCP. Currently only passive candidates are supported. This functionality is - // experimental and this API will likely change in the future. - TCPListenPort int + // TCPMux will be used for multiplexing incoming TCP connections for ICE TCP. + // Currently only passive candidates are supported. This functionality is + // experimental and the API might change in the future. + TCPMux TCPMux } // initWithDefaults populates an agent and falls back to defaults if fields are unset diff --git a/errors.go b/errors.go index 641b5cf1..55043280 100644 --- a/errors.go +++ b/errors.go @@ -103,6 +103,9 @@ var ( // ErrRunCanceled indicates a run operation was canceled by its individual done ErrRunCanceled = errors.New("run was canceled by done") + // ErrTCPMuxNotInitialized indicates TCPMux is not initialized and that invalidTCPMux is used. + ErrTCPMuxNotInitialized = errors.New("TCPMux is not initialized") + // ErrTCPRemoteAddrAlreadyExists indicates we already have the connection with same remote addr. ErrTCPRemoteAddrAlreadyExists = errors.New("conn with same remote addr already exists") ) diff --git a/gather.go b/gather.go index 1ded5a2c..7d6a7ba1 100644 --- a/gather.go +++ b/gather.go @@ -161,28 +161,22 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ var tcpType TCPType switch network { case tcp: - if a.tcp == nil { - continue - } - - // below is for passive mode + // Handle ICE TCP passive mode // TODO active mode // TODO S-O mode - mux, muxErr := a.tcp.Listen(ip) - if muxErr != nil { - a.log.Warnf("could not listen %s %s\n", network, ip) - continue - } - a.log.Debugf("GetConn by ufrag: %s\n", a.localUfrag) - conn, err = mux.GetConn(a.localUfrag) + conn, err = a.tcpMux.GetConnByUfrag(a.localUfrag) if err != nil { - a.log.Warnf("error getting tcp conn by ufrag: %s %s\n", network, ip, a.localUfrag) + if err != ErrTCPMuxNotInitialized { + a.log.Warnf("error getting tcp conn by ufrag: %s %s\n", network, ip, a.localUfrag) + } continue } port = conn.LocalAddr().(*net.TCPAddr).Port tcpType = TCPTypePassive + // TODO is there a way to verify that the listen address is even + // accessible from the current interface. case udp: conn, err = listenUDPInPortRange(a.net, a.log, int(a.portmax), int(a.portmin), network, &net.UDPAddr{IP: ip, Port: 0}) if err != nil { diff --git a/gather_test.go b/gather_test.go index 9d60992e..d1d939de 100644 --- a/gather_test.go +++ b/gather_test.go @@ -14,9 +14,11 @@ import ( "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/pkg/crypto/selfsign" + "github.com/pion/logging" "github.com/pion/transport/test" "github.com/pion/turn/v2" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestListenUDP(t *testing.T) { @@ -116,11 +118,25 @@ func TestSTUNConcurrency(t *testing.T) { Port: serverPort, }) + listener, err := net.ListenTCP("tcp", &net.TCPAddr{ + IP: net.IP{127, 0, 0, 1}, + }) + require.NoError(t, err) + defer func() { + _ = listener.Close() + }() + a, err := NewAgent(&AgentConfig{ NetworkTypes: supportedNetworkTypes, Urls: urls, CandidateTypes: []CandidateType{CandidateTypeHost, CandidateTypeServerReflexive}, - TCPListenPort: 9999, + TCPMux: NewTCPMuxDefault( + TCPMuxParams{ + Listener: listener, + Logger: logging.NewDefaultLoggerFactory().NewLogger("ice"), + ReadBufferSize: 8, + }, + ), }) assert.NoError(t, err) diff --git a/tcp_ip_mux.go b/tcp_ip_mux.go deleted file mode 100644 index f3c2d597..00000000 --- a/tcp_ip_mux.go +++ /dev/null @@ -1,102 +0,0 @@ -package ice - -import ( - "net" - "strconv" - "sync" - - "github.com/pion/logging" -) - -// tcpMuxes is a map of local addr listeners to tcpMux -var tcpMuxes map[string]*tcpMux -var tcpMuxesMu sync.Mutex - -type tcpIPMux struct { - params *tcpIPMuxParams - wg sync.WaitGroup -} - -type tcpIPMuxParams struct { - ListenPort int - ReadBufferSize int - Logger logging.LeveledLogger -} - -func newTCPIPMux(params tcpIPMuxParams) *tcpIPMux { - m := &tcpIPMux{ - params: ¶ms, - } - - tcpMuxesMu.Lock() - - if tcpMuxes == nil { - tcpMuxes = map[string]*tcpMux{} - } - - tcpMuxesMu.Unlock() - - return m -} - -func (m *tcpIPMux) Remove(key string) { - tcpMuxesMu.Lock() - defer tcpMuxesMu.Unlock() - - if tcpMux, ok := tcpMuxes[key]; ok { - err := tcpMux.Close() - if err != nil { - m.params.Logger.Errorf("Error closing tcpMux for key: %s: %s", key, err) - } - delete(tcpMuxes, key) - } -} - -func (m *tcpIPMux) RemoveUfrag(ufrag string) { - tcpMuxesMu.Lock() - defer tcpMuxesMu.Unlock() - - for _, tcpMux := range tcpMuxes { - tcpMux.RemoveConn(ufrag) - } -} - -func (m *tcpIPMux) Listen(ip net.IP) (*tcpMux, error) { - tcpMuxesMu.Lock() - defer tcpMuxesMu.Unlock() - - key := net.JoinHostPort(ip.String(), strconv.Itoa(m.params.ListenPort)) - - tcpMux, ok := tcpMuxes[key] - if ok { - return tcpMux, nil - } - - listener, err := net.ListenTCP("tcp", &net.TCPAddr{ - IP: ip, - Port: m.params.ListenPort, - }) - - if err != nil { - return nil, err - } - - key = net.JoinHostPort(ip.String(), strconv.Itoa(listener.Addr().(*net.TCPAddr).Port)) - - tcpMux = newTCPMux(tcpMuxParams{ - Listener: listener, - Logger: m.params.Logger, - ReadBufferSize: m.params.ReadBufferSize, - }) - - tcpMuxes[key] = tcpMux - - m.wg.Add(1) - go func() { - defer m.wg.Done() - <-tcpMux.CloseChannel() - m.Remove(key) - }() - - return tcpMux, nil -} diff --git a/tcp_mux.go b/tcp_mux.go index 5892b0da..5bad6174 100644 --- a/tcp_mux.go +++ b/tcp_mux.go @@ -11,31 +11,69 @@ import ( "github.com/pion/stun" ) -type tcpMux struct { - params *tcpMuxParams +// TCPMux is allows grouping multiple TCP net.Conns and using them like UDP +// net.PacketConns. The main implementation of this is TCPMuxDefault, and this +// interface exists to: +// 1. prevent SEGV panics when TCPMuxDefault is not initialized by using the +// invalidTCPMux implementation, and +// 2. allow mocking in tests. +type TCPMux interface { + io.Closer + GetConnByUfrag(ufrag string) (net.PacketConn, error) + RemoveConnByUfrag(ufrag string) +} + +// invalidTCPMux is an implementation of TCPMux that always returns ErroTCPMuxNotInitialized. +type invalidTCPMux struct { +} + +func newInvalidTCPMux() *invalidTCPMux { + return &invalidTCPMux{} +} + +// Close implements TCPMux interface. +func (m *invalidTCPMux) Close() error { + return ErrTCPMuxNotInitialized +} + +// GetConnByUfrag implements TCPMux interface. +func (m *invalidTCPMux) GetConnByUfrag(ufrag string) (net.PacketConn, error) { + return nil, ErrTCPMuxNotInitialized +} + +// RemoveConnByUfrag implements TCPMux interface. +func (m *invalidTCPMux) RemoveConnByUfrag(ufrag string) {} + +// TCPMuxDefault muxes TCP net.Conns into net.PacketConns and groups them by +// Ufrag. It is a default implementation of TCPMux interface. +type TCPMuxDefault struct { + params *TCPMuxParams + closed bool // conns is a map of all tcpPacketConns indexed by ufrag conns map[string]*tcpPacketConn - mu sync.Mutex - wg sync.WaitGroup - closedChan chan struct{} - closeOnce sync.Once + mu sync.Mutex + wg sync.WaitGroup } -type tcpMuxParams struct { +// TCPMuxParams are parameters for TCPMux. +type TCPMuxParams struct { Listener net.Listener Logger logging.LeveledLogger ReadBufferSize int } -func newTCPMux(params tcpMuxParams) *tcpMux { - m := &tcpMux{ +// NewTCPMuxDefault creates a new instance of TCPMuxDefault. +func NewTCPMuxDefault(params TCPMuxParams) *TCPMuxDefault { + if params.Logger == nil { + params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") + } + + m := &TCPMuxDefault{ params: ¶ms, conns: map[string]*tcpPacketConn{}, - - closedChan: make(chan struct{}), } m.wg.Add(1) @@ -47,7 +85,7 @@ func newTCPMux(params tcpMuxParams) *tcpMux { return m } -func (m *tcpMux) start() { +func (m *TCPMuxDefault) start() { m.params.Logger.Infof("Listening TCP on %s\n", m.params.Listener.Addr()) for { conn, err := m.params.Listener.Accept() @@ -66,14 +104,20 @@ func (m *tcpMux) start() { } } -func (m *tcpMux) LocalAddr() net.Addr { +// LocalAddr returns the listening address of this TCPMuxDefault. +func (m *TCPMuxDefault) LocalAddr() net.Addr { return m.params.Listener.Addr() } -func (m *tcpMux) GetConn(ufrag string) (net.PacketConn, error) { +// GetConnByUfrag retrieves an existing or creates a new net.PacketConn. +func (m *TCPMuxDefault) GetConnByUfrag(ufrag string) (net.PacketConn, error) { m.mu.Lock() defer m.mu.Unlock() + if m.closed { + return nil, io.ErrClosedPipe + } + conn, ok := m.conns[ufrag] if ok { @@ -86,7 +130,7 @@ func (m *tcpMux) GetConn(ufrag string) (net.PacketConn, error) { return conn, nil } -func (m *tcpMux) createConn(ufrag string, localAddr net.Addr) *tcpPacketConn { +func (m *TCPMuxDefault) createConn(ufrag string, localAddr net.Addr) *tcpPacketConn { conn := newTCPPacketConn(tcpPacketParams{ ReadBuffer: m.params.ReadBufferSize, LocalAddr: localAddr, @@ -98,20 +142,20 @@ func (m *tcpMux) createConn(ufrag string, localAddr net.Addr) *tcpPacketConn { go func() { defer m.wg.Done() <-conn.CloseChannel() - m.RemoveConn(ufrag) + m.RemoveConnByUfrag(ufrag) }() return conn } -func (m *tcpMux) closeAndLogError(closer io.Closer) { +func (m *TCPMuxDefault) closeAndLogError(closer io.Closer) { err := closer.Close() if err != nil { m.params.Logger.Warnf("Error closing connection: %s", err) } } -func (m *tcpMux) handleConn(conn net.Conn) { +func (m *TCPMuxDefault) handleConn(conn net.Conn) { buf := make([]byte, receiveMTU) n, err := readStreamingPacket(conn, buf) @@ -169,28 +213,27 @@ func (m *tcpMux) handleConn(conn net.Conn) { } } -func (m *tcpMux) Close() error { +// Close closes the listener and waits for all goroutines to exit. +func (m *TCPMuxDefault) Close() error { m.mu.Lock() + m.closed = true - m.closeOnce.Do(func() { - close(m.closedChan) - }) - + for _, conn := range m.conns { + m.closeAndLogError(conn) + } m.conns = map[string]*tcpPacketConn{} - m.mu.Unlock() err := m.params.Listener.Close() + m.mu.Unlock() + m.wg.Wait() return err } -func (m *tcpMux) CloseChannel() <-chan struct{} { - return m.closedChan -} - -func (m *tcpMux) RemoveConn(ufrag string) { +// RemoveConnByUfrag closes and removes a net.PacketConn by Ufrag. +func (m *TCPMuxDefault) RemoveConnByUfrag(ufrag string) { m.mu.Lock() defer m.mu.Unlock() @@ -198,14 +241,6 @@ func (m *tcpMux) RemoveConn(ufrag string) { m.closeAndLogError(conn) delete(m.conns, ufrag) } - - if len(m.conns) == 0 { - m.closeOnce.Do(func() { - close(m.closedChan) - }) - - m.closeAndLogError(m.params.Listener) - } } const streamingPacketHeaderLen = 2 diff --git a/tcp_ip_mux_test.go b/tcp_mux_test.go similarity index 52% rename from tcp_ip_mux_test.go rename to tcp_mux_test.go index 1007aaed..6d2f0c52 100644 --- a/tcp_ip_mux_test.go +++ b/tcp_mux_test.go @@ -1,6 +1,7 @@ package ice import ( + "io" "net" "testing" @@ -11,20 +12,30 @@ import ( "github.com/stretchr/testify/require" ) -func TestTCP_Recv(t *testing.T) { +var _ TCPMux = &TCPMuxDefault{} +var _ TCPMux = &invalidTCPMux{} + +func TestTCPMux_Recv(t *testing.T) { report := test.CheckRoutines(t) defer report() loggerFactory := logging.NewDefaultLoggerFactory() - tim := newTCPIPMux(tcpIPMuxParams{ - ListenPort: 8080, + listener, err := net.ListenTCP("tcp", &net.TCPAddr{ + IP: net.IP{127, 0, 0, 1}, + Port: 0, + }) + require.NoError(t, err, "error starting listener") + defer func() { + _ = listener.Close() + }() + + tcpMux := NewTCPMuxDefault(TCPMuxParams{ + Listener: listener, Logger: loggerFactory.NewLogger("ice"), ReadBufferSize: 20, }) - tcpMux, err := tim.Listen(net.IP{127, 0, 0, 1}) - require.NoError(t, err, "error starting listener") defer func() { _ = tcpMux.Close() }() @@ -42,7 +53,7 @@ func TestTCP_Recv(t *testing.T) { n, err := writeStreamingPacket(conn, msg.Raw) require.NoError(t, err, "error writing tcp stun packet") - pktConn, err := tcpMux.GetConn("myufrag") + pktConn, err := tcpMux.GetConnByUfrag("myufrag") require.NoError(t, err, "error retrieving muxed connection for ufrag") defer func() { _ = pktConn.Close() @@ -55,3 +66,34 @@ func TestTCP_Recv(t *testing.T) { assert.Equal(t, n, n2, "received byte size mismatch") assert.Equal(t, msg.Raw, recv, "received bytes mismatch") } + +func TestTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) { + report := test.CheckRoutines(t) + defer report() + + loggerFactory := logging.NewDefaultLoggerFactory() + + listener, err := net.ListenTCP("tcp", &net.TCPAddr{ + IP: net.IP{127, 0, 0, 1}, + Port: 0, + }) + require.NoError(t, err, "error starting listener") + defer func() { + _ = listener.Close() + }() + + tcpMux := NewTCPMuxDefault(TCPMuxParams{ + Listener: listener, + Logger: loggerFactory.NewLogger("ice"), + ReadBufferSize: 20, + }) + + _, err = tcpMux.GetConnByUfrag("test") + require.NoError(t, err, "error getting conn by ufrag") + + require.NoError(t, tcpMux.Close(), "error closing tcpMux") + + conn, err := tcpMux.GetConnByUfrag("test") + assert.Nil(t, conn, "should receive nil because mux is closed") + assert.Equal(t, io.ErrClosedPipe, err, "should receive error because mux is closed") +} diff --git a/tcp_packet_conn.go b/tcp_packet_conn.go index ebdfe501..99a462da 100644 --- a/tcp_packet_conn.go +++ b/tcp_packet_conn.go @@ -62,17 +62,16 @@ func (t *tcpPacketConn) AddConn(conn net.Conn, firstPacketData []byte) error { } if _, ok := t.conns[conn.RemoteAddr().String()]; ok { - return ErrTCPRemoteAddrAlreadyExists + return fmt.Errorf("connection with same remote address already exists: %s", conn.RemoteAddr().String()) } t.conns[conn.RemoteAddr().String()] = conn - if firstPacketData != nil { - t.recvChan <- streamingPacket{firstPacketData, conn.RemoteAddr(), nil} - } - t.wg.Add(1) go func() { + if firstPacketData != nil { + t.recvChan <- streamingPacket{firstPacketData, conn.RemoteAddr(), nil} + } defer t.wg.Done() t.startReading(conn) }()