diff --git a/api/r0/download.go b/api/r0/download.go
index e4313857..504a943f 100644
--- a/api/r0/download.go
+++ b/api/r0/download.go
@@ -1,17 +1,19 @@
 package r0
 
 import (
-	"github.com/getsentry/sentry-go"
 	"io"
 	"net/http"
 	"strconv"
 
+	"github.com/getsentry/sentry-go"
 	"github.com/gorilla/mux"
 	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/api"
 	"github.com/turt2live/matrix-media-repo/common"
 	"github.com/turt2live/matrix-media-repo/common/rcontext"
 	"github.com/turt2live/matrix-media-repo/controllers/download_controller"
+	"github.com/turt2live/matrix-media-repo/storage"
+	"github.com/turt2live/matrix-media-repo/storage/datastore"
 )
 
 type DownloadMediaResponse struct {
@@ -22,6 +24,13 @@ type DownloadMediaResponse struct {
 	TargetDisposition string
 }
 
+// Redirect is a handler result that tells the web server to answer with an HTTP
+// redirect to URL using the given Status code.
+type Redirect struct {
+	Status int
+	URL    string
+}
+
 func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} {
 	params := mux.Vars(r)
 
@@ -70,31 +79,57 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserI
 		"allowRemote": downloadRemote,
 	})
 
+	// Redirecting is only possible for media that already has a local record. A failed
+	// lookup is not treated as fatal: GetMedia below resolves the media itself (it is
+	// given downloadRemote, so presumably it can also fetch remote media that has no
+	// local record yet) and produces the error to report.
+	db := storage.GetDatabase().GetMediaStore(rctx)
+	dbMedia, err := db.Get(server, mediaId)
+	if err == nil && datastore.ShouldRedirectDownload(rctx, dbMedia.DatastoreId) {
+		media, err := download_controller.GetMediaURL(server, mediaId, filename, downloadRemote, false, asyncWaitMs, rctx)
+		if err != nil {
+			return handleDownloadError(rctx, err)
+		}
+
+		return &Redirect{
+			Status: http.StatusTemporaryRedirect,
+			URL:    media.URL,
+		}
+	}
+
 	streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, false, asyncWaitMs, rctx)
 	if err != nil {
-		if err == common.ErrMediaNotFound {
-			return api.NotFoundError()
-		} else if err == common.ErrMediaTooLarge {
-			return api.RequestTooLarge()
-		} else if err == common.ErrMediaQuarantined {
-			return api.NotFoundError() // We lie for security
-		} else if err == common.ErrNotYetUploaded {
-			return api.NotYetUploaded()
-		}
-		rctx.Log.Error("Unexpected error locating media: " + err.Error())
-		sentry.CaptureException(err)
-		return api.InternalServerError("Unexpected Error")
+		return handleDownloadError(rctx, err)
 	}
 
 	if filename == "" {
 		filename = streamedMedia.UploadName
 	}
 
 	return &DownloadMediaResponse{
 		ContentType:       streamedMedia.ContentType,
 		Filename:          filename,
 		SizeBytes:         streamedMedia.SizeBytes,
 		Data:              streamedMedia.Stream,
 		TargetDisposition: targetDisposition,
 	}
 }
+
+// handleDownloadError translates an error from the media lookup/download path into the
+// API response for the client. Quarantined media is reported as not found so that its
+// existence is not disclosed; unrecognised errors are logged, sent to Sentry and answered
+// with a generic internal server error.
+func handleDownloadError(ctx rcontext.RequestContext, err error) interface{} {
+	switch err {
+	case common.ErrMediaNotFound, common.ErrMediaQuarantined:
+		return api.NotFoundError()
+	case common.ErrMediaTooLarge:
+		return api.RequestTooLarge() // this does *not* seem like the right status code
+	case common.ErrNotYetUploaded:
+		return api.NotYetUploaded()
+	default:
+		ctx.Log.Warn("error looking up media for download: ", err)
+		sentry.CaptureException(err)
+		return api.InternalServerError("Unexpected Error")
+	}
+}
diff --git a/api/webserver/route_handler.go b/api/webserver/route_handler.go
index 63acdd20..7509fd4a 100644
--- a/api/webserver/route_handler.go
+++ b/api/webserver/route_handler.go
@@ -6,7 +6,6 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/getsentry/sentry-go"
 	"io"
 	"io/ioutil"
 	"math"
@@ -18,6 +17,7 @@ import (
 	"strings"
 
 	"github.com/alioygur/is"
+	"github.com/getsentry/sentry-go"
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/sebest/xff"
 	"github.com/sirupsen/logrus"
@@ -374,6 +374,15 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("X-Content-Security-Policy", "") // We're serving HTML, so take away the CSP
 		io.Copy(w, bytes.NewBuffer([]byte(result.HTML)))
 		return
+	case *r0.Redirect:
+		metrics.HttpResponses.With(prometheus.Labels{
+			"host":       r.Host,
+			"action":     h.action,
+			"method":     r.Method,
+			"statusCode": strconv.Itoa(result.Status),
+		}).Inc()
+		http.Redirect(w, r, result.URL, result.Status)
+		return
 	default:
 		break
 	}
