diff --git a/agent.go b/agent.go index e9c4a5c4..69a0e4e3 100644 --- a/agent.go +++ b/agent.go @@ -120,8 +120,8 @@ type Agent struct { loggerFactory logging.LoggerFactory log logging.LeveledLogger - net *vnet.Net - tcp *tcpIPMux + net *vnet.Net + tcpMux *TCPMux interfaceFilter func(string) bool @@ -306,11 +306,7 @@ func NewAgent(config *AgentConfig) (*Agent, error) { insecureSkipVerify: config.InsecureSkipVerify, } - a.tcp = newTCPIPMux(tcpIPMuxParams{ - ListenPort: config.TCPListenPort, - Logger: log, - ReadBufferSize: 8, - }) + a.tcpMux = config.TCPMux if a.net == nil { a.net = vnet.NewNet(nil) @@ -887,7 +883,11 @@ func (a *Agent) Close() error { a.gatherCandidateCancel() a.err.Store(ErrClosed) - a.tcp.RemoveUfrag(a.localUfrag) + + if a.tcpMux != nil { + a.tcpMux.RemoveConnByUfrag(a.localUfrag) + } + close(a.done) <-done diff --git a/agent_config.go b/agent_config.go index 25c38f04..6500c59a 100644 --- a/agent_config.go +++ b/agent_config.go @@ -139,10 +139,10 @@ type AgentConfig struct { // to TURN servers via TLS or DTLS InsecureSkipVerify bool - // TCPListenPort will be used to start a TCP listener on all allowed interfaces for - // ICE TCP. Currently only passive candidates are supported. This functionality is - // experimental and this API will likely change in the future. - TCPListenPort int + // TCPMux will be used for multiplexing incoming TCP connections for ICE TCP. + // Currently only passive candidates are supported. This functionality is + // experimental and the API might change in the future. + TCPMux *TCPMux } // initWithDefaults populates an agent and falls back to defaults if fields are unset diff --git a/gather.go b/gather.go index 1ded5a2c..1a2c0a11 100644 --- a/gather.go +++ b/gather.go @@ -161,28 +161,23 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ var tcpType TCPType switch network { case tcp: - if a.tcp == nil { + if a.tcpMux == nil { continue } - // below is for passive mode // TODO active mode // TODO S-O mode - mux, muxErr := a.tcp.Listen(ip) - if muxErr != nil { - a.log.Warnf("could not listen %s %s\n", network, ip) - continue - } - a.log.Debugf("GetConn by ufrag: %s\n", a.localUfrag) - conn, err = mux.GetConn(a.localUfrag) + conn, err = a.tcpMux.GetConnByUfrag(a.localUfrag) if err != nil { a.log.Warnf("error getting tcp conn by ufrag: %s %s\n", network, ip, a.localUfrag) continue } port = conn.LocalAddr().(*net.TCPAddr).Port tcpType = TCPTypePassive + // TODO is there a way to verify that the listen address is even + // accessible from the current interface. case udp: conn, err = listenUDPInPortRange(a.net, a.log, int(a.portmax), int(a.portmin), network, &net.UDPAddr{IP: ip, Port: 0}) if err != nil { diff --git a/gather_test.go b/gather_test.go index 9d60992e..99a035ff 100644 --- a/gather_test.go +++ b/gather_test.go @@ -14,9 +14,11 @@ import ( "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/pkg/crypto/selfsign" + "github.com/pion/logging" "github.com/pion/transport/test" "github.com/pion/turn/v2" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestListenUDP(t *testing.T) { @@ -116,11 +118,25 @@ func TestSTUNConcurrency(t *testing.T) { Port: serverPort, }) + listener, err := net.ListenTCP("tcp", &net.TCPAddr{ + IP: net.IP{127, 0, 0, 1}, + }) + require.NoError(t, err) + defer func() { + _ = listener.Close() + }() + a, err := NewAgent(&AgentConfig{ NetworkTypes: supportedNetworkTypes, Urls: urls, CandidateTypes: []CandidateType{CandidateTypeHost, CandidateTypeServerReflexive}, - TCPListenPort: 9999, + TCPMux: NewTCPMux( + TCPMuxParams{ + Listener: listener, + Logger: logging.NewDefaultLoggerFactory().NewLogger("ice"), + ReadBufferSize: 8, + }, + ), }) assert.NoError(t, err) diff --git a/go.sum b/go.sum index 6a41bd3f..1c9efd8e 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,7 @@ github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/pion/dtls/v2 v2.0.2 h1:FHCHTiM182Y8e15aFTiORroiATUI16ryHiQh8AIOJ1E= github.com/pion/dtls/v2 v2.0.2/go.mod h1:27PEO3MDdaCfo21heT59/vsdmZc0zMt9wQPcSlLu/1I= +github.com/pion/ice v0.7.17 h1:0dD2RASsDY/28idIKcVZkZgGctORIW5wGJJx93lyD3s= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/mdns v0.0.4 h1:O4vvVqr4DGX63vzmO6Fw9vpy3lfztVWHGCQfyw0ZLSY= diff --git a/tcp_ip_mux.go b/tcp_ip_mux.go deleted file mode 100644 index f3c2d597..00000000 --- a/tcp_ip_mux.go +++ /dev/null @@ -1,102 +0,0 @@ -package ice - -import ( - "net" - "strconv" - "sync" - - "github.com/pion/logging" -) - -// tcpMuxes is a map of local addr listeners to tcpMux -var tcpMuxes map[string]*tcpMux -var tcpMuxesMu sync.Mutex - -type tcpIPMux struct { - params *tcpIPMuxParams - wg sync.WaitGroup -} - -type tcpIPMuxParams struct { - ListenPort int - ReadBufferSize int - Logger logging.LeveledLogger -} - -func newTCPIPMux(params tcpIPMuxParams) *tcpIPMux { - m := &tcpIPMux{ - params: ¶ms, - } - - tcpMuxesMu.Lock() - - if tcpMuxes == nil { - tcpMuxes = map[string]*tcpMux{} - } - - tcpMuxesMu.Unlock() - - return m -} - -func (m *tcpIPMux) Remove(key string) { - tcpMuxesMu.Lock() - defer tcpMuxesMu.Unlock() - - if tcpMux, ok := tcpMuxes[key]; ok { - err := tcpMux.Close() - if err != nil { - m.params.Logger.Errorf("Error closing tcpMux for key: %s: %s", key, err) - } - delete(tcpMuxes, key) - } -} - -func (m *tcpIPMux) RemoveUfrag(ufrag string) { - tcpMuxesMu.Lock() - defer tcpMuxesMu.Unlock() - - for _, tcpMux := range tcpMuxes { - tcpMux.RemoveConn(ufrag) - } -} - -func (m *tcpIPMux) Listen(ip net.IP) (*tcpMux, error) { - tcpMuxesMu.Lock() - defer tcpMuxesMu.Unlock() - - key := net.JoinHostPort(ip.String(), strconv.Itoa(m.params.ListenPort)) - - tcpMux, ok := tcpMuxes[key] - if ok { - return tcpMux, nil - } - - listener, err := net.ListenTCP("tcp", &net.TCPAddr{ - IP: ip, - Port: m.params.ListenPort, - }) - - if err != nil { - return nil, err - } - - key = net.JoinHostPort(ip.String(), strconv.Itoa(listener.Addr().(*net.TCPAddr).Port)) - - tcpMux = newTCPMux(tcpMuxParams{ - Listener: listener, - Logger: m.params.Logger, - ReadBufferSize: m.params.ReadBufferSize, - }) - - tcpMuxes[key] = tcpMux - - m.wg.Add(1) - go func() { - defer m.wg.Done() - <-tcpMux.CloseChannel() - m.Remove(key) - }() - - return tcpMux, nil -} diff --git a/tcp_mux.go b/tcp_mux.go index 5892b0da..dcbf0e0e 100644 --- a/tcp_mux.go +++ b/tcp_mux.go @@ -11,31 +11,31 @@ import ( "github.com/pion/stun" ) -type tcpMux struct { - params *tcpMuxParams +type TCPMux struct { + params *TCPMuxParams // conns is a map of all tcpPacketConns indexed by ufrag conns map[string]*tcpPacketConn - mu sync.Mutex - wg sync.WaitGroup - closedChan chan struct{} - closeOnce sync.Once + mu sync.Mutex + wg sync.WaitGroup } -type tcpMuxParams struct { +type TCPMuxParams struct { Listener net.Listener Logger logging.LeveledLogger ReadBufferSize int } -func newTCPMux(params tcpMuxParams) *tcpMux { - m := &tcpMux{ +func NewTCPMux(params TCPMuxParams) *TCPMux { + if params.Logger == nil { + params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") + } + + m := &TCPMux{ params: ¶ms, conns: map[string]*tcpPacketConn{}, - - closedChan: make(chan struct{}), } m.wg.Add(1) @@ -47,7 +47,7 @@ func newTCPMux(params tcpMuxParams) *tcpMux { return m } -func (m *tcpMux) start() { +func (m *TCPMux) start() { m.params.Logger.Infof("Listening TCP on %s\n", m.params.Listener.Addr()) for { conn, err := m.params.Listener.Accept() @@ -66,11 +66,11 @@ func (m *tcpMux) start() { } } -func (m *tcpMux) LocalAddr() net.Addr { +func (m *TCPMux) LocalAddr() net.Addr { return m.params.Listener.Addr() } -func (m *tcpMux) GetConn(ufrag string) (net.PacketConn, error) { +func (m *TCPMux) GetConnByUfrag(ufrag string) (net.PacketConn, error) { m.mu.Lock() defer m.mu.Unlock() @@ -86,7 +86,7 @@ func (m *tcpMux) GetConn(ufrag string) (net.PacketConn, error) { return conn, nil } -func (m *tcpMux) createConn(ufrag string, localAddr net.Addr) *tcpPacketConn { +func (m *TCPMux) createConn(ufrag string, localAddr net.Addr) *tcpPacketConn { conn := newTCPPacketConn(tcpPacketParams{ ReadBuffer: m.params.ReadBufferSize, LocalAddr: localAddr, @@ -98,20 +98,20 @@ func (m *tcpMux) createConn(ufrag string, localAddr net.Addr) *tcpPacketConn { go func() { defer m.wg.Done() <-conn.CloseChannel() - m.RemoveConn(ufrag) + m.RemoveConnByUfrag(ufrag) }() return conn } -func (m *tcpMux) closeAndLogError(closer io.Closer) { +func (m *TCPMux) closeAndLogError(closer io.Closer) { err := closer.Close() if err != nil { m.params.Logger.Warnf("Error closing connection: %s", err) } } -func (m *tcpMux) handleConn(conn net.Conn) { +func (m *TCPMux) handleConn(conn net.Conn) { buf := make([]byte, receiveMTU) n, err := readStreamingPacket(conn, buf) @@ -169,13 +169,9 @@ func (m *tcpMux) handleConn(conn net.Conn) { } } -func (m *tcpMux) Close() error { +func (m *TCPMux) Close() error { m.mu.Lock() - m.closeOnce.Do(func() { - close(m.closedChan) - }) - m.conns = map[string]*tcpPacketConn{} m.mu.Unlock() @@ -186,11 +182,7 @@ func (m *tcpMux) Close() error { return err } -func (m *tcpMux) CloseChannel() <-chan struct{} { - return m.closedChan -} - -func (m *tcpMux) RemoveConn(ufrag string) { +func (m *TCPMux) RemoveConnByUfrag(ufrag string) { m.mu.Lock() defer m.mu.Unlock() @@ -200,10 +192,6 @@ func (m *tcpMux) RemoveConn(ufrag string) { } if len(m.conns) == 0 { - m.closeOnce.Do(func() { - close(m.closedChan) - }) - m.closeAndLogError(m.params.Listener) } } diff --git a/tcp_ip_mux_test.go b/tcp_mux_test.go similarity index 82% rename from tcp_ip_mux_test.go rename to tcp_mux_test.go index 1007aaed..0d29fb8a 100644 --- a/tcp_ip_mux_test.go +++ b/tcp_mux_test.go @@ -11,20 +11,27 @@ import ( "github.com/stretchr/testify/require" ) -func TestTCP_Recv(t *testing.T) { +func TestTCPMux_Recv(t *testing.T) { report := test.CheckRoutines(t) defer report() loggerFactory := logging.NewDefaultLoggerFactory() - tim := newTCPIPMux(tcpIPMuxParams{ - ListenPort: 8080, + listener, err := net.ListenTCP("tcp", &net.TCPAddr{ + IP: net.IP{127, 0, 0, 1}, + Port: 0, + }) + require.NoError(t, err, "error starting listener") + defer func() { + _ = listener.Close() + }() + + tcpMux := NewTCPMux(TCPMuxParams{ + Listener: listener, Logger: loggerFactory.NewLogger("ice"), ReadBufferSize: 20, }) - tcpMux, err := tim.Listen(net.IP{127, 0, 0, 1}) - require.NoError(t, err, "error starting listener") defer func() { _ = tcpMux.Close() }() @@ -42,7 +49,7 @@ func TestTCP_Recv(t *testing.T) { n, err := writeStreamingPacket(conn, msg.Raw) require.NoError(t, err, "error writing tcp stun packet") - pktConn, err := tcpMux.GetConn("myufrag") + pktConn, err := tcpMux.GetConnByUfrag("myufrag") require.NoError(t, err, "error retrieving muxed connection for ufrag") defer func() { _ = pktConn.Close() diff --git a/tcp_packet_conn.go b/tcp_packet_conn.go index ebdfe501..99a462da 100644 --- a/tcp_packet_conn.go +++ b/tcp_packet_conn.go @@ -62,17 +62,16 @@ func (t *tcpPacketConn) AddConn(conn net.Conn, firstPacketData []byte) error { } if _, ok := t.conns[conn.RemoteAddr().String()]; ok { - return ErrTCPRemoteAddrAlreadyExists + return fmt.Errorf("connection with same remote address already exists: %s", conn.RemoteAddr().String()) } t.conns[conn.RemoteAddr().String()] = conn - if firstPacketData != nil { - t.recvChan <- streamingPacket{firstPacketData, conn.RemoteAddr(), nil} - } - t.wg.Add(1) go func() { + if firstPacketData != nil { + t.recvChan <- streamingPacket{firstPacketData, conn.RemoteAddr(), nil} + } defer t.wg.Done() t.startReading(conn) }()