mem: ReadAll for more efficient io.Reader consumption
ash2k committed Sep 20, 2024
1 parent 1418e5e commit d601913
Showing 3 changed files with 348 additions and 2 deletions.
48 changes: 48 additions & 0 deletions mem/buffer_slice.go
@@ -22,6 +22,10 @@ import (
"io"
)

const (
readAllBufSize = 32 * 1024 // 32 KiB
)

// BufferSlice offers a means to represent data that spans one or more Buffer
// instances. A BufferSlice is meant to be immutable after creation, and methods
// like Ref create and return copies of the slice. This is why all methods have
@@ -224,3 +228,47 @@ func (w *writer) Write(p []byte) (n int, err error) {
func NewWriter(buffers *BufferSlice, pool BufferPool) io.Writer {
return &writer{buffers: buffers, pool: pool}
}

// ReadAll reads from r until an error or EOF and returns the data it read.
// A successful call returns err == nil, not err == EOF. Because ReadAll is
// defined to read from r until EOF, it does not treat an EOF from Read
// as an error to be reported.
// Always free the returned BufferSlice, even if ReadAll returns an error.
func ReadAll(r io.Reader, pool BufferPool) (BufferSlice, error) {
var result BufferSlice
wt, ok := r.(io.WriterTo)
if ok {
// This is more efficient since wt knows the size of the chunks it wants to write and, hence, we can allocate
// buffers of an optimal size to fit them. E.g. it might write one big chunk that we would otherwise chop into pieces.
w := NewWriter(&result, pool)
_, err := wt.WriteTo(w)
return result, err
}
for {
buf := pool.Get(readAllBufSize)
// We asked for 32KiB but may have been given a bigger buffer. Use all of it if that's the case.
*buf = (*buf)[:cap(*buf)]
usedCap := 0
for {
n, err := r.Read((*buf)[usedCap:])
usedCap += n
if err != nil {
if usedCap == 0 {
// Nothing in this buf, put it back
pool.Put(buf)
} else {
*buf = (*buf)[:usedCap]
result = append(result, NewBuffer(buf, pool))
}
if err == io.EOF {
err = nil
}
return result, err
}
if len(*buf) == usedCap {
result = append(result, NewBuffer(buf, pool))
break // grab a new buf from pool
}
}
}
}
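
For orientation, here is a minimal, hypothetical usage sketch (not part of this commit) of how a caller might drain a reader with ReadAll. The consume helper and its signature are assumptions for illustration; per the doc comment above, the returned BufferSlice is freed even when an error is returned.

// Hypothetical caller sketch, for illustration only. Assumes imports
// "io" and "google.golang.org/grpc/mem".
func consume(r io.Reader) ([]byte, error) {
    data, err := mem.ReadAll(r, mem.DefaultBufferPool())
    if err != nil {
        data.Free() // free pooled buffers even on error, as ReadAll's contract requires
        return nil, err
    }
    out := data.Materialize() // copy the BufferSlice into one contiguous []byte
    data.Free()               // return the pooled buffers to the pool
    return out, nil
}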
299 changes: 299 additions & 0 deletions mem/buffer_slice_test.go
@@ -20,13 +20,23 @@ package mem_test

import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"io"
"testing"

"google.golang.org/grpc/mem"
)

const (
// 1025 is a value above 1024 for which mem.IsBelowBufferPoolingThreshold() returns false.
// See https://github.com/grpc/grpc-go/issues/7631.
minReadSize = 1025
// Must match readAllBufSize in buffer_slice.go; that constant is unexported and this
// test is in package mem_test, so it is duplicated here.
readAllBufSize = 32 * 1024 // 32 KiB
)

func newBuffer(data []byte, pool mem.BufferPool) mem.Buffer {
return mem.NewBuffer(&data, pool)
}
@@ -156,6 +166,249 @@ func (s) TestBufferSlice_Reader(t *testing.T) {
}
}

func (s) TestBufferSlice_ReadAll_Reads(t *testing.T) {
testcases := []struct {
name string
reads []readStep
expectedErr string
expectedBufs int
}{
{
name: "EOF",
reads: []readStep{
{
err: io.EOF,
},
},
},
{
name: "data,EOF",
reads: []readStep{
{
n: minReadSize,
},
{
err: io.EOF,
},
},
expectedBufs: 1,
},
{
name: "data+EOF",
reads: []readStep{
{
n: minReadSize,
err: io.EOF,
},
},
expectedBufs: 1,
},
{
name: "0,data+EOF",
reads: []readStep{
{},
{
n: minReadSize,
err: io.EOF,
},
},
expectedBufs: 1,
},
{
name: "0,data,EOF",
reads: []readStep{
{},
{
n: minReadSize,
},
{
err: io.EOF,
},
},
expectedBufs: 1,
},
{
name: "data,data+EOF",
reads: []readStep{
{
n: minReadSize,
},
{
n: minReadSize,
err: io.EOF,
},
},
expectedBufs: 1,
},
{
name: "error",
reads: []readStep{
{
err: errors.New("boom"),
},
},
expectedErr: "boom",
},
{
name: "data+error",
reads: []readStep{
{
n: minReadSize,
err: errors.New("boom"),
},
},
expectedErr: "boom",
expectedBufs: 1,
},
{
name: "data,data+error",
reads: []readStep{
{
n: minReadSize,
},
{
n: minReadSize,
err: errors.New("boom"),
},
},
expectedErr: "boom",
expectedBufs: 1,
},
{
name: "data,data+EOF - whole buf",
reads: []readStep{
{
n: minReadSize,
},
{
n: readAllBufSize - minReadSize,
err: io.EOF,
},
},
expectedBufs: 1,
},
{
name: "data,data,EOF - whole buf",
reads: []readStep{
{
n: minReadSize,
},
{
n: readAllBufSize - minReadSize,
},
{
err: io.EOF,
},
},
expectedBufs: 1,
},
{
name: "data,data,EOF - 2 bufs",
reads: []readStep{
{
n: readAllBufSize,
},
{
n: minReadSize,
},
{
n: readAllBufSize - minReadSize,
},
{
n: minReadSize,
},
{
err: io.EOF,
},
},
expectedBufs: 3,
},
}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
pool := &testPool{
allocated: make(map[*[]byte]struct{}),
}
r := &stepReader{
reads: tc.reads,
}
data, err := mem.ReadAll(r, pool)
if tc.expectedErr != "" {
if err == nil || err.Error() != tc.expectedErr {
t.Fatalf("ReadAll() expected error %q, got %q", tc.expectedErr, err)
}
} else {
if err != nil {
t.Fatal(err)
}
}
actualData := data.Materialize()
if !bytes.Equal(r.read, actualData) {
t.Fatalf("ReadAll() expected data %q, got %q", r.read, actualData)
}
if len(data) != tc.expectedBufs {
t.Fatalf("ReadAll() expected %d bufs, got %d", tc.expectedBufs, len(data))
}
for i := 0; i < len(data)-1; i++ { // all but last should be full buffers
if data[i].Len() != readAllBufSize {
t.Fatalf("ReadAll() expected data length %d, got %d", readAllBufSize, data[i].Len())
}
}
data.Free()
if len(pool.allocated) > 0 {
t.Fatalf("expected no allocated buffers, got %d", len(pool.allocated))
}
})
}
}

func (s) TestBufferSlice_ReadAll_WriteTo(t *testing.T) {
testcases := []struct {
name string
size int
}{
{
name: "small",
size: minReadSize,
},
{
name: "exact size",
size: readAllBufSize,
},
{
name: "big",
size: readAllBufSize * 3,
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
pool := &testPool{
allocated: make(map[*[]byte]struct{}),
}
buf := make([]byte, tc.size)
_, err := rand.Read(buf)
if err != nil {
t.Fatal(err)
}
r := bytes.NewBuffer(buf)
data, err := mem.ReadAll(r, pool)
if err != nil {
t.Fatal(err)
}

actualData := data.Materialize()
if !bytes.Equal(buf, actualData) {
t.Fatalf("ReadAll() expected data %q, got %q", buf, actualData)
}
data.Free()
if len(pool.allocated) > 0 {
t.Fatalf("expected no allocated buffers, got %d", len(pool.allocated))
}
})
}
}

func ExampleNewWriter() {
var bs mem.BufferSlice
pool := mem.DefaultBufferPool()
@@ -176,3 +429,49 @@ func ExampleNewWriter() {
// Wrote 4 bytes, err: <nil>
// abcdabcdabcd
}
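
A testable example for ReadAll is not included in this commit; a sketch in the style of ExampleNewWriter above could look like the following. It assumes a strings import; strings.Reader implements io.WriterTo, so this would exercise the WriteTo fast path rather than the Read loop.

func ExampleReadAll() {
    pool := mem.DefaultBufferPool()
    // strings.Reader implements io.WriterTo, so ReadAll delegates to WriteTo
    // instead of looping over Read with 32 KiB pooled buffers.
    data, err := mem.ReadAll(strings.NewReader("abcd"), pool)
    fmt.Printf("err: %v\n", err)
    fmt.Println(string(data.Materialize()))
    data.Free()
    // Output:
    // err: <nil>
    // abcd
}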

var (
_ io.Reader = (*stepReader)(nil)
_ mem.BufferPool = (*testPool)(nil)
)

type readStep struct {
n int
err error
}

type stepReader struct {
reads []readStep
read []byte
}

func (s *stepReader) Read(buf []byte) (int, error) {
if len(s.reads) == 0 {
panic("unexpected Read() call")
}
read := s.reads[0]
s.reads = s.reads[1:]
_, err := rand.Read(buf[:read.n])
if err != nil {
panic(err)
}
s.read = append(s.read, buf[:read.n]...)
return read.n, read.err
}

type testPool struct {
allocated map[*[]byte]struct{}
}

func (t *testPool) Get(length int) *[]byte {
buf := make([]byte, length)
t.allocated[&buf] = struct{}{}
return &buf
}

func (t *testPool) Put(buf *[]byte) {
if _, ok := t.allocated[buf]; !ok {
panic("unexpected put")
}
delete(t.allocated, buf)
}
3 changes: 1 addition & 2 deletions rpc_util.go
@@ -899,8 +899,7 @@ func decompress(compressor encoding.Compressor, d mem.BufferSlice, maxReceiveMes
// }
//}

-var out mem.BufferSlice
-_, err = io.Copy(mem.NewWriter(&out, pool), io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
+out, err := mem.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1), pool)
if err != nil {
out.Free()
return nil, 0, err
