Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for sending error codes on session close #121

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
67 changes: 64 additions & 3 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import (
"encoding/binary"
"fmt"
"time"
)

type Error struct {
Expand All @@ -22,6 +23,64 @@
return ye.temporary
}

type GoAwayError struct {
ErrorCode uint32
Remote bool
}

func (e *GoAwayError) Error() string {
if e.Remote {
return fmt.Sprintf("remote sent go away, code: %d", e.ErrorCode)
}

Check warning on line 34 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L33-L34

Added lines #L33 - L34 were not covered by tests
return fmt.Sprintf("sent go away, code: %d", e.ErrorCode)
}

func (e *GoAwayError) Timeout() bool {
return false

Check warning on line 39 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L38-L39

Added lines #L38 - L39 were not covered by tests
}

func (e *GoAwayError) Temporary() bool {
return false

Check warning on line 43 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L42-L43

Added lines #L42 - L43 were not covered by tests
}

func (e *GoAwayError) Is(target error) bool {
// to maintain compatibility with errors returned by previous versions
if e.Remote && target == ErrRemoteGoAway {
return true

Check warning on line 49 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L49

Added line #L49 was not covered by tests
} else if !e.Remote && target == ErrSessionShutdown {
return true

Check warning on line 51 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L51

Added line #L51 was not covered by tests
} else if target == ErrStreamReset {
// A GoAway on a connection also resets all the streams.
return true
}

Check warning on line 55 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L53-L55

Added lines #L53 - L55 were not covered by tests

if err, ok := target.(*GoAwayError); ok {
return *e == *err
}
return false

Check warning on line 60 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L60

Added line #L60 was not covered by tests
}

// A StreamError is used for errors returned from Read and Write calls after the stream is Reset
type StreamError struct {
ErrorCode uint32
Remote bool
}

func (s *StreamError) Error() string {
if s.Remote {
return fmt.Sprintf("stream reset by remote, error code: %d", s.ErrorCode)
}
return fmt.Sprintf("stream reset, error code: %d", s.ErrorCode)

Check warning on line 73 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L69-L73

Added lines #L69 - L73 were not covered by tests
}

func (s *StreamError) Is(target error) bool {
if target == ErrStreamReset {
return true
}
e, ok := target.(*StreamError)
return ok && *e == *s

Check warning on line 81 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L80-L81

Added lines #L80 - L81 were not covered by tests
}

var (
// ErrInvalidVersion means we received a frame with an
// invalid version
Expand All @@ -33,7 +92,7 @@

// ErrSessionShutdown is used if there is a shutdown during
// an operation
ErrSessionShutdown = &Error{msg: "session shutdown"}
ErrSessionShutdown = &GoAwayError{ErrorCode: goAwayNormal, Remote: false}

// ErrStreamsExhausted is returned if we have no more
// stream ids to issue
Expand All @@ -55,8 +114,9 @@
// ErrUnexpectedFlag is set when we get an unexpected flag
ErrUnexpectedFlag = &Error{msg: "unexpected flag"}

// ErrRemoteGoAway is used when we get a go away from the other side
ErrRemoteGoAway = &Error{msg: "remote end is not accepting connections"}
// ErrRemoteGoAway is used when we get a go away from the other side with error code
// goAwayNormal(0).
ErrRemoteGoAway = &GoAwayError{Remote: true, ErrorCode: goAwayNormal}

// ErrStreamReset is sent if a stream is reset. This can happen
// if the backlog is exceeded, or if there was a remote GoAway.
Expand Down Expand Up @@ -117,6 +177,7 @@
// It's not an implementation choice, the value defined in the specification.
initialStreamWindow = 256 * 1024
maxStreamWindow = 16 * 1024 * 1024
goAwayWaitTime = 100 * time.Millisecond
)

