diff --git a/conn.go b/conn.go index e208d116..1a57c656 100644 --- a/conn.go +++ b/conn.go @@ -246,6 +246,15 @@ func (m *mu) forceLock() { m.ch <- struct{}{} } +func (m *mu) tryLock() bool { + select { + case m.ch <- struct{}{}: + return true + default: + return false + } +} + func (m *mu) lock(ctx context.Context) error { select { case <-m.c.closed: diff --git a/netconn.go b/netconn.go index 64aadf0b..ae04b20a 100644 --- a/netconn.go +++ b/netconn.go @@ -6,7 +6,7 @@ import ( "io" "math" "net" - "sync" + "sync/atomic" "time" ) @@ -28,9 +28,10 @@ import ( // // Close will close the *websocket.Conn with StatusNormalClosure. // -// When a deadline is hit, the connection will be closed. This is -// different from most net.Conn implementations where only the -// reading/writing goroutines are interrupted but the connection is kept alive. +// When a deadline is hit and there is an active read or write goroutine, the +// connection will be closed. This is different from most net.Conn implementations +// where only the reading/writing goroutines are interrupted but the connection +// is kept alive. // // The Addr methods will return a mock net.Addr that returns "websocket" for Network // and "websocket/unknown-addr" for String. @@ -41,17 +42,43 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { nc := &netConn{ c: c, msgType: msgType, + readMu: newMu(c), + writeMu: newMu(c), } - var cancel context.CancelFunc - nc.writeContext, cancel = context.WithCancel(ctx) - nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel) + var writeCancel context.CancelFunc + nc.writeCtx, writeCancel = context.WithCancel(ctx) + var readCancel context.CancelFunc + nc.readCtx, readCancel = context.WithCancel(ctx) + + nc.writeTimer = time.AfterFunc(math.MaxInt64, func() { + if !nc.writeMu.tryLock() { + // If the lock cannot be acquired, then there is an + // active write goroutine and so we should cancel the context. + writeCancel() + return + } + defer nc.writeMu.unlock() + + // Prevents future writes from writing until the deadline is reset. + atomic.StoreInt64(&nc.writeExpired, 1) + }) if !nc.writeTimer.Stop() { <-nc.writeTimer.C } - nc.readContext, cancel = context.WithCancel(ctx) - nc.readTimer = time.AfterFunc(math.MaxInt64, cancel) + nc.readTimer = time.AfterFunc(math.MaxInt64, func() { + if !nc.readMu.tryLock() { + // If the lock cannot be acquired, then there is an + // active read goroutine and so we should cancel the context. + readCancel() + return + } + defer nc.readMu.unlock() + + // Prevents future reads from reading until the deadline is reset. + atomic.StoreInt64(&nc.readExpired, 1) + }) if !nc.readTimer.Stop() { <-nc.readTimer.C } @@ -64,59 +91,72 @@ type netConn struct { msgType MessageType writeTimer *time.Timer - writeContext context.Context + writeMu *mu + writeExpired int64 + writeCtx context.Context readTimer *time.Timer - readContext context.Context - - readMu sync.Mutex - eofed bool - reader io.Reader + readMu *mu + readExpired int64 + readCtx context.Context + readEOFed bool + reader io.Reader } var _ net.Conn = &netConn{} -func (c *netConn) Close() error { - return c.c.Close(StatusNormalClosure, "") +func (nc *netConn) Close() error { + return nc.c.Close(StatusNormalClosure, "") } -func (c *netConn) Write(p []byte) (int, error) { - err := c.c.Write(c.writeContext, c.msgType, p) +func (nc *netConn) Write(p []byte) (int, error) { + nc.writeMu.forceLock() + defer nc.writeMu.unlock() + + if atomic.LoadInt64(&nc.writeExpired) == 1 { + return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded) + } + + err := nc.c.Write(nc.writeCtx, nc.msgType, p) if err != nil { return 0, err } return len(p), nil } -func (c *netConn) Read(p []byte) (int, error) { - c.readMu.Lock() - defer c.readMu.Unlock() +func (nc *netConn) Read(p []byte) (int, error) { + nc.readMu.forceLock() + defer nc.readMu.unlock() + + if atomic.LoadInt64(&nc.readExpired) == 1 { + return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded) + } - if c.eofed { + if nc.readEOFed { return 0, io.EOF } - if c.reader == nil { - typ, r, err := c.c.Reader(c.readContext) + if nc.reader == nil { + typ, r, err := nc.c.Reader(nc.readCtx) if err != nil { switch CloseStatus(err) { case StatusNormalClosure, StatusGoingAway: - c.eofed = true + nc.readEOFed = true return 0, io.EOF } return 0, err } - if typ != c.msgType { - err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ) - c.c.Close(StatusUnsupportedData, err.Error()) + if typ != nc.msgType { + err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ) + nc.c.Close(StatusUnsupportedData, err.Error()) return 0, err } - c.reader = r + nc.reader = r } - n, err := c.reader.Read(p) + n, err := nc.reader.Read(p) if err == io.EOF { - c.reader = nil + nc.reader = nil err = nil } return n, err @@ -133,34 +173,36 @@ func (a websocketAddr) String() string { return "websocket/unknown-addr" } -func (c *netConn) RemoteAddr() net.Addr { +func (nc *netConn) RemoteAddr() net.Addr { return websocketAddr{} } -func (c *netConn) LocalAddr() net.Addr { +func (nc *netConn) LocalAddr() net.Addr { return websocketAddr{} } -func (c *netConn) SetDeadline(t time.Time) error { - c.SetWriteDeadline(t) - c.SetReadDeadline(t) +func (nc *netConn) SetDeadline(t time.Time) error { + nc.SetWriteDeadline(t) + nc.SetReadDeadline(t) return nil } -func (c *netConn) SetWriteDeadline(t time.Time) error { +func (nc *netConn) SetWriteDeadline(t time.Time) error { + atomic.StoreInt64(&nc.writeExpired, 0) if t.IsZero() { - c.writeTimer.Stop() + nc.writeTimer.Stop() } else { - c.writeTimer.Reset(t.Sub(time.Now())) + nc.writeTimer.Reset(t.Sub(time.Now())) } return nil } -func (c *netConn) SetReadDeadline(t time.Time) error { +func (nc *netConn) SetReadDeadline(t time.Time) error { + atomic.StoreInt64(&nc.readExpired, 0) if t.IsZero() { - c.readTimer.Stop() + nc.readTimer.Stop() } else { - c.readTimer.Reset(t.Sub(time.Now())) + nc.readTimer.Reset(t.Sub(time.Now())) } return nil } diff --git a/ws_js.go b/ws_js.go index 31e3c2f6..27ba17a7 100644 --- a/ws_js.go +++ b/ws_js.go @@ -511,3 +511,58 @@ const ( // MessageBinary is for binary messages like protobufs. MessageBinary ) + +type mu struct { + c *Conn + ch chan struct{} +} + +func newMu(c *Conn) *mu { + return &mu{ + c: c, + ch: make(chan struct{}, 1), + } +} + +func (m *mu) forceLock() { + m.ch <- struct{}{} +} + +func (m *mu) tryLock() bool { + select { + case m.ch <- struct{}{}: + return true + default: + return false + } +} + +func (m *mu) lock(ctx context.Context) error { + select { + case <-m.c.closed: + return m.c.closeErr + case <-ctx.Done(): + err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) + m.c.close(err, false) + return err + case m.ch <- struct{}{}: + // To make sure the connection is certainly alive. + // As it's possible the send on m.ch was selected + // over the receive on closed. + select { + case <-m.c.closed: + // Make sure to release. + m.unlock() + return m.c.closeErr + default: + } + return nil + } +} + +func (m *mu) unlock() { + select { + case <-m.ch: + default: + } +}