Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: gzip read seeker causes OSS put failure #1

Merged
merged 3 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ func (a *S3) GetAndDecompress(ctx context.Context, key string) (string, error) {
return "", errors.New("GetAndDecompress only supports snappy for now, got " + *compressor)
}

rawBytes, err := ioutil.ReadAll(body)
rawBytes, err := io.ReadAll(body)
if err != nil {
return "", err
}
Expand All @@ -224,7 +224,7 @@ func (a *S3) GetAndDecompress(ctx context.Context, key string) (string, error) {
if err != nil {
if errors.Is(err, snappy.ErrCorrupt) {
reader := snappy.NewReader(bytes.NewReader(rawBytes))
data, err := ioutil.ReadAll(reader)
data, err := io.ReadAll(reader)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -258,12 +258,10 @@ func (a *S3) Put(ctx context.Context, key string, reader io.ReadSeeker, meta map
if err != nil {
return err
}

putOptions := DefaultPutOptions()
for _, opt := range options {
opt(putOptions)
}

input := &s3.PutObjectInput{
Body: reader,
Bucket: aws.String(bucketName),
Expand All @@ -283,19 +281,20 @@ func (a *S3) Put(ctx context.Context, key string, reader io.ReadSeeker, meta map
if putOptions.expires != nil {
input.Expires = putOptions.expires
}

if a.compressor != nil {
l, err := GetReaderLength(input.Body)
wrapReader, l, err := WrapReader(input.Body)
if err != nil {
return err
}
if l > a.cfg.CompressLimit {
input.Body, err = a.compressor.Compress(input.Body)
input.Body, _, err = a.compressor.Compress(wrapReader)
if err != nil {
return err
}
encoding := a.compressor.ContentEncoding()
input.ContentEncoding = &encoding
} else {
input.Body = wrapReader
}
}

Expand Down
57 changes: 30 additions & 27 deletions aws_test.go

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptrace"
"strings"
Expand Down Expand Up @@ -111,6 +112,7 @@ func newStorage(name string, cfg *BucketConfig, logger *elog.Component) (Client,
}
if cfg.Debug {
config.LogLevel = aws.LogLevel(aws.LogDebugWithHTTPBody | aws.LogDebugWithSigning)
slog.Default().Enabled(context.Background(), slog.LevelDebug)
}

config.HTTPClient = &http.Client{
Expand Down
35 changes: 28 additions & 7 deletions compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,34 @@ func Register(comp Compressor) {
}

// Compressor turns a seekable payload into its compressed form before upload.
type Compressor interface {
	// Compress consumes reader and returns a fresh ReadSeeker over the
	// compressed bytes together with the compressed length in bytes.
	// Returning a seekable result lets the upload paths retry and lets
	// callers set the content length up front.
	Compress(reader io.ReadSeeker) (compressed io.ReadSeeker, n int64, err error)
	// ContentEncoding reports the Content-Encoding header value that
	// matches the compression format (e.g. "gzip").
	ContentEncoding() string
}

// GzipCompressor compresses payloads with the standard library gzip
// format. It holds no state, so the zero value is ready to use.
type GzipCompressor struct {
}

func (g *GzipCompressor) Compress(reader io.ReadSeeker) (gzipReader io.ReadSeeker, err error) {
return &gzipReadSeeker{
reader: reader,
}, nil
func (g *GzipCompressor) Compress(reader io.ReadSeeker) (gzipReader io.ReadSeeker, len int64, err error) {
var buffer bytes.Buffer
gzipWriter := gzip.NewWriter(&buffer)
_, err = io.Copy(gzipWriter, reader)
if err != nil {
return nil, 0, err
}
err = gzipWriter.Close()
if err != nil {
return nil, 0, err
}
return bytes.NewReader(buffer.Bytes()), int64(buffer.Len()), nil
}

// func (g *GzipCompressor) Compress(reader io.ReadSeeker) (gzipReader io.ReadSeeker, err error) {
// return &gzipReadSeeker{
// reader: reader,
// }, nil
// }

// ContentEncoding returns the Content-Encoding value advertised for
// payloads produced by Compress (the package's gzip constant).
func (g *GzipCompressor) ContentEncoding() string {
	return compressTypeGzip
}
Expand Down Expand Up @@ -74,13 +89,20 @@ func (crs *gzipReadSeeker) Seek(offset int64, whence int) (int64, error) {

// DefaultGzipCompressor is a shared, ready-to-use gzip Compressor.
var DefaultGzipCompressor = &GzipCompressor{}

func WrapReader(reader io.ReadSeeker) (io.ReadSeeker, int64, error) {
all, err := io.ReadAll(reader)
if err != nil {
return nil, 0, err
}
return bytes.NewReader(all), int64(len(all)), nil
}

func GetReaderLength(reader io.ReadSeeker) (int64, error) {
// 保存当前的读写位置
originalPos, err := reader.Seek(0, io.SeekCurrent)
if err != nil {
return 0, err
}

// 移动到文件末尾以获取字节长度
length, err := reader.Seek(0, io.SeekEnd)
if err != nil {
Expand All @@ -91,6 +113,5 @@ func GetReaderLength(reader io.ReadSeeker) (int64, error) {
if err != nil {
return 0, err
}

return length, nil
}
4 changes: 2 additions & 2 deletions compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestCompress_gzip(t *testing.T) {
if err != nil {
panic(err)
}
reader, err := DefaultGzipCompressor.Compress(source)
reader, _, err := DefaultGzipCompressor.Compress(source)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -62,7 +62,7 @@ func TestGetLength(t *testing.T) {
}
t.Logf("length %d", length)

seeker, _ := DefaultGzipCompressor.Compress(source)
seeker, _, _ := DefaultGzipCompressor.Compress(source)
targetPath := os.Getenv("target_path")
_, err = os.Stat(targetPath)
if err == nil {
Expand Down
2 changes: 1 addition & 1 deletion oss.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func (ossClient *OSS) Put(ctx context.Context, key string, reader io.ReadSeeker,
return err
}
if l > ossClient.cfg.CompressLimit {
reader, err = ossClient.compressor.Compress(reader)
reader, _, err = ossClient.compressor.Compress(reader)
if err != nil {
return err
}
Expand Down