Skip to content

Commit 6ed989a

Browse files
committed
Ensure no goroutines leak after Close
Closes #330
1 parent e6a7e0e commit 6ed989a

7 files changed

+70
-31
lines changed

conn.go

+24-10
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,10 @@ type Conn struct {
5353
br *bufio.Reader
5454
bw *bufio.Writer
5555

56-
readTimeout chan context.Context
57-
writeTimeout chan context.Context
56+
timeoutLoopCancel context.CancelFunc
57+
timeoutLoopDone chan struct{}
58+
readTimeout chan context.Context
59+
writeTimeout chan context.Context
5860

5961
// Read state.
6062
readMu *mu
@@ -102,8 +104,9 @@ func newConn(cfg connConfig) *Conn {
102104
br: cfg.br,
103105
bw: cfg.bw,
104106

105-
readTimeout: make(chan context.Context),
106-
writeTimeout: make(chan context.Context),
107+
timeoutLoopDone: make(chan struct{}),
108+
readTimeout: make(chan context.Context),
109+
writeTimeout: make(chan context.Context),
107110

108111
closed: make(chan struct{}),
109112
activePings: make(map[string]chan<- struct{}),
@@ -130,7 +133,9 @@ func newConn(cfg connConfig) *Conn {
130133
c.close(errors.New("connection garbage collected"))
131134
})
132135

133-
go c.timeoutLoop()
136+
var ctx context.Context
137+
ctx, c.timeoutLoopCancel = context.WithCancel(context.Background())
138+
go c.timeoutLoop(ctx)
134139

135140
return c
136141
}
@@ -152,6 +157,10 @@ func (c *Conn) close(err error) {
152157
err = c.rwc.Close()
153158
}
154159
c.setCloseErrLocked(err)
160+
161+
c.timeoutLoopCancel()
162+
<-c.timeoutLoopDone
163+
155164
close(c.closed)
156165
runtime.SetFinalizer(c, nil)
157166

@@ -160,18 +169,23 @@ func (c *Conn) close(err error) {
160169
// closeErr.
161170
c.rwc.Close()
162171

163-
go func() {
164-
c.msgWriter.close()
165-
c.msgReader.close()
166-
}()
172+
c.closeMu.Unlock()
173+
defer c.closeMu.Lock()
174+
175+
c.msgWriter.close()
176+
c.msgReader.close()
167177
}
168178

169-
func (c *Conn) timeoutLoop() {
179+
func (c *Conn) timeoutLoop(ctx context.Context) {
180+
defer close(c.timeoutLoopDone)
181+
170182
readCtx := context.Background()
171183
writeCtx := context.Background()
172184

173185
for {
174186
select {
187+
case <-ctx.Done():
188+
return
175189
case <-c.closed:
176190
return
177191

conn_test.go

+9-8
Original file line numberDiff line numberDiff line change
@@ -399,10 +399,8 @@ func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *webs
399399
c1, c2 = c2, c1
400400
}
401401
t.Cleanup(func() {
402-
// We don't actually care whether this succeeds so we just run it in a separate goroutine to avoid
403-
// blocking the test shutting down.
404-
go c2.Close(websocket.StatusInternalError, "")
405-
go c1.Close(websocket.StatusInternalError, "")
402+
c2.CloseNow()
403+
c1.CloseNow()
406404
})
407405

408406
return tt, c1, c2
@@ -596,16 +594,19 @@ func TestConcurrentClosePing(t *testing.T) {
596594
defer c2.CloseNow()
597595
c1.CloseRead(context.Background())
598596
c2.CloseRead(context.Background())
599-
go func() {
597+
errc := xsync.Go(func() error {
600598
for range time.Tick(time.Millisecond) {
601-
if err := c1.Ping(context.Background()); err != nil {
602-
return
599+
err := c1.Ping(context.Background())
600+
if err != nil {
601+
return err
603602
}
604603
}
605-
}()
604+
panic("unreachable")
605+
})
606606

607607
time.Sleep(10 * time.Millisecond)
608608
assert.Success(t, c1.Close(websocket.StatusNormalClosure, ""))
609+
<-errc
609610
}()
610611
}
611612
}

dial_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,12 @@ func Test_verifyHostOverride(t *testing.T) {
164164
}, nil
165165
}
166166

167-
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
167+
c, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
168168
HTTPClient: mockHTTPClient(rt),
169169
Host: tc.host,
170170
})
171171
assert.Success(t, err)
172+
c.CloseNow()
172173
})
173174
}
174175

main_test.go

+14-1
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,23 @@ import (
77
"testing"
88
)
99

