Skip to content
This repository has been archived by the owner on Oct 14, 2022. It is now read-only.

Media references table/API #3

Draft
wants to merge 2 commits into
base: beeper
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions api/r0/references.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package r0

import (
"encoding/json"
"net/http"

"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/api"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/controllers/upload_controller"
)

type AddMediaReferenceBody struct {
RoomID string `json:"room_id"`
}

func AddMediaReference(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} {
params := mux.Vars(r)
server := params["server"]
mediaID := params["mediaId"]
rctx = rctx.LogWithFields(logrus.Fields{
"server": server,
"mediaId": mediaID,
})

defer r.Body.Close()
body := AddMediaReferenceBody{}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
return api.BadRequest("error parsing request body as json")
}

rctx.Log.Info("media referenced in room ", body.RoomID)

if err := upload_controller.AddMediaReference(server, mediaID, body.RoomID, rctx); err != nil {
rctx.Log.Error("error storing room reference for media upload: ", err)
return api.InternalServerError("unexpected error")
}
return api.EmptyResponse{}
}
17 changes: 16 additions & 1 deletion api/r0/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ func CreateMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInf
return api.InternalServerError("Unexpected Error")
}

roomID := r.URL.Query().Get("room_id")
if roomID != "" {
if err = upload_controller.AddMediaReference(media.Origin, media.MediaId, roomID, rctx); err != nil {
rctx.Log.Error("error storing room reference for media upload: ", err)
return api.InternalServerError("Unexpected Error")
}
}

return &MediaCreatedResponse{
ContentUri: media.MxcUri(),
UnusedExpiresAt: time.Now().Unix() + int64(rctx.Config.Features.MSC2246Async.AsyncUploadExpirySecs),
Expand Down Expand Up @@ -210,7 +218,6 @@ func UploadMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInf
}

contentLength := upload_controller.EstimateContentLength(r.ContentLength, r.Header.Get("Content-Length"))

media, err := upload_controller.UploadMedia(r.Body, contentLength, contentType, filename, user.UserId, r.Host, mediaId, rctx)
if err != nil {
io.Copy(ioutil.Discard, r.Body) // Ditch the entire request
Expand Down Expand Up @@ -241,6 +248,14 @@ func UploadMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInf
}
}

roomID := r.URL.Query().Get("room_id")
if roomID != "" {
if err = upload_controller.AddMediaReference(media.Origin, media.MediaId, roomID, rctx); err != nil {
rctx.Log.Error("error storing room reference for media upload: ", err)
return api.InternalServerError("unexpected error")
}
}

return &MediaUploadedResponse{
ContentUri: media.MxcUri(),
}
Expand Down
2 changes: 1 addition & 1 deletion api/unstable/local_copy.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package unstable