diff --git a/controllers/download_controller/download_controller.go b/controllers/download_controller/download_controller.go
index e3d6c92e..90333694 100644
--- a/controllers/download_controller/download_controller.go
+++ b/controllers/download_controller/download_controller.go
@@ -5,12 +5,12 @@ import (
 	"database/sql"
 	"errors"
 	"fmt"
-	"github.com/getsentry/sentry-go"
 	"io"
 	"io/ioutil"
 	"time"
 
 	"github.com/disintegration/imaging"
+	"github.com/getsentry/sentry-go"
 	"github.com/patrickmn/go-cache"
 	"github.com/turt2live/matrix-media-repo/common"
 	"github.com/turt2live/matrix-media-repo/common/globals"
@@ -26,6 +26,48 @@ import (
 
 var localCache = cache.New(30*time.Second, 60*time.Second)
 
+// GetMediaURL is like GetMedia, but instead of streaming the media it returns a
+// MinimalMedia whose URL field holds a pre-signed S3 download URL for it. An empty
+// filename falls back to the name the media was uploaded with.
+//
+// Unlike GetMedia this only serves media that already has a database record:
+// downloadRemote and blockForMedia are currently unused.
+// NOTE(review): GetMedia presumably also rejects quarantined media - confirm and mirror
+// that check here so that quarantined media is never redirected to.
+func GetMediaURL(origin string, mediaId string, filename string, downloadRemote bool, blockForMedia bool, asyncWaitMs *int, ctx rcontext.RequestContext) (*types.MinimalMedia, error) {
+	db := storage.GetDatabase().GetMediaStore(ctx)
+
+	ctx.Log.Info("Getting media record from database")
+	dbMedia, err := db.Get(origin, mediaId)
+	if err != nil {
+		return nil, err
+	}
+
+	media, err := waitForUpload(dbMedia, asyncWaitMs, ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	if filename == "" {
+		filename = media.UploadName
+	}
+
+	downloadURL, err := datastore.GetDownloadURL(ctx, media.DatastoreId, media.Location, filename)
+	if err != nil {
+		return nil, err
+	}
+
+	return &types.MinimalMedia{
+		Origin:      media.Origin,
+		MediaId:     media.MediaId,
+		ContentType: media.ContentType,
+		UploadName:  media.UploadName,
+		SizeBytes:   media.SizeBytes,
+		KnownMedia:  media,
+		URL:         downloadURL,
+	}, nil
+}
+
 func GetMedia(origin string, mediaId string, downloadRemote bool, blockForMedia bool, asyncWaitMs *int, ctx rcontext.RequestContext) (*types.MinimalMedia, error) {
 	cacheKey := fmt.Sprintf("%s/%s?r=%t&b=%t", origin, mediaId, downloadRemote, blockForMedia)
 	v, _, err := globals.DefaultRequestGroup.Do(cacheKey, func() (interface{}, error) {
diff --git a/controllers/upload_controller/upload_controller.go b/controllers/upload_controller/upload_controller.go
index 731b6889..36834f68 100644
--- a/controllers/upload_controller/upload_controller.go
+++ b/controllers/upload_controller/upload_controller.go
@@ -202,7 +202,6 @@ func UploadMedia(contents io.ReadCloser, contentLength int64, contentType string
 				ctx.Log.Warn("Unexpected error trying to notify cache about media: " + err.Error())
 			}
 		}
-
 	}
 	return m, err
 }
diff --git a/storage/datastore/datastore.go b/storage/datastore/datastore.go
index 26a9f35e..db2948e7 100644
--- a/storage/datastore/datastore.go
+++ b/storage/datastore/datastore.go
@@ -2,12 +2,12 @@ package datastore
 
 import (
 	"fmt"
-	"github.com/getsentry/sentry-go"
-	"github.com/turt2live/matrix-media-repo/common"
 	"io"
 
+	"github.com/getsentry/sentry-go"
 	"github.com/pkg/errors"
 	"github.com/sirupsen/logrus"
+	"github.com/turt2live/matrix-media-repo/common"
 	"github.com/turt2live/matrix-media-repo/common/config"
 	"github.com/turt2live/matrix-media-repo/common/rcontext"
 	"github.com/turt2live/matrix-media-repo/storage"
@@ -56,6 +56,29 @@ func DownloadStream(ctx rcontext.RequestContext, datastoreId string, location st
 	return ref.DownloadFile(location)
 }
 
+// GetDownloadURL returns a URL from which the object at location in the given datastore
+// can be downloaded directly, with filename suggested as the download name.
+func GetDownloadURL(ctx rcontext.RequestContext, datastoreId string, location string, filename string) (string, error) {
+	ref, err := LocateDatastore(ctx, datastoreId)
+	if err != nil {
+		return "", err
+	}
+
+	return ref.GetDownloadURL(location, filename)
+}
+
+// ShouldRedirectDownload reports whether downloads from the given datastore should be
+// answered with a redirect to a download URL rather than streamed. It is false when the
+// datastore cannot be located.
+func ShouldRedirectDownload(ctx rcontext.RequestContext, datastoreId string) bool {
+	ref, err := LocateDatastore(ctx, datastoreId)
+	if err != nil {
+		return false
+	}
+
+	return ref.ShouldRedirectDownload()
+}
+
 func GetDatastoreConfig(ds *types.Datastore) (config.DatastoreConfig, error) {
 	for _, dsConf := range config.UniqueDatastores() {
 		if dsConf.Type == ds.Type && GetUriForDatastore(dsConf) == ds.Uri {
diff --git a/storage/datastore/datastore_ref.go b/storage/datastore/datastore_ref.go
index f3b8ee88..a831a5e7 100644
--- a/storage/datastore/datastore_ref.go
+++ b/storage/datastore/datastore_ref.go
@@ -5,6 +5,7 @@ import (
 	"io"
 	"os"
 	"path"
+	"strconv"
 
 	"github.com/sirupsen/logrus"
 	config2 "github.com/turt2live/matrix-media-repo/common/config"
@@ -88,6 +89,21 @@ func (d *DatastoreRef) DownloadFile(location string) (io.ReadCloser, error) {
 	}
 }
 
+// GetDownloadURL returns a pre-signed download URL for the object at location. Only s3
+// datastores support this; any other type yields an error.
+func (d *DatastoreRef) GetDownloadURL(location string, filename string) (string, error) {
+	if d.Type != "s3" {
+		logrus.Error("attempting to get an download URL but datasource is of type ", d.Type)
+		return "", errors.New("download URLs unsupported for non-s3 datastores")
+	}
+
+	s3, err := ds_s3.GetOrCreateS3Datastore(d.DatastoreId, d.config)
+	if err != nil {
+		return "", err
+	}
+	return s3.GetDownloadURL(location, filename)
+}
+
 func (d *DatastoreRef) ObjectExists(location string) bool {
 	if d.Type == "file" {
 		ok, err := util.FileExists(path.Join(d.Uri, location))
@@ -128,3 +144,17 @@ func (d *DatastoreRef) OverwriteObject(location string, stream io.ReadCloser, ct
 		return errors.New("unknown datastore type")
 	}
 }
+
+// ShouldRedirectDownload reports whether downloads from this datastore should be served
+// by redirecting the client to a download URL. Only s3 datastores with the
+// "redirectDownloads" option set to a true value opt in.
+func (d *DatastoreRef) ShouldRedirectDownload() bool {
+	if d.Type != "s3" {
+		return false
+	}
+
+	// A missing or unparseable option yields false (ParseBool returns false alongside
+	// its error), so the error can be ignored here.
+	redirectDownloads, _ := strconv.ParseBool(d.config.Options["redirectDownloads"])
+	return redirectDownloads
+}
diff --git a/storage/datastore/ds_s3/s3_store.go b/storage/datastore/ds_s3/s3_store.go
index 530eb99d..0c6c77eb 100644
--- a/storage/datastore/ds_s3/s3_store.go
+++ b/storage/datastore/ds_s3/s3_store.go
@@ -4,9 +4,11 @@ import (
 	"fmt"
 	"io"
 	"io/ioutil"
+	"net/url"
 	"os"
 	"strconv"
 	"strings"
+	"time"
 
 	"github.com/minio/minio-go/v6"
 	"github.com/pkg/errors"
@@ -230,6 +232,27 @@ func (s *s3Datastore) DownloadObject(location string) (io.ReadCloser, error) {
 	return s.client.GetObject(s.bucket, location, minio.GetObjectOptions{})
 }
 
+// GetDownloadURL returns a pre-signed URL, valid for five minutes, for fetching the object
+// at location straight from the bucket. filename is passed along as the attachment name
+// of the response's Content-Disposition header.
+func (s *s3Datastore) GetDownloadURL(location string, filename string) (string, error) {
+	logrus.Info("getting pre-signed download URL for object from bucket ", s.bucket, ": ", location)
+
+	// filename may be chosen by the requester or uploader: escape backslashes and quotes so
+	// it cannot end the quoted-string early and inject further header parameters.
+	quoted := strings.NewReplacer(`\`, `\\`, `"`, `\"`).Replace(filename)
+
+	reqParams := make(url.Values)
+	reqParams.Set("response-content-disposition", fmt.Sprintf("attachment; filename=\"%s\"", quoted))
+
+	u, err := s.client.PresignedGetObject(s.bucket, location, time.Minute*5, reqParams)
+	if err != nil {
+		return "", err
+	}
+
+	return u.String(), nil
+}
+
 func (s *s3Datastore) ObjectExists(location string) bool {
 	stat, err := s.client.StatObject(s.bucket, location, minio.StatObjectOptions{})
 	if err != nil {
diff --git a/types/media.go b/types/media.go
index 492f0b51..a69de8b4 100644
--- a/types/media.go
+++ b/types/media.go
@@ -40,6 +40,7 @@ type MinimalMedia struct {
 	ContentType string
 	SizeBytes   int64
 	KnownMedia  *Media
+	URL         string
 }
 
 type MinimalMediaMetadata struct {