Skip to content

Commit 9e852ba

Browse files
hcrgmjyxjjjCopilotjenfonro
authored
fix(baidu_netdisk): improve upload experience (#1562)
* fix(baidu_netdisk): improve upload experience * fix(typo): URL should be uppercase, apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: ShenLin <773933146@qq.com> * fix(typo): URL should be uppercase, apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: ShenLin <773933146@qq.com> * fix(baidu_netdisk): use "UploadAPI" as a fallback when using dynamic upload api * fix(baidu_netdisk): all uploads share the same upload url cache * fix(drivers/baidu_netdisk): defer uploadUrlMu unlock * update driver.go to main --------- Signed-off-by: ShenLin <773933146@qq.com> Signed-off-by: jenfonro <799170122@qq.com> Co-authored-by: ShenLin <773933146@qq.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: jenfonro <799170122@qq.com>
1 parent 174eae8 commit 9e852ba

File tree

4 files changed

+162
-20
lines changed

4 files changed

+162
-20
lines changed

drivers/baidu_netdisk/driver.go

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
stdpath "path"
1313
"strconv"
1414
"strings"
15+
"sync"
1516
"time"
1617

1718
"github.com/OpenListTeam/OpenList/v4/drivers/base"
@@ -20,8 +21,10 @@ import (
2021
"github.com/OpenListTeam/OpenList/v4/internal/errs"
2122
"github.com/OpenListTeam/OpenList/v4/internal/model"
2223
"github.com/OpenListTeam/OpenList/v4/pkg/errgroup"
24+
"github.com/OpenListTeam/OpenList/v4/pkg/singleflight"
2325
"github.com/OpenListTeam/OpenList/v4/pkg/utils"
2426
"github.com/avast/retry-go"
27+
"github.com/go-resty/resty/v2"
2528
log "github.com/sirupsen/logrus"
2629
)
2730

@@ -31,6 +34,12 @@ type BaiduNetdisk struct {
3134

3235
uploadThread int
3336
vipType int // 会员类型,0普通用户(4G/4M)、1普通会员(10G/16M)、2超级会员(20G/32M)
37+
38+
upClient *resty.Client // 上传文件使用的http客户端
39+
uploadUrlG singleflight.Group[string]
40+
uploadUrlMu sync.RWMutex
41+
uploadUrl string // 上传域名
42+
uploadUrlUpdateTime time.Time // 上传域名上次更新时间
3443
}
3544

3645
var ErrUploadIDExpired = errors.New("uploadid expired")
@@ -44,19 +53,26 @@ func (d *BaiduNetdisk) GetAddition() driver.Additional {
4453
}
4554

4655
func (d *BaiduNetdisk) Init(ctx context.Context) error {
56+
d.upClient = base.NewRestyClient().
57+
SetTimeout(UPLOAD_TIMEOUT).
58+
SetRetryCount(UPLOAD_RETRY_COUNT).
59+
SetRetryWaitTime(UPLOAD_RETRY_WAIT_TIME).
60+
SetRetryMaxWaitTime(UPLOAD_RETRY_MAX_WAIT_TIME)
4761
d.uploadThread, _ = strconv.Atoi(d.UploadThread)
48-
if d.uploadThread < 1 || d.uploadThread > 32 {
49-
d.uploadThread, d.UploadThread = 3, "3"
62+
if d.uploadThread < 1 {
63+
d.uploadThread, d.UploadThread = 1, "1"
64+
} else if d.uploadThread > 32 {
65+
d.uploadThread, d.UploadThread = 32, "32"
5066
}
5167

5268
if _, err := url.Parse(d.UploadAPI); d.UploadAPI == "" || err != nil {
53-
d.UploadAPI = "https://d.pcs.baidu.com"
69+
d.UploadAPI = UPLOAD_FALLBACK_API
5470
}
5571

5672
res, err := d.get("/xpan/nas", map[string]string{
5773
"method": "uinfo",
5874
}, nil)
59-
log.Debugf("[baidu] get uinfo: %s", string(res))
75+
log.Debugf("[baidu_netdisk] get uinfo: %s", string(res))
6076
if err != nil {
6177
return err
6278
}
@@ -183,6 +199,11 @@ func (d *BaiduNetdisk) PutRapid(ctx context.Context, dstDir model.Obj, stream mo
183199
// **注意**: 截至 2024/04/20 百度云盘 api 接口返回的时间永远是当前时间,而不是文件时间。
184200
// 而实际上云盘存储的时间是文件时间,所以此处需要覆盖时间,保证缓存与云盘的数据一致
185201
func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
202+
// 百度网盘不允许上传空文件
203+
if stream.GetSize() < 1 {
204+
return nil, ErrBaiduEmptyFilesNotAllowed
205+
}
206+
186207
// rapid upload
187208
if newObj, err := d.PutRapid(ctx, dstDir, stream); err == nil {
188209
return newObj, nil
@@ -281,6 +302,9 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
281302
// step.2 上传分片
282303
uploadLoop:
283304
for attempt := 0; attempt < 2; attempt++ {
305+
// 获取上传域名
306+
uploadUrl := d.getUploadUrl(path, precreateResp.Uploadid)
307+
// 并发上传
284308
threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
285309
retry.Attempts(1),
286310
retry.Delay(time.Second),
@@ -292,7 +316,6 @@ uploadLoop:
292316
}
293317

294318
totalParts := len(precreateResp.BlockList)
295-
doneParts := 0 // 新增计数器
296319

297320
for i, partseq := range precreateResp.BlockList {
298321
if utils.IsCanceled(upCtx) || partseq < 0 {
@@ -313,15 +336,15 @@ uploadLoop:
313336
"partseq": strconv.Itoa(partseq),
314337
}
315338
section := io.NewSectionReader(cacheReaderAt, offset, size)
316-
err := d.uploadSlice(ctx, params, stream.GetName(), driver.NewLimitedUploadStream(ctx, section))
339+
err := d.uploadSlice(ctx, uploadUrl, params, stream.GetName(), driver.NewLimitedUploadStream(ctx, section))
317340
if err != nil {
318341
return err
319342
}
320343
precreateResp.BlockList[i] = -1
321-
doneParts++ // 直接递增计数器
322-
if totalParts > 0 {
323-
up(float64(doneParts) * 100.0 / float64(totalParts))
324-
}
344+
// 当前goroutine还没退出,+1才是真正成功的数量
345+
success := threadG.Success() + 1
346+
progress := float64(success) * 100 / float64(totalParts)
347+
up(progress)
325348
return nil
326349
})
327350
}
@@ -405,12 +428,12 @@ func (d *BaiduNetdisk) precreate(ctx context.Context, path string, streamSize in
405428
return &precreateResp, nil
406429
}
407430

408-
func (d *BaiduNetdisk) uploadSlice(ctx context.Context, params map[string]string, fileName string, file io.Reader) error {
409-
res, err := base.RestyClient.R().
431+
func (d *BaiduNetdisk) uploadSlice(ctx context.Context, uploadUrl string, params map[string]string, fileName string, file io.Reader) error {
432+
res, err := d.upClient.R().
410433
SetContext(ctx).
411434
SetQueryParams(params).
412435
SetFileReader("file", fileName, file).
413-
Post(d.UploadAPI + "/rest/2.0/pcs/superfile2")
436+
Post(uploadUrl + "/rest/2.0/pcs/superfile2")
414437
if err != nil {
415438
return err
416439
}

drivers/baidu_netdisk/meta.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package baidu_netdisk
33
import (
44
"github.com/OpenListTeam/OpenList/v4/internal/driver"
55
"github.com/OpenListTeam/OpenList/v4/internal/op"
6+
"time"
67
)
78

89
type Addition struct {
@@ -19,11 +20,21 @@ type Addition struct {
1920
RefreshToken string `json:"refresh_token" required:"true"`
2021
UploadThread string `json:"upload_thread" default:"3" help:"1<=thread<=32"`
2122
UploadAPI string `json:"upload_api" default:"https://d.pcs.baidu.com"`
23+
UseDynamicUploadAPI bool `json:"use_dynamic_upload_api" default:"true" help:"dynamically get upload api domain, when enabled, the 'Upload API' setting will be used as a fallback if failed to get"`
2224
CustomUploadPartSize int64 `json:"custom_upload_part_size" type:"number" default:"0" help:"0 for auto"`
2325
LowBandwithUploadMode bool `json:"low_bandwith_upload_mode" default:"false"`
2426
OnlyListVideoFile bool `json:"only_list_video_file" default:"false"`
2527
}
2628

29+
const (
30+
UPLOAD_FALLBACK_API = "https://d.pcs.baidu.com" // 备用上传地址
31+
UPLOAD_URL_EXPIRE_TIME = time.Minute * 60 // 上传地址有效期(分钟)
32+
UPLOAD_TIMEOUT = time.Minute * 30 // 上传请求超时时间
33+
UPLOAD_RETRY_COUNT = 3
34+
UPLOAD_RETRY_WAIT_TIME = time.Second * 1
35+
UPLOAD_RETRY_MAX_WAIT_TIME = time.Second * 5
36+
)
37+
2738
var config = driver.Config{
2839
Name: "BaiduNetdisk",
2940
DefaultRoot: "/",

drivers/baidu_netdisk/types.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package baidu_netdisk
22

33
import (
4+
"errors"
45
"path"
56
"strconv"
67
"time"
@@ -9,6 +10,10 @@ import (
910
"github.com/OpenListTeam/OpenList/v4/pkg/utils"
1011
)
1112

13+
var (
14+
ErrBaiduEmptyFilesNotAllowed = errors.New("empty files are not allowed by baidu netdisk")
15+
)
16+
1217
type TokenErrResp struct {
1318
ErrorDescription string `json:"error_description"`
1419
Error string `json:"error"`
@@ -190,6 +195,30 @@ type PrecreateResp struct {
190195
File File `json:"info"`
191196
}
192197

198+
type UploadServerResp struct {
199+
BakServer []any `json:"bak_server"`
200+
BakServers []struct {
201+
Server string `json:"server"`
202+
} `json:"bak_servers"`
203+
ClientIP string `json:"client_ip"`
204+
ErrorCode int `json:"error_code"`
205+
ErrorMsg string `json:"error_msg"`
206+
Expire int `json:"expire"`
207+
Host string `json:"host"`
208+
Newno string `json:"newno"`
209+
QuicServer []any `json:"quic_server"`
210+
QuicServers []struct {
211+
Server string `json:"server"`
212+
} `json:"quic_servers"`
213+
RequestID int64 `json:"request_id"`
214+
Server []any `json:"server"`
215+
ServerTime int `json:"server_time"`
216+
Servers []struct {
217+
Server string `json:"server"`
218+
} `json:"servers"`
219+
Sl int `json:"sl"`
220+
}
221+
193222
type QuotaResp struct {
194223
Errno int `json:"errno"`
195224
RequestId int64 `json:"request_id"`

drivers/baidu_netdisk/util.go

Lines changed: 86 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ func (d *BaiduNetdisk) request(furl string, method string, callback base.ReqCall
115115
errno := utils.Json.Get(res.Body(), "errno").ToInt()
116116
if errno != 0 {
117117
if utils.SliceContains([]int{111, -6}, errno) {
118-
log.Info("refreshing baidu_netdisk token.")
118+
log.Info("[baidu_netdisk] refreshing baidu_netdisk token.")
119119
err2 := d.refreshToken()
120120
if err2 != nil {
121121
return retry.Unrecoverable(err2)
@@ -326,28 +326,28 @@ func (d *BaiduNetdisk) getSliceSize(filesize int64) int64 {
326326
// 非会员固定为 4MB
327327
if d.vipType == 0 {
328328
if d.CustomUploadPartSize != 0 {
329-
log.Warnf("CustomUploadPartSize is not supported for non-vip user, use DefaultSliceSize")
329+
log.Warnf("[baidu_netdisk] CustomUploadPartSize is not supported for non-vip user, use DefaultSliceSize")
330330
}
331331
if filesize > MaxSliceNum*DefaultSliceSize {
332-
log.Warnf("File size(%d) is too large, may cause upload failure", filesize)
332+
log.Warnf("[baidu_netdisk] File size(%d) is too large, may cause upload failure", filesize)
333333
}
334334

335335
return DefaultSliceSize
336336
}
337337

338338
if d.CustomUploadPartSize != 0 {
339339
if d.CustomUploadPartSize < DefaultSliceSize {
340-
log.Warnf("CustomUploadPartSize(%d) is less than DefaultSliceSize(%d), use DefaultSliceSize", d.CustomUploadPartSize, DefaultSliceSize)
340+
log.Warnf("[baidu_netdisk] CustomUploadPartSize(%d) is less than DefaultSliceSize(%d), use DefaultSliceSize", d.CustomUploadPartSize, DefaultSliceSize)
341341
return DefaultSliceSize
342342
}
343343

344344
if d.vipType == 1 && d.CustomUploadPartSize > VipSliceSize {
345-
log.Warnf("CustomUploadPartSize(%d) is greater than VipSliceSize(%d), use VipSliceSize", d.CustomUploadPartSize, VipSliceSize)
345+
log.Warnf("[baidu_netdisk] CustomUploadPartSize(%d) is greater than VipSliceSize(%d), use VipSliceSize", d.CustomUploadPartSize, VipSliceSize)
346346
return VipSliceSize
347347
}
348348

349349
if d.vipType == 2 && d.CustomUploadPartSize > SVipSliceSize {
350-
log.Warnf("CustomUploadPartSize(%d) is greater than SVipSliceSize(%d), use SVipSliceSize", d.CustomUploadPartSize, SVipSliceSize)
350+
log.Warnf("[baidu_netdisk] CustomUploadPartSize(%d) is greater than SVipSliceSize(%d), use SVipSliceSize", d.CustomUploadPartSize, SVipSliceSize)
351351
return SVipSliceSize
352352
}
353353

@@ -377,7 +377,7 @@ func (d *BaiduNetdisk) getSliceSize(filesize int64) int64 {
377377
}
378378

379379
if filesize > MaxSliceNum*maxSliceSize {
380-
log.Warnf("File size(%d) is too large, may cause upload failure", filesize)
380+
log.Warnf("[baidu_netdisk] File size(%d) is too large, may cause upload failure", filesize)
381381
}
382382

383383
return maxSliceSize
@@ -394,6 +394,85 @@ func (d *BaiduNetdisk) quota(ctx context.Context) (model.DiskUsage, error) {
394394
return driver.DiskUsageFromUsedAndTotal(resp.Used, resp.Total), nil
395395
}
396396

397+
// getUploadUrl 从开放平台获取上传域名/地址,并发请求会被合并,结果会被缓存1h。
398+
// 如果获取失败,则返回 Upload API设置项。
399+
func (d *BaiduNetdisk) getUploadUrl(path, uploadId string) string {
400+
if !d.UseDynamicUploadAPI {
401+
return d.UploadAPI
402+
}
403+
getCachedUrlFunc := func() string {
404+
d.uploadUrlMu.RLock()
405+
defer d.uploadUrlMu.RUnlock()
406+
if d.uploadUrl != "" && time.Since(d.uploadUrlUpdateTime) < UPLOAD_URL_EXPIRE_TIME {
407+
uploadUrl := d.uploadUrl
408+
return uploadUrl
409+
}
410+
return ""
411+
}
412+
// 检查地址缓存
413+
if uploadUrl := getCachedUrlFunc(); uploadUrl != "" {
414+
return uploadUrl
415+
}
416+
417+
uploadUrlGetFunc := func() (string, error) {
418+
// 双重检查缓存
419+
if uploadUrl := getCachedUrlFunc(); uploadUrl != "" {
420+
return uploadUrl, nil
421+
}
422+
423+
uploadUrl, err := d.requestForUploadUrl(path, uploadId)
424+
if err != nil {
425+
return "", err
426+
}
427+
428+
d.uploadUrlMu.Lock()
429+
defer d.uploadUrlMu.Unlock()
430+
d.uploadUrl = uploadUrl
431+
d.uploadUrlUpdateTime = time.Now()
432+
return uploadUrl, nil
433+
}
434+
435+
uploadUrl, err, _ := d.uploadUrlG.Do("", uploadUrlGetFunc)
436+
if err != nil {
437+
fallback := d.UploadAPI
438+
log.Warnf("[baidu_netdisk] get upload URL failed (%v), will use fallback URL: %s", err, fallback)
439+
return fallback
440+
}
441+
return uploadUrl
442+
}
443+
444+
// requestForUploadUrl 请求获取上传地址。
445+
// 实测此接口不需要认证,传method和upload_version就行,不过还是按文档规范调用。
446+
// https://pan.baidu.com/union/doc/Mlvw5hfnr
447+
func (d *BaiduNetdisk) requestForUploadUrl(path, uploadId string) (string, error) {
448+
params := map[string]string{
449+
"method": "locateupload",
450+
"appid": "250528",
451+
"path": path,
452+
"uploadid": uploadId,
453+
"upload_version": "2.0",
454+
}
455+
apiUrl := "https://d.pcs.baidu.com/rest/2.0/pcs/file"
456+
var resp UploadServerResp
457+
_, err := d.request(apiUrl, http.MethodGet, func(req *resty.Request) {
458+
req.SetQueryParams(params)
459+
}, &resp)
460+
if err != nil {
461+
return "", err
462+
}
463+
// 应该是https开头的一个地址
464+
var uploadUrl string
465+
if len(resp.Servers) > 0 {
466+
uploadUrl = resp.Servers[0].Server
467+
} else if len(resp.BakServers) > 0 {
468+
uploadUrl = resp.BakServers[0].Server
469+
}
470+
if uploadUrl == "" {
471+
return "", errors.New("upload URL is empty")
472+
}
473+
return uploadUrl, nil
474+
}
475+
397476
// func encodeURIComponent(str string) string {
398477
// r := url.QueryEscape(str)
399478
// r = strings.ReplaceAll(r, "+", "%20")

0 commit comments

Comments
 (0)