Skip to content

Commit d91a212

Browse files
committed
wsjs: Ensure no goroutines leak after Close
Closes #330
1 parent 7b1a6bb commit d91a212

File tree

3 files changed

+26
-23
lines changed

3 files changed

+26
-23
lines changed

close.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,14 @@ func CloseStatus(err error) StatusCode {
9999
// Close will unblock all goroutines interacting with the connection once
100100
// complete.
101101
func (c *Conn) Close(code StatusCode, reason string) error {
102-
defer c.wgWait()
102+
defer c.wg.Wait()
103103
return c.closeHandshake(code, reason)
104104
}
105105

106106
// CloseNow closes the WebSocket connection without attempting a close handshake.
107107
// Use when you do not want the overhead of the close handshake.
108108
func (c *Conn) CloseNow() (err error) {
109-
defer c.wgWait()
109+
defer c.wg.Wait()
110110
defer errd.Wrap(&err, "failed to close WebSocket")
111111

112112
if c.isClosed() {

conn.go

+14-19
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@ const (
4545
type Conn struct {
4646
noCopy
4747

48-
wg sync.WaitGroup
49-
5048
subprotocol string
5149
rwc io.ReadWriteCloser
5250
client bool
@@ -72,6 +70,7 @@ type Conn struct {
7270
writeHeaderBuf [8]byte
7371
writeHeader header
7472

73+
wg sync.WaitGroup
7574
closed chan struct{}
7675
closeMu sync.Mutex
7776
closeErr error
@@ -132,7 +131,11 @@ func newConn(cfg connConfig) *Conn {
132131
c.close(errors.New("connection garbage collected"))
133132
})
134133

135-
c.wgGo(c.timeoutLoop)
134+
c.wg.Add(1)
135+
go func() {
136+
defer c.wg.Done()
137+
c.timeoutLoop()
138+
}()
136139

137140
return c
138141
}
@@ -163,10 +166,12 @@ func (c *Conn) close(err error) {
163166
// closeErr.
164167
c.rwc.Close()
165168

166-
c.wgGo(func() {
169+
c.wg.Add(1)
170+
go func() {
171+
defer c.wg.Done()
167172
c.msgWriter.close()
168173
c.msgReader.close()
169-
})
174+
}()
170175
}
171176

172177
func (c *Conn) timeoutLoop() {
@@ -183,9 +188,11 @@ func (c *Conn) timeoutLoop() {
183188

184189
case <-readCtx.Done():
185190
c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
186-
c.wgGo(func() {
191+
c.wg.Add(1)
192+
go func() {
193+
defer c.wg.Done()
187194
c.writeError(StatusPolicyViolation, errors.New("read timed out"))
188-
})
195+
}()
189196
case <-writeCtx.Done():
190197
c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
191198
return
@@ -302,15 +309,3 @@ func (m *mu) unlock() {
302309
type noCopy struct{}
303310

304311
func (*noCopy) Lock() {}
305-
306-
func (c *Conn) wgGo(fn func()) {
307-
c.wg.Add(1)
308-
go func() {
309-
defer c.wg.Done()
310-
fn()
311-
}()
312-
}
313-
314-
func (c *Conn) wgWait() {
315-
c.wg.Wait()
316-
}

ws_js.go

+10-2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ type Conn struct {
4747
// read limit for a message in bytes.
4848
msgReadLimit xsync.Int64
4949

50+
wg sync.WaitGroup
5051
closingMu sync.Mutex
5152
isReadClosed xsync.Int64
5253
closeOnce sync.Once
@@ -223,6 +224,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
223224
// or the connection is closed.
224225
// It thus performs the full WebSocket close handshake.
225226
func (c *Conn) Close(code StatusCode, reason string) error {
227+
defer c.wg.Wait()
226228
err := c.exportedClose(code, reason)
227229
if err != nil {
228230
return fmt.Errorf("failed to close WebSocket: %w", err)
@@ -236,6 +238,7 @@ func (c *Conn) Close(code StatusCode, reason string) error {
236238
// note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close
237239
// a WebSocket without the close handshake.
238240
func (c *Conn) CloseNow() error {
241+
defer c.wg.Wait()
239242
return c.Close(StatusGoingAway, "")
240243
}
241244

@@ -388,10 +391,15 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context {
388391
c.isReadClosed.Store(1)
389392

390393
ctx, cancel := context.WithCancel(ctx)
394+
c.wg.Add(1)
391395
go func() {
396+
defer c.CloseNow()
397+
defer c.wg.Done()
392398
defer cancel()
393-
c.read(ctx)
394-
c.Close(StatusPolicyViolation, "unexpected data message")
399+
_, _, err := c.read(ctx)
400+
if err != nil {
401+
c.Close(StatusPolicyViolation, "unexpected data message")
402+
}
395403
}()
396404
return ctx
397405
}

0 commit comments

Comments
 (0)