Skip to content

Commit 4a143a0

Browse files
committed
Make SetDeadline on NetConn not always close Conn
NetConn has to close the connection to interrupt in progress reads and writes. However, it can error out on reads and writes that occur after the deadline instead of closing the connection. Closes #228
1 parent 1695216 commit 4a143a0

File tree

3 files changed

+126
-43
lines changed

3 files changed

+126
-43
lines changed

conn.go

+9
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,15 @@ func (m *mu) forceLock() {
246246
m.ch <- struct{}{}
247247
}
248248

249+
func (m *mu) tryLock() bool {
250+
select {
251+
case m.ch <- struct{}{}:
252+
return true
253+
default:
254+
return false
255+
}
256+
}
257+
249258
func (m *mu) lock(ctx context.Context) error {
250259
select {
251260
case <-m.c.closed:

netconn.go

+85-43
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import (
66
"io"
77
"math"
88
"net"
9-
"sync"
9+
"sync/atomic"
1010
"time"
1111
)
1212

@@ -28,9 +28,10 @@ import (
2828
//
2929
// Close will close the *websocket.Conn with StatusNormalClosure.
3030
//
31-
// When a deadline is hit, the connection will be closed. This is
32-
// different from most net.Conn implementations where only the
33-
// reading/writing goroutines are interrupted but the connection is kept alive.
31+
// When a deadline is hit and there is an active read or write goroutine, the
32+
// connection will be closed. This is different from most net.Conn implementations
33+
// where only the reading/writing goroutines are interrupted but the connection
34+
// is kept alive.
3435
//
3536
// The Addr methods will return a mock net.Addr that returns "websocket" for Network
3637
// and "websocket/unknown-addr" for String.
@@ -41,17 +42,43 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
4142
nc := &netConn{
4243
c: c,
4344
msgType: msgType,
45+
readMu: newMu(c),
46+
writeMu: newMu(c),
4447
}
4548

46-
var cancel context.CancelFunc
47-
nc.writeContext, cancel = context.WithCancel(ctx)
48-
nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel)
49+
var writeCancel context.CancelFunc
50+
nc.writeCtx, writeCancel = context.WithCancel(ctx)
51+
var readCancel context.CancelFunc
52+
nc.readCtx, readCancel = context.WithCancel(ctx)
53+
54+
nc.writeTimer = time.AfterFunc(math.MaxInt64, func() {
55+
if !nc.writeMu.tryLock() {
56+
// If the lock cannot be acquired, then there is an
57+
// active write goroutine and so we should cancel the context.
58+
writeCancel()
59+
return
60+
}
61+
defer nc.writeMu.unlock()
62+
63+
// Prevents future writes from writing until the deadline is reset.
64+
atomic.StoreInt64(&nc.writeExpired, 1)
65+
})
4966
if !nc.writeTimer.Stop() {
5067
<-nc.writeTimer.C
5168
}
5269

53-
nc.readContext, cancel = context.WithCancel(ctx)
54-
nc.readTimer = time.AfterFunc(math.MaxInt64, cancel)
70+
nc.readTimer = time.AfterFunc(math.MaxInt64, func() {
71+
if !nc.readMu.tryLock() {
72+
// If the lock cannot be acquired, then there is an
73+
// active read goroutine and so we should cancel the context.
74+
readCancel()
75+
return
76+
}
77+
defer nc.readMu.unlock()
78+
79+
// Prevents future reads from reading until the deadline is reset.
80+
atomic.StoreInt64(&nc.readExpired, 1)
81+
})
5582
if !nc.readTimer.Stop() {
5683
<-nc.readTimer.C
5784
}
@@ -64,59 +91,72 @@ type netConn struct {
6491
msgType MessageType
6592

6693
writeTimer *time.Timer
67-
writeContext context.Context
94+
writeMu *mu
95+
writeExpired int64
96+
writeCtx context.Context
6897

6998
readTimer *time.Timer
70-
readContext context.Context
71-
72-
readMu sync.Mutex
73-
eofed bool
74-
reader io.Reader
99+
readMu *mu
100+
readExpired int64
101+
readCtx context.Context
102+
readEOFed bool
103+
reader io.Reader
75104
}
76105

77106
var _ net.Conn = &netConn{}
78107

79-
func (c *netConn) Close() error {
80-
return c.c.Close(StatusNormalClosure, "")
108+
func (nc *netConn) Close() error {
109+
return nc.c.Close(StatusNormalClosure, "")
81110
}
82111

83-
func (c *netConn) Write(p []byte) (int, error) {
84-
err := c.c.Write(c.writeContext, c.msgType, p)
112+
func (nc *netConn) Write(p []byte) (int, error) {
113+
nc.writeMu.forceLock()
114+
defer nc.writeMu.unlock()
115+
116+
if atomic.LoadInt64(&nc.writeExpired) == 1 {
117+
return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded)
118+
}
119+
120+
err := nc.c.Write(nc.writeCtx, nc.msgType, p)
85121
if err != nil {
86122
return 0, err
87123
}
88124
return len(p), nil
89125
}
90126

91-
func (c *netConn) Read(p []byte) (int, error) {
92-
c.readMu.Lock()
93-
defer c.readMu.Unlock()
127+
func (nc *netConn) Read(p []byte) (int, error) {
128+
nc.readMu.forceLock()
129+
defer nc.readMu.unlock()
130+
131+
if atomic.LoadInt64(&nc.readExpired) == 1 {
132+
return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded)
133+
}
94134

95-
if c.eofed {
135+
if nc.readEOFed {
96136
return 0, io.EOF
97137
}
98138

99-
if c.reader == nil {
100-
typ, r, err := c.c.Reader(c.readContext)
139+
if nc.reader == nil {
140+
typ, r, err := nc.c.Reader(nc.readCtx)
101141
if err != nil {
102142
switch CloseStatus(err) {
103143
case StatusNormalClosure, StatusGoingAway:
104-
c.eofed = true
144+
nc.readEOFed = true
105145
return 0, io.EOF
106146
}
107147
return 0, err
108148
}
109-
if typ != c.msgType {
110-
err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ)
111-
c.c.Close(StatusUnsupportedData, err.Error())
149+
if typ != nc.msgType {
150+
err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ)
151+
nc.c.Close(StatusUnsupportedData, err.Error())
112152
return 0, err
113153
}
114-
c.reader = r
154+
nc.reader = r
115155
}
116156

117-
n, err := c.reader.Read(p)
157+
n, err := nc.reader.Read(p)
118158
if err == io.EOF {
119-
c.reader = nil
159+
nc.reader = nil
120160
err = nil
121161
}
122162
return n, err
@@ -133,34 +173,36 @@ func (a websocketAddr) String() string {
133173
return "websocket/unknown-addr"
134174
}
135175

136-
func (c *netConn) RemoteAddr() net.Addr {
176+
func (nc *netConn) RemoteAddr() net.Addr {
137177
return websocketAddr{}
138178
}
139179

140-
func (c *netConn) LocalAddr() net.Addr {
180+
func (nc *netConn) LocalAddr() net.Addr {
141181
return websocketAddr{}
142182
}
143183

144-
func (c *netConn) SetDeadline(t time.Time) error {
145-
c.SetWriteDeadline(t)
146-
c.SetReadDeadline(t)
184+
func (nc *netConn) SetDeadline(t time.Time) error {
185+
nc.SetWriteDeadline(t)
186+
nc.SetReadDeadline(t)
147187
return nil
148188
}
149189

150-
func (c *netConn) SetWriteDeadline(t time.Time) error {
190+
func (nc *netConn) SetWriteDeadline(t time.Time) error {
191+
atomic.StoreInt64(&nc.writeExpired, 0)
151192
if t.IsZero() {
152-
c.writeTimer.Stop()
193+
nc.writeTimer.Stop()
153194
} else {
154-
c.writeTimer.Reset(t.Sub(time.Now()))
195+
nc.writeTimer.Reset(t.Sub(time.Now()))
155196
}
156197
return nil
157198
}
158199

159-
func (c *netConn) SetReadDeadline(t time.Time) error {
200+
func (nc *netConn) SetReadDeadline(t time.Time) error {
201+
atomic.StoreInt64(&nc.readExpired, 0)
160202
if t.IsZero() {
161-
c.readTimer.Stop()
203+
nc.readTimer.Stop()
162204
} else {
163-
c.readTimer.Reset(t.Sub(time.Now()))
205+
nc.readTimer.Reset(t.Sub(time.Now()))
164206
}
165207
return nil
166208
}

ws_js.go

+32
Original file line numberDiff line numberDiff line change
@@ -511,3 +511,35 @@ const (
511511
// MessageBinary is for binary messages like protobufs.
512512
MessageBinary
513513
)
514+
515+
type mu struct {
516+
c *Conn
517+
ch chan struct{}
518+
}
519+
520+
func newMu(c *Conn) *mu {
521+
return &mu{
522+
c: c,
523+
ch: make(chan struct{}, 1),
524+
}
525+
}
526+
527+
func (m *mu) forceLock() {
528+
m.ch <- struct{}{}
529+
}
530+
531+
func (m *mu) tryLock() bool {
532+
select {
533+
case m.ch <- struct{}{}:
534+
return true
535+
default:
536+
return false
537+
}
538+
}
539+
540+
func (m *mu) unlock() {
541+
select {
542+
case <-m.ch:
543+
default:
544+
}
545+
}

0 commit comments

Comments
 (0)