Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Nov 19, 2024
1 parent ede18a5 commit b523bdd
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 36 deletions.
7 changes: 4 additions & 3 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (e *GoAwayError) Temporary() bool {

func (e *GoAwayError) Is(target error) bool {
// to maintain compatibility with errors returned by previous versions
if e.Remote && target == ErrRemoteGoAwayNormal {
if e.Remote && target == ErrRemoteGoAway {
return true

Check warning on line 49 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L49

Added line #L49 was not covered by tests
} else if !e.Remote && target == ErrSessionShutdown {
return true

Check warning on line 51 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L51

Added line #L51 was not covered by tests
Expand Down Expand Up @@ -114,8 +114,9 @@ var (
// ErrUnexpectedFlag is set when we get an unexpected flag
ErrUnexpectedFlag = &Error{msg: "unexpected flag"}

// ErrRemoteGoAwayNormal is used when we get a go away from the other side
ErrRemoteGoAwayNormal = &GoAwayError{Remote: true, ErrorCode: goAwayNormal}
// ErrRemoteGoAway is used when we get a go away from the other side with error code
// goAwayNormal(0).
ErrRemoteGoAway = &GoAwayError{Remote: true, ErrorCode: goAwayNormal}

// ErrStreamReset is sent if a stream is reset. This can happen
// if the backlog is exceeded, or if there was a remote GoAway.
Expand Down
18 changes: 9 additions & 9 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,6 @@ func (s *Session) OpenStream(ctx context.Context) (*Stream, error) {
if s.IsClosed() {
return nil, s.shutdownErr
}
if atomic.LoadInt32(&s.remoteGoAwayNormal) == 1 {
return nil, ErrRemoteGoAwayNormal
}

// Block if we have too many inflight SYNs
select {
Expand Down Expand Up @@ -535,8 +532,14 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
// send is a long running goroutine that sends data
func (s *Session) send() {
if err := s.sendLoop(); err != nil {
// Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code
// received in a GoAway frame received just before the TCP RST that closed the sendLoop
// If we are shutting down because remote closed the connection, prefer the recvLoop error
// over the sendLoop error. The receive loop might have error code received in a GoAway frame,
// which was received just before the TCP RST that closed the sendLoop.
//
// If we are closing because of an write error, we use the error from the sendLoop and not the recvLoop.
// We hold the shutdownLock, close the connection, and wait for the receive loop to finish and
// use the sendLoop error. Holding the shutdownLock ensures that the recvLoop doesn't trigger connection close
// but the sendLoop does.
s.shutdownLock.Lock()
if s.shutdownErr == nil {
s.conn.Close()
Expand Down Expand Up @@ -815,10 +818,7 @@ func (s *Session) handleGoAway(hdr header) error {
code := hdr.Length()
switch code {
case goAwayNormal:
atomic.SwapInt32(&s.remoteGoAwayNormal, 1)
// Don't close connection on normal go away. Let the existing streams
// complete gracefully.
return nil
return ErrRemoteGoAway
case goAwayProtoErr:
s.logger.Printf("[ERR] yamux: received protocol error go away")
case goAwayInternalErr:
Expand Down
38 changes: 17 additions & 21 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -648,15 +648,16 @@ func TestGoAway(t *testing.T) {

for i := 0; i < 100; i++ {
s, err := client.Open(context.Background())
switch err {
case nil:
if err == nil {
s.Close()
case ErrRemoteGoAwayNormal:
time.Sleep(50 * time.Millisecond)
continue
}
if err != ErrRemoteGoAway {
t.Fatalf("expected %s, got %s", ErrRemoteGoAway, err)
} else {
return
default:
t.Fatalf("err: %v", err)
}
time.Sleep(50 * time.Millisecond)
}
t.Fatalf("expected GoAway error")
}
Expand Down Expand Up @@ -1578,7 +1579,7 @@ func TestStreamResetWithError(t *testing.T) {
defer server.Close()

wc := new(sync.WaitGroup)
wc.Add(2)
wc.Add(1)
go func() {
defer wc.Done()
stream, err := server.AcceptStream()
Expand All @@ -1589,7 +1590,7 @@ func TestStreamResetWithError(t *testing.T) {
se := &StreamError{}
_, err = io.ReadAll(stream)
if !errors.As(err, &se) {
t.Errorf("exptected StreamError, got type:%T, err: %s", err, err)
t.Errorf("expected StreamError, got type:%T, err: %s", err, err)
return
}
expected := &StreamError{Remote: true, ErrorCode: 42}
Expand All @@ -1601,24 +1602,19 @@ func TestStreamResetWithError(t *testing.T) {
t.Error(err)
}

go func() {
defer wc.Done()

se := &StreamError{}
_, err := io.ReadAll(stream)
if !errors.As(err, &se) {
t.Errorf("exptected StreamError, got type:%T, err: %s", err, err)
return
}
expected := &StreamError{Remote: false, ErrorCode: 42}
assert.Equal(t, se, expected)
}()

time.Sleep(1 * time.Second)
err = stream.ResetWithError(42)
if err != nil {
t.Fatal(err)
}
se := &StreamError{}
_, err = io.ReadAll(stream)
if !errors.As(err, &se) {
t.Errorf("expected StreamError, got type:%T, err: %s", err, err)
return
}
expected := &StreamError{Remote: false, ErrorCode: 42}
assert.Equal(t, se, expected)
wc.Wait()
}

Expand Down
6 changes: 3 additions & 3 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ func (s *Stream) cleanup() {

// processFlags is used to update the state of the stream
// based on set flags, if any. Lock must be held
func (s *Stream) processFlags(flags uint16, hdr header) {
func (s *Stream) processFlags(hdr header, flags uint16) {
// Close the stream without holding the state lock
var closeStream bool
defer func() {
Expand Down Expand Up @@ -459,15 +459,15 @@ func (s *Stream) notifyWaiting() {

// incrSendWindow updates the size of our send window
func (s *Stream) incrSendWindow(hdr header, flags uint16) {
s.processFlags(flags, hdr)
s.processFlags(hdr, flags)
// Increase window, unblock a sender
atomic.AddUint32(&s.sendWindow, hdr.Length())
asyncNotify(s.sendNotifyCh)
}

// readData is used to handle a data frame
func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
s.processFlags(flags, hdr)
s.processFlags(hdr, flags)

// Check that our recv window is not exceeded
length := hdr.Length()
Expand Down

0 comments on commit b523bdd

Please sign in to comment.