diff --git a/factory/http2_test.go b/factory/http2_test.go index e2f742a..8ebe279 100644 --- a/factory/http2_test.go +++ b/factory/http2_test.go @@ -145,7 +145,9 @@ func getStream(data []byte) io.Reader { go func() { time.Sleep(100 * time.Millisecond) - writer.Write(data) + if len(data) != 0 { + writer.Write(data) + } writer.Close() }() @@ -156,8 +158,10 @@ func getBadStream(data []byte) io.Reader { reader, writer := io.Pipe() go func() { - writer.Write(data) time.Sleep(100 * time.Millisecond) + if len(data) != 0 { + writer.Write(data) + } writer.CloseWithError(errors.New("test error")) }() diff --git a/server.go b/server.go index 1c09530..9affc0c 100644 --- a/server.go +++ b/server.go @@ -1829,26 +1829,29 @@ func (sc *serverConn) newWriterAndRequestNoBody(st *stream) (*responseWriter, er return rw, nil } -func writeResponseBody(rw *responseWriter, reqCtx *app.RequestContext) error { +func writeResponseBody(rw *responseWriter, reqCtx *app.RequestContext) (err error) { if reqCtx.Response.IsBodyStream() { + var n int vbuf := utils.CopyBufPool.Get() buf := vbuf.([]byte) for { - n, err := reqCtx.Response.BodyStream().Read(buf) - if err == io.EOF { + n, err = reqCtx.Response.BodyStream().Read(buf) + if n == 0 { + if err == nil { + return errors.New("response bodyStream().Read(buf) returns 0, nil") + } + if err == io.EOF { + err = nil + } break } - if err != nil { - return err - } - _, err = rw.Write(buf[:n]) - if err != nil { - return err + if _, err = rw.Write(buf[:n]); err != nil { + break } } utils.CopyBufPool.Put(vbuf) - return nil + return err } else { // reqCtx.Response.Body can be no error // will split at FrameWriteRequest's Consume function diff --git a/server_test.go b/server_test.go index 028ca53..9f34d1a 100644 --- a/server_test.go +++ b/server_test.go @@ -202,6 +202,7 @@ func newHertzServerTester(t testing.TB, handler app.HandlerFunc, opts ...interfa } st.hpackEnc = hpack.NewEncoder(&st.headerBuf) st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField) + st.addLogFilter("Transport readFrame error") testHookGetServerConn = func(v *serverConn) { st.scMu.Lock() @@ -228,6 +229,7 @@ func newHertzServerTester(t testing.TB, handler app.HandlerFunc, opts ...interfa st.fr.logWrites = true } } + return st }