Skip to content

Commit

Permalink
Make SetDeadline on NetConn not always close Conn
Browse files Browse the repository at this point in the history
NetConn has to close the connection to interrupt in progress reads
and writes. However, it can error out on reads and writes that occur
after the deadline instead of closing the connection.

Closes #228
  • Loading branch information
nhooyr committed May 18, 2020
1 parent 1695216 commit b16701e
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 43 deletions.
9 changes: 9 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,15 @@ func (m *mu) forceLock() {
m.ch <- struct{}{}
}

func (m *mu) tryLock() bool {
select {
case m.ch <- struct{}{}:
return true
default:
return false
}
}

func (m *mu) lock(ctx context.Context) error {
select {
case <-m.c.closed:
Expand Down
128 changes: 85 additions & 43 deletions netconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"io"
"math"
"net"
"sync"
"sync/atomic"
"time"
)

Expand All @@ -28,9 +28,10 @@ import (
//
// Close will close the *websocket.Conn with StatusNormalClosure.
//
// When a deadline is hit, the connection will be closed. This is
// different from most net.Conn implementations where only the
// reading/writing goroutines are interrupted but the connection is kept alive.
// When a deadline is hit and there is an active read or write goroutine, the
// connection will be closed. This is different from most net.Conn implementations
// where only the reading/writing goroutines are interrupted but the connection
// is kept alive.
//
// The Addr methods will return a mock net.Addr that returns "websocket" for Network
// and "websocket/unknown-addr" for String.
Expand All @@ -41,17 +42,43 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
nc := &netConn{
c: c,
msgType: msgType,
readMu: newMu(c),
writeMu: newMu(c),
}

var cancel context.CancelFunc
nc.writeContext, cancel = context.WithCancel(ctx)
nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel)
var writeCancel context.CancelFunc
nc.writeCtx, writeCancel = context.WithCancel(ctx)
var readCancel context.CancelFunc
nc.readCtx, readCancel = context.WithCancel(ctx)

nc.writeTimer = time.AfterFunc(math.MaxInt64, func() {
if !nc.writeMu.tryLock() {
// If the lock cannot be acquired, then there is an
// active write goroutine and so we should cancel the context.
writeCancel()
return
}
defer nc.writeMu.unlock()

// Prevents future writes from writing until the deadline is reset.
atomic.StoreInt64(&nc.writeExpired, 1)
})
if !nc.writeTimer.Stop() {
<-nc.writeTimer.C
}

nc.readContext, cancel = context.WithCancel(ctx)
nc.readTimer = time.AfterFunc(math.MaxInt64, cancel)
nc.readTimer = time.AfterFunc(math.MaxInt64, func() {
if !nc.readMu.tryLock() {
// If the lock cannot be acquired, then there is an
// active read goroutine and so we should cancel the context.
readCancel()
return
}
defer nc.readMu.unlock()

// Prevents future reads from reading until the deadline is reset.
atomic.StoreInt64(&nc.readExpired, 1)
})
if !nc.readTimer.Stop() {
<-nc.readTimer.C
}
Expand All @@ -64,59 +91,72 @@ type netConn struct {
msgType MessageType

writeTimer *time.Timer
writeContext context.Context
writeMu *mu
writeExpired int64
writeCtx context.Context

readTimer *time.Timer
readContext context.Context

readMu sync.Mutex
eofed bool
reader io.Reader
readMu *mu
readExpired int64
readCtx context.Context
readEOFed bool
reader io.Reader
}

var _ net.Conn = &netConn{}

func (c *netConn) Close() error {
return c.c.Close(StatusNormalClosure, "")
func (nc *netConn) Close() error {
return nc.c.Close(StatusNormalClosure, "")
}

func (c *netConn) Write(p []byte) (int, error) {
err := c.c.Write(c.writeContext, c.msgType, p)
func (nc *netConn) Write(p []byte) (int, error) {
nc.writeMu.forceLock()
defer nc.writeMu.unlock()

if atomic.LoadInt64(&nc.writeExpired) == 1 {
return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded)
}

err := nc.c.Write(nc.writeCtx, nc.msgType, p)
if err != nil {
return 0, err
}
return len(p), nil
}

func (c *netConn) Read(p []byte) (int, error) {
c.readMu.Lock()
defer c.readMu.Unlock()
func (nc *netConn) Read(p []byte) (int, error) {
nc.readMu.forceLock()
defer nc.readMu.unlock()

if atomic.LoadInt64(&nc.readExpired) == 1 {
return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded)
}

if c.eofed {
if nc.readEOFed {
return 0, io.EOF
}

if c.reader == nil {
typ, r, err := c.c.Reader(c.readContext)
if nc.reader == nil {
typ, r, err := nc.c.Reader(nc.readCtx)
if err != nil {
switch CloseStatus(err) {
case StatusNormalClosure, StatusGoingAway:
c.eofed = true
nc.readEOFed = true
return 0, io.EOF
}
return 0, err
}
if typ != c.msgType {
err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ)
c.c.Close(StatusUnsupportedData, err.Error())
if typ != nc.msgType {
err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ)
nc.c.Close(StatusUnsupportedData, err.Error())
return 0, err
}
c.reader = r
nc.reader = r
}

n, err := c.reader.Read(p)
n, err := nc.reader.Read(p)
if err == io.EOF {
c.reader = nil
nc.reader = nil
err = nil
}
return n, err
Expand All @@ -133,34 +173,36 @@ func (a websocketAddr) String() string {
return "websocket/unknown-addr"
}

func (c *netConn) RemoteAddr() net.Addr {
func (nc *netConn) RemoteAddr() net.Addr {
return websocketAddr{}
}

func (c *netConn) LocalAddr() net.Addr {
func (nc *netConn) LocalAddr() net.Addr {
return websocketAddr{}
}

func (c *netConn) SetDeadline(t time.Time) error {
c.SetWriteDeadline(t)
c.SetReadDeadline(t)
func (nc *netConn) SetDeadline(t time.Time) error {
nc.SetWriteDeadline(t)
nc.SetReadDeadline(t)
return nil
}

func (c *netConn) SetWriteDeadline(t time.Time) error {
func (nc *netConn) SetWriteDeadline(t time.Time) error {
atomic.StoreInt64(&nc.writeExpired, 0)
if t.IsZero() {
c.writeTimer.Stop()
nc.writeTimer.Stop()
} else {
c.writeTimer.Reset(t.Sub(time.Now()))
nc.writeTimer.Reset(t.Sub(time.Now()))
}
return nil
}

func (c *netConn) SetReadDeadline(t time.Time) error {
func (nc *netConn) SetReadDeadline(t time.Time) error {
atomic.StoreInt64(&nc.readExpired, 0)
if t.IsZero() {
c.readTimer.Stop()
nc.readTimer.Stop()
} else {
c.readTimer.Reset(t.Sub(time.Now()))
nc.readTimer.Reset(t.Sub(time.Now()))
}
return nil
}
55 changes: 55 additions & 0 deletions ws_js.go
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,58 @@ const (
// MessageBinary is for binary messages like protobufs.
MessageBinary
)

type mu struct {
c *Conn
ch chan struct{}
}

func newMu(c *Conn) *mu {
return &mu{
c: c,
ch: make(chan struct{}, 1),
}
}

func (m *mu) forceLock() {
m.ch <- struct{}{}
}

func (m *mu) tryLock() bool {
select {
case m.ch <- struct{}{}:
return true
default:
return false
}
}

func (m *mu) lock(ctx context.Context) error {
select {
case <-m.c.closed:
return m.c.closeErr
case <-ctx.Done():
err := fmt.Errorf("failed to acquire lock: %w", ctx.Err())
m.c.close(err, false)
return err
case m.ch <- struct{}{}:
// To make sure the connection is certainly alive.
// As it's possible the send on m.ch was selected
// over the receive on closed.
select {
case <-m.c.closed:
// Make sure to release.
m.unlock()
return m.c.closeErr
default:
}
return nil
}
}

func (m *mu) unlock() {
select {
case <-m.ch:
default:
}
}

0 comments on commit b16701e

Please sign in to comment.