@@ -5,11 +5,13 @@ import (
55 "crypto/md5"
66 "encoding/hex"
77 "errors"
8+ "fmt"
89 "io"
910 "net/url"
1011 "os"
1112 stdpath "path"
1213 "strconv"
14+ "strings"
1315 "time"
1416
1517 "github.com/OpenListTeam/OpenList/v4/drivers/base"
@@ -31,6 +33,8 @@ type BaiduNetdisk struct {
3133 vipType int // 会员类型,0普通用户(4G/4M)、1普通会员(10G/16M)、2超级会员(20G/32M)
3234}
3335
36+ var ErrUploadIDExpired = errors .New ("uploadid expired" )
37+
3438func (d * BaiduNetdisk ) Config () driver.Config {
3539 return config
3640}
@@ -214,7 +218,6 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
214218
215219 // cal md5 for first 256k data
216220 const SliceSize int64 = 256 * utils .KB
217- // cal md5
218221 blockList := make ([]string , 0 , count )
219222 byteSize := sliceSize
220223 fileMd5H := md5 .New ()
@@ -244,7 +247,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
244247 }
245248 if tmpF != nil {
246249 if written != streamSize {
247- return nil , errs .NewErr (err , "CreateTempFile failed, incoming stream actual size= %d, expect = %d " , written , streamSize )
250+ return nil , errs .NewErr (err , "CreateTempFile failed, size mismatch: %d ! = %d " , written , streamSize )
248251 }
249252 _ , err = tmpF .Seek (0 , io .SeekStart )
250253 if err != nil {
@@ -258,31 +261,14 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
258261 mtime := stream .ModTime ().Unix ()
259262 ctime := stream .CreateTime ().Unix ()
260263
261- // step.1 预上传
262- // 尝试获取之前的进度
264+ // step.1 尝试读取已保存进度
263265 precreateResp , ok := base .GetUploadProgress [* PrecreateResp ](d , d .AccessToken , contentMd5 )
264266 if ! ok {
265- params := map [string ]string {
266- "method" : "precreate" ,
267- }
268- form := map [string ]string {
269- "path" : path ,
270- "size" : strconv .FormatInt (streamSize , 10 ),
271- "isdir" : "0" ,
272- "autoinit" : "1" ,
273- "rtype" : "3" ,
274- "block_list" : blockListStr ,
275- "content-md5" : contentMd5 ,
276- "slice-md5" : sliceMd5 ,
277- }
278- joinTime (form , ctime , mtime )
279-
280- log .Debugf ("[baidu_netdisk] precreate data: %s" , form )
281- _ , err = d .postForm ("/xpan/file" , params , form , & precreateResp )
267+ // 没有进度,走预上传
268+ precreateResp , err = d .precreate (ctx , path , streamSize , blockListStr , contentMd5 , sliceMd5 , ctime , mtime )
282269 if err != nil {
283270 return nil , err
284271 }
285- log .Debugf ("%+v" , precreateResp )
286272 if precreateResp .ReturnType == 2 {
287273 // rapid upload, since got md5 match from baidu server
288274 // 修复时间,具体原因见 Put 方法注释的 **注意**
@@ -291,45 +277,81 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
291277 return fileToObj (precreateResp .File ), nil
292278 }
293279 }
280+
294281 // step.2 上传分片
295- threadG , upCtx := errgroup .NewGroupWithContext (ctx , d .uploadThread ,
296- retry .Attempts (1 ),
297- retry .Delay (time .Second ),
298- retry .DelayType (retry .BackOffDelay ))
299-
300- for i , partseq := range precreateResp .BlockList {
301- if utils .IsCanceled (upCtx ) {
302- break
282+ uploadLoop:
283+ for attempt := 0 ; attempt < 2 ; attempt ++ {
284+ threadG , upCtx := errgroup .NewGroupWithContext (ctx , d .uploadThread ,
285+ retry .Attempts (1 ),
286+ retry .Delay (time .Second ),
287+ retry .DelayType (retry .BackOffDelay ))
288+
289+ cacheReaderAt , okReaderAt := cache .(io.ReaderAt )
290+ if ! okReaderAt {
291+ return nil , fmt .Errorf ("cache object must implement io.ReaderAt interface for upload operations" )
303292 }
304293
305- i , partseq , offset , byteSize := i , partseq , int64 (partseq )* sliceSize , sliceSize
306- if partseq + 1 == count {
307- byteSize = lastBlockSize
308- }
309- threadG .Go (func (ctx context.Context ) error {
310- params := map [string ]string {
311- "method" : "upload" ,
312- "access_token" : d .AccessToken ,
313- "type" : "tmpfile" ,
314- "path" : path ,
315- "uploadid" : precreateResp .Uploadid ,
316- "partseq" : strconv .Itoa (partseq ),
294+ totalParts := len (precreateResp .BlockList )
295+ doneParts := 0 // 新增计数器
296+
297+ for i , partseq := range precreateResp .BlockList {
298+ if utils .IsCanceled (upCtx ) || partseq < 0 {
299+ continue
317300 }
318- err := d . uploadSlice ( ctx , params , stream . GetName (),
319- driver . NewLimitedUploadStream ( ctx , io . NewSectionReader ( cache , offset , byteSize )))
320- if err != nil {
321- return err
301+ i , partseq := i , partseq
302+ offset , size := int64 ( partseq ) * sliceSize , sliceSize
303+ if partseq + 1 == count {
304+ size = lastBlockSize
322305 }
323- up (float64 (threadG .Success ()) * 100 / float64 (len (precreateResp .BlockList )))
324- precreateResp .BlockList [i ] = - 1
325- return nil
326- })
327- }
328- if err = threadG .Wait (); err != nil {
329- // 如果属于用户主动取消,则保存上传进度
306+ threadG .Go (func (ctx context.Context ) error {
307+ params := map [string ]string {
308+ "method" : "upload" ,
309+ "access_token" : d .AccessToken ,
310+ "type" : "tmpfile" ,
311+ "path" : path ,
312+ "uploadid" : precreateResp .Uploadid ,
313+ "partseq" : strconv .Itoa (partseq ),
314+ }
315+ section := io .NewSectionReader (cacheReaderAt , offset , size )
316+ err := d .uploadSlice (ctx , params , stream .GetName (), driver .NewLimitedUploadStream (ctx , section ))
317+ if err != nil {
318+ return err
319+ }
320+ precreateResp .BlockList [i ] = - 1
321+ doneParts ++ // 直接递增计数器
322+ if totalParts > 0 {
323+ up (float64 (doneParts ) * 100.0 / float64 (totalParts ))
324+ }
325+ return nil
326+ })
327+ }
328+
329+ err = threadG .Wait ()
330+ if err == nil {
331+ break uploadLoop
332+ }
333+
334+ // 保存进度(所有错误都会保存)
335+ precreateResp .BlockList = utils .SliceFilter (precreateResp .BlockList , func (s int ) bool { return s >= 0 })
336+ base .SaveUploadProgress (d , precreateResp , d .AccessToken , contentMd5 )
337+
330338 if errors .Is (err , context .Canceled ) {
331- precreateResp .BlockList = utils .SliceFilter (precreateResp .BlockList , func (s int ) bool { return s >= 0 })
339+ return nil , err
340+ }
341+ if errors .Is (err , ErrUploadIDExpired ) {
342+ log .Warn ("[baidu_netdisk] uploadid expired, will restart from scratch" )
343+ // 重新 precreate(所有分片都要重传)
344+ newPre , err2 := d .precreate (ctx , path , streamSize , blockListStr , "" , "" , ctime , mtime )
345+ if err2 != nil {
346+ return nil , err2
347+ }
348+ if newPre .ReturnType == 2 {
349+ return fileToObj (newPre .File ), nil
350+ }
351+ precreateResp = newPre
352+ // 覆盖掉旧的进度
332353 base .SaveUploadProgress (d , precreateResp , d .AccessToken , contentMd5 )
354+ continue uploadLoop
333355 }
334356 return nil , err
335357 }
@@ -343,9 +365,46 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
343365 // 修复时间,具体原因见 Put 方法注释的 **注意**
344366 newFile .Ctime = ctime
345367 newFile .Mtime = mtime
368+ // 上传成功清理进度
369+ base .SaveUploadProgress (d , nil , d .AccessToken , contentMd5 )
346370 return fileToObj (newFile ), nil
347371}
348372
373+ // precreate 执行预上传操作,支持首次上传和 uploadid 过期重试
374+ func (d * BaiduNetdisk ) precreate (ctx context.Context , path string , streamSize int64 , blockListStr , contentMd5 , sliceMd5 string , ctime , mtime int64 ) (* PrecreateResp , error ) {
375+ params := map [string ]string {"method" : "precreate" }
376+ form := map [string ]string {
377+ "path" : path ,
378+ "size" : strconv .FormatInt (streamSize , 10 ),
379+ "isdir" : "0" ,
380+ "autoinit" : "1" ,
381+ "rtype" : "3" ,
382+ "block_list" : blockListStr ,
383+ }
384+
385+ // 只有在首次上传时才包含 content-md5 和 slice-md5
386+ if contentMd5 != "" && sliceMd5 != "" {
387+ form ["content-md5" ] = contentMd5
388+ form ["slice-md5" ] = sliceMd5
389+ }
390+
391+ joinTime (form , ctime , mtime )
392+
393+ var precreateResp PrecreateResp
394+ _ , err := d .postForm ("/xpan/file" , params , form , & precreateResp )
395+ if err != nil {
396+ return nil , err
397+ }
398+
399+ // 修复时间,具体原因见 Put 方法注释的 **注意**
400+ if precreateResp .ReturnType == 2 {
401+ precreateResp .File .Ctime = ctime
402+ precreateResp .File .Mtime = mtime
403+ }
404+
405+ return & precreateResp , nil
406+ }
407+
349408func (d * BaiduNetdisk ) uploadSlice (ctx context.Context , params map [string ]string , fileName string , file io.Reader ) error {
350409 res , err := base .RestyClient .R ().
351410 SetContext (ctx ).
@@ -358,8 +417,16 @@ func (d *BaiduNetdisk) uploadSlice(ctx context.Context, params map[string]string
358417 log .Debugln (res .RawResponse .Status + res .String ())
359418 errCode := utils .Json .Get (res .Body (), "error_code" ).ToInt ()
360419 errNo := utils .Json .Get (res .Body (), "errno" ).ToInt ()
420+ respStr := res .String ()
421+ lower := strings .ToLower (respStr )
422+ // 合并 uploadid 过期检测逻辑
423+ if strings .Contains (lower , "uploadid" ) &&
424+ (strings .Contains (lower , "invalid" ) || strings .Contains (lower , "expired" ) || strings .Contains (lower , "not found" )) {
425+ return ErrUploadIDExpired
426+ }
427+
361428 if errCode != 0 || errNo != 0 {
362- return errs .NewErr (errs .StreamIncomplete , "error in uploading to baidu, will retry. response=%s" , res .String ())
429+ return errs .NewErr (errs .StreamIncomplete , "error uploading to baidu, response=%s" , res .String ())
363430 }
364431 return nil
365432}
0 commit comments