Skip to content

Commit 2844797

Browse files
authored
fix(baidu_netdisk): support resuming uploads when an error occurs (#1279)
support resuming uploads when an error occurs
1 parent 9f4e439 commit 2844797

File tree

1 file changed

+122
-55
lines changed

1 file changed

+122
-55
lines changed

drivers/baidu_netdisk/driver.go

Lines changed: 122 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ import (
55
"crypto/md5"
66
"encoding/hex"
77
"errors"
8+
"fmt"
89
"io"
910
"net/url"
1011
"os"
1112
stdpath "path"
1213
"strconv"
14+
"strings"
1315
"time"
1416

1517
"github.com/OpenListTeam/OpenList/v4/drivers/base"
@@ -31,6 +33,8 @@ type BaiduNetdisk struct {
3133
vipType int // 会员类型,0普通用户(4G/4M)、1普通会员(10G/16M)、2超级会员(20G/32M)
3234
}
3335

36+
var ErrUploadIDExpired = errors.New("uploadid expired")
37+
3438
func (d *BaiduNetdisk) Config() driver.Config {
3539
return config
3640
}
@@ -214,7 +218,6 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
214218

215219
// cal md5 for first 256k data
216220
const SliceSize int64 = 256 * utils.KB
217-
// cal md5
218221
blockList := make([]string, 0, count)
219222
byteSize := sliceSize
220223
fileMd5H := md5.New()
@@ -244,7 +247,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
244247
}
245248
if tmpF != nil {
246249
if written != streamSize {
247-
return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %d, expect = %d ", written, streamSize)
250+
return nil, errs.NewErr(err, "CreateTempFile failed, size mismatch: %d != %d ", written, streamSize)
248251
}
249252
_, err = tmpF.Seek(0, io.SeekStart)
250253
if err != nil {
@@ -258,31 +261,14 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
258261
mtime := stream.ModTime().Unix()
259262
ctime := stream.CreateTime().Unix()
260263

261-
// step.1 预上传
262-
// 尝试获取之前的进度
264+
// step.1 尝试读取已保存进度
263265
precreateResp, ok := base.GetUploadProgress[*PrecreateResp](d, d.AccessToken, contentMd5)
264266
if !ok {
265-
params := map[string]string{
266-
"method": "precreate",
267-
}
268-
form := map[string]string{
269-
"path": path,
270-
"size": strconv.FormatInt(streamSize, 10),
271-
"isdir": "0",
272-
"autoinit": "1",
273-
"rtype": "3",
274-
"block_list": blockListStr,
275-
"content-md5": contentMd5,
276-
"slice-md5": sliceMd5,
277-
}
278-
joinTime(form, ctime, mtime)
279-
280-
log.Debugf("[baidu_netdisk] precreate data: %s", form)
281-
_, err = d.postForm("/xpan/file", params, form, &precreateResp)
267+
// 没有进度,走预上传
268+
precreateResp, err = d.precreate(ctx, path, streamSize, blockListStr, contentMd5, sliceMd5, ctime, mtime)
282269
if err != nil {
283270
return nil, err
284271
}
285-
log.Debugf("%+v", precreateResp)
286272
if precreateResp.ReturnType == 2 {
287273
// rapid upload, since got md5 match from baidu server
288274
// 修复时间,具体原因见 Put 方法注释的 **注意**
@@ -291,45 +277,81 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
291277
return fileToObj(precreateResp.File), nil
292278
}
293279
}
280+
294281
// step.2 上传分片
295-
threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
296-
retry.Attempts(1),
297-
retry.Delay(time.Second),
298-
retry.DelayType(retry.BackOffDelay))
299-
300-
for i, partseq := range precreateResp.BlockList {
301-
if utils.IsCanceled(upCtx) {
302-
break
282+
uploadLoop:
283+
for attempt := 0; attempt < 2; attempt++ {
284+
threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
285+
retry.Attempts(1),
286+
retry.Delay(time.Second),
287+
retry.DelayType(retry.BackOffDelay))
288+
289+
cacheReaderAt, okReaderAt := cache.(io.ReaderAt)
290+
if !okReaderAt {
291+
return nil, fmt.Errorf("cache object must implement io.ReaderAt interface for upload operations")
303292
}
304293

305-
i, partseq, offset, byteSize := i, partseq, int64(partseq)*sliceSize, sliceSize
306-
if partseq+1 == count {
307-
byteSize = lastBlockSize
308-
}
309-
threadG.Go(func(ctx context.Context) error {
310-
params := map[string]string{
311-
"method": "upload",
312-
"access_token": d.AccessToken,
313-
"type": "tmpfile",
314-
"path": path,
315-
"uploadid": precreateResp.Uploadid,
316-
"partseq": strconv.Itoa(partseq),
294+
totalParts := len(precreateResp.BlockList)
295+
doneParts := 0 // 新增计数器
296+
297+
for i, partseq := range precreateResp.BlockList {
298+
if utils.IsCanceled(upCtx) || partseq < 0 {
299+
continue
317300
}
318-
err := d.uploadSlice(ctx, params, stream.GetName(),
319-
driver.NewLimitedUploadStream(ctx, io.NewSectionReader(cache, offset, byteSize)))
320-
if err != nil {
321-
return err
301+
i, partseq := i, partseq
302+
offset, size := int64(partseq)*sliceSize, sliceSize
303+
if partseq+1 == count {
304+
size = lastBlockSize
322305
}
323-
up(float64(threadG.Success()) * 100 / float64(len(precreateResp.BlockList)))
324-
precreateResp.BlockList[i] = -1
325-
return nil
326-
})
327-
}
328-
if err = threadG.Wait(); err != nil {
329-
// 如果属于用户主动取消,则保存上传进度
306+
threadG.Go(func(ctx context.Context) error {
307+
params := map[string]string{
308+
"method": "upload",
309+
"access_token": d.AccessToken,
310+
"type": "tmpfile",
311+
"path": path,
312+
"uploadid": precreateResp.Uploadid,
313+
"partseq": strconv.Itoa(partseq),
314+
}
315+
section := io.NewSectionReader(cacheReaderAt, offset, size)
316+
err := d.uploadSlice(ctx, params, stream.GetName(), driver.NewLimitedUploadStream(ctx, section))
317+
if err != nil {
318+
return err
319+
}
320+
precreateResp.BlockList[i] = -1
321+
doneParts++ // 直接递增计数器
322+
if totalParts > 0 {
323+
up(float64(doneParts) * 100.0 / float64(totalParts))
324+
}
325+
return nil
326+
})
327+
}
328+
329+
err = threadG.Wait()
330+
if err == nil {
331+
break uploadLoop
332+
}
333+
334+
// 保存进度(所有错误都会保存)
335+
precreateResp.BlockList = utils.SliceFilter(precreateResp.BlockList, func(s int) bool { return s >= 0 })
336+
base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5)
337+
330338
if errors.Is(err, context.Canceled) {
331-
precreateResp.BlockList = utils.SliceFilter(precreateResp.BlockList, func(s int) bool { return s >= 0 })
339+
return nil, err
340+
}
341+
if errors.Is(err, ErrUploadIDExpired) {
342+
log.Warn("[baidu_netdisk] uploadid expired, will restart from scratch")
343+
// 重新 precreate(所有分片都要重传)
344+
newPre, err2 := d.precreate(ctx, path, streamSize, blockListStr, "", "", ctime, mtime)
345+
if err2 != nil {
346+
return nil, err2
347+
}
348+
if newPre.ReturnType == 2 {
349+
return fileToObj(newPre.File), nil
350+
}
351+
precreateResp = newPre
352+
// 覆盖掉旧的进度
332353
base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5)
354+
continue uploadLoop
333355
}
334356
return nil, err
335357
}
@@ -343,9 +365,46 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
343365
// 修复时间,具体原因见 Put 方法注释的 **注意**
344366
newFile.Ctime = ctime
345367
newFile.Mtime = mtime
368+
// 上传成功清理进度
369+
base.SaveUploadProgress(d, nil, d.AccessToken, contentMd5)
346370
return fileToObj(newFile), nil
347371
}
348372

373+
// precreate 执行预上传操作,支持首次上传和 uploadid 过期重试
374+
func (d *BaiduNetdisk) precreate(ctx context.Context, path string, streamSize int64, blockListStr, contentMd5, sliceMd5 string, ctime, mtime int64) (*PrecreateResp, error) {
375+
params := map[string]string{"method": "precreate"}
376+
form := map[string]string{
377+
"path": path,
378+
"size": strconv.FormatInt(streamSize, 10),
379+
"isdir": "0",
380+
"autoinit": "1",
381+
"rtype": "3",
382+
"block_list": blockListStr,
383+
}
384+
385+
// 只有在首次上传时才包含 content-md5 和 slice-md5
386+
if contentMd5 != "" && sliceMd5 != "" {
387+
form["content-md5"] = contentMd5
388+
form["slice-md5"] = sliceMd5
389+
}
390+
391+
joinTime(form, ctime, mtime)
392+
393+
var precreateResp PrecreateResp
394+
_, err := d.postForm("/xpan/file", params, form, &precreateResp)
395+
if err != nil {
396+
return nil, err
397+
}
398+
399+
// 修复时间,具体原因见 Put 方法注释的 **注意**
400+
if precreateResp.ReturnType == 2 {
401+
precreateResp.File.Ctime = ctime
402+
precreateResp.File.Mtime = mtime
403+
}
404+
405+
return &precreateResp, nil
406+
}
407+
349408
func (d *BaiduNetdisk) uploadSlice(ctx context.Context, params map[string]string, fileName string, file io.Reader) error {
350409
res, err := base.RestyClient.R().
351410
SetContext(ctx).
@@ -358,8 +417,16 @@ func (d *BaiduNetdisk) uploadSlice(ctx context.Context, params map[string]string
358417
log.Debugln(res.RawResponse.Status + res.String())
359418
errCode := utils.Json.Get(res.Body(), "error_code").ToInt()
360419
errNo := utils.Json.Get(res.Body(), "errno").ToInt()
420+
respStr := res.String()
421+
lower := strings.ToLower(respStr)
422+
// 合并 uploadid 过期检测逻辑
423+
if strings.Contains(lower, "uploadid") &&
424+
(strings.Contains(lower, "invalid") || strings.Contains(lower, "expired") || strings.Contains(lower, "not found")) {
425+
return ErrUploadIDExpired
426+
}
427+
361428
if errCode != 0 || errNo != 0 {
362-
return errs.NewErr(errs.StreamIncomplete, "error in uploading to baidu, will retry. response=%s", res.String())
429+
return errs.NewErr(errs.StreamIncomplete, "error uploading to baidu, response=%s", res.String())
363430
}
364431
return nil
365432
}

0 commit comments

Comments
 (0)