Skip to content

Commit

Permalink
transport/bufWriter: fast-fail on error returned from flushKeepBuffer…
Browse files Browse the repository at this point in the history
…() (#7394)
  • Loading branch information
veshij authored Aug 7, 2024
1 parent 1490d60 commit ffaa81e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 9 deletions.
22 changes: 13 additions & 9 deletions internal/transport/http_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,28 +317,32 @@ func newBufWriter(conn net.Conn, batchSize int, pool *sync.Pool) *bufWriter {
return w
}

func (w *bufWriter) Write(b []byte) (n int, err error) {
func (w *bufWriter) Write(b []byte) (int, error) {
if w.err != nil {
return 0, w.err
}
if w.batchSize == 0 { // Buffer has been disabled.
n, err = w.conn.Write(b)
n, err := w.conn.Write(b)
return n, toIOError(err)
}
if w.buf == nil {
b := w.pool.Get().(*[]byte)
w.buf = *b
}
written := 0
for len(b) > 0 {
nn := copy(w.buf[w.offset:], b)
b = b[nn:]
w.offset += nn
n += nn
if w.offset >= w.batchSize {
err = w.flushKeepBuffer()
copied := copy(w.buf[w.offset:], b)
b = b[copied:]
written += copied
w.offset += copied
if w.offset < w.batchSize {
continue
}
if err := w.flushKeepBuffer(); err != nil {
return written, err
}
}
return n, err
return written, nil
}

func (w *bufWriter) Flush() error {
Expand Down
36 changes: 36 additions & 0 deletions internal/transport/http_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
package transport

import (
"errors"
"fmt"
"io"
"net"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -215,6 +218,39 @@ func (s) TestParseDialTarget(t *testing.T) {
}
}

type badNetworkConn struct {
net.Conn
}

func (c *badNetworkConn) Write([]byte) (int, error) {
return 0, io.EOF
}

// This test ensures Write() on a broken network connection does not lead to
// an infinite loop. See https://github.com/grpc/grpc-go/issues/7389 for more details.
func (s) TestWriteBadConnection(t *testing.T) {
data := []byte("test_data")
// Configure the bufWriter with a batchsize that results in data being flushed
// to the underlying conn, midway through Write().
writeBufferSize := (len(data) - 1) / 2
writer := newBufWriter(&badNetworkConn{}, writeBufferSize, getWriteBufferPool(writeBufferSize))

errCh := make(chan error, 1)
go func() {
_, err := writer.Write(data)
errCh <- err
}()

select {
case <-time.After(time.Second):
t.Fatalf("Write() did not return in time")
case err := <-errCh:
if !errors.Is(err, io.EOF) {
t.Fatalf("Write() = %v, want error presence = %v", err, io.EOF)
}
}
}

func BenchmarkDecodeGrpcMessage(b *testing.B) {
input := "Hello, %E4%B8%96%E7%95%8C"
want := "Hello, 世界"
Expand Down

0 comments on commit ffaa81e

Please sign in to comment.