Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 36 additions & 13 deletions drivers/baidu_netdisk/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
stdpath "path"
"strconv"
"strings"
"sync"
"time"

"github.com/OpenListTeam/OpenList/v4/drivers/base"
Expand All @@ -20,8 +21,10 @@ import (
"github.com/OpenListTeam/OpenList/v4/internal/errs"
"github.com/OpenListTeam/OpenList/v4/internal/model"
"github.com/OpenListTeam/OpenList/v4/pkg/errgroup"
"github.com/OpenListTeam/OpenList/v4/pkg/singleflight"
"github.com/OpenListTeam/OpenList/v4/pkg/utils"
"github.com/avast/retry-go"
"github.com/go-resty/resty/v2"
log "github.com/sirupsen/logrus"
)

Expand All @@ -31,6 +34,12 @@ type BaiduNetdisk struct {

uploadThread int
vipType int // 会员类型,0普通用户(4G/4M)、1普通会员(10G/16M)、2超级会员(20G/32M)

upClient *resty.Client // 上传文件使用的http客户端
uploadUrlG singleflight.Group[string]
uploadUrlMu sync.RWMutex
uploadUrl string // 上传域名
uploadUrlUpdateTime time.Time // 上传域名上次更新时间
}

var ErrUploadIDExpired = errors.New("uploadid expired")
Expand All @@ -44,19 +53,26 @@ func (d *BaiduNetdisk) GetAddition() driver.Additional {
}

