diff --git a/controllers/upload_controller/upload_controller.go b/controllers/upload_controller/upload_controller.go index 2a0d7997..731b6889 100644 --- a/controllers/upload_controller/upload_controller.go +++ b/controllers/upload_controller/upload_controller.go @@ -2,26 +2,23 @@ package upload_controller import ( "database/sql" - "fmt" - "github.com/getsentry/sentry-go" "io" "io/ioutil" "strconv" "time" + "github.com/getsentry/sentry-go" "github.com/patrickmn/go-cache" "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/internal_cache" - "github.com/turt2live/matrix-media-repo/plugins" "github.com/turt2live/matrix-media-repo/storage" "github.com/turt2live/matrix-media-repo/storage/datastore" "github.com/turt2live/matrix-media-repo/types" "github.com/turt2live/matrix-media-repo/util" "github.com/turt2live/matrix-media-repo/util/cleanup" - "github.com/turt2live/matrix-media-repo/util/util_byte_seeker" ) const NoApplicableUploadUser = "" @@ -159,21 +156,14 @@ func UploadMedia(contents io.ReadCloser, contentLength int64, contentType string data = contents } - dataBytes, err := ioutil.ReadAll(data) - if err != nil { - return nil, err - } - var mediaId string - var ds *datastore.DatastoreRef if asyncMediaId == "" { - media, newDs, err := CreateMedia(origin, ctx) + media, _, err := CreateMedia(origin, ctx) if err != nil { return nil, err } mediaId = media.MediaId - ds = newDs } else { db := storage.GetDatabase().GetMediaStore(ctx) @@ -199,42 +189,20 @@ func UploadMedia(contents io.ReadCloser, contentLength int64, contentType string } mediaId = asyncMediaId - ds, err = datastore.LocateDatastore(ctx, media.DatastoreId) - if err != nil { - return nil, err - } - } - - var existingFile *AlreadyUploadedFile = nil - if ds.Type == "ipfs" { - // Do the upload now so we can pick the media ID to point to IPFS - info, err := ds.UploadFile(util_byte_seeker.NewByteSeeker(dataBytes), contentLength, ctx) - if err != nil { - return nil, err - } - existingFile = &AlreadyUploadedFile{ - DS: ds, - ObjectInfo: info, - } - mediaId = fmt.Sprintf("ipfs:%s", info.Location[len("ipfs/"):]) } - m, err := StoreDirect(existingFile, util_byte_seeker.NewByteSeeker(dataBytes), contentLength, contentType, filename, userId, origin, mediaId, common.KindLocalMedia, ctx, asyncMediaId == "") + m, err := StoreDirect(nil, data, contentLength, contentType, filename, userId, origin, mediaId, common.KindLocalMedia, ctx, asyncMediaId == "") if err != nil { return m, err } if m != nil { util.NotifyUpload(origin, mediaId) - - cache := internal_cache.Get() - if err := cache.UploadMedia(m.Sha256Hash, util_byte_seeker.NewByteSeeker(dataBytes), ctx); err != nil { - ctx.Log.Warn("Unexpected error trying to cache media: " + err.Error()) - } if asyncMediaId != "" { - if err := cache.NotifyUpload(origin, mediaId, ctx); err != nil { + if err := internal_cache.Get().NotifyUpload(origin, mediaId, ctx); err != nil { ctx.Log.Warn("Unexpected error trying to notify cache about media: " + err.Error()) } } + } return m, err } @@ -246,23 +214,10 @@ func trackUploadAsLastAccess(ctx rcontext.RequestContext, media *types.Media) { } } -func checkSpam(contents []byte, filename string, contentType string, userId string, origin string, mediaId string) error { - spam, err := plugins.CheckForSpam(contents, filename, contentType, userId, origin, mediaId) - if err != nil { - logrus.Warn("Error checking spam - assuming not spam: " + err.Error()) - sentry.CaptureException(err) - return nil - } - if spam { - return common.ErrMediaQuarantined - } - return nil -} - func StoreDirect(f *AlreadyUploadedFile, contents io.ReadCloser, expectedSize int64, contentType string, filename string, userId string, origin string, mediaId string, kind string, ctx rcontext.RequestContext, filterUserDuplicates bool) (ret *types.Media, err error) { var ds *datastore.DatastoreRef var info *types.ObjectInfo - var contentBytes []byte + if f == nil { dsPicked, err := datastore.PickDatastore(kind, ctx) if err != nil { @@ -270,12 +225,7 @@ func StoreDirect(f *AlreadyUploadedFile, contents io.ReadCloser, expectedSize in } ds = dsPicked - contentBytes, err = ioutil.ReadAll(contents) - if err != nil { - return nil, err - } - - fInfo, err := ds.UploadFile(util.BytesToStream(contentBytes), expectedSize, ctx) + fInfo, err := ds.UploadFile(contents, expectedSize, ctx) if err != nil { return nil, err } @@ -283,16 +233,6 @@ func StoreDirect(f *AlreadyUploadedFile, contents io.ReadCloser, expectedSize in } else { ds = f.DS info = f.ObjectInfo - - // download the contents for antispam - contents, err = ds.DownloadFile(info.Location) - if err != nil { - return nil, err - } - contentBytes, err = ioutil.ReadAll(contents) - if err != nil { - return nil, err - } } defer func() { // always delete temp object if we return an error @@ -328,11 +268,6 @@ func StoreDirect(f *AlreadyUploadedFile, contents io.ReadCloser, expectedSize in } } - err = checkSpam(contentBytes, filename, contentType, userId, origin, mediaId) - if err != nil { - return nil, err - } - // We'll use the location from the first record record := records[0] if record.Quarantined { @@ -414,11 +349,6 @@ func StoreDirect(f *AlreadyUploadedFile, contents io.ReadCloser, expectedSize in return nil, errors.New("file has no contents") } - err = checkSpam(contentBytes, filename, contentType, userId, origin, mediaId) - if err != nil { - return nil, err - } - // Check if we have reserved the metadata already, validate uploader media, err := db.Get(origin, mediaId) if err == sql.ErrNoRows { diff --git a/storage/datastore/ds_s3/s3_store.go b/storage/datastore/ds_s3/s3_store.go index 26b19629..530eb99d 100644 --- a/storage/datastore/ds_s3/s3_store.go +++ b/storage/datastore/ds_s3/s3_store.go @@ -21,12 +21,12 @@ import ( var stores = make(map[string]*s3Datastore) type s3Datastore struct { - conf config.DatastoreConfig - dsId string - client *minio.Client - bucket string - region string - tempPath string + conf config.DatastoreConfig + dsId string + client *minio.Client + bucket string + region string + tempPath string storageClass string prefixLength int } @@ -78,12 +78,12 @@ func GetOrCreateS3Datastore(dsId string, conf config.DatastoreConfig) (*s3Datast } s3ds := &s3Datastore{ - conf: conf, - dsId: dsId, - client: s3client, - bucket: bucket, - region: region, - tempPath: tempPath, + conf: conf, + dsId: dsId, + client: s3client, + bucket: bucket, + region: region, + tempPath: tempPath, storageClass: storageClass, prefixLength: prefixLength, }