Skip to content

Commit

Permalink
Add ctx.DownloadMedia method
Browse files Browse the repository at this point in the history
  • Loading branch information
celestix committed Jan 26, 2024
1 parent f7ea5b8 commit f75a0aa
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 13 deletions.
17 changes: 9 additions & 8 deletions errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@ var (
)

var (
ErrPeerNotFound = errors.New("peer not found")
ErrNotChat = errors.New("not chat")
ErrNotChannel = errors.New("not channel")
ErrNotUser = errors.New("not user")
ErrTextEmpty = errors.New("text was not provided")
ErrTextInvalid = errors.New("type of text is invalid, provide one from string and []styling.StyledTextOption")
ErrMessageNotExist = errors.New("message not exist")
ErrReplyNotMessage = errors.New("reply header is not a message")
ErrPeerNotFound = errors.New("peer not found")
ErrNotChat = errors.New("not chat")
ErrNotChannel = errors.New("not channel")
ErrNotUser = errors.New("not user")
ErrTextEmpty = errors.New("text was not provided")
ErrTextInvalid = errors.New("type of text is invalid, provide one from string and []styling.StyledTextOption")
ErrMessageNotExist = errors.New("message not exist")
ErrReplyNotMessage = errors.New("reply header is not a message")
ErrUnknownTypeMedia = errors.New("unknown type media")
)
104 changes: 99 additions & 5 deletions ext/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"math/rand"
"strings"
"time"
Expand All @@ -12,6 +13,7 @@ import (
"github.com/celestix/gotgproto/functions"
"github.com/celestix/gotgproto/storage"
"github.com/celestix/gotgproto/types"
"github.com/gotd/td/telegram/downloader"
"github.com/gotd/td/telegram/message"
"github.com/gotd/td/telegram/message/entity"
"github.com/gotd/td/telegram/message/styling"
Expand Down Expand Up @@ -683,14 +685,106 @@ func (ctx *Context) GetUserProfilePhotos(userId int64, opts *tg.PhotosGetUserPho
return p.GetPhotos(), nil
}

func (ctx *Context) ForwardMediaGroup() error {
_, err := ctx.Raw.MessagesForwardMessages(ctx, &tg.MessagesForwardMessagesRequest{})
return err
}

// ExportSessionString returns session of authorized account in the form of string.
// Note: This session string can be used to log back in with the help of gotgproto.
// Check sessionMaker.SessionType for more information about it.
func (ctx *Context) ExportSessionString() (string, error) {
return functions.EncodeSessionToString(ctx.PeerStorage.GetSession())
}

func getInputFileLocation(media tg.MessageMediaClass) (tg.InputFileLocationClass, error) {
switch v := media.(type) {
case *tg.MessageMediaPhoto: // messageMediaPhoto#695150d7
f, ok := v.Photo.AsNotEmpty()
if !ok {
return nil, mtp_errors.ErrUnknownTypeMedia
}
return &tg.InputPhotoFileLocation{
ID: f.ID,
AccessHash: f.AccessHash,
FileReference: f.FileReference,
}, nil
case *tg.MessageMediaDocument: // messageMediaDocument#4cf4d72d
f, ok := v.Document.AsNotEmpty()
if !ok {
return nil, mtp_errors.ErrUnknownTypeMedia
}
return f.AsInputDocumentFileLocation(), nil
case *tg.MessageMediaStory: // messageMediaStory#68cb6283
f, ok := v.Story.(*tg.StoryItem)
if !ok {
return nil, mtp_errors.ErrUnknownTypeMedia
}
return getInputFileLocation(f.Media)
}
return nil, mtp_errors.ErrUnknownTypeMedia
}

// DownloadOutputClass is an interface which is used to download media.
// It can be one from DownloadOutputStream, DownloadOutputPath and DownloadOutputParallel.
type DownloadOutputClass interface {
run(context.Context, *downloader.Builder) (tg.StorageFileTypeClass, error)
}

// DownloadOutputStream is used to download media to an io.Writer.
// It can be used to download media to a file, memory etc.
type DownloadOutputStream struct {
io.Writer
}

func (d DownloadOutputStream) run(ctx context.Context, b *downloader.Builder) (tg.StorageFileTypeClass, error) {
return b.Stream(ctx, d)
}

// DownloadOutputPath is used to download media to a file.
type DownloadOutputPath string

func (d DownloadOutputPath) run(ctx context.Context, b *downloader.Builder) (tg.StorageFileTypeClass, error) {
return b.ToPath(ctx, string(d))
}

// DownloadOutputParallel is used to download a media parallely.
type DownloadOutputParallel struct {
io.WriterAt
}

func (d DownloadOutputParallel) run(ctx context.Context, b *downloader.Builder) (tg.StorageFileTypeClass, error) {
return b.Parallel(ctx, d)
}

// DownloadMediaOpts object contains optional parameters for Context.DownloadMedia.
// If not provided, default values will be used.
type DownloadMediaOpts struct {
// Threads sets downloading goroutines limit.
Threads int
// If verify is true, file hashes will be checked. Verify is true by default for CDN downloads.
Verify *bool

This comment has been minimized.

Copy link
@kazhuravlev

kazhuravlev Mar 20, 2024

Probably will be better to use just bool. It looks like applicable here

This comment has been minimized.

Copy link
@celestix

celestix Mar 20, 2024

Author Owner

Hello @kazhuravlev, Thank you for pointing it out and I agree with you. IIRC it was kept as a ptr because the docs suggested that verify can be true by default in some cases, so it was decided to keep it a ptr and only change value after a null value check, but while looking to the actual declaration of Builder inside gotd I realized that the verify is not being set to true explicitly anywhere in the builder till we use it.

Would you like to create a pull request for the change (since you pointed it out) and contribute?

// PartSize sets chunk size. Must be divisible by 4KB.
PartSize int
}

// DownloadMedia downloads media from the provided MessageMediaClass.
// DownloadOutputClass can be one from DownloadOutputStream, DownloadOutputPath and DownloadOutputParallel.
// DownloadMediaOpts can be used to set optional parameters.
// Returns tg.StorageFileTypeClass and error if any.
func (ctx *Context) DownloadMedia(media tg.MessageMediaClass, downloadOutput DownloadOutputClass, opts *DownloadMediaOpts) (tg.StorageFileTypeClass, error) {
if opts == nil {
opts = &DownloadMediaOpts{}
}
downloader := downloader.NewDownloader()
if opts.PartSize > 0 {
downloader.WithPartSize(opts.PartSize)
}
inputFileLocation, err := getInputFileLocation(media)
if err != nil {
return nil, err
}
d := downloader.Download(ctx.Raw, inputFileLocation)
if opts.Threads > 0 {
d.WithThreads(opts.Threads)
}
if opts.Verify != nil {
d.WithVerify(*opts.Verify)
}
return downloadOutput.run(ctx, d)
}

0 comments on commit f75a0aa

Please sign in to comment.