From daba2c5cff01bb3c883f4814f03a29f1805dc686 Mon Sep 17 00:00:00 2001 From: Michiel De Backker Date: Sat, 20 Jun 2020 21:47:16 +0200 Subject: [PATCH] Break out internal/net/udp Breaking out the UDP wrapper to make it re- usable in pion/sctp. Relates to pion/sctp#74 --- go.mod | 1 + go.sum | 2 + internal/net/udp/conn.go | 296 ----------------------- internal/net/udp/conn_test.go | 435 ---------------------------------- listener.go | 2 +- 5 files changed, 4 insertions(+), 732 deletions(-) delete mode 100644 internal/net/udp/conn.go delete mode 100644 internal/net/udp/conn_test.go diff --git a/go.mod b/go.mod index a4e6153ac..dade6d0e7 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/pion/dtls/v2 require ( github.com/pion/logging v0.2.2 github.com/pion/transport v0.10.0 + github.com/pion/udp v0.1.0 golang.org/x/crypto v0.0.0-20200602180216-279210d13fed golang.org/x/net v0.0.0-20200602114024-627f9648deb9 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 diff --git a/go.sum b/go.sum index 76d8a7172..7a3025615 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/transport v0.10.0 h1:9M12BSneJm6ggGhJyWpDveFOstJsTiQjkLf4M44rm80= github.com/pion/transport v0.10.0/go.mod h1:BnHnUipd0rZQyTVB2SBGojFHT9CBt5C5TcsJSQGkvSE= +github.com/pion/udp v0.1.0 h1:uGxQsNyrqG3GLINv36Ff60covYmfrLoxzwnCsIYspXI= +github.com/pion/udp v0.1.0/go.mod h1:BPELIjbwE9PRbd/zxI/KYBnbo7B6+oA6YuEaNE8lths= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/internal/net/udp/conn.go b/internal/net/udp/conn.go deleted file mode 100644 index 9f680e83c..000000000 --- a/internal/net/udp/conn.go +++ /dev/null @@ -1,296 +0,0 @@ -// Package udp provides a connection-oriented listener over a UDP PacketConn -package udp - -import ( - "context" - "errors" - "net" - "sync" - "sync/atomic" - "time" - - "github.com/pion/transport/deadline" - "github.com/pion/transport/packetio" -) - -const receiveMTU = 8192 -const defaultListenBacklog = 128 // same as Linux default - -var errClosedListener = errors.New("udp: listener closed") -var errListenQueueExceeded = errors.New("udp: listen queue exceeded") - -// listener augments a connection-oriented Listener over a UDP PacketConn -type listener struct { - pConn *net.UDPConn - - accepting atomic.Value // bool - acceptCh chan *Conn - doneCh chan struct{} - doneOnce sync.Once - acceptFilter func([]byte) bool - - connLock sync.Mutex - conns map[string]*Conn - connWG sync.WaitGroup - - readWG sync.WaitGroup - errClose atomic.Value // error -} - -// Accept waits for and returns the next connection to the listener. -func (l *listener) Accept() (net.Conn, error) { - select { - case c := <-l.acceptCh: - l.connWG.Add(1) - return c, nil - - case <-l.doneCh: - return nil, errClosedListener - } -} - -// Close closes the listener. -// Any blocked Accept operations will be unblocked and return errors. -func (l *listener) Close() error { - var err error - l.doneOnce.Do(func() { - l.accepting.Store(false) - close(l.doneCh) - - l.connLock.Lock() - // Close unaccepted connections - L_CLOSE: - for { - select { - case c := <-l.acceptCh: - close(c.doneCh) - delete(l.conns, c.rAddr.String()) - - default: - break L_CLOSE - } - } - nConns := len(l.conns) - l.connLock.Unlock() - - l.connWG.Done() - - if nConns == 0 { - // Wait if this is the final connection - l.readWG.Wait() - if errClose, ok := l.errClose.Load().(error); ok { - err = errClose - } - } else { - err = nil - } - }) - - return err -} - -// Addr returns the listener's network address. -func (l *listener) Addr() net.Addr { - return l.pConn.LocalAddr() -} - -// ListenConfig stores options for listening to an address. -type ListenConfig struct { - // Backlog defines the maximum length of the queue of pending - // connections. It is equivalent of the backlog argument of - // POSIX listen function. - // If a connection request arrives when the queue is full, - // the request will be silently discarded, unlike TCP. - // Set zero to use default value 128 which is same as Linux default. - Backlog int - - // AcceptFilter determines whether the new conn should be made for - // the incoming packet. If not set, any packet creates new conn. - AcceptFilter func([]byte) bool -} - -// Listen creates a new listener based on the ListenConfig. -func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (net.Listener, error) { - if lc.Backlog == 0 { - lc.Backlog = defaultListenBacklog - } - - conn, err := net.ListenUDP(network, laddr) - if err != nil { - return nil, err - } - - l := &listener{ - pConn: conn, - acceptCh: make(chan *Conn, lc.Backlog), - conns: make(map[string]*Conn), - doneCh: make(chan struct{}), - acceptFilter: lc.AcceptFilter, - } - l.accepting.Store(true) - l.connWG.Add(1) - l.readWG.Add(2) // wait readLoop and Close execution routine - - go l.readLoop() - go func() { - l.connWG.Wait() - if err := l.pConn.Close(); err != nil { - l.errClose.Store(err) - } - l.readWG.Done() - }() - - return l, nil -} - -// Listen creates a new listener using default ListenConfig. -func Listen(network string, laddr *net.UDPAddr) (net.Listener, error) { - return (&ListenConfig{}).Listen(network, laddr) -} - -var readBufferPool = &sync.Pool{ - New: func() interface{} { - buf := make([]byte, receiveMTU) - return &buf - }, -} - -// readLoop has to tasks: -// 1. Dispatching incoming packets to the correct Conn. -// It can therefore not be ended until all Conns are closed. -// 2. Creating a new Conn when receiving from a new remote. -func (l *listener) readLoop() { - defer l.readWG.Done() - - for { - buf := *(readBufferPool.Get().(*[]byte)) - n, raddr, err := l.pConn.ReadFrom(buf) - if err != nil { - return - } - conn, ok, err := l.getConn(raddr, buf[:n]) - if err != nil { - continue - } - if ok { - _, _ = conn.buffer.Write(buf[:n]) - } - } -} - -func (l *listener) getConn(raddr net.Addr, buf []byte) (*Conn, bool, error) { - l.connLock.Lock() - defer l.connLock.Unlock() - conn, ok := l.conns[raddr.String()] - if !ok { - if !l.accepting.Load().(bool) { - return nil, false, errClosedListener - } - if l.acceptFilter != nil { - if !l.acceptFilter(buf) { - return nil, false, nil - } - } - conn = l.newConn(raddr) - select { - case l.acceptCh <- conn: - l.conns[raddr.String()] = conn - default: - return nil, false, errListenQueueExceeded - } - } - return conn, true, nil -} - -// Conn augments a connection-oriented connection over a UDP PacketConn -type Conn struct { - listener *listener - - rAddr net.Addr - - buffer *packetio.Buffer - - doneCh chan struct{} - doneOnce sync.Once - - writeDeadline *deadline.Deadline -} - -func (l *listener) newConn(rAddr net.Addr) *Conn { - return &Conn{ - listener: l, - rAddr: rAddr, - buffer: packetio.NewBuffer(), - doneCh: make(chan struct{}), - writeDeadline: deadline.New(), - } -} - -// Read -func (c *Conn) Read(p []byte) (int, error) { - return c.buffer.Read(p) -} - -// Write writes len(p) bytes from p to the DTLS connection -func (c *Conn) Write(p []byte) (n int, err error) { - select { - case <-c.writeDeadline.Done(): - return 0, context.DeadlineExceeded - default: - } - return c.listener.pConn.WriteTo(p, c.rAddr) -} - -// Close closes the conn and releases any Read calls -func (c *Conn) Close() error { - var err error - c.doneOnce.Do(func() { - c.listener.connWG.Done() - close(c.doneCh) - c.listener.connLock.Lock() - delete(c.listener.conns, c.rAddr.String()) - nConns := len(c.listener.conns) - c.listener.connLock.Unlock() - - if nConns == 0 && !c.listener.accepting.Load().(bool) { - // Wait if this is the final connection - c.listener.readWG.Wait() - if errClose, ok := c.listener.errClose.Load().(error); ok { - err = errClose - } - } else { - err = nil - } - }) - - return err -} - -// LocalAddr implements net.Conn.LocalAddr -func (c *Conn) LocalAddr() net.Addr { - return c.listener.pConn.LocalAddr() -} - -// RemoteAddr implements net.Conn.RemoteAddr -func (c *Conn) RemoteAddr() net.Addr { - return c.rAddr -} - -// SetDeadline implements net.Conn.SetDeadline -func (c *Conn) SetDeadline(t time.Time) error { - c.writeDeadline.Set(t) - return c.SetReadDeadline(t) -} - -// SetReadDeadline implements net.Conn.SetDeadline -func (c *Conn) SetReadDeadline(t time.Time) error { - return c.buffer.SetReadDeadline(t) -} - -// SetWriteDeadline implements net.Conn.SetDeadline -func (c *Conn) SetWriteDeadline(t time.Time) error { - c.writeDeadline.Set(t) - // Write deadline of underlying connection should not be changed - // since the connection can be shared. - return nil -} diff --git a/internal/net/udp/conn_test.go b/internal/net/udp/conn_test.go deleted file mode 100644 index 718161e1c..000000000 --- a/internal/net/udp/conn_test.go +++ /dev/null @@ -1,435 +0,0 @@ -// +build !js - -package udp - -import ( - "bytes" - "fmt" - "net" - "sync" - "testing" - "time" - - "github.com/pion/transport/test" -) - -// Note: doesn't work since closing isn't propagated to the other side -//func TestNetTest(t *testing.T) { -// lim := test.TimeOut(time.Minute*1 + time.Second*10) -// defer lim.Stop() -// -// nettest.TestConn(t, func() (c1, c2 net.Conn, stop func(), err error) { -// listener, c1, c2, err = pipe() -// if err != nil { -// return nil, nil, nil, err -// } -// stop = func() { -// c1.Close() -// c2.Close() -// listener.Close(1 * time.Second) -// } -// return -// }) -//} - -func TestStressDuplex(t *testing.T) { - // Limit runtime in case of deadlocks - lim := test.TimeOut(time.Second * 20) - defer lim.Stop() - - // Check for leaking routines - report := test.CheckRoutines(t) - defer report() - - // Run the test - stressDuplex(t) -} - -func stressDuplex(t *testing.T) { - listener, ca, cb, err := pipe() - if err != nil { - t.Fatal(err) - } - - defer func() { - err = ca.Close() - if err != nil { - t.Fatal(err) - } - err = cb.Close() - if err != nil { - t.Fatal(err) - } - err = listener.Close() - if err != nil { - t.Fatal(err) - } - }() - - opt := test.Options{ - MsgSize: 2048, - MsgCount: 1, // Can't rely on UDP message order in CI - } - - err = test.StressDuplex(ca, cb, opt) - if err != nil { - t.Fatal(err) - } -} - -func TestListenerCloseTimeout(t *testing.T) { - // Limit runtime in case of deadlocks - lim := test.TimeOut(time.Second * 20) - defer lim.Stop() - - // Check for leaking routines - report := test.CheckRoutines(t) - defer report() - - listener, ca, _, err := pipe() - if err != nil { - t.Fatal(err) - } - - err = listener.Close() - if err != nil { - t.Fatal(err) - } - - // Close client after server closes to cleanup - err = ca.Close() - if err != nil { - t.Fatal(err) - } -} - -func TestListenerCloseUnaccepted(t *testing.T) { - // Limit runtime in case of deadlocks - lim := test.TimeOut(time.Second * 20) - defer lim.Stop() - - // Check for leaking routines - report := test.CheckRoutines(t) - defer report() - - const backlog = 2 - - network, addr := getConfig() - listener, err := (&ListenConfig{ - Backlog: backlog, - }).Listen(network, addr) - if err != nil { - t.Fatal(err) - } - - for i := 0; i < backlog; i++ { - conn, derr := net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) - if derr != nil { - t.Error(derr) - continue - } - if _, werr := conn.Write([]byte{byte(i)}); werr != nil { - t.Error(werr) - } - if cerr := conn.Close(); cerr != nil { - t.Error(cerr) - } - } - - time.Sleep(100 * time.Millisecond) // Wait all packets being processed by readLoop - - // Unaccepted connections must be closed by listener.Close() - err = listener.Close() - if err != nil { - t.Fatal(err) - } -} - -func TestListenerAcceptFilter(t *testing.T) { - // Limit runtime in case of deadlocks - lim := test.TimeOut(time.Second * 20) - defer lim.Stop() - - // Check for leaking routines - report := test.CheckRoutines(t) - defer report() - - testCases := map[string]struct { - packet []byte - accept bool - }{ - "CreateConn": { - packet: []byte{0xAA}, - accept: true, - }, - "Discarded": { - packet: []byte{0x00}, - accept: false, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - network, addr := getConfig() - listener, err := (&ListenConfig{ - AcceptFilter: func(pkt []byte) bool { - return pkt[0] == 0xAA - }, - }).Listen(network, addr) - if err != nil { - t.Fatal(err) - } - - var wgAcceptLoop sync.WaitGroup - wgAcceptLoop.Add(1) - defer func() { - cerr := listener.Close() - if cerr != nil { - t.Fatal(cerr) - } - wgAcceptLoop.Wait() - }() - - conn, derr := net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) - if derr != nil { - t.Fatal(derr) - } - if _, werr := conn.Write(testCase.packet); werr != nil { - t.Fatal(werr) - } - defer func() { - if cerr := conn.Close(); cerr != nil { - t.Error(cerr) - } - }() - - chAccepted := make(chan struct{}) - go func() { - defer wgAcceptLoop.Done() - - conn, aerr := listener.Accept() - if aerr != nil { - if aerr != errClosedListener { - t.Error(aerr) - } - return - } - close(chAccepted) - if cerr := conn.Close(); cerr != nil { - t.Error(cerr) - } - }() - - var accepted bool - select { - case <-chAccepted: - accepted = true - case <-time.After(10 * time.Millisecond): - } - - if accepted != testCase.accept { - if testCase.accept { - t.Error("Packet should create new conn") - } else { - t.Error("Packet should not create new conn") - } - } - }) - } -} - -func TestListenerConcurrent(t *testing.T) { - // Limit runtime in case of deadlocks - lim := test.TimeOut(time.Second * 20) - defer lim.Stop() - - // Check for leaking routines - report := test.CheckRoutines(t) - defer report() - - const backlog = 2 - - network, addr := getConfig() - listener, err := (&ListenConfig{ - Backlog: backlog, - }).Listen(network, addr) - if err != nil { - t.Fatal(err) - } - - for i := 0; i < backlog+1; i++ { - conn, derr := net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) - if derr != nil { - t.Error(derr) - continue - } - if _, werr := conn.Write([]byte{byte(i)}); werr != nil { - t.Error(werr) - } - if cerr := conn.Close(); cerr != nil { - t.Error(cerr) - } - } - - time.Sleep(100 * time.Millisecond) // Wait all packets being processed by readLoop - - for i := 0; i < backlog; i++ { - conn, aerr := listener.Accept() - if aerr != nil { - t.Error(aerr) - continue - } - b := make([]byte, 1) - n, rerr := conn.Read(b) - if rerr != nil { - t.Error(rerr) - } else if !bytes.Equal([]byte{byte(i)}, b[:n]) { - t.Errorf("Packet from connection %d is wrong, expected: [%d], got: %v", i, i, b[:n]) - } - if err = conn.Close(); err != nil { - t.Error(err) - } - } - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - if conn, aerr := listener.Accept(); aerr != errClosedListener { - t.Errorf("Connection exceeding backlog limit must be discarded: %v", aerr) - if aerr == nil { - _ = conn.Close() - } - } - }() - - time.Sleep(100 * time.Millisecond) // Last Accept should be discarded - err = listener.Close() - if err != nil { - t.Fatal(err) - } - - wg.Wait() -} - -func pipe() (net.Listener, net.Conn, *net.UDPConn, error) { - // Start listening - network, addr := getConfig() - listener, err := Listen(network, addr) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to listen: %v", err) - } - - // Open a connection - var dConn *net.UDPConn - dConn, err = net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to dial: %v", err) - } - - // Write to the connection to initiate it - handshake := "hello" - _, err = dConn.Write([]byte(handshake)) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to write to dialed Conn: %v", err) - } - - // Accept the connection - var lConn net.Conn - lConn, err = listener.Accept() - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to accept Conn: %v", err) - } - - buf := make([]byte, len(handshake)) - n := 0 - n, err = lConn.Read(buf) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to read handshake: %v", err) - } - - result := string(buf[:n]) - if handshake != result { - return nil, nil, nil, fmt.Errorf("handshake failed: %s != %s", handshake, result) - } - - return listener, lConn, dConn, nil -} - -func getConfig() (string, *net.UDPAddr) { - return "udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0} -} - -func TestConnClose(t *testing.T) { - lim := test.TimeOut(time.Second * 5) - defer lim.Stop() - - t.Run("Close", func(t *testing.T) { - // Check for leaking routines - report := test.CheckRoutines(t) - defer report() - - l, ca, cb, errPipe := pipe() - if errPipe != nil { - t.Fatal(errPipe) - } - if err := ca.Close(); err != nil { - t.Errorf("Failed to close A side: %v", err) - } - if err := cb.Close(); err != nil { - t.Errorf("Failed to close B side: %v", err) - } - if err := l.Close(); err != nil { - t.Errorf("Failed to close listener: %v", err) - } - }) - t.Run("CloseError1", func(t *testing.T) { - // Check for leaking routines - report := test.CheckRoutines(t) - defer report() - - l, ca, cb, errPipe := pipe() - if errPipe != nil { - t.Fatal(errPipe) - } - // Close l.pConn to inject error. - if err := l.(*listener).pConn.Close(); err != nil { - t.Error(err) - } - - if err := cb.Close(); err != nil { - t.Errorf("Failed to close A side: %v", err) - } - if err := ca.Close(); err != nil { - t.Errorf("Failed to close B side: %v", err) - } - if err := l.Close(); err == nil { - t.Errorf("Error is not propagated to Listener.Close") - } - }) - t.Run("CloseError2", func(t *testing.T) { - // Check for leaking routines - report := test.CheckRoutines(t) - defer report() - - l, ca, cb, errPipe := pipe() - if errPipe != nil { - t.Fatal(errPipe) - } - // Close l.pConn to inject error. - if err := l.(*listener).pConn.Close(); err != nil { - t.Error(err) - } - - if err := cb.Close(); err != nil { - t.Errorf("Failed to close A side: %v", err) - } - if err := l.Close(); err != nil { - t.Errorf("Failed to close listener: %v", err) - } - if err := ca.Close(); err == nil { - t.Errorf("Error is not propagated to Conn.Close") - } - }) -} diff --git a/listener.go b/listener.go index 4c55de633..dd27b2c83 100644 --- a/listener.go +++ b/listener.go @@ -3,7 +3,7 @@ package dtls import ( "net" - "github.com/pion/dtls/v2/internal/net/udp" + "github.com/pion/udp" ) // Listen creates a DTLS listener