Skip to content

Commit

Permalink
Merge pull request #1185 from retailnext/reuse-compression-objects
Browse files Browse the repository at this point in the history
Reuse compression objects
  • Loading branch information
bai committed Dec 14, 2018
2 parents 97315fe + f352e5c commit 94536b3
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 165 deletions.
75 changes: 75 additions & 0 deletions compress.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package sarama

import (
"bytes"
"compress/gzip"
"fmt"
"sync"

"github.com/eapache/go-xerial-snappy"
"github.com/pierrec/lz4"
)

// Pools of reusable compression writers. Constructing a gzip or lz4
// writer allocates non-trivial internal state, so the encode path grabs
// one here, Resets it onto a fresh buffer, and returns it when done.
var (
	lz4WriterPool = sync.Pool{
		New: func() interface{} { return lz4.NewWriter(nil) },
	}

	gzipWriterPool = sync.Pool{
		New: func() interface{} { return gzip.NewWriter(nil) },
	}
)

// compress applies the given codec (at the given level, for codecs that
// support levels) to data and returns the compressed bytes. gzip and lz4
// writers are drawn from package-level pools so they are not re-allocated
// for every message.
func compress(cc CompressionCodec, level int, data []byte) ([]byte, error) {
	switch cc {
	case CompressionNone:
		return data, nil
	case CompressionSnappy:
		return snappy.Encode(data), nil
	case CompressionZSTD:
		return zstdCompressLevel(nil, data, level)
	case CompressionGZIP:
		var out bytes.Buffer
		var w *gzip.Writer
		if level == CompressionLevelDefault {
			// Default level: reuse a pooled writer.
			w = gzipWriterPool.Get().(*gzip.Writer)
			defer gzipWriterPool.Put(w)
			w.Reset(&out)
		} else {
			// Non-default levels are not pooled; build a fresh writer.
			var err error
			w, err = gzip.NewWriterLevel(&out, level)
			if err != nil {
				return nil, err
			}
		}
		if _, err := w.Write(data); err != nil {
			return nil, err
		}
		if err := w.Close(); err != nil {
			return nil, err
		}
		return out.Bytes(), nil
	case CompressionLZ4:
		w := lz4WriterPool.Get().(*lz4.Writer)
		defer lz4WriterPool.Put(w)

		var out bytes.Buffer
		w.Reset(&out)

		if _, err := w.Write(data); err != nil {
			return nil, err
		}
		if err := w.Close(); err != nil {
			return nil, err
		}
		return out.Bytes(), nil
	default:
		return nil, PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", cc)}
	}
}
63 changes: 63 additions & 0 deletions decompress.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package sarama

import (
"bytes"
"compress/gzip"
"fmt"
"io/ioutil"
"sync"

"github.com/eapache/go-xerial-snappy"
"github.com/pierrec/lz4"
)

// Pools of reusable decompression readers, mirroring the writer pools in
// compress.go. gzipReaderPool deliberately has no New function: a
// *gzip.Reader can only be constructed against real input, so the first
// use builds one with gzip.NewReader and later uses Reset a pooled
// instance.
var (
	lz4ReaderPool = sync.Pool{
		New: func() interface{} { return lz4.NewReader(nil) },
	}

	gzipReaderPool sync.Pool
)

// decompress reverses the given codec's compression on data and returns
// the plain bytes. gzip and lz4 readers are pooled so they are not
// re-allocated for every message.
//
// Fix vs. the previous version: when the gzip pool was empty, the reader
// was built with gzip.NewReader (which consumes and validates the gzip
// header) and then unconditionally Reset on the same input, reading the
// header a second time. Reset is now only called on readers that came
// from the pool.
func decompress(cc CompressionCodec, data []byte) ([]byte, error) {
	switch cc {
	case CompressionNone:
		return data, nil
	case CompressionGZIP:
		var err error
		reader, ok := gzipReaderPool.Get().(*gzip.Reader)
		if ok {
			// Pooled reader: point it at the new input (this re-reads the
			// gzip header from data).
			err = reader.Reset(bytes.NewReader(data))
		} else {
			// Pool was empty: construct a reader directly against the
			// input; NewReader already consumes the header.
			reader, err = gzip.NewReader(bytes.NewReader(data))
		}
		if err != nil {
			return nil, err
		}

		defer gzipReaderPool.Put(reader)

		return ioutil.ReadAll(reader)
	case CompressionSnappy:
		return snappy.Decode(data)
	case CompressionLZ4:
		reader := lz4ReaderPool.Get().(*lz4.Reader)
		defer lz4ReaderPool.Put(reader)

		reader.Reset(bytes.NewReader(data))
		return ioutil.ReadAll(reader)
	case CompressionZSTD:
		return zstdDecompress(nil, data)
	default:
		return nil, PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", cc)}
	}
}
98 changes: 8 additions & 90 deletions message.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
package sarama

import (
"bytes"
"compress/gzip"
"fmt"
"io/ioutil"
"time"

"github.com/eapache/go-xerial-snappy"
"github.com/pierrec/lz4"
)

