6	"io" 
6	"io" 
7	"math" 
7	"math" 
8	"net" 
8	"net" 
9- 	"sync" 
9+ 	"sync/atomic " 
10	"time" 
10	"time" 
11)
11)
12
12
@@ -28,9 +28,10 @@ import (
28// 
28// 
29// Close will close the *websocket.Conn with StatusNormalClosure. 
29// Close will close the *websocket.Conn with StatusNormalClosure. 
30// 
30// 
31- // When a deadline is hit, the connection will be closed. This is 
31+ // When a deadline is hit and there is an active read or write goroutine, the 
32- // different from most net.Conn implementations where only the 
32+ // connection will be closed. This is different from most net.Conn implementations 
33- // reading/writing goroutines are interrupted but the connection is kept alive. 
33+ // where only the reading/writing goroutines are interrupted but the connection 
34+ // is kept alive. 
34// 
35// 
35// The Addr methods will return a mock net.Addr that returns "websocket" for Network 
36// The Addr methods will return a mock net.Addr that returns "websocket" for Network 
36// and "websocket/unknown-addr" for String. 
37// and "websocket/unknown-addr" for String. 
@@ -41,17 +42,43 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
41	nc  :=  & netConn {
42	nc  :=  & netConn {
42		c :       c ,
43		c :       c ,
43		msgType : msgType ,
44		msgType : msgType ,
45+ 		readMu :  newMu (c ),
46+ 		writeMu : newMu (c ),
44	}
47	}
45
48
46- 	var  cancel  context.CancelFunc 
49+ 	var  writeCancel  context.CancelFunc 
47- 	nc .writeContext , cancel  =  context .WithCancel (ctx )
50+ 	nc .writeCtx , writeCancel  =  context .WithCancel (ctx )
48- 	nc .writeTimer  =  time .AfterFunc (math .MaxInt64 , cancel )
51+ 	var  readCancel  context.CancelFunc 
52+ 	nc .readCtx , readCancel  =  context .WithCancel (ctx )
53+ 
54+ 	nc .writeTimer  =  time .AfterFunc (math .MaxInt64 , func () {
55+ 		if  ! nc .writeMu .tryLock () {
56+ 			// If the lock cannot be acquired, then there is an 
57+ 			// active write goroutine and so we should cancel the context. 
58+ 			writeCancel ()
59+ 			return 
60+ 		}
61+ 		defer  nc .writeMu .unlock ()
62+ 
63+ 		// Prevents future writes from writing until the deadline is reset. 
64+ 		atomic .StoreInt64 (& nc .writeExpired , 1 )
65+ 	})
49	if  ! nc .writeTimer .Stop () {
66	if  ! nc .writeTimer .Stop () {
50		<- nc .writeTimer .C 
67		<- nc .writeTimer .C 
51	}
68	}
52
69
53- 	nc .readContext , cancel  =  context .WithCancel (ctx )
70+ 	nc .readTimer  =  time .AfterFunc (math .MaxInt64 , func () {
54- 	nc .readTimer  =  time .AfterFunc (math .MaxInt64 , cancel )
71+ 		if  ! nc .readMu .tryLock () {
72+ 			// If the lock cannot be acquired, then there is an 
73+ 			// active read goroutine and so we should cancel the context. 
74+ 			readCancel ()
75+ 			return 
76+ 		}
77+ 		defer  nc .readMu .unlock ()
78+ 
79+ 		// Prevents future reads from reading until the deadline is reset. 
80+ 		atomic .StoreInt64 (& nc .readExpired , 1 )
81+ 	})
55	if  ! nc .readTimer .Stop () {
82	if  ! nc .readTimer .Stop () {
56		<- nc .readTimer .C 
83		<- nc .readTimer .C 
57	}
84	}
@@ -64,59 +91,72 @@ type netConn struct {
64	msgType  MessageType 
91	msgType  MessageType 
65
92
66	writeTimer    * time.Timer 
93	writeTimer    * time.Timer 
67- 	writeContext  context.Context 
94+ 	writeMu       * mu 
95+ 	writeExpired  int64 
96+ 	writeCtx      context.Context 
68
97
69	readTimer    * time.Timer 
98	readTimer    * time.Timer 
70- 	readContext  context. Context 
99+ 	readMu        * mu 
71- 
100+ 	 readExpired   int64 
72- 	readMu  sync. Mutex 
101+ 	readCtx      context. Context 
73- 	eofed   bool 
102+ 	readEOFed     bool 
74- 	reader  io.Reader 
103+ 	reader        io.Reader 
75}
104}
76
105
77var  _  net.Conn  =  & netConn {}
106var  _  net.Conn  =  & netConn {}
78
107
79- func  (c  * netConn ) Close () error  {
108+ func  (nc  * netConn ) Close () error  {
80- 	return  c .c .Close (StatusNormalClosure , "" )
109+ 	return  nc .c .Close (StatusNormalClosure , "" )
81}
110}
82
111
83- func  (c  * netConn ) Write (p  []byte ) (int , error ) {
112+ func  (nc  * netConn ) Write (p  []byte ) (int , error ) {
84- 	err  :=  c .c .Write (c .writeContext , c .msgType , p )
113+ 	nc .writeMu .forceLock ()
114+ 	defer  nc .writeMu .unlock ()
115+ 
116+ 	if  atomic .LoadInt64 (& nc .writeExpired ) ==  1  {
117+ 		return  0 , fmt .Errorf ("failed to write: %w" , context .DeadlineExceeded )
118+ 	}
119+ 
120+ 	err  :=  nc .c .Write (nc .writeCtx , nc .msgType , p )
85	if  err  !=  nil  {
121	if  err  !=  nil  {
86		return  0 , err 
122		return  0 , err 
87	}
123	}
88	return  len (p ), nil 
124	return  len (p ), nil 
89}
125}
90
126
91- func  (c  * netConn ) Read (p  []byte ) (int , error ) {
127+ func  (nc  * netConn ) Read (p  []byte ) (int , error ) {
92- 	c .readMu .Lock ()
128+ 	nc .readMu .forceLock ()
93- 	defer  c .readMu .Unlock ()
129+ 	defer  nc .readMu .unlock ()
130+ 
131+ 	if  atomic .LoadInt64 (& nc .readExpired ) ==  1  {
132+ 		return  0 , fmt .Errorf ("failed to read: %w" , context .DeadlineExceeded )
133+ 	}
94
134
95- 	if  c . eofed  {
135+ 	if  nc . readEOFed  {
96		return  0 , io .EOF 
136		return  0 , io .EOF 
97	}
137	}
98
138
99- 	if  c .reader  ==  nil  {
139+ 	if  nc .reader  ==  nil  {
100- 		typ , r , err  :=  c .c .Reader (c . readContext )
140+ 		typ , r , err  :=  nc .c .Reader (nc . readCtx )
101		if  err  !=  nil  {
141		if  err  !=  nil  {
102			switch  CloseStatus (err ) {
142			switch  CloseStatus (err ) {
103			case  StatusNormalClosure , StatusGoingAway :
143			case  StatusNormalClosure , StatusGoingAway :
104- 				c . eofed  =  true 
144+ 				nc . readEOFed  =  true 
105				return  0 , io .EOF 
145				return  0 , io .EOF 
106			}
146			}
107			return  0 , err 
147			return  0 , err 
108		}
148		}
109- 		if  typ  !=  c .msgType  {
149+ 		if  typ  !=  nc .msgType  {
110- 			err  :=  fmt .Errorf ("unexpected frame type read (expected %v): %v" , c .msgType , typ )
150+ 			err  :=  fmt .Errorf ("unexpected frame type read (expected %v): %v" , nc .msgType , typ )
111- 			c .c .Close (StatusUnsupportedData , err .Error ())
151+ 			nc .c .Close (StatusUnsupportedData , err .Error ())
112			return  0 , err 
152			return  0 , err 
113		}
153		}
114- 		c .reader  =  r 
154+ 		nc .reader  =  r 
115	}
155	}
116
156
117- 	n , err  :=  c .reader .Read (p )
157+ 	n , err  :=  nc .reader .Read (p )
118	if  err  ==  io .EOF  {
158	if  err  ==  io .EOF  {
119- 		c .reader  =  nil 
159+ 		nc .reader  =  nil 
120		err  =  nil 
160		err  =  nil 
121	}
161	}
122	return  n , err 
162	return  n , err 
@@ -133,34 +173,36 @@ func (a websocketAddr) String() string {
133	return  "websocket/unknown-addr" 
173	return  "websocket/unknown-addr" 
134}
174}
135
175
136- func  (c  * netConn ) RemoteAddr () net.Addr  {
176+ func  (nc  * netConn ) RemoteAddr () net.Addr  {
137	return  websocketAddr {}
177	return  websocketAddr {}
138}
178}
139
179
140- func  (c  * netConn ) LocalAddr () net.Addr  {
180+ func  (nc  * netConn ) LocalAddr () net.Addr  {
141	return  websocketAddr {}
181	return  websocketAddr {}
142}
182}
143
183
144- func  (c  * netConn ) SetDeadline (t  time.Time ) error  {
184+ func  (nc  * netConn ) SetDeadline (t  time.Time ) error  {
145- 	c .SetWriteDeadline (t )
185+ 	nc .SetWriteDeadline (t )
146- 	c .SetReadDeadline (t )
186+ 	nc .SetReadDeadline (t )
147	return  nil 
187	return  nil 
148}
188}
149
189
150- func  (c  * netConn ) SetWriteDeadline (t  time.Time ) error  {
190+ func  (nc  * netConn ) SetWriteDeadline (t  time.Time ) error  {
191+ 	atomic .StoreInt64 (& nc .writeExpired , 0 )
151	if  t .IsZero () {
192	if  t .IsZero () {
152- 		c .writeTimer .Stop ()
193+ 		nc .writeTimer .Stop ()
153	} else  {
194	} else  {
154- 		c .writeTimer .Reset (t .Sub (time .Now ()))
195+ 		nc .writeTimer .Reset (t .Sub (time .Now ()))
155	}
196	}
156	return  nil 
197	return  nil 
157}
198}
158
199
159- func  (c  * netConn ) SetReadDeadline (t  time.Time ) error  {
200+ func  (nc  * netConn ) SetReadDeadline (t  time.Time ) error  {
201+ 	atomic .StoreInt64 (& nc .readExpired , 0 )
160	if  t .IsZero () {
202	if  t .IsZero () {
161- 		c .readTimer .Stop ()
203+ 		nc .readTimer .Stop ()
162	} else  {
204	} else  {
163- 		c .readTimer .Reset (t .Sub (time .Now ()))
205+ 		nc .readTimer .Reset (t .Sub (time .Now ()))
164	}
206	}
165	return  nil 
207	return  nil 
166}
208}
0 commit comments