diff --git a/readloop.go b/readloop.go index de4c1a5..9e18cfb 100644 --- a/readloop.go +++ b/readloop.go @@ -34,6 +34,9 @@ func (s *UDPSession) defaultReadLoop() { var src string for { if n, addr, err := s.conn.ReadFrom(buf); err == nil { + if s.isClosed() { + return + } // make sure the packet is from the same source if src == "" { // set source address src = addr.String() diff --git a/readloop_linux.go b/readloop_linux.go index bcc30c9..e1817d3 100644 --- a/readloop_linux.go +++ b/readloop_linux.go @@ -52,6 +52,9 @@ func (s *UDPSession) readLoop() { for { if count, err := s.xconn.ReadBatch(msgs, 0); err == nil { + if s.isClosed() { + return + } for i := 0; i < count; i++ { msg := &msgs[i] // make sure the packet is from the same source diff --git a/sess.go b/sess.go index 05122ca..91c4d0a 100644 --- a/sess.go +++ b/sess.go @@ -402,6 +402,15 @@ RESET_TIMER: } } +func (s *UDPSession) isClosed() bool { + select { + case <-s.die: + return true + default: + return false + } +} + // Close closes the connection. func (s *UDPSession) Close() error { var once bool diff --git a/sess_test.go b/sess_test.go index c6cebd3..f01aaf0 100644 --- a/sess_test.go +++ b/sess_test.go @@ -729,3 +729,69 @@ func TestControl(t *testing.T) { t.Fatal(err) } } + +func TestSessionReadAfterClosed(t *testing.T) { + us, _ := net.ListenPacket("udp", "127.0.0.1:0") + uc, _ := net.ListenPacket("udp", "127.0.0.1:0") + defer us.Close() + defer uc.Close() + + knockDoor := func(c net.Conn, myid string) (string, error) { + c.SetDeadline(time.Now().Add(time.Second * 3)) + _, err := c.Write([]byte(myid)) + c.SetDeadline(time.Time{}) + if err != nil { + return "", err + } + c.SetDeadline(time.Now().Add(time.Second * 3)) + var buf [1024]byte + n, err := c.Read(buf[:]) + c.SetDeadline(time.Time{}) + return string(buf[:n]), err + } + + check := func(c1, c2 *UDPSession) { + done := make(chan struct{}, 1) + go func() { + rid, err := knockDoor(c2, "4321") + done <- struct{}{} + if err != nil { + panic(err) + } + if rid != "1234" { + panic("mismatch id") + } + }() + rid, err := knockDoor(c1, "1234") + if err != nil { + panic(err) + } + if rid != "4321" { + panic("mismatch id") + } + <-done + } + + c1, err := NewConn3(0, uc.LocalAddr(), nil, 0, 0, us) + if err != nil { + panic(err) + } + c2, err := NewConn3(0, us.LocalAddr(), nil, 0, 0, uc) + if err != nil { + panic(err) + } + check(c1, c2) + c1.Close() + c2.Close() + //log.Println("conv id 0 is closed") + + c1, err = NewConn3(4321, uc.LocalAddr(), nil, 0, 0, us) + if err != nil { + panic(err) + } + c2, err = NewConn3(4321, us.LocalAddr(), nil, 0, 0, uc) + if err != nil { + panic(err) + } + check(c1, c2) +}