From e7338b046f1c3794b57d2dc9d00d5e76a9c6abdc Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 26 Aug 2024 15:58:27 +0530 Subject: [PATCH 01/11] introduce GoAwayError type --- const.go | 65 ---------------------------------- errors.go | 102 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 65 deletions(-) create mode 100644 errors.go diff --git a/const.go b/const.go index e4b2bc2..c1a2deb 100644 --- a/const.go +++ b/const.go @@ -5,71 +5,6 @@ import ( "fmt" ) -type Error struct { - msg string - timeout, temporary bool -} - -func (ye *Error) Error() string { - return ye.msg -} - -func (ye *Error) Timeout() bool { - return ye.timeout -} - -func (ye *Error) Temporary() bool { - return ye.temporary -} - -var ( - // ErrInvalidVersion means we received a frame with an - // invalid version - ErrInvalidVersion = &Error{msg: "invalid protocol version"} - - // ErrInvalidMsgType means we received a frame with an - // invalid message type - ErrInvalidMsgType = &Error{msg: "invalid msg type"} - - // ErrSessionShutdown is used if there is a shutdown during - // an operation - ErrSessionShutdown = &Error{msg: "session shutdown"} - - // ErrStreamsExhausted is returned if we have no more - // stream ids to issue - ErrStreamsExhausted = &Error{msg: "streams exhausted"} - - // ErrDuplicateStream is used if a duplicate stream is - // opened inbound - ErrDuplicateStream = &Error{msg: "duplicate stream initiated"} - - // ErrReceiveWindowExceeded indicates the window was exceeded - ErrRecvWindowExceeded = &Error{msg: "recv window exceeded"} - - // ErrTimeout is used when we reach an IO deadline - ErrTimeout = &Error{msg: "i/o deadline reached", timeout: true, temporary: true} - - // ErrStreamClosed is returned when using a closed stream - ErrStreamClosed = &Error{msg: "stream closed"} - - // ErrUnexpectedFlag is set when we get an unexpected flag - ErrUnexpectedFlag = &Error{msg: "unexpected flag"} - - // ErrRemoteGoAway is used when we get a go away from the other side - ErrRemoteGoAway = &Error{msg: "remote end is not accepting connections"} - - // ErrStreamReset is sent if a stream is reset. This can happen - // if the backlog is exceeded, or if there was a remote GoAway. - ErrStreamReset = &Error{msg: "stream reset"} - - // ErrConnectionWriteTimeout indicates that we hit the "safety valve" - // timeout writing to the underlying stream connection. - ErrConnectionWriteTimeout = &Error{msg: "connection write timeout", timeout: true} - - // ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close - ErrKeepAliveTimeout = &Error{msg: "keepalive timeout", timeout: true} -) - const ( // protoVersion is the only version we support protoVersion uint8 = 0 diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..5ffc742 --- /dev/null +++ b/errors.go @@ -0,0 +1,102 @@ +package yamux + +import "fmt" + +type Error struct { + msg string + timeout, temporary bool +} + +func (ye *Error) Error() string { + return ye.msg +} + +func (ye *Error) Timeout() bool { + return ye.timeout +} + +func (ye *Error) Temporary() bool { + return ye.temporary +} + +type GoAwayError struct { + Remote bool + ErrorCode uint32 +} + +func (e *GoAwayError) Error() string { + if e.Remote { + return fmt.Sprintf("remote sent go away, code: %d", e.ErrorCode) + } + return fmt.Sprintf("sent go away, code: %d", e.ErrorCode) +} + +func (e *GoAwayError) Timeout() bool { + return false +} + +func (e *GoAwayError) Temporary() bool { + return false +} + +func (e *GoAwayError) Is(target error) bool { + // to maintain compatibility with errors returned by previous versions + if e.Remote && target == ErrRemoteGoAway { + return true + } else if !e.Remote && target == ErrSessionShutdown { + return true + } + + if err, ok := target.(*GoAwayError); ok { + return *e == *err + } + return false +} + +var ( + // ErrInvalidVersion means we received a frame with an + // invalid version + ErrInvalidVersion = &Error{msg: "invalid protocol version"} + + // ErrInvalidMsgType means we received a frame with an + // invalid message type + ErrInvalidMsgType = &Error{msg: "invalid msg type"} + + // ErrSessionShutdown is used if there is a shutdown during + // an operation + ErrSessionShutdown = &Error{msg: "session shutdown"} + + // ErrStreamsExhausted is returned if we have no more + // stream ids to issue + ErrStreamsExhausted = &Error{msg: "streams exhausted"} + + // ErrDuplicateStream is used if a duplicate stream is + // opened inbound + ErrDuplicateStream = &Error{msg: "duplicate stream initiated"} + + // ErrReceiveWindowExceeded indicates the window was exceeded + ErrRecvWindowExceeded = &Error{msg: "recv window exceeded"} + + // ErrTimeout is used when we reach an IO deadline + ErrTimeout = &Error{msg: "i/o deadline reached", timeout: true, temporary: true} + + // ErrStreamClosed is returned when using a closed stream + ErrStreamClosed = &Error{msg: "stream closed"} + + // ErrUnexpectedFlag is set when we get an unexpected flag + ErrUnexpectedFlag = &Error{msg: "unexpected flag"} + + // ErrRemoteGoAway is used when we get a go away from the other side + ErrRemoteGoAway = &Error{msg: "remote end is not accepting connections"} + + // ErrStreamReset is sent if a stream is reset. This can happen + // if the backlog is exceeded, or if there was a remote GoAway. + ErrStreamReset = &Error{msg: "stream reset"} + + // ErrConnectionWriteTimeout indicates that we hit the "safety valve" + // timeout writing to the underlying stream connection. + ErrConnectionWriteTimeout = &Error{msg: "connection write timeout", timeout: true} + + // ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close + ErrKeepAliveTimeout = &Error{msg: "keepalive timeout", timeout: true} +) From d8cf4e74eacf8ac6195c34e977a075638d3b8460 Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 26 Aug 2024 16:12:41 +0530 Subject: [PATCH 02/11] send GoAway on Close --- const.go | 2 ++ session.go | 25 +++++++++++++++++++++---- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/const.go b/const.go index c1a2deb..e737d85 100644 --- a/const.go +++ b/const.go @@ -3,6 +3,7 @@ package yamux import ( "encoding/binary" "fmt" + "time" ) const ( @@ -52,6 +53,7 @@ const ( // It's not an implementation choice, the value defined in the specification. initialStreamWindow = 256 * 1024 maxStreamWindow = 16 * 1024 * 1024 + goAwayWaitTime = 5 * time.Second ) const ( diff --git a/session.go b/session.go index c4cd1bd..62ea2c3 100644 --- a/session.go +++ b/session.go @@ -284,8 +284,14 @@ func (s *Session) AcceptStream() (*Stream, error) { } // Close is used to close the session and all streams. -// Attempts to send a GoAway before closing the connection. +// Attempts to send a GoAway before closing the connection. The GoAway may not actually be sent depending on the +// semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or +// if there's unread data in the kernel receive buffer. func (s *Session) Close() error { + return s.close(true, goAwayNormal) +} + +func (s *Session) close(sendGoAway bool, errCode uint32) error { s.shutdownLock.Lock() defer s.shutdownLock.Unlock() @@ -297,10 +303,21 @@ func (s *Session) Close() error { s.shutdownErr = ErrSessionShutdown } close(s.shutdownCh) - s.conn.Close() s.stopKeepalive() - <-s.recvDoneCh + + // wait for write loop to exit + _ = s.conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)) // if SetWriteDeadline errored, any blocked writes will be unblocked <-s.sendDoneCh + if sendGoAway { + ga := s.goAway(errCode) + if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil { + _, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here + } + } + + s.conn.SetWriteDeadline(time.Time{}) + s.conn.Close() + <-s.recvDoneCh s.streamLock.Lock() defer s.streamLock.Unlock() @@ -320,7 +337,7 @@ func (s *Session) exitErr(err error) { s.shutdownErr = err } s.shutdownLock.Unlock() - s.Close() + s.close(false, 0) } // GoAway can be used to prevent accepting further From 4b262c087158b3b6ca5c5f0c5b51be77d72ec577 Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 26 Aug 2024 21:02:33 +0530 Subject: [PATCH 03/11] add CloseWithError --- errors.go | 4 +-- session.go | 69 ++++++++++++++++++++++++++++++------------------- session_test.go | 43 +++++++++++++++++++++++++++--- 3 files changed, 83 insertions(+), 33 deletions(-) diff --git a/errors.go b/errors.go index 5ffc742..7bedec5 100644 --- a/errors.go +++ b/errors.go @@ -64,7 +64,7 @@ var ( // ErrSessionShutdown is used if there is a shutdown during // an operation - ErrSessionShutdown = &Error{msg: "session shutdown"} + ErrSessionShutdown = &GoAwayError{ErrorCode: goAwayNormal, Remote: false} // ErrStreamsExhausted is returned if we have no more // stream ids to issue @@ -87,7 +87,7 @@ var ( ErrUnexpectedFlag = &Error{msg: "unexpected flag"} // ErrRemoteGoAway is used when we get a go away from the other side - ErrRemoteGoAway = &Error{msg: "remote end is not accepting connections"} + ErrRemoteGoAway = &GoAwayError{Remote: true, ErrorCode: goAwayNormal} // ErrStreamReset is sent if a stream is reset. This can happen // if the backlog is exceeded, or if there was a remote GoAway. diff --git a/session.go b/session.go index 62ea2c3..e31bef4 100644 --- a/session.go +++ b/session.go @@ -102,6 +102,8 @@ type Session struct { // recvDoneCh is closed when recv() exits to avoid a race // between stream registration and stream shutdown recvDoneCh chan struct{} + // recvErr is the error the receive loop ended with + recvErr error // sendDoneCh is closed when send() exits to avoid a race // between returning from a Stream.Write and exiting from the send loop @@ -288,10 +290,18 @@ func (s *Session) AcceptStream() (*Stream, error) { // semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or // if there's unread data in the kernel receive buffer. func (s *Session) Close() error { - return s.close(true, goAwayNormal) + return s.close(ErrSessionShutdown, true, goAwayNormal) } -func (s *Session) close(sendGoAway bool, errCode uint32) error { +// CloseWithError is used to close the session and all streams after sending a GoAway message with errCode. +// The GoAway may not actually be sent depending on the semantics of the underlying net.Conn. +// For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel +// receive buffer. +func (s *Session) CloseWithError(errCode uint32) error { + return s.close(&GoAwayError{Remote: false, ErrorCode: errCode}, true, errCode) +} + +func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) error { s.shutdownLock.Lock() defer s.shutdownLock.Unlock() @@ -300,23 +310,25 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error { } s.shutdown = true if s.shutdownErr == nil { - s.shutdownErr = ErrSessionShutdown + s.shutdownErr = shutdownErr } close(s.shutdownCh) s.stopKeepalive() - // wait for write loop to exit - _ = s.conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)) // if SetWriteDeadline errored, any blocked writes will be unblocked - <-s.sendDoneCh if sendGoAway { + // wait for write loop to exit + // We need to write the current frame completely before sending a goaway. + // This will wait for at most s.config.ConnectionWriteTimeout + <-s.sendDoneCh ga := s.goAway(errCode) if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil { _, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here } + s.conn.SetWriteDeadline(time.Time{}) } - s.conn.SetWriteDeadline(time.Time{}) s.conn.Close() + <-s.sendDoneCh <-s.recvDoneCh s.streamLock.Lock() @@ -329,17 +341,6 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error { return nil } -// exitErr is used to handle an error that is causing the -// session to terminate. -func (s *Session) exitErr(err error) { - s.shutdownLock.Lock() - if s.shutdownErr == nil { - s.shutdownErr = err - } - s.shutdownLock.Unlock() - s.close(false, 0) -} - // GoAway can be used to prevent accepting further // connections. It does not close the underlying conn. func (s *Session) GoAway() error { @@ -468,7 +469,7 @@ func (s *Session) startKeepalive() { if err != nil { s.logger.Printf("[ERR] yamux: keepalive failed: %v", err) - s.exitErr(ErrKeepAliveTimeout) + s.close(ErrKeepAliveTimeout, false, 0) } }) } @@ -533,7 +534,18 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err // send is a long running goroutine that sends data func (s *Session) send() { if err := s.sendLoop(); err != nil { - s.exitErr(err) + // Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code + // received in a GoAway frame received just before the TCP RST that closed the sendLoop + // + // Take the shutdownLock to avoid closing the connection concurrently with a Close call. + s.shutdownLock.Lock() + s.conn.Close() + <-s.recvDoneCh + if _, ok := s.recvErr.(*GoAwayError); ok { + err = s.recvErr + } + s.shutdownLock.Unlock() + s.close(err, false, 0) } } @@ -661,7 +673,7 @@ func (s *Session) sendLoop() (err error) { // recv is a long running goroutine that accepts new data func (s *Session) recv() { if err := s.recvLoop(); err != nil { - s.exitErr(err) + s.close(err, false, 0) } } @@ -683,7 +695,10 @@ func (s *Session) recvLoop() (err error) { err = fmt.Errorf("panic in yamux receive loop: %s", rerr) } }() - defer close(s.recvDoneCh) + defer func() { + s.recvErr = err + close(s.recvDoneCh) + }() var hdr header for { // fmt.Printf("ReadFull from %#v\n", s.reader) @@ -799,17 +814,17 @@ func (s *Session) handleGoAway(hdr header) error { switch code { case goAwayNormal: atomic.SwapInt32(&s.remoteGoAway, 1) + // Don't close connection on normal go away. Let the existing streams + // complete gracefully. + return nil case goAwayProtoErr: s.logger.Printf("[ERR] yamux: received protocol error go away") - return fmt.Errorf("yamux protocol error") case goAwayInternalErr: s.logger.Printf("[ERR] yamux: received internal error go away") - return fmt.Errorf("remote yamux internal error") default: - s.logger.Printf("[ERR] yamux: received unexpected go away") - return fmt.Errorf("unexpected go away received") + s.logger.Printf("[ERR] yamux: received go away with error code: %d", code) } - return nil + return &GoAwayError{Remote: true, ErrorCode: code} } // incomingStream is used to create a new incoming stream diff --git a/session_test.go b/session_test.go index 974b6d5..df3e3c9 100644 --- a/session_test.go +++ b/session_test.go @@ -3,6 +3,7 @@ package yamux import ( "bytes" "context" + "errors" "fmt" "io" "math/rand" @@ -39,6 +40,8 @@ type pipeConn struct { writeDeadline pipeDeadline writeBlocker chan struct{} closeCh chan struct{} + closeOnce sync.Once + closeErr error } func (p *pipeConn) SetDeadline(t time.Time) error { @@ -65,10 +68,12 @@ func (p *pipeConn) Write(b []byte) (int, error) { } func (p *pipeConn) Close() error { - p.writeDeadline.set(time.Time{}) - err := p.Conn.Close() - close(p.closeCh) - return err + p.closeOnce.Do(func() { + p.writeDeadline.set(time.Time{}) + p.closeErr = p.Conn.Close() + close(p.closeCh) + }) + return p.closeErr } func (p *pipeConn) BlockWrites() { @@ -650,6 +655,35 @@ func TestGoAway(t *testing.T) { default: t.Fatalf("err: %v", err) } + time.Sleep(50 * time.Millisecond) + } + t.Fatalf("expected GoAway error") +} + +func TestCloseWithError(t *testing.T) { + // This test is noisy. + conf := testConf() + conf.LogOutput = io.Discard + + client, server := testClientServerConfig(conf) + defer client.Close() + defer server.Close() + + if err := server.CloseWithError(42); err != nil { + t.Fatalf("err: %v", err) + } + + for i := 0; i < 100; i++ { + s, err := client.Open(context.Background()) + if err == nil { + s.Close() + time.Sleep(50 * time.Millisecond) + continue + } + if !errors.Is(err, &GoAwayError{ErrorCode: 42, Remote: true}) { + t.Fatalf("err: %v", err) + } + return } t.Fatalf("expected GoAway error") } @@ -1048,6 +1082,7 @@ func TestKeepAlive_Timeout(t *testing.T) { // Prevent the client from responding clientConn := client.conn.(*pipeConn) clientConn.BlockWrites() + defer clientConn.UnblockWrites() select { case err := <-errCh: From ea5605b186febe5c5e75c263f17365092ae13240 Mon Sep 17 00:00:00 2001 From: sukun Date: Tue, 27 Aug 2024 12:01:50 +0530 Subject: [PATCH 04/11] move errors back to const --- const.go | 99 ++++++++++++++++++++++++++++++++++++++++++++++++++++ errors.go | 102 ------------------------------------------------------ 2 files changed, 99 insertions(+), 102 deletions(-) delete mode 100644 errors.go diff --git a/const.go b/const.go index e737d85..ba33f7c 100644 --- a/const.go +++ b/const.go @@ -6,6 +6,105 @@ import ( "time" ) +type Error struct { + msg string + timeout, temporary bool +} + +func (ye *Error) Error() string { + return ye.msg +} + +func (ye *Error) Timeout() bool { + return ye.timeout +} + +func (ye *Error) Temporary() bool { + return ye.temporary +} + +type GoAwayError struct { + Remote bool + ErrorCode uint32 +} + +func (e *GoAwayError) Error() string { + if e.Remote { + return fmt.Sprintf("remote sent go away, code: %d", e.ErrorCode) + } + return fmt.Sprintf("sent go away, code: %d", e.ErrorCode) +} + +func (e *GoAwayError) Timeout() bool { + return false +} + +func (e *GoAwayError) Temporary() bool { + return false +} + +func (e *GoAwayError) Is(target error) bool { + // to maintain compatibility with errors returned by previous versions + if e.Remote && target == ErrRemoteGoAway { + return true + } else if !e.Remote && target == ErrSessionShutdown { + return true + } + + if err, ok := target.(*GoAwayError); ok { + return *e == *err + } + return false +} + +var ( + // ErrInvalidVersion means we received a frame with an + // invalid version + ErrInvalidVersion = &Error{msg: "invalid protocol version"} + + // ErrInvalidMsgType means we received a frame with an + // invalid message type + ErrInvalidMsgType = &Error{msg: "invalid msg type"} + + // ErrSessionShutdown is used if there is a shutdown during + // an operation + ErrSessionShutdown = &GoAwayError{ErrorCode: goAwayNormal, Remote: false} + + // ErrStreamsExhausted is returned if we have no more + // stream ids to issue + ErrStreamsExhausted = &Error{msg: "streams exhausted"} + + // ErrDuplicateStream is used if a duplicate stream is + // opened inbound + ErrDuplicateStream = &Error{msg: "duplicate stream initiated"} + + // ErrReceiveWindowExceeded indicates the window was exceeded + ErrRecvWindowExceeded = &Error{msg: "recv window exceeded"} + + // ErrTimeout is used when we reach an IO deadline + ErrTimeout = &Error{msg: "i/o deadline reached", timeout: true, temporary: true} + + // ErrStreamClosed is returned when using a closed stream + ErrStreamClosed = &Error{msg: "stream closed"} + + // ErrUnexpectedFlag is set when we get an unexpected flag + ErrUnexpectedFlag = &Error{msg: "unexpected flag"} + + // ErrRemoteGoAway is used when we get a go away from the other side + ErrRemoteGoAway = &GoAwayError{Remote: true, ErrorCode: goAwayNormal} + + // ErrStreamReset is sent if a stream is reset. This can happen + // if the backlog is exceeded, or if there was a remote GoAway. + ErrStreamReset = &Error{msg: "stream reset"} + + // ErrConnectionWriteTimeout indicates that we hit the "safety valve" + // timeout writing to the underlying stream connection. + ErrConnectionWriteTimeout = &Error{msg: "connection write timeout", timeout: true} + + // ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close + ErrKeepAliveTimeout = &Error{msg: "keepalive timeout", timeout: true} +) + const ( // protoVersion is the only version we support protoVersion uint8 = 0 diff --git a/errors.go b/errors.go deleted file mode 100644 index 7bedec5..0000000 --- a/errors.go +++ /dev/null @@ -1,102 +0,0 @@ -package yamux - -import "fmt" - -type Error struct { - msg string - timeout, temporary bool -} - -func (ye *Error) Error() string { - return ye.msg -} - -func (ye *Error) Timeout() bool { - return ye.timeout -} - -func (ye *Error) Temporary() bool { - return ye.temporary -} - -type GoAwayError struct { - Remote bool - ErrorCode uint32 -} - -func (e *GoAwayError) Error() string { - if e.Remote { - return fmt.Sprintf("remote sent go away, code: %d", e.ErrorCode) - } - return fmt.Sprintf("sent go away, code: %d", e.ErrorCode) -} - -func (e *GoAwayError) Timeout() bool { - return false -} - -func (e *GoAwayError) Temporary() bool { - return false -} - -func (e *GoAwayError) Is(target error) bool { - // to maintain compatibility with errors returned by previous versions - if e.Remote && target == ErrRemoteGoAway { - return true - } else if !e.Remote && target == ErrSessionShutdown { - return true - } - - if err, ok := target.(*GoAwayError); ok { - return *e == *err - } - return false -} - -var ( - // ErrInvalidVersion means we received a frame with an - // invalid version - ErrInvalidVersion = &Error{msg: "invalid protocol version"} - - // ErrInvalidMsgType means we received a frame with an - // invalid message type - ErrInvalidMsgType = &Error{msg: "invalid msg type"} - - // ErrSessionShutdown is used if there is a shutdown during - // an operation - ErrSessionShutdown = &GoAwayError{ErrorCode: goAwayNormal, Remote: false} - - // ErrStreamsExhausted is returned if we have no more - // stream ids to issue - ErrStreamsExhausted = &Error{msg: "streams exhausted"} - - // ErrDuplicateStream is used if a duplicate stream is - // opened inbound - ErrDuplicateStream = &Error{msg: "duplicate stream initiated"} - - // ErrReceiveWindowExceeded indicates the window was exceeded - ErrRecvWindowExceeded = &Error{msg: "recv window exceeded"} - - // ErrTimeout is used when we reach an IO deadline - ErrTimeout = &Error{msg: "i/o deadline reached", timeout: true, temporary: true} - - // ErrStreamClosed is returned when using a closed stream - ErrStreamClosed = &Error{msg: "stream closed"} - - // ErrUnexpectedFlag is set when we get an unexpected flag - ErrUnexpectedFlag = &Error{msg: "unexpected flag"} - - // ErrRemoteGoAway is used when we get a go away from the other side - ErrRemoteGoAway = &GoAwayError{Remote: true, ErrorCode: goAwayNormal} - - // ErrStreamReset is sent if a stream is reset. This can happen - // if the backlog is exceeded, or if there was a remote GoAway. - ErrStreamReset = &Error{msg: "stream reset"} - - // ErrConnectionWriteTimeout indicates that we hit the "safety valve" - // timeout writing to the underlying stream connection. - ErrConnectionWriteTimeout = &Error{msg: "connection write timeout", timeout: true} - - // ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close - ErrKeepAliveTimeout = &Error{msg: "keepalive timeout", timeout: true} -) From 8adb9a831e0f9a715d87c982e2e34220d25d51ab Mon Sep 17 00:00:00 2001 From: sukun Date: Tue, 27 Aug 2024 12:56:01 +0530 Subject: [PATCH 05/11] fix race in write timeout --- const.go | 2 +- session.go | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/const.go b/const.go index ba33f7c..7062231 100644 --- a/const.go +++ b/const.go @@ -24,8 +24,8 @@ func (ye *Error) Temporary() bool { } type GoAwayError struct { - Remote bool ErrorCode uint32 + Remote bool } func (e *GoAwayError) Error() string { diff --git a/session.go b/session.go index e31bef4..204b168 100644 --- a/session.go +++ b/session.go @@ -536,13 +536,14 @@ func (s *Session) send() { if err := s.sendLoop(); err != nil { // Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code // received in a GoAway frame received just before the TCP RST that closed the sendLoop - // - // Take the shutdownLock to avoid closing the connection concurrently with a Close call. s.shutdownLock.Lock() - s.conn.Close() - <-s.recvDoneCh - if _, ok := s.recvErr.(*GoAwayError); ok { - err = s.recvErr + if s.shutdownErr == nil { + s.conn.Close() + <-s.recvDoneCh + if _, ok := s.recvErr.(*GoAwayError); ok { + err = s.recvErr + } + s.shutdownErr = err } s.shutdownLock.Unlock() s.close(err, false, 0) From f56b1c3a9a9f693a27680d6a3729c5909e18731b Mon Sep 17 00:00:00 2001 From: sukun Date: Wed, 28 Aug 2024 03:31:48 +0530 Subject: [PATCH 06/11] add support for sending error codes on stream reset --- const.go | 21 ++++++++++++++++++++ session_test.go | 53 ++++++++++++++++++++++++++++++++++++++++++++++++- stream.go | 28 ++++++++++++++++++-------- 3 files changed, 93 insertions(+), 9 deletions(-) diff --git a/const.go b/const.go index 7062231..716d085 100644 --- a/const.go +++ b/const.go @@ -57,6 +57,27 @@ func (e *GoAwayError) Is(target error) bool { return false } +// A StreamError is used for errors returned from Read and Write calls after the stream is Reset +type StreamError struct { + ErrorCode uint32 + Remote bool +} + +func (s *StreamError) Error() string { + if s.Remote { + return fmt.Sprintf("stream reset by remote, error code: %d", s.ErrorCode) + } + return fmt.Sprintf("stream reset, error code: %d", s.ErrorCode) +} + +func (s *StreamError) Is(target error) bool { + if target == ErrStreamReset { + return true + } + e, ok := target.(*StreamError) + return ok && *e == *s +} + var ( // ErrInvalidVersion means we received a frame with an // invalid version diff --git a/session_test.go b/session_test.go index df3e3c9..2c06abb 100644 --- a/session_test.go +++ b/session_test.go @@ -16,6 +16,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -1571,6 +1572,56 @@ func TestStreamResetRead(t *testing.T) { wc.Wait() } +func TestStreamResetWithError(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + wc := new(sync.WaitGroup) + wc.Add(2) + go func() { + defer wc.Done() + stream, err := server.AcceptStream() + if err != nil { + t.Error(err) + } + + se := &StreamError{} + _, err = io.ReadAll(stream) + if !errors.As(err, &se) { + t.Errorf("exptected StreamError, got type:%T, err: %s", err, err) + return + } + expected := &StreamError{Remote: true, ErrorCode: 42} + assert.Equal(t, se, expected) + }() + + stream, err := client.OpenStream(context.Background()) + if err != nil { + t.Error(err) + } + + go func() { + defer wc.Done() + + se := &StreamError{} + _, err := io.ReadAll(stream) + if !errors.As(err, &se) { + t.Errorf("exptected StreamError, got type:%T, err: %s", err, err) + return + } + expected := &StreamError{Remote: false, ErrorCode: 42} + assert.Equal(t, se, expected) + }() + + time.Sleep(1 * time.Second) + err = stream.ResetWithError(42) + if err != nil { + t.Fatal(err) + } + wc.Wait() +} + func TestLotsOfWritesWithStreamDeadline(t *testing.T) { config := testConf() config.EnableKeepAlive = false @@ -1809,7 +1860,7 @@ func TestMaxIncomingStreams(t *testing.T) { require.NoError(t, err) str.SetDeadline(time.Now().Add(time.Second)) _, err = str.Read([]byte{0}) - require.EqualError(t, err, "stream reset") + require.ErrorIs(t, err, ErrStreamReset) // Now close one of the streams. // This should then allow the client to open a new stream. diff --git a/stream.go b/stream.go index e1e5602..f6f32ec 100644 --- a/stream.go +++ b/stream.go @@ -42,6 +42,7 @@ type Stream struct { state streamState writeState, readState halfStreamState stateLock sync.Mutex + resetErr *StreamError recvBuf segmentedBuffer @@ -89,6 +90,7 @@ func (s *Stream) Read(b []byte) (n int, err error) { START: s.stateLock.Lock() state := s.readState + resetErr := s.resetErr s.stateLock.Unlock() switch state { @@ -101,7 +103,7 @@ START: } // Closed, but we have data pending -> read. case halfReset: - return 0, ErrStreamReset + return 0, resetErr default: panic("unknown state") } @@ -147,6 +149,7 @@ func (s *Stream) write(b []byte) (n int, err error) { START: s.stateLock.Lock() state := s.writeState + resetErr := s.resetErr s.stateLock.Unlock() switch state { @@ -155,7 +158,7 @@ START: case halfClosed: return 0, ErrStreamClosed case halfReset: - return 0, ErrStreamReset + return 0, resetErr default: panic("unknown state") } @@ -250,13 +253,17 @@ func (s *Stream) sendClose() error { } // sendReset is used to send a RST -func (s *Stream) sendReset() error { - hdr := encode(typeWindowUpdate, flagRST, s.id, 0) +func (s *Stream) sendReset(errCode uint32) error { + hdr := encode(typeWindowUpdate, flagRST, s.id, errCode) return s.session.sendMsg(hdr, nil, nil) } // Reset resets the stream (forcibly closes the stream) func (s *Stream) Reset() error { + return s.ResetWithError(0) +} + +func (s *Stream) ResetWithError(errCode uint32) error { sendReset := false s.stateLock.Lock() switch s.state { @@ -281,10 +288,11 @@ func (s *Stream) Reset() error { s.readState = halfReset } s.state = streamFinished + s.resetErr = &StreamError{Remote: false, ErrorCode: errCode} s.notifyWaiting() s.stateLock.Unlock() if sendReset { - _ = s.sendReset() + _ = s.sendReset(errCode) } s.cleanup() return nil @@ -382,7 +390,7 @@ func (s *Stream) cleanup() { // processFlags is used to update the state of the stream // based on set flags, if any. Lock must be held -func (s *Stream) processFlags(flags uint16) { +func (s *Stream) processFlags(flags uint16, hdr header) { // Close the stream without holding the state lock var closeStream bool defer func() { @@ -425,6 +433,10 @@ func (s *Stream) processFlags(flags uint16) { s.writeState = halfReset } s.state = streamFinished + // Length in a window update frame with RST flag encodes an error code. + if hdr.MsgType() == typeWindowUpdate && s.resetErr == nil { + s.resetErr = &StreamError{Remote: true, ErrorCode: hdr.Length()} + } s.stateLock.Unlock() closeStream = true s.notifyWaiting() @@ -439,7 +451,7 @@ func (s *Stream) notifyWaiting() { // incrSendWindow updates the size of our send window func (s *Stream) incrSendWindow(hdr header, flags uint16) { - s.processFlags(flags) + s.processFlags(flags, hdr) // Increase window, unblock a sender atomic.AddUint32(&s.sendWindow, hdr.Length()) asyncNotify(s.sendNotifyCh) @@ -447,7 +459,7 @@ func (s *Stream) incrSendWindow(hdr header, flags uint16) { // readData is used to handle a data frame func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { - s.processFlags(flags) + s.processFlags(flags, hdr) // Check that our recv window is not exceeded length := hdr.Length() From 9190b780f8929b9a2ae6dcb0524ef0c64d9374f6 Mon Sep 17 00:00:00 2001 From: sukun Date: Thu, 29 Aug 2024 01:00:53 +0530 Subject: [PATCH 07/11] fix err on conn close --- const.go | 2 +- session.go | 2 +- stream.go | 25 ++++++++++++++++--------- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/const.go b/const.go index 716d085..e1e9dc5 100644 --- a/const.go +++ b/const.go @@ -173,7 +173,7 @@ const ( // It's not an implementation choice, the value defined in the specification. initialStreamWindow = 256 * 1024 maxStreamWindow = 16 * 1024 * 1024 - goAwayWaitTime = 5 * time.Second + goAwayWaitTime = 50 * time.Millisecond ) const ( diff --git a/session.go b/session.go index 204b168..c9af6e0 100644 --- a/session.go +++ b/session.go @@ -334,7 +334,7 @@ func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) erro s.streamLock.Lock() defer s.streamLock.Unlock() for id, stream := range s.streams { - stream.forceClose() + stream.forceClose(fmt.Errorf("%w: connection closed: %w", ErrStreamReset, s.shutdownErr)) delete(s.streams, id) stream.memorySpan.Done() } diff --git a/stream.go b/stream.go index f6f32ec..e79562d 100644 --- a/stream.go +++ b/stream.go @@ -41,8 +41,8 @@ type Stream struct { state streamState writeState, readState halfStreamState + writeErr, readErr error stateLock sync.Mutex - resetErr *StreamError recvBuf segmentedBuffer @@ -90,7 +90,7 @@ func (s *Stream) Read(b []byte) (n int, err error) { START: s.stateLock.Lock() state := s.readState - resetErr := s.resetErr + resetErr := s.readErr s.stateLock.Unlock() switch state { @@ -149,7 +149,7 @@ func (s *Stream) write(b []byte) (n int, err error) { START: s.stateLock.Lock() state := s.writeState - resetErr := s.resetErr + resetErr := s.writeErr s.stateLock.Unlock() switch state { @@ -283,12 +283,13 @@ func (s *Stream) ResetWithError(errCode uint32) error { // If we've already sent/received an EOF, no need to reset that side. if s.writeState == halfOpen { s.writeState = halfReset + s.writeErr = &StreamError{Remote: false, ErrorCode: errCode} } if s.readState == halfOpen { s.readState = halfReset + s.readErr = &StreamError{Remote: false, ErrorCode: errCode} } s.state = streamFinished - s.resetErr = &StreamError{Remote: false, ErrorCode: errCode} s.notifyWaiting() s.stateLock.Unlock() if sendReset { @@ -344,6 +345,7 @@ func (s *Stream) CloseRead() error { panic("invalid state") } s.readState = halfReset + s.readErr = ErrStreamReset cleanup = s.writeState != halfOpen if cleanup { s.state = streamFinished @@ -365,13 +367,15 @@ func (s *Stream) Close() error { } // forceClose is used for when the session is exiting -func (s *Stream) forceClose() { +func (s *Stream) forceClose(err error) { s.stateLock.Lock() if s.readState == halfOpen { s.readState = halfReset + s.readErr = err } if s.writeState == halfOpen { s.writeState = halfReset + s.writeErr = err } s.state = streamFinished s.notifyWaiting() @@ -426,17 +430,20 @@ func (s *Stream) processFlags(flags uint16, hdr header) { } if flags&flagRST == flagRST { s.stateLock.Lock() + var resetErr error = ErrStreamReset + // Length in a window update frame with RST flag encodes an error code. + if hdr.MsgType() == typeWindowUpdate { + resetErr = &StreamError{Remote: true, ErrorCode: hdr.Length()} + } if s.readState == halfOpen { s.readState = halfReset + s.readErr = resetErr } if s.writeState == halfOpen { s.writeState = halfReset + s.writeErr = resetErr } s.state = streamFinished - // Length in a window update frame with RST flag encodes an error code. - if hdr.MsgType() == typeWindowUpdate && s.resetErr == nil { - s.resetErr = &StreamError{Remote: true, ErrorCode: hdr.Length()} - } s.stateLock.Unlock() closeStream = true s.notifyWaiting() From 5727def301f6e42f1f07ab751535e2b59b1cf460 Mon Sep 17 00:00:00 2001 From: sukun Date: Wed, 4 Sep 2024 20:03:27 +0530 Subject: [PATCH 08/11] only send goaway on close --- const.go | 2 +- session.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/const.go b/const.go index 7062231..5a5dee3 100644 --- a/const.go +++ b/const.go @@ -152,7 +152,7 @@ const ( // It's not an implementation choice, the value defined in the specification. initialStreamWindow = 256 * 1024 maxStreamWindow = 16 * 1024 * 1024 - goAwayWaitTime = 5 * time.Second + goAwayWaitTime = 100 * time.Millisecond ) const ( diff --git a/session.go b/session.go index 204b168..8f1cdc1 100644 --- a/session.go +++ b/session.go @@ -290,7 +290,7 @@ func (s *Session) AcceptStream() (*Stream, error) { // semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or // if there's unread data in the kernel receive buffer. func (s *Session) Close() error { - return s.close(ErrSessionShutdown, true, goAwayNormal) + return s.close(ErrSessionShutdown, false, goAwayNormal) } // CloseWithError is used to close the session and all streams after sending a GoAway message with errCode. From ede18a5a56bdc2510742b9c5471090313c0f2265 Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 18 Nov 2024 20:18:16 +0530 Subject: [PATCH 09/11] don't block on connection close --- const.go | 9 ++++++--- session.go | 23 ++++++++++++----------- session_test.go | 2 +- stream.go | 5 +++-- 4 files changed, 22 insertions(+), 17 deletions(-) diff --git a/const.go b/const.go index 93fb83f..3ecba41 100644 --- a/const.go +++ b/const.go @@ -45,10 +45,13 @@ func (e *GoAwayError) Temporary() bool { func (e *GoAwayError) Is(target error) bool { // to maintain compatibility with errors returned by previous versions - if e.Remote && target == ErrRemoteGoAway { + if e.Remote && target == ErrRemoteGoAwayNormal { return true } else if !e.Remote && target == ErrSessionShutdown { return true + } else if target == ErrStreamReset { + // A GoAway on a connection also resets all the streams. + return true } if err, ok := target.(*GoAwayError); ok { @@ -111,8 +114,8 @@ var ( // ErrUnexpectedFlag is set when we get an unexpected flag ErrUnexpectedFlag = &Error{msg: "unexpected flag"} - // ErrRemoteGoAway is used when we get a go away from the other side - ErrRemoteGoAway = &GoAwayError{Remote: true, ErrorCode: goAwayNormal} + // ErrRemoteGoAwayNormal is used when we get a go away from the other side + ErrRemoteGoAwayNormal = &GoAwayError{Remote: true, ErrorCode: goAwayNormal} // ErrStreamReset is sent if a stream is reset. This can happen // if the backlog is exceeded, or if there was a remote GoAway. diff --git a/session.go b/session.go index bbecf19..06d20fa 100644 --- a/session.go +++ b/session.go @@ -46,9 +46,9 @@ var nullMemoryManager = &nullMemoryManagerImpl{} type Session struct { rtt int64 // to be accessed atomically, in nanoseconds - // remoteGoAway indicates the remote side does + // remoteGoAwayNormal indicates the remote side does // not want futher connections. Must be first for alignment. - remoteGoAway int32 + remoteGoAwayNormal int32 // localGoAway indicates that we should stop // accepting futher connections. Must be first for alignment. @@ -205,8 +205,8 @@ func (s *Session) OpenStream(ctx context.Context) (*Stream, error) { if s.IsClosed() { return nil, s.shutdownErr } - if atomic.LoadInt32(&s.remoteGoAway) == 1 { - return nil, ErrRemoteGoAway + if atomic.LoadInt32(&s.remoteGoAwayNormal) == 1 { + return nil, ErrRemoteGoAwayNormal } // Block if we have too many inflight SYNs @@ -285,15 +285,15 @@ func (s *Session) AcceptStream() (*Stream, error) { } } -// Close is used to close the session and all streams. -// Attempts to send a GoAway before closing the connection. The GoAway may not actually be sent depending on the -// semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or -// if there's unread data in the kernel receive buffer. +// Close is used to close the session and all streams. It doesn't send a GoAway before +// closing the connection. func (s *Session) Close() error { return s.close(ErrSessionShutdown, false, goAwayNormal) } // CloseWithError is used to close the session and all streams after sending a GoAway message with errCode. +// Blocks for ConnectionWriteTimeout to write the GoAway message. +// // The GoAway may not actually be sent depending on the semantics of the underlying net.Conn. // For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel // receive buffer. @@ -315,7 +315,8 @@ func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) erro close(s.shutdownCh) s.stopKeepalive() - if sendGoAway { + // Only send GoAway if we have an error code. + if sendGoAway && errCode != goAwayNormal { // wait for write loop to exit // We need to write the current frame completely before sending a goaway. // This will wait for at most s.config.ConnectionWriteTimeout @@ -334,7 +335,7 @@ func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) erro s.streamLock.Lock() defer s.streamLock.Unlock() for id, stream := range s.streams { - stream.forceClose(fmt.Errorf("%w: connection closed: %w", ErrStreamReset, s.shutdownErr)) + stream.forceClose(s.shutdownErr) delete(s.streams, id) stream.memorySpan.Done() } @@ -814,7 +815,7 @@ func (s *Session) handleGoAway(hdr header) error { code := hdr.Length() switch code { case goAwayNormal: - atomic.SwapInt32(&s.remoteGoAway, 1) + atomic.SwapInt32(&s.remoteGoAwayNormal, 1) // Don't close connection on normal go away. Let the existing streams // complete gracefully. return nil diff --git a/session_test.go b/session_test.go index 2c06abb..dc6c3f0 100644 --- a/session_test.go +++ b/session_test.go @@ -651,7 +651,7 @@ func TestGoAway(t *testing.T) { switch err { case nil: s.Close() - case ErrRemoteGoAway: + case ErrRemoteGoAwayNormal: return default: t.Fatalf("err: %v", err) diff --git a/stream.go b/stream.go index e79562d..0835165 100644 --- a/stream.go +++ b/stream.go @@ -310,7 +310,7 @@ func (s *Stream) CloseWrite() error { return nil case halfReset: s.stateLock.Unlock() - return ErrStreamReset + return s.writeErr default: panic("invalid state") } @@ -331,7 +331,8 @@ func (s *Stream) CloseWrite() error { return err } -// CloseRead is used to close the stream for writing. +// CloseRead is used to close the stream for reading. +// Note: Remote is not notified. func (s *Stream) CloseRead() error { cleanup := false s.stateLock.Lock() From 3eaea398c49ca9c4f24d419f5f349805c3098dad Mon Sep 17 00:00:00 2001 From: sukun Date: Tue, 19 Nov 2024 18:58:02 +0530 Subject: [PATCH 10/11] review comments --- const.go | 7 ++++--- session.go | 22 +++++++++------------- session_test.go | 38 +++++++++++++++++--------------------- stream.go | 6 +++--- 4 files changed, 33 insertions(+), 40 deletions(-) diff --git a/const.go b/const.go index 3ecba41..08199af 100644 --- a/const.go +++ b/const.go @@ -45,7 +45,7 @@ func (e *GoAwayError) Temporary() bool { func (e *GoAwayError) Is(target error) bool { // to maintain compatibility with errors returned by previous versions - if e.Remote && target == ErrRemoteGoAwayNormal { + if e.Remote && target == ErrRemoteGoAway { return true } else if !e.Remote && target == ErrSessionShutdown { return true @@ -114,8 +114,9 @@ var ( // ErrUnexpectedFlag is set when we get an unexpected flag ErrUnexpectedFlag = &Error{msg: "unexpected flag"} - // ErrRemoteGoAwayNormal is used when we get a go away from the other side - ErrRemoteGoAwayNormal = &GoAwayError{Remote: true, ErrorCode: goAwayNormal} + // ErrRemoteGoAway is used when we get a go away from the other side with error code + // goAwayNormal(0). + ErrRemoteGoAway = &GoAwayError{Remote: true, ErrorCode: goAwayNormal} // ErrStreamReset is sent if a stream is reset. This can happen // if the backlog is exceeded, or if there was a remote GoAway. diff --git a/session.go b/session.go index 06d20fa..6fb6731 100644 --- a/session.go +++ b/session.go @@ -46,10 +46,6 @@ var nullMemoryManager = &nullMemoryManagerImpl{} type Session struct { rtt int64 // to be accessed atomically, in nanoseconds - // remoteGoAwayNormal indicates the remote side does - // not want futher connections. Must be first for alignment. - remoteGoAwayNormal int32 - // localGoAway indicates that we should stop // accepting futher connections. Must be first for alignment. localGoAway int32 @@ -205,9 +201,6 @@ func (s *Session) OpenStream(ctx context.Context) (*Stream, error) { if s.IsClosed() { return nil, s.shutdownErr } - if atomic.LoadInt32(&s.remoteGoAwayNormal) == 1 { - return nil, ErrRemoteGoAwayNormal - } // Block if we have too many inflight SYNs select { @@ -535,8 +528,14 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err // send is a long running goroutine that sends data func (s *Session) send() { if err := s.sendLoop(); err != nil { - // Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code - // received in a GoAway frame received just before the TCP RST that closed the sendLoop + // If we are shutting down because remote closed the connection, prefer the recvLoop error + // over the sendLoop error. The receive loop might have error code received in a GoAway frame, + // which was received just before the TCP RST that closed the sendLoop. + // + // If we are closing because of an write error, we use the error from the sendLoop and not the recvLoop. + // We hold the shutdownLock, close the connection, and wait for the receive loop to finish and + // use the sendLoop error. Holding the shutdownLock ensures that the recvLoop doesn't trigger connection close + // but the sendLoop does. s.shutdownLock.Lock() if s.shutdownErr == nil { s.conn.Close() @@ -815,10 +814,7 @@ func (s *Session) handleGoAway(hdr header) error { code := hdr.Length() switch code { case goAwayNormal: - atomic.SwapInt32(&s.remoteGoAwayNormal, 1) - // Don't close connection on normal go away. Let the existing streams - // complete gracefully. - return nil + return ErrRemoteGoAway case goAwayProtoErr: s.logger.Printf("[ERR] yamux: received protocol error go away") case goAwayInternalErr: diff --git a/session_test.go b/session_test.go index dc6c3f0..6d3bce0 100644 --- a/session_test.go +++ b/session_test.go @@ -648,15 +648,16 @@ func TestGoAway(t *testing.T) { for i := 0; i < 100; i++ { s, err := client.Open(context.Background()) - switch err { - case nil: + if err == nil { s.Close() - case ErrRemoteGoAwayNormal: + time.Sleep(50 * time.Millisecond) + continue + } + if err != ErrRemoteGoAway { + t.Fatalf("expected %s, got %s", ErrRemoteGoAway, err) + } else { return - default: - t.Fatalf("err: %v", err) } - time.Sleep(50 * time.Millisecond) } t.Fatalf("expected GoAway error") } @@ -1578,7 +1579,7 @@ func TestStreamResetWithError(t *testing.T) { defer server.Close() wc := new(sync.WaitGroup) - wc.Add(2) + wc.Add(1) go func() { defer wc.Done() stream, err := server.AcceptStream() @@ -1589,7 +1590,7 @@ func TestStreamResetWithError(t *testing.T) { se := &StreamError{} _, err = io.ReadAll(stream) if !errors.As(err, &se) { - t.Errorf("exptected StreamError, got type:%T, err: %s", err, err) + t.Errorf("expected StreamError, got type:%T, err: %s", err, err) return } expected := &StreamError{Remote: true, ErrorCode: 42} @@ -1601,24 +1602,19 @@ func TestStreamResetWithError(t *testing.T) { t.Error(err) } - go func() { - defer wc.Done() - - se := &StreamError{} - _, err := io.ReadAll(stream) - if !errors.As(err, &se) { - t.Errorf("exptected StreamError, got type:%T, err: %s", err, err) - return - } - expected := &StreamError{Remote: false, ErrorCode: 42} - assert.Equal(t, se, expected) - }() - time.Sleep(1 * time.Second) err = stream.ResetWithError(42) if err != nil { t.Fatal(err) } + se := &StreamError{} + _, err = io.ReadAll(stream) + if !errors.As(err, &se) { + t.Errorf("expected StreamError, got type:%T, err: %s", err, err) + return + } + expected := &StreamError{Remote: false, ErrorCode: 42} + assert.Equal(t, se, expected) wc.Wait() } diff --git a/stream.go b/stream.go index 0835165..15a8b56 100644 --- a/stream.go +++ b/stream.go @@ -395,7 +395,7 @@ func (s *Stream) cleanup() { // processFlags is used to update the state of the stream // based on set flags, if any. Lock must be held -func (s *Stream) processFlags(flags uint16, hdr header) { +func (s *Stream) processFlags(hdr header, flags uint16) { // Close the stream without holding the state lock var closeStream bool defer func() { @@ -459,7 +459,7 @@ func (s *Stream) notifyWaiting() { // incrSendWindow updates the size of our send window func (s *Stream) incrSendWindow(hdr header, flags uint16) { - s.processFlags(flags, hdr) + s.processFlags(hdr, flags) // Increase window, unblock a sender atomic.AddUint32(&s.sendWindow, hdr.Length()) asyncNotify(s.sendNotifyCh) @@ -467,7 +467,7 @@ func (s *Stream) incrSendWindow(hdr header, flags uint16) { // readData is used to handle a data frame func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { - s.processFlags(flags, hdr) + s.processFlags(hdr, flags) // Check that our recv window is not exceeded length := hdr.Length() From 39abe7ed206a6ddee6689eb2179f07d8ba8fc358 Mon Sep 17 00:00:00 2001 From: sukun Date: Wed, 20 Nov 2024 15:33:19 +0530 Subject: [PATCH 11/11] use ErrStreamReset for resetting streams --- session.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/session.go b/session.go index 6fb6731..e229730 100644 --- a/session.go +++ b/session.go @@ -325,10 +325,14 @@ func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) erro <-s.sendDoneCh <-s.recvDoneCh + resetErr := shutdownErr + if _, ok := resetErr.(*GoAwayError); !ok { + resetErr = fmt.Errorf("%w: connection closed: %w", ErrStreamReset, shutdownErr) + } s.streamLock.Lock() defer s.streamLock.Unlock() for id, stream := range s.streams { - stream.forceClose(s.shutdownErr) + stream.forceClose(resetErr) delete(s.streams, id) stream.memorySpan.Done() }