Skip to content

Commit

Permalink
End stream flag bugfix (#3803)
Browse files Browse the repository at this point in the history
  • Loading branch information
GarrettGutierrez1 authored Aug 21, 2020
1 parent e14f1c2 commit b9bc8e7
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 7 deletions.
4 changes: 4 additions & 0 deletions internal/transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,10 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
if !ok {
return
}
if s.getState() == streamReadDone {
t.closeStream(s, true, http2.ErrCodeStreamClosed, false)
return
}
if size > 0 {
if err := s.fc.onData(size); err != nil {
t.closeStream(s, true, http2.ErrCodeFlowControl, false)
Expand Down
110 changes: 105 additions & 5 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4535,7 +4535,7 @@ func testClientRequestBodyErrorUnexpectedEOF(t *testing.T, e env) {
te.startServer(ts)
defer te.tearDown()
te.withServerTester(func(st *serverTester) {
st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall")
st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall", false)
// Say we have 5 bytes coming, but set END_STREAM flag:
st.writeData(1, true, []byte{0, 0, 0, 0, 5})
st.wantAnyFrame() // wait for server to crash (it used to crash)
Expand All @@ -4559,7 +4559,7 @@ func testClientRequestBodyErrorCloseAfterLength(t *testing.T, e env) {
te.startServer(ts)
defer te.tearDown()
te.withServerTester(func(st *serverTester) {
st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall")
st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall", false)
// say we're sending 5 bytes, but then close the connection instead.
st.writeData(1, false, []byte{0, 0, 0, 0, 5})
st.cc.Close()
Expand All @@ -4582,7 +4582,7 @@ func testClientRequestBodyErrorCancel(t *testing.T, e env) {
te.startServer(ts)
defer te.tearDown()
te.withServerTester(func(st *serverTester) {
st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall")
st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall", false)
// Say we have 5 bytes coming, but cancel it instead.
st.writeRSTStream(1, http2.ErrCodeCancel)
st.writeData(1, false, []byte{0, 0, 0, 0, 5})
Expand All @@ -4595,7 +4595,7 @@ func testClientRequestBodyErrorCancel(t *testing.T, e env) {
}

// And now send an uncanceled (but still invalid), just to get a response.
st.writeHeadersGRPC(3, "/grpc.testing.TestService/UnaryCall")
st.writeHeadersGRPC(3, "/grpc.testing.TestService/UnaryCall", false)
st.writeData(3, true, []byte{0, 0, 0, 0, 0})
<-gotCall
st.wantAnyFrame()
Expand All @@ -4619,7 +4619,7 @@ func testClientRequestBodyErrorCancelStreamingInput(t *testing.T, e env) {
te.startServer(ts)
defer te.tearDown()
te.withServerTester(func(st *serverTester) {
st.writeHeadersGRPC(1, "/grpc.testing.TestService/StreamingInputCall")
st.writeHeadersGRPC(1, "/grpc.testing.TestService/StreamingInputCall", false)
// Say we have 5 bytes coming, but cancel it instead.
st.writeData(1, false, []byte{0, 0, 0, 0, 5})
st.writeRSTStream(1, http2.ErrCodeCancel)
Expand All @@ -4636,6 +4636,106 @@ func testClientRequestBodyErrorCancelStreamingInput(t *testing.T, e env) {
})
}

func (s) TestClientInitialHeaderEndStream(t *testing.T) {
for _, e := range listTestEnv() {
if e.httpHandler {
continue
}
testClientInitialHeaderEndStream(t, e)
}
}

func testClientInitialHeaderEndStream(t *testing.T, e env) {
// To ensure RST_STREAM is sent for illegal data write and not normal stream
// close.
frameCheckingDone := make(chan struct{})
// To ensure goroutine for test does not end before RPC handler performs error
// checking.
handlerDone := make(chan struct{})
te := newTest(t, e)
ts := &funcServer{streamingInputCall: func(stream testpb.TestService_StreamingInputCallServer) error {
defer close(handlerDone)
// Block on serverTester receiving RST_STREAM. This ensures server has closed
// stream before stream.Recv().
<-frameCheckingDone
data, err := stream.Recv()
if err == nil {
t.Errorf("unexpected data received in func server method: '%v'", data)
} else if status.Code(err) != codes.Canceled {
t.Errorf("expected canceled error, instead received '%v'", err)
}
return nil
}}
te.startServer(ts)
defer te.tearDown()
te.withServerTester(func(st *serverTester) {
// Send a headers with END_STREAM flag, but then write data.
st.writeHeadersGRPC(1, "/grpc.testing.TestService/StreamingInputCall", true)
st.writeData(1, false, []byte{0, 0, 0, 0, 0})
st.wantAnyFrame()
st.wantAnyFrame()
st.wantRSTStream(http2.ErrCodeStreamClosed)
close(frameCheckingDone)
<-handlerDone
})
}

func (s) TestClientSendDataAfterCloseSend(t *testing.T) {
for _, e := range listTestEnv() {
if e.httpHandler {
continue
}
testClientSendDataAfterCloseSend(t, e)
}
}

func testClientSendDataAfterCloseSend(t *testing.T, e env) {
// To ensure RST_STREAM is sent for illegal data write prior to execution of RPC
// handler.
frameCheckingDone := make(chan struct{})
// To ensure goroutine for test does not end before RPC handler performs error
// checking.
handlerDone := make(chan struct{})
te := newTest(t, e)
ts := &funcServer{streamingInputCall: func(stream testpb.TestService_StreamingInputCallServer) error {
defer close(handlerDone)
// Block on serverTester receiving RST_STREAM. This ensures server has closed
// stream before stream.Recv().
<-frameCheckingDone
for {
_, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
if status.Code(err) != codes.Canceled {
t.Errorf("expected canceled error, instead received '%v'", err)
}
break
}
}
if err := stream.SendMsg(nil); err == nil {
t.Error("expected error sending message on stream after stream closed due to illegal data")
} else if status.Code(err) != codes.Internal {
t.Errorf("expected internal error, instead received '%v'", err)
}
return nil
}}
te.startServer(ts)
defer te.tearDown()
te.withServerTester(func(st *serverTester) {
st.writeHeadersGRPC(1, "/grpc.testing.TestService/StreamingInputCall", false)
// Send data with END_STREAM flag, but then write more data.
st.writeData(1, true, []byte{0, 0, 0, 0, 0})
st.writeData(1, false, []byte{0, 0, 0, 0, 0})
st.wantAnyFrame()
st.wantAnyFrame()
st.wantRSTStream(http2.ErrCodeStreamClosed)
close(frameCheckingDone)
<-handlerDone
})
}

func (s) TestClientResourceExhaustedCancelFullDuplex(t *testing.T) {
for _, e := range listTestEnv() {
if e.httpHandler {
Expand Down
19 changes: 17 additions & 2 deletions test/servertester.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,21 @@ func (st *serverTester) writeSettingsAck() {
}
}

func (st *serverTester) wantRSTStream(errCode http2.ErrCode) *http2.RSTStreamFrame {
f, err := st.readFrame()
if err != nil {
st.t.Fatalf("Error while expecting an RST frame: %v", err)
}
sf, ok := f.(*http2.RSTStreamFrame)
if !ok {
st.t.Fatalf("got a %T; want *http2.RSTStreamFrame", f)
}
if sf.ErrCode != errCode {
st.t.Fatalf("expected RST error code '%v', got '%v'", errCode.String(), sf.ErrCode.String())
}
return sf
}

func (st *serverTester) wantSettings() *http2.SettingsFrame {
f, err := st.readFrame()
if err != nil {
Expand Down Expand Up @@ -227,7 +242,7 @@ func (st *serverTester) encodeHeader(headers ...string) []byte {
return st.headerBuf.Bytes()
}

func (st *serverTester) writeHeadersGRPC(streamID uint32, path string) {
func (st *serverTester) writeHeadersGRPC(streamID uint32, path string, endStream bool) {
st.writeHeaders(http2.HeadersFrameParam{
StreamID: streamID,
BlockFragment: st.encodeHeader(
Expand All @@ -236,7 +251,7 @@ func (st *serverTester) writeHeadersGRPC(streamID uint32, path string) {
"content-type", "application/grpc",
"te", "trailers",
),
EndStream: false,
EndStream: endStream,
EndHeaders: true,
})
}
Expand Down

0 comments on commit b9bc8e7

Please sign in to comment.