func (d *BaiduNetdisk) Init(ctx context.Context) error {
d.upClient = base.NewRestyClient().
SetTimeout(UPLOAD_TIMEOUT).
SetRetryCount(UPLOAD_RETRY_COUNT).
SetRetryWaitTime(UPLOAD_RETRY_WAIT_TIME).
SetRetryMaxWaitTime(UPLOAD_RETRY_MAX_WAIT_TIME)
d.uploadThread, _ = strconv.Atoi(d.UploadThread)
if d.uploadThread < 1 || d.uploadThread > 32 {
d.uploadThread, d.UploadThread = 3, "3"
if d.uploadThread < 1 {
d.uploadThread, d.UploadThread = 1, "1"
} else if d.uploadThread > 32 {
d.uploadThread, d.UploadThread = 32, "32"
}

if _, err := url.Parse(d.UploadAPI); d.UploadAPI == "" || err != nil {
d.UploadAPI = "https://d.pcs.baidu.com"
d.UploadAPI = UPLOAD_FALLBACK_API
}

res, err := d.get("/xpan/nas", map[string]string{
"method": "uinfo",
}, nil)
log.Debugf("[baidu] get uinfo: %s", string(res))
log.Debugf("[baidu_netdisk] get uinfo: %s", string(res))
if err != nil {
return err
}
Expand Down Expand Up @@ -183,6 +199,11 @@ func (d *BaiduNetdisk) PutRapid(ctx context.Context, dstDir model.Obj, stream mo
// **注意**: 截至 2024/04/20 百度云盘 api 接口返回的时间永远是当前时间,而不是文件时间。
// 而实际上云盘存储的时间是文件时间,所以此处需要覆盖时间,保证缓存与云盘的数据一致
func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
// 百度网盘不允许上传空文件
if stream.GetSize() < 1 {
return nil, ErrBaiduEmptyFilesNotAllowed
}

// rapid upload
if newObj, err := d.PutRapid(ctx, dstDir, stream); err == nil {
return newObj, nil
Expand Down Expand Up @@ -281,6 +302,9 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
// step.2 上传分片
uploadLoop:
for attempt := 0; attempt < 2; attempt++ {
// 获取上传域名
uploadUrl := d.getUploadUrl(path, precreateResp.Uploadid)
// 并发上传
threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
retry.Attempts(1),
retry.Delay(time.Second),
Expand All @@ -292,7 +316,6 @@ uploadLoop:
}

totalParts := len(precreateResp.BlockList)
doneParts := 0 // 新增计数器

for i, partseq := range precreateResp.BlockList {
if utils.IsCanceled(upCtx) || partseq < 0 {
Expand All @@ -313,15 +336,15 @@ uploadLoop:
"partseq": strconv.Itoa(partseq),
}
section := io.NewSectionReader(cacheReaderAt, offset, size)
err := d.uploadSlice(ctx, params, stream.GetName(), driver.NewLimitedUploadStream(ctx, section))
err := d.uploadSlice(ctx, uploadUrl, params, stream.GetName(), driver.NewLimitedUploadStream(ctx, section))
if err != nil {
return err
}
precreateResp.BlockList[i] = -1
doneParts++ // 直接递增计数器
if totalParts > 0 {
up(float64(doneParts) * 100.0 / float64(totalParts))
}
// 当前goroutine还没退出,+1才是真正成功的数量
success := threadG.Success() + 1
progress := float64(success) * 100 / float64(totalParts)
up(progress)
return nil
})
}
Expand Down Expand Up @@ -405,12 +428,12 @@ func (d *BaiduNetdisk) precreate(ctx context.Context, path string, streamSize in
return &precreateResp, nil
}

func (d *BaiduNetdisk) uploadSlice(ctx context.Context, params map[string]string, fileName string, file io.Reader) error {
res, err := base.RestyClient.R().
func (d *BaiduNetdisk) uploadSlice(ctx context.Context, uploadUrl string, params map[string]string, fileName string, file io.Reader) error {
res, err := d.upClient.R().
SetContext(ctx).
SetQueryParams(params).
SetFileReader("file", fileName, file).
Post(d.UploadAPI + "/rest/2.0/pcs/superfile2")
Post(uploadUrl + "/rest/2.0/pcs/superfile2")
if err != nil {
return err
}
Expand Down
11 changes: 11 additions & 0 deletions drivers/baidu_netdisk/meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package baidu_netdisk
import (
"github.com/OpenListTeam/OpenList/v4/internal/driver"
"github.com/OpenListTeam/OpenList/v4/internal/op"
"time"
)

type Addition struct {
Expand All @@ -19,11 +20,21 @@ type Addition struct {
RefreshToken string `json:"refresh_token" required:"true"`
UploadThread string `json:"upload_thread" default:"3" help:"1<=thread<=32"`
UploadAPI string `json:"upload_api" default:"https://d.pcs.baidu.com"`
UseDynamicUploadAPI bool `json:"use_dynamic_upload_api" default:"true" help:"dynamically get upload api domain, when enabled, the 'Upload API' setting will be used as a fallback if failed to get"`
CustomUploadPartSize int64 `json:"custom_upload_part_size" type:"number" default:"0" help:"0 for auto"`
LowBandwithUploadMode bool `json:"low_bandwith_upload_mode" default:"false"`
OnlyListVideoFile bool `json:"only_list_video_file" default:"false"`
}

const (
UPLOAD_FALLBACK_API = "https://d.pcs.baidu.com" // 备用上传地址
UPLOAD_URL_EXPIRE_TIME = time.Minute * 60 // 上传地址有效期(分钟)
UPLOAD_TIMEOUT = time.Minute * 30 // 上传请求超时时间
UPLOAD_RETRY_COUNT = 3
UPLOAD_RETRY_WAIT_TIME = time.Second * 1
UPLOAD_RETRY_MAX_WAIT_TIME = time.Second * 5
)

var config = driver.Config{
Name: "BaiduNetdisk",
DefaultRoot: "/",
Expand Down
29 changes: 29 additions & 0 deletions drivers/baidu_netdisk/types.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package baidu_netdisk

import (
"errors"
"path"
"strconv"
"time"
Expand All @@ -9,6 +10,10 @@ import (
"github.com/OpenListTeam/OpenList/v4/pkg/utils"
)

var (
ErrBaiduEmptyFilesNotAllowed = errors.New("empty files are not allowed by baidu netdisk")
)

type TokenErrResp struct {
ErrorDescription string `json:"error_description"`
Error string `json:"error"`
Expand Down Expand Up @@ -190,6 +195,30 @@ type PrecreateResp struct {
File File `json:"info"`
}

type UploadServerResp struct {
BakServer []any `json:"bak_server"`
BakServers []struct {
Server string `json:"server"`
} `json:"bak_servers"`
ClientIP string `json:"client_ip"`
ErrorCode int `json:"error_code"`
ErrorMsg string `json:"error_msg"`
Expire int `json:"expire"`
Host string `json:"host"`
Newno string `json:"newno"`
QuicServer []any `json:"quic_server"`
QuicServers []struct {
Server string `json:"server"`
} `json:"quic_servers"`
RequestID int64 `json:"request_id"`
Server []any `json:"server"`
ServerTime int `json:"server_time"`
Servers []struct {
Server string `json:"server"`
} `json:"servers"`
Sl int `json:"sl"`
}

type QuotaResp struct {
Errno int `json:"errno"`
RequestId int64 `json:"request_id"`
Expand Down
93 changes: 86 additions & 7 deletions drivers/baidu_netdisk/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (d *BaiduNetdisk) request(furl string, method string, callback base.ReqCall
errno := utils.Json.Get(res.Body(), "errno").ToInt()
if errno != 0 {
if utils.SliceContains([]int{111, -6}, errno) {
log.Info("refreshing baidu_netdisk token.")
log.Info("[baidu_netdisk] refreshing baidu_netdisk token.")
err2 := d.refreshToken()
if err2 != nil {
return retry.Unrecoverable(err2)
Expand Down Expand Up @@ -326,28 +326,28 @@ func (d *BaiduNetdisk) getSliceSize(filesize int64) int64 {
// 非会员固定为 4MB
if d.vipType == 0 {
if d.CustomUploadPartSize != 0 {
log.Warnf("CustomUploadPartSize is not supported for non-vip user, use DefaultSliceSize")
log.Warnf("[baidu_netdisk] CustomUploadPartSize is not supported for non-vip user, use DefaultSliceSize")
}
if filesize > MaxSliceNum*DefaultSliceSize {
log.Warnf("File size(%d) is too large, may cause upload failure", filesize)
log.Warnf("[baidu_netdisk] File size(%d) is too large, may cause upload failure", filesize)
}

return DefaultSliceSize
}

if d.CustomUploadPartSize != 0 {
if d.CustomUploadPartSize < DefaultSliceSize {
log.Warnf("CustomUploadPartSize(%d) is less than DefaultSliceSize(%d), use DefaultSliceSize", d.CustomUploadPartSize, DefaultSliceSize)
log.Warnf("[baidu_netdisk] CustomUploadPartSize(%d) is less than DefaultSliceSize(%d), use DefaultSliceSize", d.CustomUploadPartSize, DefaultSliceSize)
return DefaultSliceSize
}

if d.vipType == 1 && d.CustomUploadPartSize > VipSliceSize {
log.Warnf("CustomUploadPartSize(%d) is greater than VipSliceSize(%d), use VipSliceSize", d.CustomUploadPartSize, VipSliceSize)
log.Warnf("[baidu_netdisk] CustomUploadPartSize(%d) is greater than VipSliceSize(%d), use VipSliceSize", d.CustomUploadPartSize, VipSliceSize)
return VipSliceSize
}

if d.vipType == 2 && d.CustomUploadPartSize > SVipSliceSize {
log.Warnf("CustomUploadPartSize(%d) is greater than SVipSliceSize(%d), use SVipSliceSize", d.CustomUploadPartSize, SVipSliceSize)
log.Warnf("[baidu_netdisk] CustomUploadPartSize(%d) is greater than SVipSliceSize(%d), use SVipSliceSize", d.CustomUploadPartSize, SVipSliceSize)
return SVipSliceSize
}

Expand Down Expand Up @@ -377,7 +377,7 @@ func (d *BaiduNetdisk) getSliceSize(filesize int64) int64 {
}

if filesize > MaxSliceNum*maxSliceSize {
log.Warnf("File size(%d) is too large, may cause upload failure", filesize)
log.Warnf("[baidu_netdisk] File size(%d) is too large, may cause upload failure", filesize)
}

return maxSliceSize
Expand All @@ -394,6 +394,85 @@ func (d *BaiduNetdisk) quota(ctx context.Context) (model.DiskUsage, error) {
return driver.DiskUsageFromUsedAndTotal(resp.Used, resp.Total), nil
}

// getUploadUrl 从开放平台获取上传域名/地址,并发请求会被合并,结果会被缓存1h。
// 如果获取失败,则返回 Upload API设置项。
func (d *BaiduNetdisk) getUploadUrl(path, uploadId string) string {
if !d.UseDynamicUploadAPI {
return d.UploadAPI
}
getCachedUrlFunc := func() string {
d.uploadUrlMu.RLock()
defer d.uploadUrlMu.RUnlock()
if d.uploadUrl != "" && time.Since(d.uploadUrlUpdateTime) < UPLOAD_URL_EXPIRE_TIME {
uploadUrl := d.uploadUrl
return uploadUrl
}
return ""
}
// 检查地址缓存
if uploadUrl := getCachedUrlFunc(); uploadUrl != "" {
return uploadUrl
}

uploadUrlGetFunc := func() (string, error) {
// 双重检查缓存
if uploadUrl := getCachedUrlFunc(); uploadUrl != "" {
return uploadUrl, nil
}

uploadUrl, err := d.requestForUploadUrl(path, uploadId)
if err != nil {
return "", err
}

d.uploadUrlMu.Lock()
defer d.uploadUrlMu.Unlock()
d.uploadUrl = uploadUrl
d.uploadUrlUpdateTime = time.Now()
return uploadUrl, nil
}

uploadUrl, err, _ := d.uploadUrlG.Do("", uploadUrlGetFunc)
if err != nil {
fallback := d.UploadAPI
log.Warnf("[baidu_netdisk] get upload URL failed (%v), will use fallback URL: %s", err, fallback)
return fallback
}
return uploadUrl
}

// requestForUploadUrl 请求获取上传地址。
// 实测此接口不需要认证,传method和upload_version就行,不过还是按文档规范调用。
// https://pan.baidu.com/union/doc/Mlvw5hfnr
func (d *BaiduNetdisk) requestForUploadUrl(path, uploadId string) (string, error) {
params := map[string]string{
"method": "locateupload",
"appid": "250528",
"path": path,
"uploadid": uploadId,
"upload_version": "2.0",
}
apiUrl := "https://d.pcs.baidu.com/rest/2.0/pcs/file"
var resp UploadServerResp
_, err := d.request(apiUrl, http.MethodGet, func(req *resty.Request) {
req.SetQueryParams(params)
}, &resp)
if err != nil {
return "", err
}
// 应该是https开头的一个地址
var uploadUrl string
if len(resp.Servers) > 0 {
uploadUrl = resp.Servers[0].Server
} else if len(resp.BakServers) > 0 {
uploadUrl = resp.BakServers[0].Server
}
if uploadUrl == "" {
return "", errors.New("upload URL is empty")
}
return uploadUrl, nil
}

// func encodeURIComponent(str string) string {
// r := url.QueryEscape(str)
// r = strings.ReplaceAll(r, "+", "%20")
Expand Down