Merge pull request #3 from ortuman/reduce-decompression-buffer-allocations

Signed-off-by: Miguel Ángel Ortuño <ortuman@gmail.com>
ortuman committed Oct 3, 2024
1 parent bf66f21 commit 6010ccd
Showing 7 changed files with 366 additions and 43 deletions.
57 changes: 49 additions & 8 deletions pkg/kgo/compression.go
@@ -12,6 +12,8 @@ import (
"github.com/klauspost/compress/s2"
"github.com/klauspost/compress/zstd"
"github.com/pierrec/lz4/v4"

"github.com/twmb/franz-go/pkg/kgo/internal/pool"
)

var byteBuffers = sync.Pool{New: func() any { return bytes.NewBuffer(make([]byte, 8<<10)) }}
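For context: byteBuffers above is the pre-existing mechanism, a sync.Pool handing out 8 KiB bytes.Buffers that every caller must Reset before use. A minimal, self-contained sketch of that pattern (the main function and its contents are illustrative, not part of the diff):

package main

import (
	"bytes"
	"fmt"
	"sync"
)

var byteBuffers = sync.Pool{New: func() any { return bytes.NewBuffer(make([]byte, 8<<10)) }}

func main() {
	out := byteBuffers.Get().(*bytes.Buffer)
	out.Reset() // contents are stale from the previous user
	out.WriteString("scratch work")
	fmt.Println(out.Len())
	byteBuffers.Put(out) // recycle; out must not be retained past this point
}

The commit keeps this pool for compression but moves decompression onto the size-bucketed pool introduced below, so that small batches stop pinning one-size buffers and the output no longer has to be re-allocated on every call.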
@@ -266,15 +268,23 @@ type zstdDecoder struct {
inner *zstd.Decoder
}

-func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) {
+func (d *decompressor) decompress(src []byte, codec byte, pool *pool.BucketedPool[byte]) ([]byte, error) {
// Early return in case there is no compression
compCodec := codecType(codec)
if compCodec == codecNone {
return src, nil
}
-	out := byteBuffers.Get().(*bytes.Buffer)
-	out.Reset()
-	defer byteBuffers.Put(out)

+	out, buf, err := d.getDecodedBuffer(src, compCodec, pool)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		if compCodec == codecSnappy {
+			return
+		}
+		pool.Put(buf)
+	}()

switch compCodec {
case codecGzip:
@@ -286,7 +296,7 @@ func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) {
if _, err := io.Copy(out, ungz); err != nil {
return nil, err
}
-		return append([]byte(nil), out.Bytes()...), nil
+		return d.copyDecodedBuffer(out.Bytes(), compCodec, pool), nil
case codecSnappy:
if len(src) > 16 && bytes.HasPrefix(src, xerialPfx) {
return xerialDecode(src)
@@ -295,28 +305,59 @@ func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) {
if err != nil {
return nil, err
}
-		return append([]byte(nil), decoded...), nil
+		return d.copyDecodedBuffer(decoded, compCodec, pool), nil
case codecLZ4:
unlz4 := d.unlz4Pool.Get().(*lz4.Reader)
defer d.unlz4Pool.Put(unlz4)
unlz4.Reset(bytes.NewReader(src))
if _, err := io.Copy(out, unlz4); err != nil {
return nil, err
}
-		return append([]byte(nil), out.Bytes()...), nil
+		return d.copyDecodedBuffer(out.Bytes(), compCodec, pool), nil
case codecZstd:
unzstd := d.unzstdPool.Get().(*zstdDecoder)
defer d.unzstdPool.Put(unzstd)
decoded, err := unzstd.inner.DecodeAll(src, out.Bytes())
if err != nil {
return nil, err
}
-		return append([]byte(nil), decoded...), nil
+		return d.copyDecodedBuffer(decoded, compCodec, pool), nil
default:
return nil, errors.New("unknown compression codec")
}
}
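Two ownership rules fall out of the rewritten decompress: the caller owns whatever slice is returned, and the scratch buffer buf is recycled by the deferred pool.Put for every codec except snappy (where the pooled buffer itself is handed to the caller). For the streaming codecs the shape is decode into scratch, then copy out at exact size. A self-contained sketch of that pattern using only the standard library; gunzipCopy is a hypothetical stand-in for the gzip branch above, with a plain make in place of the pool:

package main

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"io"
)

// gunzipCopy decodes into a reusable scratch buffer, then copies the result
// into an exactly sized slice so the scratch can be recycled immediately.
func gunzipCopy(src []byte, scratch *bytes.Buffer) ([]byte, error) {
	scratch.Reset()
	zr, err := gzip.NewReader(bytes.NewReader(src))
	if err != nil {
		return nil, err
	}
	if _, err := io.Copy(scratch, zr); err != nil {
		return nil, err
	}
	out := make([]byte, scratch.Len()) // pool.Get(...) in the real code
	copy(out, scratch.Bytes())
	return out, nil
}

func main() {
	var zbuf bytes.Buffer
	zw := gzip.NewWriter(&zbuf)
	zw.Write([]byte("hello hello hello"))
	zw.Close()

	var scratch bytes.Buffer
	out, err := gunzipCopy(zbuf.Bytes(), &scratch)
	fmt.Println(string(out), err)
}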

+func (d *decompressor) getDecodedBuffer(src []byte, compCodec codecType, pool *pool.BucketedPool[byte]) (*bytes.Buffer, []byte, error) {
+	var (
+		decodedBufSize int
+		err            error
+	)
+	switch compCodec {
+	case codecSnappy:
+		decodedBufSize, err = s2.DecodedLen(src)
+		if err != nil {
+			return nil, nil, err
+		}
+
+	default:
+		// Make a guess at the output size.
+		decodedBufSize = len(src) * 2
+	}
+	buf := pool.Get(decodedBufSize)[:0]
+
+	return bytes.NewBuffer(buf), buf, nil
+}
+
+func (d *decompressor) copyDecodedBuffer(decoded []byte, compCodec codecType, pool *pool.BucketedPool[byte]) []byte {
+	if compCodec == codecSnappy {
+		// We already know the actual size of the decoded buffer before decompression,
+		// so there's no need to copy the buffer.
+		return decoded
+	}
+	out := pool.Get(len(decoded))
+	return append(out[:0], decoded...)
+}
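The snappy special case works because snappy/S2 blocks carry their decoded length in the header: s2.DecodedLen reads it without decompressing, so a right-sized buffer can be pulled from the pool up front and returned to the caller with no final copy. That is also why the deferred pool.Put in decompress skips codecSnappy: ownership of that buffer transfers to the caller. A runnable sketch of the round trip (the elided snappy branch presumably decodes into the pooled buffer; this sketch uses a plain make):

package main

import (
	"fmt"

	"github.com/klauspost/compress/s2"
)

func main() {
	src := s2.EncodeSnappy(nil, []byte("some record batch bytes"))

	// The decoded size is known before decoding...
	n, err := s2.DecodedLen(src)
	if err != nil {
		panic(err)
	}
	// ...so the destination can be exactly right-sized from the start.
	dst := make([]byte, n) // pool.Get(n) in the real code
	decoded, err := s2.Decode(dst, src)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(decoded))
}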

var xerialPfx = []byte{130, 83, 78, 65, 80, 80, 89, 0}

var errMalformedXerial = errors.New("malformed xerial framing")
8 changes: 6 additions & 2 deletions pkg/kgo/compression_test.go
@@ -10,6 +10,8 @@ import (
"testing"

"github.com/pierrec/lz4/v4"

"github.com/twmb/franz-go/pkg/kgo/internal/pool"
)

// Regression test for #778.
@@ -78,6 +80,8 @@ func TestCompressDecompress(t *testing.T) {
randStr(1 << 8),
}

+	buffPool := pool.NewBucketedPool(1, 1<<16, 2, func(int) []byte { return make([]byte, 1<<16) })

var wg sync.WaitGroup
for _, produceVersion := range []int16{
0, 7,
@@ -110,7 +114,7 @@
w.Reset()

got, used := c.compress(w, in, produceVersion)
-				got, err := d.decompress(got, byte(used))
+				got, err := d.decompress(got, byte(used), buffPool)
if err != nil {
t.Errorf("unexpected decompress err: %v", err)
return
@@ -156,7 +160,7 @@ func BenchmarkDecompress(b *testing.B) {
b.Run(fmt.Sprint(codec), func(b *testing.B) {
for i := 0; i < b.N; i++ {
d := newDecompressor()
-			d.decompress(w.Bytes(), byte(codec))
+			d.decompress(w.Bytes(), byte(codec), pool.NewBucketedPool(1, 1<<16, 2, func(int) []byte { return make([]byte, 1<<16) }))
}
})
byteBuffers.Put(w)
10 changes: 9 additions & 1 deletion pkg/kgo/config.go
Expand Up @@ -16,6 +16,8 @@ import (
"github.com/twmb/franz-go/pkg/kmsg"
"github.com/twmb/franz-go/pkg/kversion"
"github.com/twmb/franz-go/pkg/sasl"

"github.com/twmb/franz-go/pkg/kgo/internal/pool"
)

// Opt is an option to configure a client.
@@ -151,7 +153,8 @@ type cfg struct {
partitions map[string]map[int32]Offset // partitions to directly consume from
regex bool

-	recordsPool recordsPool
+	recordsPool          *recordsPool
+	decompressBufferPool *pool.BucketedPool[byte]

////////////////////////////
// CONSUMER GROUP SECTION //
@@ -391,6 +394,11 @@ func (cfg *cfg) validate() error {
}
cfg.hooks = processedHooks

+	// Assume a 2x compression ratio.
+	maxDecompressedBatchSize := int(cfg.maxBytes.load()) * 2
+	cfg.decompressBufferPool = pool.NewBucketedPool[byte](4096, maxDecompressedBatchSize, 2, func(sz int) []byte {
+		return make([]byte, sz)
+	})
return nil
}
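Back-of-the-envelope on the pool this builds: assuming franz-go's default fetch max bytes of 50 MiB is unchanged in this fork, maxDecompressedBatchSize comes out to 100 MiB, and the buckets double from 4 KiB up to the largest size that still fits. A sketch reproducing the ladder:

package main

import "fmt"

func main() {
	// Assumes the default cfg.maxBytes of 50<<20 (50 MiB).
	maxBytes := 50 << 20
	maxDecompressedBatchSize := maxBytes * 2 // the assumed 2x compression ratio

	for s := 4096; s <= maxDecompressedBatchSize; s *= 2 {
		fmt.Println(s)
	}
	// Prints 4096, 8192, ..., 67108864: the 64 MiB bucket is the last one
	// under the 100 MiB cap; anything bigger is allocated outside the pool.
}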

94 changes: 94 additions & 0 deletions pkg/kgo/internal/pool/bucketed_pool.go
@@ -0,0 +1,94 @@
// Copyright 2017 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pool

import (
"sync"
)

// BucketedPool is a bucketed pool for variably sized slices.
type BucketedPool[T any] struct {
buckets []sync.Pool
sizes []int
// make is the function used to create an empty slice when none exist yet.
make func(int) []T
}

// NewBucketedPool returns a new BucketedPool with size buckets for minSize to maxSize
// increasing by the given factor.
func NewBucketedPool[T any](minSize, maxSize int, factor float64, makeFunc func(int) []T) *BucketedPool[T] {
if minSize < 1 {
panic("invalid minimum pool size")
}
if maxSize < 1 {
panic("invalid maximum pool size")
}
if factor < 1 {
panic("invalid factor")
}

var sizes []int

for s := minSize; s <= maxSize; s = int(float64(s) * factor) {
sizes = append(sizes, s)
}

p := &BucketedPool[T]{
buckets: make([]sync.Pool, len(sizes)),
sizes: sizes,
make: makeFunc,
}
return p
}

// Get returns a new slice with capacity greater than or equal to size.
func (p *BucketedPool[T]) Get(size int) []T {
for i, bktSize := range p.sizes {
if size > bktSize {
continue
}
buff := p.buckets[i].Get()
if buff == nil {
buff = p.make(bktSize)
}
return buff.([]T)
}
return p.make(size)
}

// Put adds a slice to the right bucket in the pool.
// If the slice does not belong to any bucket in the pool, it is ignored.
func (p *BucketedPool[T]) Put(s []T) {
sCap := cap(s)
if sCap < p.sizes[0] {
return
}

for i, size := range p.sizes {
if sCap > size {
continue
}

if sCap == size {
// Buffer is exactly the minimum size for this bucket. Add it to this bucket.
p.buckets[i].Put(s)
} else {
// Buffer belongs in previous bucket.
p.buckets[i-1].Put(s)
}
return
}
}
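A quick usage sketch of the pool. The package is internal to kgo, so this only compiles from inside the repository, and the sizes here are illustrative:

package main

import (
	"fmt"

	"github.com/twmb/franz-go/pkg/kgo/internal/pool"
)

func main() {
	p := pool.NewBucketedPool[byte](1024, 1<<20, 2, func(sz int) []byte {
		return make([]byte, sz)
	})

	buf := p.Get(3000) // smallest bucket that fits 3000 is 4096
	fmt.Println(cap(buf) >= 3000)

	p.Put(buf) // cap is exactly a bucket size, so the slice is retained

	big := p.Get(4 << 20) // above maxSize: freshly allocated at exact size
	p.Put(big)            // cap exceeds every bucket, so Put drops it
}

Note that Get hands back whatever the make function produced (the decompress path immediately reslices with [:0]), and that Put silently drops slices larger than the biggest bucket, so oversized one-off buffers never accumulate in the pool.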


68 changes: 68 additions & 0 deletions pkg/kgo/internal/pool/bucketed_pool_test.go
@@ -0,0 +1,68 @@
// SPDX-License-Identifier: Apache-2.0
// Provenance-includes-location: https://github.com/prometheus/prometheus/blob/main/util/pool/pool_test.go
// Provenance-includes-copyright: The Prometheus Authors

package pool

import (
"testing"
)

func makeFunc(size int) []int {
return make([]int, 0, size)
}

func TestBucketedPool_HappyPath(t *testing.T) {
testPool := NewBucketedPool(1, 8, 2, makeFunc)
cases := []struct {
size int
expectedCap int
}{
{
size: -1,
expectedCap: 1,
},
{
size: 3,
expectedCap: 4,
},
{
size: 10,
expectedCap: 10,
},
}
for _, c := range cases {
ret := testPool.Get(c.size)
if cap(ret) < c.expectedCap {
t.Fatalf("expected cap >= %d, got %d", c.expectedCap , cap(ret))
}
testPool.Put(ret)
}
}

func TestBucketedPool_SliceNotAlignedToBuckets(t *testing.T) {
pool := NewBucketedPool(1, 1000, 10, makeFunc)
pool.Put(make([]int, 0, 2))
s := pool.Get(3)
if cap(s) < 3 {
t.Fatalf("expected cap >= 3, got %d", cap(s))
}
}

func TestBucketedPool_PutEmptySlice(t *testing.T) {
pool := NewBucketedPool(1, 1000, 10, makeFunc)
pool.Put([]int{})
s := pool.Get(1)
if cap(s) < 1 {
t.Fatalf("expected cap >= 1, got %d", cap(s))
}
}

func TestBucketedPool_PutSliceSmallerThanMinimum(t *testing.T) {
pool := NewBucketedPool(3, 1000, 10, makeFunc)
pool.Put([]int{1, 2})
s := pool.Get(3)
if cap(s) < 3 {
t.Fatalf("expected cap >= 3, got %d", cap(s))
}
}
25 changes: 23 additions & 2 deletions pkg/kgo/record_and_fetch.go
@@ -6,6 +6,8 @@ import (
"reflect"
"time"
"unsafe"

"github.com/twmb/franz-go/pkg/kmsg"
)

// RecordHeader contains extra information that can be sent with Records.
@@ -153,15 +155,34 @@ type Record struct {
// recordsPool is the pool that this record was fetched from, if any.
//
// When reused, record is returned to this pool.
-	recordsPool recordsPool
+	recordsPool *recordsPool

+	// rcBatchBuffer is used to keep track of the raw buffer that this record was
+	// derived from when consuming, after decompression.
+	//
+	// This is used to allow reusing these buffers when record pooling has been enabled
+	// via the EnableRecordsPool option.
+	rcBatchBuffer *rcBuffer[byte]
+
+	// rcRawRecordsBuffer is used to keep track of the raw record buffer that this record
+	// was derived from when consuming.
+	//
+	// This is used to allow reusing these buffers when record pooling has been enabled
+	// via the EnableRecordsPool option.
+	rcRawRecordsBuffer *rcBuffer[kmsg.Record]
}
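The rcBuffer type referenced by these fields is defined in a part of the change that this page did not render. Purely as an illustration of what the release calls in Reuse below imply, a hypothetical reference-counted wrapper might look like the following; the names, fields, and semantics here are assumptions, not the commit's actual code:

package kgo

import (
	"sync/atomic"

	"github.com/twmb/franz-go/pkg/kgo/internal/pool"
)

// rcBuffer (hypothetical sketch): a pooled slice shared by many records,
// returned to its bucketed pool once the last reference is released.
type rcBuffer[T any] struct {
	refCount atomic.Int32
	buffer   []T
	pool     *pool.BucketedPool[T]
}

func (b *rcBuffer[T]) acquire() { b.refCount.Add(1) }

func (b *rcBuffer[T]) release() {
	if b.refCount.Add(-1) == 0 {
		b.pool.Put(b.buffer)
		b.buffer = nil
	}
}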

// Reuse releases the record back to the pool.
+//
+// Once this method has been called, any reference to the passed record should be considered invalid by the caller,
+// as it may be reused as a result of future calls to the PollFetches/PollRecords method.
func (r *Record) Reuse() {
-	r.recordsPool.put(r)
+	if r.recordsPool != nil {
+		r.rcRawRecordsBuffer.release()
+		r.rcBatchBuffer.release()
+		r.recordsPool.put(r)
+	}
}
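Putting it together, a consumer that opts into pooling would poll, fully process each record, and then hand it back. A sketch, with the caveat that the client-construction option is only referred to above as EnableRecordsPool and this assumes the fork exposes it as a kgo option:

package main

import (
	"context"

	"github.com/twmb/franz-go/pkg/kgo"
)

func main() {
	cl, err := kgo.NewClient(kgo.ConsumeTopics("my-topic"))
	if err != nil {
		panic(err)
	}
	defer cl.Close()

	for {
		fetches := cl.PollFetches(context.Background())
		if fetches.IsClientClosed() {
			return
		}
		fetches.EachRecord(func(r *kgo.Record) {
			_ = r.Value // process the record fully before releasing it
			r.Reuse()   // record and its backing buffers may now be recycled
		})
	}
}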

func (r *Record) userSize() int64 {
(The diff for the seventh changed file did not load and is not shown.)
