Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion internal/model/obj.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ type FileStreamer interface {
// for a non-seekable Stream, if Read is called, this function won't work.
// caches the full Stream and writes it to writer (if provided, even if the stream is already cached).
CacheFullAndWriter(up *UpdateProgress, writer io.Writer) (File, error)
SetTmpFile(file File)
// if the Stream is not a File and is not cached, returns nil.
GetFile() File
}
Expand Down
126 changes: 73 additions & 53 deletions internal/stream/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ type FileStream struct {
ForceStreamUpload bool
Exist model.Obj //the file existed in the destination, we can reuse some info since we wil overwrite it
utils.Closers

tmpFile model.File //if present, tmpFile has full content, it will be deleted at last
peekBuff *buffer.Reader
size int64
oriReader io.Reader // the original reader, used for caching
Expand All @@ -39,12 +37,6 @@ func (f *FileStream) GetSize() int64 {
if f.size > 0 {
return f.size
}
if file, ok := f.tmpFile.(*os.File); ok {
info, err := file.Stat()
if err == nil {
return info.Size()
}
}
return f.Obj.GetSize()
}

Expand All @@ -71,14 +63,13 @@ func (f *FileStream) Close() error {
if errors.Is(err1, os.ErrClosed) {
err1 = nil
}
if file, ok := f.tmpFile.(*os.File); ok {
if file, ok := f.Reader.(*os.File); ok {
err2 = os.RemoveAll(file.Name())
if err2 != nil {
err2 = errs.NewErr(err2, "failed to remove tmpFile [%s]", file.Name())
} else {
f.tmpFile = nil
}
}
f.Reader = nil

return errors.Join(err1, err2)
}
Expand All @@ -94,50 +85,50 @@ func (f *FileStream) SetExist(obj model.Obj) {
// It's not thread-safe!
func (f *FileStream) CacheFullAndWriter(up *model.UpdateProgress, writer io.Writer) (model.File, error) {
if cache := f.GetFile(); cache != nil {
_, err := cache.Seek(0, io.SeekStart)
if err != nil {
return nil, err
}
if writer == nil {
return cache, nil
}
_, err := cache.Seek(0, io.SeekStart)
if err == nil {
reader := f.Reader
if up != nil {
cacheProgress := model.UpdateProgressWithRange(*up, 0, 50)
*up = model.UpdateProgressWithRange(*up, 50, 100)
reader = &ReaderUpdatingProgress{
Reader: &SimpleReaderWithSize{
Reader: reader,
Size: f.GetSize(),
},
UpdateProgress: cacheProgress,
}
}
_, err = utils.CopyWithBuffer(writer, reader)
if err == nil {
_, err = cache.Seek(0, io.SeekStart)
reader := f.Reader
if up != nil {
cacheProgress := model.UpdateProgressWithRange(*up, 0, 50)
*up = model.UpdateProgressWithRange(*up, 50, 100)
reader = &ReaderUpdatingProgress{
Reader: &SimpleReaderWithSize{
Reader: reader,
Size: f.GetSize(),
},
UpdateProgress: cacheProgress,
}
}
_, err = utils.CopyWithBuffer(writer, reader)
if err == nil {
_, err = cache.Seek(0, io.SeekStart)
}
if err != nil {
return nil, err
}
return cache, nil
}

reader := f.Reader
if up != nil {
cacheProgress := model.UpdateProgressWithRange(*up, 0, 50)
*up = model.UpdateProgressWithRange(*up, 50, 100)
reader = &ReaderUpdatingProgress{
Reader: &SimpleReaderWithSize{
Reader: reader,
Size: f.GetSize(),
},
UpdateProgress: cacheProgress,
if f.peekBuff != nil {
f.peekBuff.Seek(0, io.SeekStart)
if writer != nil {
_, err := utils.CopyWithBuffer(writer, f.peekBuff)
if err != nil {
return nil, err
}
f.peekBuff.Seek(0, io.SeekStart)
}
reader = f.oriReader
}
if writer != nil {
reader = io.TeeReader(reader, writer)
}

if f.GetSize() < 0 {
if f.peekBuff == nil {
f.peekBuff = &buffer.Reader{}
Expand Down Expand Up @@ -174,7 +165,6 @@ func (f *FileStream) CacheFullAndWriter(up *model.UpdateProgress, writer io.Writ
}
}
}

tmpF, err := utils.CreateTempFile(reader, 0)
if err != nil {
return nil, err
Expand All @@ -191,14 +181,33 @@ func (f *FileStream) CacheFullAndWriter(up *model.UpdateProgress, writer io.Writ
return peekF, nil
}

f.Reader = reader
if up != nil {
cacheProgress := model.UpdateProgressWithRange(*up, 0, 50)
*up = model.UpdateProgressWithRange(*up, 50, 100)
size := f.GetSize()
if f.peekBuff != nil {
peekSize := f.peekBuff.Size()
cacheProgress(float64(peekSize) / float64(size) * 100)
size -= peekSize
}
reader = &ReaderUpdatingProgress{
Reader: &SimpleReaderWithSize{
Reader: reader,
Size: size,
},
UpdateProgress: cacheProgress,
}
}

if f.peekBuff != nil {
f.oriReader = reader
} else {
f.Reader = reader
}
return f.cache(f.GetSize())
}

func (f *FileStream) GetFile() model.File {
if f.tmpFile != nil {
return f.tmpFile
}
if file, ok := f.Reader.(model.File); ok {
return file
}
Expand Down Expand Up @@ -234,12 +243,29 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {

func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
if maxCacheSize > int64(conf.MaxBufferLimit) {
tmpF, err := utils.CreateTempFile(f.Reader, f.GetSize())
size := f.GetSize()
reader := f.Reader
if f.peekBuff != nil {
size -= f.peekBuff.Size()
reader = f.oriReader
}
tmpF, err := utils.CreateTempFile(reader, size)
if err != nil {
return nil, err
}
if f.peekBuff != nil {
f.Add(utils.CloseFunc(func() error {
return errors.Join(tmpF.Close(), os.RemoveAll(tmpF.Name()))
}))
peekF, err := buffer.NewPeekFile(f.peekBuff, tmpF)
if err != nil {
return nil, err
}
f.Reader = peekF
return peekF, nil
}

f.Add(tmpF)
f.tmpFile = tmpF
f.Reader = tmpF
return tmpF, nil
}
Expand All @@ -248,7 +274,7 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
f.peekBuff = &buffer.Reader{}
f.oriReader = f.Reader
}
bufSize := maxCacheSize - int64(f.peekBuff.Size())
bufSize := maxCacheSize - f.peekBuff.Size()
var buf []byte
if conf.MmapThreshold > 0 && bufSize >= int64(conf.MmapThreshold) {
m, err := mmap.Alloc(int(bufSize))
Expand All @@ -267,7 +293,7 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", bufSize, n, err)
}
f.peekBuff.Append(buf)
if int64(f.peekBuff.Size()) >= f.GetSize() {
if f.peekBuff.Size() >= f.GetSize() {
f.Reader = f.peekBuff
f.oriReader = nil
} else {
Expand All @@ -276,12 +302,6 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
return f.peekBuff, nil
}

func (f *FileStream) SetTmpFile(file model.File) {
f.AddIfCloser(file)
f.tmpFile = file
f.Reader = file
}

var _ model.FileStreamer = (*SeekableStream)(nil)
var _ model.FileStreamer = (*FileStream)(nil)

Expand Down
49 changes: 35 additions & 14 deletions internal/stream/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@ import (
"io"
"testing"

"github.com/OpenListTeam/OpenList/v4/internal/conf"
"github.com/OpenListTeam/OpenList/v4/internal/model"
"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
"github.com/OpenListTeam/OpenList/v4/pkg/utils"
)

func TestFileStream_RangeRead(t *testing.T) {
conf.MaxBufferLimit = 16 * 1024 * 1024
type args struct {
httpRange http_range.Range
}
Expand Down Expand Up @@ -73,16 +72,38 @@ func TestFileStream_RangeRead(t *testing.T) {
}
})
}
t.Run("after", func(t *testing.T) {
if f.GetFile() == nil {
t.Error("not cached")
}
buf2 := make([]byte, len(buf))
if _, err := io.ReadFull(f, buf2); err != nil {
t.Errorf("FileStream.Read() error = %v", err)
}
if !bytes.Equal(buf, buf2) {
t.Errorf("FileStream.Read() = %s, want %s", buf2, buf)
}
})
if f.GetFile() == nil {
t.Error("not cached")
}
buf2 := make([]byte, len(buf))
if _, err := io.ReadFull(f, buf2); err != nil {
t.Errorf("FileStream.Read() error = %v", err)
}
if !bytes.Equal(buf, buf2) {
t.Errorf("FileStream.Read() = %s, want %s", buf2, buf)
}
}

func TestFileStream_With_PreHash(t *testing.T) {
buf := []byte("github.com/OpenListTeam/OpenList")
f := &FileStream{
Obj: &model.Object{
Size: int64(len(buf)),
},
Reader: io.NopCloser(bytes.NewReader(buf)),
}

const hashSize int64 = 20
reader, _ := f.RangeRead(http_range.Range{Start: 0, Length: hashSize})
preHash, _ := utils.HashReader(utils.SHA1, reader)
if preHash == "" {
t.Error("preHash is empty")
}
tmpF, fullHash, _ := CacheFullAndHash(f, nil, utils.SHA1)
fmt.Println(fullHash)
fileFullHash, _ := utils.HashFile(utils.SHA1, tmpF)
fmt.Println(fileFullHash)
if fullHash != fileFullHash {
t.Errorf("fullHash and fileFullHash should match: fullHash=%s fileFullHash=%s", fullHash, fileFullHash)
}
}