// CompressionCodec represents the various compression codecs recognized by Kafka in messages.
Expand Down Expand Up @@ -77,53 +71,12 @@ func (m *Message) encode(pe packetEncoder) error {
payload = m.compressedCache
m.compressedCache = nil
} else if m.Value != nil {
switch m.Codec {
case CompressionNone:
payload = m.Value
case CompressionGZIP:
var buf bytes.Buffer
var writer *gzip.Writer
if m.CompressionLevel != CompressionLevelDefault {
writer, err = gzip.NewWriterLevel(&buf, m.CompressionLevel)
if err != nil {
return err
}
} else {
writer = gzip.NewWriter(&buf)
}
if _, err = writer.Write(m.Value); err != nil {
return err
}
if err = writer.Close(); err != nil {
return err
}
m.compressedCache = buf.Bytes()
payload = m.compressedCache
case CompressionSnappy:
tmp := snappy.Encode(m.Value)
m.compressedCache = tmp
payload = m.compressedCache
case CompressionLZ4:
var buf bytes.Buffer
writer := lz4.NewWriter(&buf)
if _, err = writer.Write(m.Value); err != nil {
return err
}
if err = writer.Close(); err != nil {
return err
}
m.compressedCache = buf.Bytes()
payload = m.compressedCache
case CompressionZSTD:
c, err := zstdCompressLevel(nil, m.Value, m.CompressionLevel)
if err != nil {
return err
}
m.compressedCache = c
payload = m.compressedCache
default:
return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", m.Codec)}

payload, err = compress(m.Codec, m.CompressionLevel, m.Value)
if err != nil {
return err
}
m.compressedCache = payload
// Keep in mind the compressed payload size for metric gathering
m.compressedSize = len(payload)
}
Expand Down Expand Up @@ -179,53 +132,18 @@ func (m *Message) decode(pd packetDecoder) (err error) {
switch m.Codec {
case CompressionNone:
// nothing to do
case CompressionGZIP:
default:
if m.Value == nil {
break
}
reader, err := gzip.NewReader(bytes.NewReader(m.Value))

m.Value, err = decompress(m.Codec, m.Value)
if err != nil {
return err
}
if m.Value, err = ioutil.ReadAll(reader); err != nil {
return err
}
if err := m.decodeSet(); err != nil {
return err
}
case CompressionSnappy:
if m.Value == nil {
break
}
if m.Value, err = snappy.Decode(m.Value); err != nil {
return err
}
if err := m.decodeSet(); err != nil {
return err
}
case CompressionLZ4:
if m.Value == nil {
break
}
reader := lz4.NewReader(bytes.NewReader(m.Value))
if m.Value, err = ioutil.ReadAll(reader); err != nil {
return err
}
if err := m.decodeSet(); err != nil {
return err
}
case CompressionZSTD:
if m.Value == nil {
break
}
if m.Value, err = zstdDecompress(nil, m.Value); err != nil {
return err
}
if err := m.decodeSet(); err != nil {
return err
}
default:
return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", m.Codec)}
}

return pd.pop()
Expand Down
80 changes: 5 additions & 75 deletions record_batch.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
package sarama

import (
"bytes"
"compress/gzip"
"fmt"
"io/ioutil"
"time"

"github.com/eapache/go-xerial-snappy"
"github.com/pierrec/lz4"
)

const recordBatchOverhead = 49
Expand Down Expand Up @@ -174,31 +168,9 @@ func (b *RecordBatch) decode(pd packetDecoder) (err error) {
return err
}

switch b.Codec {
case CompressionNone:
case CompressionGZIP:
reader, err := gzip.NewReader(bytes.NewReader(recBuffer))
if err != nil {
return err
}
if recBuffer, err = ioutil.ReadAll(reader); err != nil {
return err
}
case CompressionSnappy:
if recBuffer, err = snappy.Decode(recBuffer); err != nil {
return err
}
case CompressionLZ4:
reader := lz4.NewReader(bytes.NewReader(recBuffer))
if recBuffer, err = ioutil.ReadAll(reader); err != nil {
return err
}
case CompressionZSTD:
if recBuffer, err = zstdDecompress(nil, recBuffer); err != nil {
return err
}
default:
return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", b.Codec)}
recBuffer, err = decompress(b.Codec, recBuffer)
if err != nil {
return err
}

b.recordsLen = len(recBuffer)
Expand All @@ -219,50 +191,8 @@ func (b *RecordBatch) encodeRecords(pe packetEncoder) error {
}
b.recordsLen = len(raw)

switch b.Codec {
case CompressionNone:
b.compressedRecords = raw
case CompressionGZIP:
var buf bytes.Buffer
var writer *gzip.Writer
if b.CompressionLevel != CompressionLevelDefault {
writer, err = gzip.NewWriterLevel(&buf, b.CompressionLevel)
if err != nil {
return err
}
} else {
writer = gzip.NewWriter(&buf)
}
if _, err := writer.Write(raw); err != nil {
return err
}
if err := writer.Close(); err != nil {
return err
}
b.compressedRecords = buf.Bytes()
case CompressionSnappy:
b.compressedRecords = snappy.Encode(raw)
case CompressionLZ4:
var buf bytes.Buffer
writer := lz4.NewWriter(&buf)
if _, err := writer.Write(raw); err != nil {
return err
}
if err := writer.Close(); err != nil {
return err
}
b.compressedRecords = buf.Bytes()
case CompressionZSTD:
c, err := zstdCompressLevel(nil, raw, b.CompressionLevel)
if err != nil {
return err
}
b.compressedRecords = c
default:
return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", b.Codec)}
}

return nil
b.compressedRecords, err = compress(b.Codec, b.CompressionLevel, raw)
return err
}

func (b *RecordBatch) computeAttributes() int16 {
Expand Down

0 comments on commit 94536b3

Please sign in to comment.