const (
Expand Down
102 changes: 68 additions & 34 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@ var nullMemoryManager = &nullMemoryManagerImpl{}
type Session struct {
rtt int64 // to be accessed atomically, in nanoseconds

// remoteGoAway indicates the remote side does
// not want futher connections. Must be first for alignment.
remoteGoAway int32

// localGoAway indicates that we should stop
// accepting futher connections. Must be first for alignment.
localGoAway int32
Expand Down Expand Up @@ -102,6 +98,8 @@ type Session struct {
// recvDoneCh is closed when recv() exits to avoid a race
// between stream registration and stream shutdown
recvDoneCh chan struct{}
// recvErr is the error the receive loop ended with
recvErr error

// sendDoneCh is closed when send() exits to avoid a race
// between returning from a Stream.Write and exiting from the send loop
Expand Down Expand Up @@ -203,9 +201,6 @@ func (s *Session) OpenStream(ctx context.Context) (*Stream, error) {
if s.IsClosed() {
return nil, s.shutdownErr
}
if atomic.LoadInt32(&s.remoteGoAway) == 1 {
return nil, ErrRemoteGoAway
}

// Block if we have too many inflight SYNs
select {
Expand Down Expand Up @@ -283,9 +278,23 @@ func (s *Session) AcceptStream() (*Stream, error) {
}
}

// Close is used to close the session and all streams.
// Attempts to send a GoAway before closing the connection.
// Close is used to close the session and all streams. It doesn't send a GoAway before
// closing the connection.
func (s *Session) Close() error {
return s.close(ErrSessionShutdown, false, goAwayNormal)
}

// CloseWithError is used to close the session and all streams after sending a GoAway message with errCode.
// Blocks for ConnectionWriteTimeout to write the GoAway message.
//
// The GoAway may not actually be sent depending on the semantics of the underlying net.Conn.
// For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel
// receive buffer.
func (s *Session) CloseWithError(errCode uint32) error {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have we updated the connection manager to be able to deal with potential blocking here? Also, we should probably document it.

return s.close(&GoAwayError{Remote: false, ErrorCode: errCode}, true, errCode)
}

func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) error {
s.shutdownLock.Lock()
defer s.shutdownLock.Unlock()

Expand All @@ -294,35 +303,42 @@ func (s *Session) Close() error {
}
s.shutdown = true
if s.shutdownErr == nil {
s.shutdownErr = ErrSessionShutdown
s.shutdownErr = shutdownErr
}
close(s.shutdownCh)
s.conn.Close()
s.stopKeepalive()
<-s.recvDoneCh

// Only send GoAway if we have an error code.
if sendGoAway && errCode != goAwayNormal {
// wait for write loop to exit
// We need to write the current frame completely before sending a goaway.
// This will wait for at most s.config.ConnectionWriteTimeout
<-s.sendDoneCh
ga := s.goAway(errCode)
if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil {
_, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here
}
s.conn.SetWriteDeadline(time.Time{})
}

s.conn.Close()
<-s.sendDoneCh
<-s.recvDoneCh

resetErr := shutdownErr
if _, ok := resetErr.(*GoAwayError); !ok {
resetErr = fmt.Errorf("%w: connection closed: %w", ErrStreamReset, shutdownErr)
}
s.streamLock.Lock()
defer s.streamLock.Unlock()
for id, stream := range s.streams {
stream.forceClose()
stream.forceClose(resetErr)
delete(s.streams, id)
stream.memorySpan.Done()
}
return nil
}

// exitErr is used to handle an error that is causing the
// session to terminate.
func (s *Session) exitErr(err error) {
s.shutdownLock.Lock()
if s.shutdownErr == nil {
s.shutdownErr = err
}
s.shutdownLock.Unlock()
s.Close()
}

// GoAway can be used to prevent accepting further
// connections. It does not close the underlying conn.
func (s *Session) GoAway() error {
Expand Down Expand Up @@ -451,7 +467,7 @@ func (s *Session) startKeepalive() {

if err != nil {
s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
s.exitErr(ErrKeepAliveTimeout)
s.close(ErrKeepAliveTimeout, false, 0)
}
})
}
Expand Down Expand Up @@ -516,7 +532,25 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
// send is a long running goroutine that sends data
func (s *Session) send() {
if err := s.sendLoop(); err != nil {
s.exitErr(err)
// If we are shutting down because remote closed the connection, prefer the recvLoop error
// over the sendLoop error. The receive loop might have error code received in a GoAway frame,
// which was received just before the TCP RST that closed the sendLoop.
//
// If we are closing because of an write error, we use the error from the sendLoop and not the recvLoop.
// We hold the shutdownLock, close the connection, and wait for the receive loop to finish and
// use the sendLoop error. Holding the shutdownLock ensures that the recvLoop doesn't trigger connection close
// but the sendLoop does.
s.shutdownLock.Lock()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment as to why you are holding the shutdownLock around this section.

Copy link
Member Author

@sukunrt sukunrt Nov 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added some comment, can you review once more?

if s.shutdownErr == nil {
s.conn.Close()
<-s.recvDoneCh
if _, ok := s.recvErr.(*GoAwayError); ok {
err = s.recvErr
}
s.shutdownErr = err
}
s.shutdownLock.Unlock()
s.close(err, false, 0)
}
}

Expand Down Expand Up @@ -644,7 +678,7 @@ func (s *Session) sendLoop() (err error) {
// recv is a long running goroutine that accepts new data
func (s *Session) recv() {
if err := s.recvLoop(); err != nil {
s.exitErr(err)
s.close(err, false, 0)
}
}

Expand All @@ -666,7 +700,10 @@ func (s *Session) recvLoop() (err error) {
err = fmt.Errorf("panic in yamux receive loop: %s", rerr)
}
}()
defer close(s.recvDoneCh)
defer func() {
s.recvErr = err
close(s.recvDoneCh)
}()
var hdr header
for {
// fmt.Printf("ReadFull from %#v\n", s.reader)
Expand Down Expand Up @@ -781,18 +818,15 @@ func (s *Session) handleGoAway(hdr header) error {
code := hdr.Length()
switch code {
case goAwayNormal:
atomic.SwapInt32(&s.remoteGoAway, 1)
return ErrRemoteGoAway
case goAwayProtoErr:
s.logger.Printf("[ERR] yamux: received protocol error go away")
return fmt.Errorf("yamux protocol error")
case goAwayInternalErr:
s.logger.Printf("[ERR] yamux: received internal error go away")
return fmt.Errorf("remote yamux internal error")
default:
s.logger.Printf("[ERR] yamux: received unexpected go away")
return fmt.Errorf("unexpected go away received")
s.logger.Printf("[ERR] yamux: received go away with error code: %d", code)
}
return nil
return &GoAwayError{Remote: true, ErrorCode: code}
}

// incomingStream is used to create a new incoming stream
Expand Down
Loading