Skip to content

Commit

Permalink
Merge pull request #1 from kl7sn/master
Browse files Browse the repository at this point in the history
fix: gzip read seeker causes oss put fail
  • Loading branch information
sevennt authored Sep 27, 2023
2 parents 2b15d50 + 93af576 commit 58a5a68
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 44 deletions.
13 changes: 6 additions & 7 deletions aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ func (a *S3) GetAndDecompress(ctx context.Context, key string) (string, error) {
return "", errors.New("GetAndDecompress only supports snappy for now, got " + *compressor)
}

rawBytes, err := ioutil.ReadAll(body)
rawBytes, err := io.ReadAll(body)
if err != nil {
return "", err
}
Expand All @@ -224,7 +224,7 @@ func (a *S3) GetAndDecompress(ctx context.Context, key string) (string, error) {
if err != nil {
if errors.Is(err, snappy.ErrCorrupt) {
reader := snappy.NewReader(bytes.NewReader(rawBytes))
data, err := ioutil.ReadAll(reader)
data, err := io.ReadAll(reader)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -258,12 +258,10 @@ func (a *S3) Put(ctx context.Context, key string, reader io.ReadSeeker, meta map
if err != nil {
return err
}

putOptions := DefaultPutOptions()
for _, opt := range options {
opt(putOptions)
}

input := &s3.PutObjectInput{
Body: reader,
Bucket: aws.String(bucketName),
Expand All @@ -283,19 +281,20 @@ func (a *S3) Put(ctx context.Context, key string, reader io.ReadSeeker, meta map
if putOptions.expires != nil {
input.Expires = putOptions.expires
}

if a.compressor != nil {
l, err := GetReaderLength(input.Body)
wrapReader, l, err := WrapReader(input.Body)
if err != nil {
return err
}
if l > a.cfg.CompressLimit {
input.Body, err = a.compressor.Compress(input.Body)
input.Body, _, err = a.compressor.Compress(wrapReader)
if err != nil {
return err
}
encoding := a.compressor.ContentEncoding()
input.ContentEncoding = &encoding
} else {
input.Body = wrapReader
}
}

Expand Down
57 changes: 30 additions & 27 deletions aws_test.go

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptrace"
"strings"
Expand Down Expand Up @@ -111,6 +112,7 @@ func newStorage(name string, cfg *BucketConfig, logger *elog.Component) (Client,
}
if cfg.Debug {
config.LogLevel = aws.LogLevel(aws.LogDebugWithHTTPBody | aws.LogDebugWithSigning)
slog.Default().Enabled(context.Background(), slog.LevelDebug)
}

config.HTTPClient = &http.Client{
Expand Down
35 changes: 28 additions & 7 deletions compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,34 @@ func Register(comp Compressor) {
}

type Compressor interface {
Compress(reader io.ReadSeeker) (gzipReader io.ReadSeeker, err error)
// Compress(reader io.ReadSeeker) (gzipReader io.ReadSeeker, err error)
Compress(reader io.ReadSeeker) (gzipReader io.ReadSeeker, len int64, err error)
ContentEncoding() string
}

type GzipCompressor struct {
}

func (g *GzipCompressor) Compress(reader io.ReadSeeker) (gzipReader io.ReadSeeker, err error) {
return &gzipReadSeeker{
reader: reader,
}, nil
func (g *GzipCompressor) Compress(reader io.ReadSeeker) (gzipReader io.ReadSeeker, len int64, err error) {
var buffer bytes.Buffer
gzipWriter := gzip.NewWriter(&buffer)
_, err = io.Copy(gzipWriter, reader)
if err != nil {
return nil, 0, err
}
err = gzipWriter.Close()
if err != nil {
return nil, 0, err
}
return bytes.NewReader(buffer.Bytes()), int64(buffer.Len()), nil
}

// func (g *GzipCompressor) Compress(reader io.ReadSeeker) (gzipReader io.ReadSeeker, err error) {
// return &gzipReadSeeker{
// reader: reader,
// }, nil
// }

func (g *GzipCompressor) ContentEncoding() string {
return compressTypeGzip
}
Expand Down Expand Up @@ -74,13 +89,20 @@ func (crs *gzipReadSeeker) Seek(offset int64, whence int) (int64, error) {

var DefaultGzipCompressor = &GzipCompressor{}

func WrapReader(reader io.ReadSeeker) (io.ReadSeeker, int64, error) {
all, err := io.ReadAll(reader)
if err != nil {
return nil, 0, err
}
return bytes.NewReader(all), int64(len(all)), nil
}

func GetReaderLength(reader io.ReadSeeker) (int64, error) {
// 保存当前的读写位置
originalPos, err := reader.Seek(0, io.SeekCurrent)
if err != nil {
return 0, err
}

// 移动到文件末尾以获取字节长度
length, err := reader.Seek(0, io.SeekEnd)
if err != nil {
Expand All @@ -91,6 +113,5 @@ func GetReaderLength(reader io.ReadSeeker) (int64, error) {
if err != nil {
return 0, err
}

return length, nil
}
4 changes: 2 additions & 2 deletions compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestCompress_gzip(t *testing.T) {
if err != nil {
panic(err)
}
reader, err := DefaultGzipCompressor.Compress(source)
reader, _, err := DefaultGzipCompressor.Compress(source)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -62,7 +62,7 @@ func TestGetLength(t *testing.T) {
}
t.Logf("length %d", length)

seeker, _ := DefaultGzipCompressor.Compress(source)
seeker, _, _ := DefaultGzipCompressor.Compress(source)
targetPath := os.Getenv("target_path")
_, err = os.Stat(targetPath)
if err == nil {
Expand Down
2 changes: 1 addition & 1 deletion oss.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func (ossClient *OSS) Put(ctx context.Context, key string, reader io.ReadSeeker,
return err
}
if l > ossClient.cfg.CompressLimit {
reader, err = ossClient.compressor.Compress(reader)
reader, _, err = ossClient.compressor.Compress(reader)
if err != nil {
return err
}
Expand Down

0 comments on commit 58a5a68

Please sign in to comment.