Skip to content

Commit

Permalink
refactor: merge readers and writers
Browse files Browse the repository at this point in the history
  • Loading branch information
wolfogre committed Aug 1, 2024
1 parent 03de907 commit 2d98663
Showing 1 changed file with 62 additions and 51 deletions.
113 changes: 62 additions & 51 deletions modules/zstd/zstd.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,73 +10,49 @@ import (
"github.com/klauspost/compress/zstd"
)

type Writer zstd.Encoder
type Writer struct {
enc *zstd.Encoder

var _ io.WriteCloser = (*Writer)(nil)

func NewWriter(w io.Writer, opts ...WriterOption) (*Writer, error) {
zstdW, err := zstd.NewWriter(w, opts...)
if err != nil {
return nil, err
}
return (*Writer)(zstdW), nil
}

func (w *Writer) Write(p []byte) (int, error) {
return (*zstd.Encoder)(w).Write(p)
}

func (w *Writer) Close() error {
return (*zstd.Encoder)(w).Close()
skw seekable.Writer
buf []byte
n int
}

type Reader zstd.Decoder

var _ io.ReadCloser = (*Reader)(nil)
var _ io.WriteCloser = (*Writer)(nil)

func NewReader(r io.Reader, opts ...ReaderOption) (*Reader, error) {
zstdR, err := zstd.NewReader(r, opts...)
func NewWriter(w io.Writer, opts ...WriterOption) (*Writer, error) {
enc, err := zstd.NewWriter(w, opts...)
if err != nil {
return nil, err
}
return (*Reader)(zstdR), nil
}

func (r *Reader) Read(p []byte) (int, error) {
return (*zstd.Decoder)(r).Read(p)
}

func (r *Reader) Close() error {
(*zstd.Decoder)(r).Close() // no error returned
return nil
}

type SeekableWriter struct {
buf []byte
n int
w seekable.Writer
return &Writer{
enc: enc,
}, nil
}

var _ io.WriteCloser = (*SeekableWriter)(nil)

func NewSeekableWriter(w io.Writer, blockSize int, opts ...WriterOption) (*SeekableWriter, error) {
zstdW, err := zstd.NewWriter(nil, opts...)
func NewSeekableWriter(w io.Writer, blockSize int, opts ...WriterOption) (*Writer, error) {
enc, err := zstd.NewWriter(nil, opts...)
if err != nil {
return nil, err
}

seekableW, err := seekable.NewWriter(w, zstdW)
skw, err := seekable.NewWriter(w, enc)
if err != nil {
return nil, err
}

return &SeekableWriter{
return &Writer{
enc: enc,
skw: skw,
buf: make([]byte, blockSize),
w: seekableW,
}, nil
}

func (w *SeekableWriter) Write(p []byte) (int, error) {
func (w *Writer) Write(p []byte) (int, error) {
if w.skw != nil {
return w.enc.Write(p)
}

written := 0
for len(p) > 0 {
n := copy(w.buf[w.n:], p)
Expand All @@ -85,7 +61,7 @@ func (w *SeekableWriter) Write(p []byte) (int, error) {
p = p[n:]

if w.n == len(w.buf) {
if _, err := w.w.Write(w.buf); err != nil {
if _, err := w.skw.Write(w.buf); err != nil {
return written, err
}
w.n = 0
Expand All @@ -94,13 +70,48 @@ func (w *SeekableWriter) Write(p []byte) (int, error) {
return written, nil
}

func (w *SeekableWriter) Close() error {
if w.n > 0 {
if _, err := w.w.Write(w.buf[:w.n]); err != nil {
func (w *Writer) Close() error {
if w.skw != nil {
if w.n > 0 {
if _, err := w.skw.Write(w.buf[:w.n]); err != nil {
return err
}
}
if err := w.skw.Close(); err != nil {
return err
}
}
return w.w.Close()
return w.enc.Close()
}

type Reader struct {
dec *zstd.Decoder
skr seekable.Reader
}

var _ io.ReadCloser = (*Reader)(nil)

func NewReader(r io.Reader, opts ...ReaderOption) (*Reader, error) {
dec, err := zstd.NewReader(r, opts...)
if err != nil {
return nil, err
}
return &Reader{
dec: dec,
}, nil
}

func (r *Reader) Read(p []byte) (int, error) {
return r.dec.Read(p)
}

func (r *Reader) Close() error {
r.dec.Close() // no error returned
return nil
}

func (r *Reader) SeekReader() (seekable.Reader, error) {
return r.skr
}

type SeekableReader seekable.Reader
Expand Down

0 comments on commit 2d98663

Please sign in to comment.