import (
"github.com/getsentry/sentry-go"
"net/http"
"strconv"

"github.com/getsentry/sentry-go"
"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/api"
Expand Down
2 changes: 2 additions & 0 deletions api/webserver/webserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ func Init() *sync.WaitGroup {
logoutAllHandler := handler{api.AccessTokenRequiredRoute(r0.LogoutAll), "logout_all", counter, false}
getMediaAttrsHandler := handler{api.AccessTokenRequiredRoute(custom.GetAttributes), "get_media_attributes", counter, false}
setMediaAttrsHandler := handler{api.AccessTokenRequiredRoute(custom.SetAttributes), "set_media_attributes", counter, false}
addMediaReferenceHandler := handler{api.AccessTokenRequiredRoute(r0.AddMediaReference), "add_media_reference", counter, false}

routes := make([]definedRoute, 0)
// r0 is typically clients and v1 is typically servers. v1 is deprecated.
Expand Down Expand Up @@ -148,6 +149,7 @@ func Init() *sync.WaitGroup {
routes = append(routes, definedRoute{"/_matrix/media/" + version + "/admin/import/{importId:[a-zA-Z0-9.:\\-_]+}/close", route{"POST", stopImportHandler}})
routes = append(routes, definedRoute{"/_matrix/media/" + version + "/admin/media/{server:[a-zA-Z0-9.:\\-_]+}/{mediaId:[^/]+}/attributes", route{"GET", getMediaAttrsHandler}})
routes = append(routes, definedRoute{"/_matrix/media/" + version + "/admin/media/{server:[a-zA-Z0-9.:\\-_]+}/{mediaId:[^/]+}/attributes/set", route{"POST", setMediaAttrsHandler}})
routes = append(routes, definedRoute{"/_matrix/media/" + version + "/reference/{server:[a-zA-Z0-9.:\\-_]+}/{mediaId:[^/]+}", route{"POST", addMediaReferenceHandler}})

// Routes that we should handle but aren't in the media namespace (synapse compat)
routes = append(routes, definedRoute{"/_matrix/client/" + version + "/admin/purge_media_cache", route{"POST", purgeRemote}})
Expand Down
1 change: 1 addition & 0 deletions common/config/models_domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type UrlPreviewsConfig struct {
OEmbed bool `yaml:"oEmbed"`
Proxy string `yaml:"proxy"`
MetricsDomains []string `yaml:"metricsDomains"`
DisableTunny bool `yaml:"disableTunny"`
}

type IdenticonsConfig struct {
Expand Down
24 changes: 20 additions & 4 deletions controllers/preview_controller/preview_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/turt2live/matrix-media-repo/storage/stores"
"github.com/turt2live/matrix-media-repo/types"
"github.com/turt2live/matrix-media-repo/util"
"github.com/turt2live/matrix-media-repo/util/resource_handler"
)

func GetPreview(urlStr string, onHost string, forUserId string, atTs int64, languageHeader string, ctx rcontext.RequestContext) (*types.UrlPreview, error) {
Expand Down Expand Up @@ -79,10 +80,25 @@ func GetPreview(urlStr string, onHost string, forUserId string, atTs int64, lang
break
}

previewChan := getResourceHandler().GeneratePreview(urlToPreview, forUserId, onHost, languageHeader, ctx.Config.UrlPreviews.OEmbed)
defer close(previewChan)

result := <-previewChan
var result *urlPreviewResponse
if ctx.Config.UrlPreviews.DisableTunny {
ctx.Log.Info("running preview generation in http request goroutine")
result = urlPreviewWorkFn(&resource_handler.WorkRequest{
Id: fmt.Sprintf("preview_%s", urlToPreview.UrlString),
Metadata: &urlPreviewRequest{
urlPayload: urlToPreview,
forUserId: forUserId,
onHost: onHost,
languageHeader: languageHeader,
allowOEmbed: ctx.Config.UrlPreviews.OEmbed,
},
})
} else {
ctx.Log.Info("scheduling preview generation with Tunny")
previewChan := getResourceHandler().GeneratePreview(urlToPreview, forUserId, onHost, languageHeader, ctx.Config.UrlPreviews.OEmbed)
defer close(previewChan)
result = <-previewChan
}
return result.preview, result.err
})

Expand Down
2 changes: 1 addition & 1 deletion controllers/preview_controller/preview_resource_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ package preview_controller

import (
"fmt"
"github.com/getsentry/sentry-go"
"sync"

"github.com/disintegration/imaging"
"github.com/getsentry/sentry-go"
"github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/common/config"
Expand Down
5 changes: 5 additions & 0 deletions controllers/upload_controller/upload_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,8 @@ func StoreDirect(f *AlreadyUploadedFile, contents io.ReadCloser, expectedSize in
trackUploadAsLastAccess(ctx, media)
return media, nil
}

func AddMediaReference(origin string, mediaID string, roomID string, ctx rcontext.RequestContext) error {
db := storage.GetDatabase().GetMediaStore(ctx)
return db.InsertMediaReference(origin, mediaID, roomID)
}
2 changes: 2 additions & 0 deletions migrations/19_create_media_references_table_down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
DROP INDEX media_references_index;
DROP TABLE media_references;
7 changes: 7 additions & 0 deletions migrations/19_create_media_references_table_up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
CREATE TABLE IF NOT EXISTS media_references (
media_id TEXT NOT NULL,
origin TEXT NOT NULL,
room_id TEXT NOT NULL,
FOREIGN KEY (media_id, origin) REFERENCES media (media_id, origin)
);
CREATE UNIQUE INDEX IF NOT EXISTS media_references_index ON media_references (media_id, origin, room_id);
15 changes: 15 additions & 0 deletions storage/stores/media_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ const selectMediaByUserBefore = "SELECT origin, media_id, upload_name, content_t
const selectMediaByDomainBefore = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE origin = $1 AND creation_ts <= $2"
const selectMediaByLocation = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE datastore_id = $1 AND location = $2"
const selectIfQuarantined = "SELECT 1 FROM media WHERE sha256_hash = $1 AND quarantined = $2 LIMIT 1;"
const insertMediaReference = "INSERT INTO media_references (origin, media_id, room_id) values ($1, $2, $3)"

var dsCacheByPath = sync.Map{} // [string] => Datastore
var dsCacheById = sync.Map{} // [string] => Datastore
Expand Down Expand Up @@ -63,6 +64,7 @@ type mediaStoreStatements struct {
selectMediaByDomainBefore *sql.Stmt
selectMediaByLocation *sql.Stmt
selectIfQuarantined *sql.Stmt
insertMediaReference *sql.Stmt
}

type MediaStoreFactory struct {
Expand Down Expand Up @@ -154,6 +156,9 @@ func InitMediaStore(sqlDb *sql.DB) (*MediaStoreFactory, error) {
if store.stmts.selectIfQuarantined, err = store.sqlDb.Prepare(selectIfQuarantined); err != nil {
return nil, err
}
if store.stmts.insertMediaReference, err = store.sqlDb.Prepare(insertMediaReference); err != nil {
return nil, err
}

return &store, nil
}
Expand Down Expand Up @@ -740,3 +745,13 @@ func (s *MediaStore) IsQuarantined(sha256hash string) (bool, error) {
}
return true, nil
}

func (s *MediaStore) InsertMediaReference(origin string, mediaID string, roomID string) error {
_, err := s.statements.insertMediaReference.ExecContext(
s.ctx,
origin,
mediaID,
roomID,
)
return err
}