6
6
"io"
7
7
"math"
8
8
"net"
9
- "sync"
9
+ "sync/atomic "
10
10
"time"
11
11
)
12
12
@@ -28,9 +28,10 @@ import (
28
28
//
29
29
// Close will close the *websocket.Conn with StatusNormalClosure.
30
30
//
31
- // When a deadline is hit, the connection will be closed. This is
32
- // different from most net.Conn implementations where only the
33
- // reading/writing goroutines are interrupted but the connection is kept alive.
31
+ // When a deadline is hit and there is an active read or write goroutine, the
32
+ // connection will be closed. This is different from most net.Conn implementations
33
+ // where only the reading/writing goroutines are interrupted but the connection
34
+ // is kept alive.
34
35
//
35
36
// The Addr methods will return a mock net.Addr that returns "websocket" for Network
36
37
// and "websocket/unknown-addr" for String.
@@ -41,17 +42,43 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
41
42
nc := & netConn {
42
43
c : c ,
43
44
msgType : msgType ,
45
+ readMu : newMu (c ),
46
+ writeMu : newMu (c ),
44
47
}
45
48
46
- var cancel context.CancelFunc
47
- nc .writeContext , cancel = context .WithCancel (ctx )
48
- nc .writeTimer = time .AfterFunc (math .MaxInt64 , cancel )
49
+ var writeCancel context.CancelFunc
50
+ nc .writeCtx , writeCancel = context .WithCancel (ctx )
51
+ var readCancel context.CancelFunc
52
+ nc .readCtx , readCancel = context .WithCancel (ctx )
53
+
54
+ nc .writeTimer = time .AfterFunc (math .MaxInt64 , func () {
55
+ if ! nc .writeMu .tryLock () {
56
+ // If the lock cannot be acquired, then there is an
57
+ // active write goroutine and so we should cancel the context.
58
+ writeCancel ()
59
+ return
60
+ }
61
+ defer nc .writeMu .unlock ()
62
+
63
+ // Prevents future writes from writing until the deadline is reset.
64
+ atomic .StoreInt64 (& nc .writeExpired , 1 )
65
+ })
49
66
if ! nc .writeTimer .Stop () {
50
67
<- nc .writeTimer .C
51
68
}
52
69
53
- nc .readContext , cancel = context .WithCancel (ctx )
54
- nc .readTimer = time .AfterFunc (math .MaxInt64 , cancel )
70
+ nc .readTimer = time .AfterFunc (math .MaxInt64 , func () {
71
+ if ! nc .readMu .tryLock () {
72
+ // If the lock cannot be acquired, then there is an
73
+ // active read goroutine and so we should cancel the context.
74
+ readCancel ()
75
+ return
76
+ }
77
+ defer nc .readMu .unlock ()
78
+
79
+ // Prevents future reads from reading until the deadline is reset.
80
+ atomic .StoreInt64 (& nc .readExpired , 1 )
81
+ })
55
82
if ! nc .readTimer .Stop () {
56
83
<- nc .readTimer .C
57
84
}
@@ -64,59 +91,72 @@ type netConn struct {
64
91
msgType MessageType
65
92
66
93
writeTimer * time.Timer
67
- writeContext context.Context
94
+ writeMu * mu
95
+ writeExpired int64
96
+ writeCtx context.Context
68
97
69
98
readTimer * time.Timer
70
- readContext context. Context
71
-
72
- readMu sync. Mutex
73
- eofed bool
74
- reader io.Reader
99
+ readMu * mu
100
+ readExpired int64
101
+ readCtx context. Context
102
+ readEOFed bool
103
+ reader io.Reader
75
104
}
76
105
77
106
var _ net.Conn = & netConn {}
78
107
79
- func (c * netConn ) Close () error {
80
- return c .c .Close (StatusNormalClosure , "" )
108
+ func (nc * netConn ) Close () error {
109
+ return nc .c .Close (StatusNormalClosure , "" )
81
110
}
82
111
83
- func (c * netConn ) Write (p []byte ) (int , error ) {
84
- err := c .c .Write (c .writeContext , c .msgType , p )
112
+ func (nc * netConn ) Write (p []byte ) (int , error ) {
113
+ nc .writeMu .forceLock ()
114
+ defer nc .writeMu .unlock ()
115
+
116
+ if atomic .LoadInt64 (& nc .writeExpired ) == 1 {
117
+ return 0 , fmt .Errorf ("failed to write: %w" , context .DeadlineExceeded )
118
+ }
119
+
120
+ err := nc .c .Write (nc .writeCtx , nc .msgType , p )
85
121
if err != nil {
86
122
return 0 , err
87
123
}
88
124
return len (p ), nil
89
125
}
90
126
91
- func (c * netConn ) Read (p []byte ) (int , error ) {
92
- c .readMu .Lock ()
93
- defer c .readMu .Unlock ()
127
+ func (nc * netConn ) Read (p []byte ) (int , error ) {
128
+ nc .readMu .forceLock ()
129
+ defer nc .readMu .unlock ()
130
+
131
+ if atomic .LoadInt64 (& nc .readExpired ) == 1 {
132
+ return 0 , fmt .Errorf ("failed to read: %w" , context .DeadlineExceeded )
133
+ }
94
134
95
- if c . eofed {
135
+ if nc . readEOFed {
96
136
return 0 , io .EOF
97
137
}
98
138
99
- if c .reader == nil {
100
- typ , r , err := c .c .Reader (c . readContext )
139
+ if nc .reader == nil {
140
+ typ , r , err := nc .c .Reader (nc . readCtx )
101
141
if err != nil {
102
142
switch CloseStatus (err ) {
103
143
case StatusNormalClosure , StatusGoingAway :
104
- c . eofed = true
144
+ nc . readEOFed = true
105
145
return 0 , io .EOF
106
146
}
107
147
return 0 , err
108
148
}
109
- if typ != c .msgType {
110
- err := fmt .Errorf ("unexpected frame type read (expected %v): %v" , c .msgType , typ )
111
- c .c .Close (StatusUnsupportedData , err .Error ())
149
+ if typ != nc .msgType {
150
+ err := fmt .Errorf ("unexpected frame type read (expected %v): %v" , nc .msgType , typ )
151
+ nc .c .Close (StatusUnsupportedData , err .Error ())
112
152
return 0 , err
113
153
}
114
- c .reader = r
154
+ nc .reader = r
115
155
}
116
156
117
- n , err := c .reader .Read (p )
157
+ n , err := nc .reader .Read (p )
118
158
if err == io .EOF {
119
- c .reader = nil
159
+ nc .reader = nil
120
160
err = nil
121
161
}
122
162
return n , err
@@ -133,34 +173,36 @@ func (a websocketAddr) String() string {
133
173
return "websocket/unknown-addr"
134
174
}
135
175
136
- func (c * netConn ) RemoteAddr () net.Addr {
176
+ func (nc * netConn ) RemoteAddr () net.Addr {
137
177
return websocketAddr {}
138
178
}
139
179
140
- func (c * netConn ) LocalAddr () net.Addr {
180
+ func (nc * netConn ) LocalAddr () net.Addr {
141
181
return websocketAddr {}
142
182
}
143
183
144
- func (c * netConn ) SetDeadline (t time.Time ) error {
145
- c .SetWriteDeadline (t )
146
- c .SetReadDeadline (t )
184
+ func (nc * netConn ) SetDeadline (t time.Time ) error {
185
+ nc .SetWriteDeadline (t )
186
+ nc .SetReadDeadline (t )
147
187
return nil
148
188
}
149
189
150
- func (c * netConn ) SetWriteDeadline (t time.Time ) error {
190
+ func (nc * netConn ) SetWriteDeadline (t time.Time ) error {
191
+ atomic .StoreInt64 (& nc .writeExpired , 0 )
151
192
if t .IsZero () {
152
- c .writeTimer .Stop ()
193
+ nc .writeTimer .Stop ()
153
194
} else {
154
- c .writeTimer .Reset (t .Sub (time .Now ()))
195
+ nc .writeTimer .Reset (t .Sub (time .Now ()))
155
196
}
156
197
return nil
157
198
}
158
199
159
- func (c * netConn ) SetReadDeadline (t time.Time ) error {
200
+ func (nc * netConn ) SetReadDeadline (t time.Time ) error {
201
+ atomic .StoreInt64 (& nc .readExpired , 0 )
160
202
if t .IsZero () {
161
- c .readTimer .Stop ()
203
+ nc .readTimer .Stop ()
162
204
} else {
163
- c .readTimer .Reset (t .Sub (time .Now ()))
205
+ nc .readTimer .Reset (t .Sub (time .Now ()))
164
206
}
165
207
return nil
166
208
}
0 commit comments