From bc4fce01803c504367a6996b2bb66aee1eb5a143 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Fri, 11 Oct 2019 15:18:35 -0400 Subject: [PATCH] Improve Close handshake behaviour - For JS we ensure we indicate which size initiated the close first from our POV - For normal Go, concurrent closes block until the first one succeeds instead of returning early --- conn.go | 33 ++++++++++++++++++++++++++++----- conn_common.go | 9 +++++++++ conn_test.go | 6 +++++- websocket_js.go | 33 +++++++++++++++------------------ 4 files changed, 57 insertions(+), 24 deletions(-) diff --git a/conn.go b/conn.go index 43a94397..861b2390 100644 --- a/conn.go +++ b/conn.go @@ -851,6 +851,13 @@ func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, e // complete. func (c *Conn) Close(code StatusCode, reason string) error { err := c.exportedClose(code, reason, true) + var ec errClosing + if errors.As(err, &ec) { + <-c.closed + // We wait until the connection closes. + // We use writeClose and not exportedClose to avoid a second failed to marshal close frame error. + err = c.writeClose(nil, ec.ce, true) + } if err != nil { return fmt.Errorf("failed to close websocket connection: %w", err) } @@ -878,15 +885,31 @@ func (c *Conn) exportedClose(code StatusCode, reason string, handshake bool) err return c.writeClose(p, fmt.Errorf("sent close: %w", ce), handshake) } +type errClosing struct { + ce error +} + +func (e errClosing) Error() string { + return "already closing connection" +} + func (c *Conn) writeClose(p []byte, ce error, handshake bool) error { - select { - case <-c.closed: - return fmt.Errorf("tried to close with %v but connection already closed: %w", ce, c.closeErr) - default: + if c.isClosed() { + return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr) } if !c.closing.CAS(0, 1) { - return fmt.Errorf("another goroutine is closing") + // Normally, we would want to wait until the connection is closed, + // at least for when a user calls into Close, so we handle that case in + // the exported Close function. + // + // But for internal library usage, we always want to return early, e.g. + // if we are performing a close handshake and the peer sends their close frame, + // we do not want to block here waiting for c.closed to close because it won't, + // at least not until we return since the gorouine that will close it is this one. + return errClosing{ + ce: ce, + } } // No matter what happens next, close error should be set. diff --git a/conn_common.go b/conn_common.go index 9f0b045a..1247df6e 100644 --- a/conn_common.go +++ b/conn_common.go @@ -234,3 +234,12 @@ func (v *atomicInt64) Increment(delta int64) int64 { func (v *atomicInt64) CAS(old, new int64) (swapped bool) { return atomic.CompareAndSwapInt64(&v.v, old, new) } + +func (c *Conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} diff --git a/conn_test.go b/conn_test.go index 1acdf595..8413c4c2 100644 --- a/conn_test.go +++ b/conn_test.go @@ -602,7 +602,11 @@ func TestConn(t *testing.T) { { name: "largeControlFrame", server: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, websocket.OpClose, []byte(strings.Repeat("x", 4096))) + err := c.WriteHeader(ctx, websocket.Header{ + Fin: true, + OpCode: websocket.OpClose, + PayloadLength: 4096, + }) if err != nil { return err } diff --git a/websocket_js.go b/websocket_js.go index f297f9d4..d27809cf 100644 --- a/websocket_js.go +++ b/websocket_js.go @@ -23,7 +23,7 @@ type Conn struct { // read limit for a message in bytes. msgReadLimit *atomicInt64 - closeMu sync.Mutex + closingMu sync.Mutex isReadClosed *atomicInt64 closeOnce sync.Once closed chan struct{} @@ -43,6 +43,9 @@ func (c *Conn) close(err error, wasClean bool) { c.closeOnce.Do(func() { runtime.SetFinalizer(c, nil) + if !wasClean { + err = fmt.Errorf("unclean connection close: %w", err) + } c.setCloseErr(err) c.closeWasClean = wasClean close(c.closed) @@ -59,14 +62,11 @@ func (c *Conn) init() { c.isReadClosed = &atomicInt64{} c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) { - var err error = CloseError{ + err := CloseError{ Code: StatusCode(e.Code), Reason: e.Reason, } - if !e.WasClean { - err = fmt.Errorf("connection close was not clean: %w", err) - } - c.close(err, e.WasClean) + c.close(fmt.Errorf("received close: %w", err), e.WasClean) c.releaseOnClose() c.releaseOnMessage() @@ -182,15 +182,6 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { } } -func (c *Conn) isClosed() bool { - select { - case <-c.closed: - return true - default: - return false - } -} - // Close closes the websocket with the given code and reason. // It will wait until the peer responds with a close frame // or the connection is closed. @@ -204,13 +195,19 @@ func (c *Conn) Close(code StatusCode, reason string) error { } func (c *Conn) exportedClose(code StatusCode, reason string) error { - c.closeMu.Lock() - defer c.closeMu.Unlock() + c.closingMu.Lock() + defer c.closingMu.Unlock() + + ce := fmt.Errorf("sent close: %w", CloseError{ + Code: code, + Reason: reason, + }) if c.isClosed() { - return fmt.Errorf("already closed: %w", c.closeErr) + return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr) } + c.setCloseErr(ce) err := c.ws.Close(int(code), reason) if err != nil { return err