From 8cb13731e270413809148197809ea8eb69f79c0b Mon Sep 17 00:00:00 2001 From: Yufan Sheng Date: Wed, 16 Nov 2022 02:47:31 +0800 Subject: [PATCH] feat: add file download for telegram. --- internal/driver/aliyun.go | 36 +++++++---- internal/driver/aliyun/share.go | 10 +-- internal/driver/common.go | 6 +- internal/driver/lanzou.go | 12 +++- internal/driver/lanzou/common.go | 8 ++- internal/driver/telecom.go | 21 +++---- internal/fetcher/fetcher.go | 13 +--- internal/fetcher/sanqiu.go | 15 ++--- internal/fetcher/service.go | 18 +----- internal/fetcher/talebook.go | 11 +++- internal/fetcher/telegram.go | 105 ++++++++++++++++++++++++++++--- internal/telegram/common.go | 58 ++++------------- internal/telegram/download.go | 100 +++++++++++++++++++++++++++++ internal/telegram/proxy.go | 4 +- 14 files changed, 288 insertions(+), 129 deletions(-) create mode 100644 internal/telegram/download.go diff --git a/internal/driver/aliyun.go b/internal/driver/aliyun.go index 36fc4ec..b7c2793 100644 --- a/internal/driver/aliyun.go +++ b/internal/driver/aliyun.go @@ -44,35 +44,45 @@ func (a *aliyunDriver) Source() Source { } func (a *aliyunDriver) Resolve(shareLink string, passcode string) ([]Share, error) { - c := a.client - token, err := c.ShareToken(shareLink, passcode) + token, err := a.client.ShareToken(shareLink, passcode) if err != nil { return nil, err } - shareFiles, err := c.Share(shareLink, token.ShareToken) + + files, err := a.client.Share(shareLink, token.ShareToken) if err != nil { return nil, err } + var shares []Share - for index := range shareFiles { - item := shareFiles[index] - shares = append(shares, Share{ + for index := range files { + item := files[index] + share := Share{ FileName: item.Name, + Size: int64(item.Size), URL: item.FileID, - Properties: map[string]string{ + Properties: map[string]any{ "shareToken": token.ShareToken, "shareID": shareLink, + "fileID": item.FileID, }, - }) + } + + shares = append(shares, share) } + return shares, nil } -func (a *aliyunDriver) Download(share Share) (io.ReadCloser, int64, error) { - c := a.client - url, err := c.DownloadURL(share.Properties["shareToken"], share.Properties["shareID"], share.URL) +func (a *aliyunDriver) Download(share Share) (io.ReadCloser, error) { + shareToken := share.Properties["shareToken"].(string) + shareID := share.Properties["shareID"].(string) + fileID := share.Properties["fileID"].(string) + + url, err := a.client.DownloadURL(shareToken, shareID, fileID) if err != nil { - return nil, 0, err + return nil, err } - return c.DownloadFile(url) + + return a.client.DownloadFile(url) } diff --git a/internal/driver/aliyun/share.go b/internal/driver/aliyun/share.go index fbf6534..db42107 100644 --- a/internal/driver/aliyun/share.go +++ b/internal/driver/aliyun/share.go @@ -117,7 +117,7 @@ func (ali *Aliyun) ShareToken(shareID string, sharePwd string) (*ShareTokenResp, return resp.Result().(*ShareTokenResp), nil } -func (ali *Aliyun) DownloadURL(shareToken string, shareID string, fileID string) (string, error) { +func (ali *Aliyun) DownloadURL(shareToken, shareID, fileID string) (string, error) { token, err := ali.AuthToken() if err != nil { return "", err @@ -143,13 +143,15 @@ func (ali *Aliyun) DownloadURL(shareToken string, shareID string, fileID string) return res.DownloadURL, nil } -func (ali *Aliyun) DownloadFile(downloadURL string) (io.ReadCloser, int64, error) { +func (ali *Aliyun) DownloadFile(downloadURL string) (io.ReadCloser, error) { log.Debugf("Start to download file from aliyun drive: %s", downloadURL) resp, err := ali.client.R(). SetDoNotParseResponse(true). Get(downloadURL) - response := resp.RawResponse + if err != nil { + return nil, err + } - return response.Body, response.ContentLength, err + return resp.RawBody(), err } diff --git a/internal/driver/common.go b/internal/driver/common.go index 724baf5..80c3a46 100644 --- a/internal/driver/common.go +++ b/internal/driver/common.go @@ -15,10 +15,12 @@ type ( Share struct { // FileName is a file name with the file extension. FileName string + // Size is the file size in bytes. + Size int64 // URL is the downloadable url for this file. URL string // Properties could be some metadata, such as the token for this downloadable share. - Properties map[string]string + Properties map[string]any } // Driver is used to resolve the links from a Source. @@ -30,7 +32,7 @@ type ( Resolve(shareLink string, passcode string) ([]Share, error) // Download the given link. - Download(share Share) (io.ReadCloser, int64, error) + Download(share Share) (io.ReadCloser, error) } ) diff --git a/internal/driver/lanzou.go b/internal/driver/lanzou.go index 1a40415..b86be2a 100644 --- a/internal/driver/lanzou.go +++ b/internal/driver/lanzou.go @@ -1,6 +1,7 @@ package driver import ( + "errors" "fmt" "io" @@ -14,7 +15,7 @@ func newLanzouDriver(c *client.Config, _ map[string]string) (Driver, error) { return nil, err } - return &lanzouDriver{driver: drive}, nil + return &lanzouDriver{driver: drive}, errors.New("we don't support lanzou currently") } type lanzouDriver struct { @@ -34,9 +35,14 @@ func (l *lanzouDriver) Resolve(shareLink string, passcode string) ([]Share, erro return nil, fmt.Errorf("parsed faild: %v", resp.Msg) } - return []Share{{FileName: resp.Data.Name, URL: resp.Data.URL, Properties: nil}}, err + share := Share{ + FileName: resp.Data.Name, + URL: resp.Data.URL, + } + + return []Share{share}, err } -func (l *lanzouDriver) Download(share Share) (io.ReadCloser, int64, error) { +func (l *lanzouDriver) Download(share Share) (io.ReadCloser, error) { return l.driver.DownloadFile(share.URL) } diff --git a/internal/driver/lanzou/common.go b/internal/driver/lanzou/common.go index 95b807d..4e82a17 100644 --- a/internal/driver/lanzou/common.go +++ b/internal/driver/lanzou/common.go @@ -37,13 +37,15 @@ func NewDrive(config *client.Config) (*Drive, error) { return &Drive{client: cl}, nil } -func (l *Drive) DownloadFile(downloadURL string) (io.ReadCloser, int64, error) { +func (l *Drive) DownloadFile(downloadURL string) (io.ReadCloser, error) { log.Debugf("Start to download file from aliyun drive: %s", downloadURL) resp, err := l.client.R(). SetDoNotParseResponse(true). Get(downloadURL) - response := resp.RawResponse + if err != nil { + return nil, err + } - return response.Body, response.ContentLength, err + return resp.RawBody(), err } diff --git a/internal/driver/telecom.go b/internal/driver/telecom.go index d6c4809..4ccbfda 100644 --- a/internal/driver/telecom.go +++ b/internal/driver/telecom.go @@ -42,11 +42,11 @@ func (t *telecomDriver) Resolve(shareLink string, passcode string) ([]Share, err for _, file := range files { shares = append(shares, Share{ FileName: file.Name, - Properties: map[string]string{ - "fileID": strconv.FormatInt(file.ID, 10), - "shareID": strconv.FormatInt(info.ShareID, 10), + Size: file.Size, + Properties: map[string]any{ "shareCode": code, - "fileSize": strconv.FormatInt(file.Size, 10), + "shareID": strconv.FormatInt(info.ShareID, 10), + "fileID": strconv.FormatInt(file.ID, 10), }, }) } @@ -54,20 +54,19 @@ func (t *telecomDriver) Resolve(shareLink string, passcode string) ([]Share, err return shares, nil } -func (t *telecomDriver) Download(share Share) (io.ReadCloser, int64, error) { - shareCode := share.Properties["shareCode"] - shareID := share.Properties["shareID"] - fileID := share.Properties["fileID"] - size, _ := strconv.ParseInt(share.Properties["fileSize"], 10, 64) +func (t *telecomDriver) Download(share Share) (io.ReadCloser, error) { + shareCode := share.Properties["shareCode"].(string) + shareID := share.Properties["shareID"].(string) + fileID := share.Properties["fileID"].(string) // Resolve the link. url, err := t.client.DownloadURL(shareCode, shareID, fileID) if err != nil { - return nil, 0, err + return nil, err } // Download the file. file, err := t.client.DownloadFile(url) - return file, size, err + return file, err } diff --git a/internal/fetcher/fetcher.go b/internal/fetcher/fetcher.go index 2765ff2..497d8a0 100644 --- a/internal/fetcher/fetcher.go +++ b/internal/fetcher/fetcher.go @@ -124,12 +124,6 @@ func (f *commonFetcher) startDownload() { // downloadFile in a thread. func (f *commonFetcher) downloadFile(bookID int64, format Format, share driver.Share) error { - file, err := f.service.fetch(bookID, format, share) - if err != nil { - return err - } - defer func() { _ = file.content.Close() }() - // Rename if it was required. prefix := strconv.FormatInt(bookID, 10) if f.Rename { @@ -160,12 +154,11 @@ func (f *commonFetcher) downloadFile(bookID int64, format Format, share driver.S defer func() { _ = writer.Close() }() // Add download progress. - bar := log.NewProgressBar(bookID, f.progress.Size(), share.FileName, file.size) + bar := log.NewProgressBar(bookID, f.progress.Size(), share.FileName, share.Size) defer func() { _ = bar.Close() }() - // Write file content - _, err = io.Copy(io.MultiWriter(writer, bar), file.content) - if err != nil { + // Write file content. + if err := f.service.fetch(bookID, format, share, io.MultiWriter(writer, bar)); err != nil { return err } diff --git a/internal/fetcher/sanqiu.go b/internal/fetcher/sanqiu.go index 38be247..876835e 100644 --- a/internal/fetcher/sanqiu.go +++ b/internal/fetcher/sanqiu.go @@ -2,6 +2,7 @@ package fetcher import ( "errors" + "io" "strconv" "github.com/bookstairs/bookhunter/internal/client" @@ -110,14 +111,14 @@ func (s *sanqiuService) formats(id int64) (map[Format]driver.Share, error) { return map[Format]driver.Share{}, nil } -func (s *sanqiuService) fetch(id int64, format Format, share driver.Share) (*fetch, error) { - content, size, err := s.driver.Download(share) +func (s *sanqiuService) fetch(_ int64, _ Format, share driver.Share, writer io.Writer) error { + content, err := s.driver.Download(share) if err != nil { - return nil, err + return err } + defer func() { _ = content.Close() }() - return &fetch{ - content: content, - size: size, - }, nil + // Save the download content info files. + _, err = io.Copy(writer, content) + return err } diff --git a/internal/fetcher/service.go b/internal/fetcher/service.go index 516ffdb..789c21f 100644 --- a/internal/fetcher/service.go +++ b/internal/fetcher/service.go @@ -5,8 +5,6 @@ import ( "fmt" "io" - "github.com/go-resty/resty/v2" - "github.com/bookstairs/bookhunter/internal/driver" ) @@ -19,21 +17,7 @@ type service interface { formats(int64) (map[Format]driver.Share, error) // fetch the given book ID. - fetch(int64, Format, driver.Share) (*fetch, error) -} - -// fetch is the result which can be created from a resty.Response -type fetch struct { - content io.ReadCloser - size int64 -} - -// createFetch will try to create a fetch. Remember to create the response with the `SetDoNotParseResponse` option. -func createFetch(resp *resty.Response) *fetch { - return &fetch{ - content: resp.RawBody(), - size: resp.RawResponse.ContentLength, - } + fetch(int64, Format, driver.Share, io.Writer) error } // newService is the endpoint for creating all the supported download service. diff --git a/internal/fetcher/talebook.go b/internal/fetcher/talebook.go index 358463b..c800674 100644 --- a/internal/fetcher/talebook.go +++ b/internal/fetcher/talebook.go @@ -3,6 +3,7 @@ package fetcher import ( "errors" "fmt" + "io" "net/http" "strconv" "strings" @@ -132,13 +133,17 @@ func (t *talebookService) formats(id int64) (map[Format]driver.Share, error) { } } -func (t *talebookService) fetch(_ int64, _ Format, share driver.Share) (*fetch, error) { +func (t *talebookService) fetch(_ int64, _ Format, share driver.Share, writer io.Writer) error { resp, err := t.client.R(). SetDoNotParseResponse(true). Get(share.URL) if err != nil { - return nil, err + return err } + body := resp.RawBody() + defer func() { _ = body.Close() }() - return createFetch(resp), nil + // Save the download content info files. + _, err = io.Copy(writer, body) + return err } diff --git a/internal/fetcher/telegram.go b/internal/fetcher/telegram.go index 6672531..6c74556 100644 --- a/internal/fetcher/telegram.go +++ b/internal/fetcher/telegram.go @@ -1,25 +1,116 @@ package fetcher -import "github.com/bookstairs/bookhunter/internal/driver" +import ( + "io" + "os" + "path/filepath" + "strconv" + "time" + + "github.com/gotd/contrib/middleware/floodwait" + "github.com/gotd/contrib/middleware/ratelimit" + "github.com/gotd/td/session" + client "github.com/gotd/td/telegram" + "github.com/gotd/td/telegram/dcs" + "github.com/gotd/td/tg" + "golang.org/x/time/rate" + + "github.com/bookstairs/bookhunter/internal/driver" + "github.com/bookstairs/bookhunter/internal/telegram" +) type telegramService struct { - config *Config + config *Config + telegram *telegram.Telegram + info *telegram.ChannelInfo } func newTelegramService(config *Config) (service, error) { + // Create the session file. + path, err := config.ConfigPath() + if err != nil { + return nil, err + } + sessionPath := filepath.Join(path, "session.json") + if refresh, _ := strconv.ParseBool(config.Property("reLogin")); refresh { + _ = os.Remove(sessionPath) + } + + channelID := config.Property("channelID") + mobile := config.Property("mobile") + appID, _ := strconv.ParseInt(config.Property("appID"), 10, 64) + appHash := config.Property("appHash") + + // Create the http proxy dial. + dialFunc, err := telegram.CreateProxy(config.Proxy) + if err != nil { + return nil, err + } + + // Create the backend telegram client. + c := client.NewClient( + int(appID), + appHash, + client.Options{ + Resolver: dcs.Plain(dcs.PlainOptions{Dial: dialFunc}), + SessionStorage: &session.FileStorage{Path: sessionPath}, + Middlewares: []client.Middleware{ + floodwait.NewSimpleWaiter().WithMaxRetries(uint(3)), + ratelimit.New(rate.Every(time.Minute), config.RateLimit), + }, + }, + ) + + tel, err := telegram.New(channelID, mobile, appID, appHash, c) + if err != nil { + return nil, err + } + return &telegramService{ - config: config, + config: config, + telegram: tel, }, nil } func (s *telegramService) size() (int64, error) { - panic("implement me") + info, err := s.telegram.ChannelInfo() + if err != nil { + return 0, err + } + s.info = info + + return info.LastMsgID, nil } func (s *telegramService) formats(id int64) (map[Format]driver.Share, error) { - panic("implement me") + files, err := s.telegram.ParseMessage(s.info, id) + if err != nil { + return nil, err + } + + res := make(map[Format]driver.Share) + for _, file := range files { + res[Format(file.Format)] = driver.Share{ + FileName: file.Name, + Size: file.Size, + Properties: map[string]any{ + "fileID": file.ID, + "document": file.Document, + }, + } + } + + return res, nil } -func (s *telegramService) fetch(id int64, format Format, share driver.Share) (*fetch, error) { - panic("implement me") +func (s *telegramService) fetch(_ int64, f Format, share driver.Share, writer io.Writer) error { + file := &telegram.File{ + ID: share.Properties["fileID"].(int64), + Name: share.FileName, + Format: string(f), + Size: share.Size, + Document: share.Properties["document"].(*tg.InputDocumentFileLocation), + } + + return s.telegram.DownloadFile(file, writer) } diff --git a/internal/telegram/common.go b/internal/telegram/common.go index 4cfb3cf..ebf1488 100644 --- a/internal/telegram/common.go +++ b/internal/telegram/common.go @@ -2,19 +2,9 @@ package telegram import ( "context" - "os" - "path/filepath" - "strconv" - "time" - "github.com/gotd/contrib/middleware/floodwait" - "github.com/gotd/contrib/middleware/ratelimit" - "github.com/gotd/td/session" "github.com/gotd/td/telegram" - "github.com/gotd/td/telegram/dcs" - "golang.org/x/time/rate" - - "github.com/bookstairs/bookhunter/internal/fetcher" + "github.com/gotd/td/tg" ) type ( @@ -31,45 +21,19 @@ type ( AccessHash int64 LastMsgID int64 } -) - -// New will create a telegram client. -func New(config *fetcher.Config) (*Telegram, error) { - // Create the session file. - path, err := config.ConfigPath() - if err != nil { - return nil, err - } - sessionPath := filepath.Join(path, "session.json") - if refresh, _ := strconv.ParseBool(config.Property("reLogin")); refresh { - _ = os.Remove(sessionPath) - } - - channelID := config.Property("channelID") - mobile := config.Property("mobile") - appID, _ := strconv.ParseInt(config.Property("appID"), 10, 64) - appHash := config.Property("appHash") - // Create the http proxy dial. - dialFunc, err := createProxy(config.Proxy) - if err != nil { - return nil, err + // File is the file info from the telegram channel. + File struct { + ID int64 + Name string + Format string + Size int64 + Document *tg.InputDocumentFileLocation } +) - // Create the backend telegram client. - client := telegram.NewClient( - int(appID), - appHash, - telegram.Options{ - Resolver: dcs.Plain(dcs.PlainOptions{Dial: dialFunc}), - SessionStorage: &session.FileStorage{Path: sessionPath}, - Middlewares: []telegram.Middleware{ - floodwait.NewSimpleWaiter().WithMaxRetries(uint(3)), - ratelimit.New(rate.Every(time.Minute), config.RateLimit), - }, - }, - ) - +// New will create a telegram client. +func New(channelID string, mobile string, appID int64, appHash string, client *telegram.Client) (*Telegram, error) { t := &Telegram{ channelID: channelID, mobile: mobile, diff --git a/internal/telegram/download.go b/internal/telegram/download.go new file mode 100644 index 0000000..237948d --- /dev/null +++ b/internal/telegram/download.go @@ -0,0 +1,100 @@ +package telegram + +import ( + "context" + "io" + "math" + + "github.com/gotd/td/telegram" + "github.com/gotd/td/telegram/downloader" + "github.com/gotd/td/tg" + + "github.com/bookstairs/bookhunter/internal/log" + "github.com/bookstairs/bookhunter/internal/naming" +) + +func (t *Telegram) DownloadFile(file *File, writer io.Writer) error { + return t.execute(func(ctx context.Context, client *telegram.Client) error { + tool := downloader.NewDownloader() + thread := int(math.Ceil(float64(file.Size) / (512 * 1024))) + _, err := tool.Download(client.API(), file.Document).WithThreads(thread).Stream(ctx, writer) + + return err + }) +} + +// ParseMessage will parse the given message id. +func (t *Telegram) ParseMessage(info *ChannelInfo, msgID int64) ([]File, error) { + var files []File + err := t.execute(func(ctx context.Context, client *telegram.Client) error { + // This API is translated from official C++ client. + api := client.API() + history, err := api.MessagesSearch(ctx, &tg.MessagesSearchRequest{ + Peer: &tg.InputPeerChannel{ + ChannelID: info.ID, + AccessHash: info.AccessHash, + }, + Filter: &tg.InputMessagesFilterEmpty{}, + Q: "", + OffsetID: int(msgID), + Limit: 1, + AddOffset: -1, + }) + if err != nil { + return err + } + + messages := history.(*tg.MessagesChannelMessages) + for i := len(messages.Messages) - 1; i >= 0; i-- { + message := messages.Messages[i] + + if file, ok := parseFile(message); ok { + files = append(files, *file) + } else { + log.Warnf("[%d/%d] No downloadable files found.", message.GetID(), info.LastMsgID) + continue + } + } + + return nil + }) + + return files, err +} + +func parseFile(message tg.MessageClass) (*File, bool) { + if message == nil { + return nil, false + } + msg, ok := message.(*tg.Message) + if !ok { + return nil, false + } + if msg.Media == nil { + return nil, false + } + s, ok := msg.Media.(*tg.MessageMediaDocument) + if !ok { + return nil, false + } + document := s.Document.(*tg.Document) + fileName := "" + for _, attribute := range document.Attributes { + x, ok := attribute.(*tg.DocumentAttributeFilename) + if ok { + fileName = x.FileName + } + } + if fileName == "" { + return nil, false + } + format, _ := naming.Extension(fileName) + + return &File{ + ID: int64(msg.ID), + Name: fileName, + Format: format, + Size: document.Size, + Document: document.AsInputDocumentFileLocation(), + }, true +} diff --git a/internal/telegram/proxy.go b/internal/telegram/proxy.go index df1e600..b133bff 100644 --- a/internal/telegram/proxy.go +++ b/internal/telegram/proxy.go @@ -13,9 +13,9 @@ import ( // This file is used to manually create a proxy with the arguments and system environment. -// createProxy is used to create a dcs.DialFunc for the telegram to send request. +// CreateProxy is used to create a dcs.DialFunc for the telegram to send request. // We don't support MTProxy now. -func createProxy(proxyURL string) (dcs.DialFunc, error) { +func CreateProxy(proxyURL string) (dcs.DialFunc, error) { if proxyURL != "" { log.Debugf("Try to manually create the proxy through %s", proxyURL)