10+
func goroutineStacks() []byte {
11+
buf := make([]byte, 512)
12+
for {
13+
m := runtime.Stack(buf, true)
14+
if m < len(buf) {
15+
return buf[:m]
16+
}
17+
buf = make([]byte, len(buf)*2)
18+
}
19+
}
20+
1021
func TestMain(m *testing.M) {
1122
code := m.Run()
12-
if runtime.NumGoroutine() != 1 {
23+
if runtime.GOOS != "js" && runtime.NumGoroutine() != 1 ||
24+
runtime.GOOS == "js" && runtime.NumGoroutine() != 2 {
1325
fmt.Fprintf(os.Stderr, "goroutine leak detected, expected 1 but got %d goroutines\n", runtime.NumGoroutine())
26+
fmt.Fprintf(os.Stderr, "%s\n", goroutineStacks())
1427
os.Exit(1)
1528
}
1629
os.Exit(code)

read.go

+5
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
219219
case <-ctx.Done():
220220
return header{}, ctx.Err()
221221
default:
222+
c.readMu.unlock()
222223
c.close(err)
223224
return header{}, err
224225
}
@@ -249,6 +250,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
249250
return n, ctx.Err()
250251
default:
251252
err = fmt.Errorf("failed to read frame payload: %w", err)
253+
c.readMu.unlock()
252254
c.close(err)
253255
return n, err
254256
}
@@ -319,6 +321,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
319321
err = fmt.Errorf("received close frame: %w", ce)
320322
c.setCloseErr(err)
321323
c.writeClose(ce.Code, ce.Reason)
324+
c.readMu.unlock()
322325
c.close(err)
323326
return err
324327
}
@@ -334,6 +337,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro
334337

335338
if !c.msgReader.fin {
336339
err = errors.New("previous message not read to completion")
340+
c.readMu.unlock()
337341
c.close(fmt.Errorf("failed to get reader: %w", err))
338342
return 0, nil, err
339343
}
@@ -409,6 +413,7 @@ func (mr *msgReader) Read(p []byte) (n int, err error) {
409413
}
410414
if err != nil {
411415
err = fmt.Errorf("failed to read: %w", err)
416+
mr.c.readMu.unlock()
412417
mr.c.close(err)
413418
}
414419
return n, err

write.go

+15-10
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
109109

110110
if !c.flate() {
111111
defer c.msgWriter.mu.unlock()
112-
return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
112+
return c.writeFrame(true, ctx, true, false, c.msgWriter.opcode, p)
113113
}
114114

115115
n, err := mw.Write(p)
@@ -159,6 +159,7 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) {
159159
defer func() {
160160
if err != nil {
161161
err = fmt.Errorf("failed to write: %w", err)
162+
mw.writeMu.unlock()
162163
mw.c.close(err)
163164
}
164165
}()
@@ -179,7 +180,7 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) {
179180
}
180181

181182
func (mw *msgWriter) write(p []byte) (int, error) {
182-
n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
183+
n, err := mw.c.writeFrame(true, mw.ctx, false, mw.flate, mw.opcode, p)
183184
if err != nil {
184185
return n, fmt.Errorf("failed to write data frame: %w", err)
185186
}
@@ -191,25 +192,25 @@ func (mw *msgWriter) write(p []byte) (int, error) {
191192
func (mw *msgWriter) Close() (err error) {
192193
defer errd.Wrap(&err, "failed to close writer")
193194

194-
if mw.closed {
195-
return errors.New("writer already closed")
196-
}
197-
mw.closed = true
198-
199195
err = mw.writeMu.lock(mw.ctx)
200196
if err != nil {
201197
return err
202198
}
203199
defer mw.writeMu.unlock()
204200

201+
if mw.closed {
202+
return errors.New("writer already closed")
203+
}
204+
mw.closed = true
205+
205206
if mw.flate {
206207
err = mw.flateWriter.Flush()
207208
if err != nil {
208209
return fmt.Errorf("failed to flush flate: %w", err)
209210
}
210211
}
211212

212-
_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
213+
_, err = mw.c.writeFrame(true, mw.ctx, true, mw.flate, mw.opcode, nil)
213214
if err != nil {
214215
return fmt.Errorf("failed to write fin frame: %w", err)
215216
}
@@ -235,15 +236,15 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
235236
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
236237
defer cancel()
237238

238-
_, err := c.writeFrame(ctx, true, false, opcode, p)
239+
_, err := c.writeFrame(false, ctx, true, false, opcode, p)
239240
if err != nil {
240241
return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
241242
}
242243
return nil
243244
}
244245

245246
// frame handles all writes to the connection.
246-
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
247+
func (c *Conn) writeFrame(msgWriter bool, ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
247248
err = c.writeFrameMu.lock(ctx)
248249
if err != nil {
249250
return 0, err
@@ -283,6 +284,10 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
283284
err = ctx.Err()
284285
default:
285286
}
287+
c.writeFrameMu.unlock()
288+
if msgWriter {
289+
c.msgWriter.writeMu.unlock()
290+
}
286291
c.close(err)
287292
err = fmt.Errorf("failed to write frame: %w", err)
288293
}

ws_js.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ func (c *Conn) Close(code StatusCode, reason string) error {
231231
}
232232

233233
// CloseNow closes the WebSocket connection without attempting a close handshake.
234-
// Use When you do not want the overhead of the close handshake.
234+
// Use when you do not want the overhead of the close handshake.
235235
//
236236
// note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close
237237
// a WebSocket without the close handshake.

0 commit comments

Comments
 (0)