diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index bdda2ae4331c..66df4d28d623 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -26,6 +26,7 @@ from twisted.web.resource import Resource from synapse.api.errors import ( + Codes, FederationDeniedError, HttpResponseException, NotFoundError, @@ -173,6 +174,77 @@ async def create_media_id(self, auth_user: UserID) -> Tuple[str, int]: ) return f"mxc://{self.server_name}/{media_id}", unused_expires_at + async def verify_can_upload(self, media_id: str, auth_user: UserID) -> None: + """Verify that the media ID can be uploaded to by the given user. This + function checks that: + + * the media ID exists + * the media ID does not already have content + * the user uploading is the same as the one who created the media ID + * the media ID has not expired + + Args: + media_id: The media ID to verify + auth_user: The user_id of the uploader + """ + media = await self.store.get_local_media(media_id) + if media is None: + raise SynapseError(404, "Unknow media ID", errcode=Codes.NOT_FOUND) + + if media["user_id"] != str(auth_user): + raise SynapseError( + 403, + "Only the creator of the media ID can upload to it", + errcode=Codes.FORBIDDEN, + ) + + if media.get("media_length") is not None: + raise SynapseError( + 409, + "Media ID already has content", + errcode="FI.MAU.MSC2246_CANNOT_OVERWRITE_MEDIA", + ) + + if media.get("unused_expires_at", 0) < self.clock.time_msec(): + raise SynapseError( + 409, + "Media ID has expired", + errcode="FI.MAU.MSC2246_CANNOT_OVERWRITE_MEDIA", + ) + + async def update_content( + self, + media_id: str, + media_type: str, + upload_name: Optional[str], + content: IO, + content_length: int, + auth_user: UserID, + ) -> None: + """Update the content of the given media ID. + + Args: + media_id: The media ID to replace. + media_type: The content type of the file. + upload_name: The name of the file, if provided. + content: A file like object that is the content to store + content_length: The length of the content + auth_user: The user_id of the uploader + """ + file_info = FileInfo(server_name=None, file_id=media_id) + fname = await self.media_storage.store_file(content, file_info) + logger.info("Stored local media in file %r", fname) + + await self.store.update_local_media( + media_id=media_id, + media_type=media_type, + upload_name=upload_name, + media_length=content_length, + user_id=auth_user, + ) + + await self._generate_thumbnails(None, media_id, media_id, media_type) + async def create_content( self, media_type: str, @@ -991,6 +1063,7 @@ def __init__(self, hs: "HomeServer"): media_repo = hs.get_media_repository() self.putChild(b"create", CreateResource(hs, media_repo)) + self.putChild(b"upload", UploadResource(hs, media_repo, True)) class VersionedMediaRepositoryResource(Resource): @@ -1046,7 +1119,7 @@ def __init__(self, hs: "HomeServer", version: MediaVersion): super().__init__() media_repo = hs.get_media_repository() - self.putChild(b"upload", UploadResource(hs, media_repo)) + self.putChild(b"upload", UploadResource(hs, media_repo, False)) self.putChild(b"download", DownloadResource(hs, media_repo)) self.putChild( b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage) diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index dccada05a41a..bea63e521285 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import IO, TYPE_CHECKING, Dict, List, Optional +from typing import IO, TYPE_CHECKING, Dict, List, Optional, Tuple from synapse.api.errors import Codes, SynapseError from synapse.http.server import DirectServeJsonResource, respond_with_json @@ -22,32 +22,40 @@ from synapse.http.site import SynapseRequest from synapse.rest.media.v1.media_storage import SpamMediaException +from ._base import parse_media_id + if TYPE_CHECKING: from synapse.rest.media.v1.media_repository import MediaRepository from synapse.server import HomeServer logger = logging.getLogger(__name__) +# The name of the lock to use when uploading media. +_UPLOAD_MEDIA_LOCK_NAME = "upload_media" + class UploadResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): + def __init__( + self, + hs: "HomeServer", + media_repo: "MediaRepository", + enable_async_uploads: bool, + ): super().__init__() + self.enable_async_uploads = enable_async_uploads self.media_repo = media_repo self.filepaths = media_repo.filepaths self.store = hs.get_datastores().main - self.clock = hs.get_clock() self.server_name = hs.hostname self.auth = hs.get_auth() self.max_upload_size = hs.config.media.max_upload_size - async def _async_render_OPTIONS(self, request: SynapseRequest) -> None: - respond_with_json(request, 200, {}, send_cors=True) - - async def _async_render_POST(self, request: SynapseRequest) -> None: - requester = await self.auth.get_user_by_req(request) + def _get_file_metadata( + self, request: SynapseRequest + ) -> Tuple[int, Optional[str], str]: raw_content_length = request.getHeader("Content-Length") if raw_content_length is None: raise SynapseError(msg="Request must specify a Content-Length", code=400) @@ -90,6 +98,15 @@ async def _async_render_POST(self, request: SynapseRequest) -> None: # disposition = headers.getRawHeaders(b"Content-Disposition")[0] # TODO(markjh): parse content-dispostion + return content_length, upload_name, media_type + + async def _async_render_OPTIONS(self, request: SynapseRequest) -> None: + respond_with_json(request, 200, {}, send_cors=True) + + async def _async_render_POST(self, request: SynapseRequest) -> None: + requester = await self.auth.get_user_by_req(request) + content_length, upload_name, media_type = self._get_file_metadata(request) + try: content: IO = request.content # type: ignore content_uri = await self.media_repo.create_content( @@ -103,3 +120,51 @@ async def _async_render_POST(self, request: SynapseRequest) -> None: logger.info("Uploaded content with URI %r", content_uri) respond_with_json(request, 200, {"content_uri": content_uri}, send_cors=True) + + async def _async_render_PUT(self, request: SynapseRequest) -> None: + if not self.enable_async_uploads: + raise SynapseError( + 405, + "Asynchronous uploads are not enabled on this homeserver", + errcode=Codes.UNRECOGNIZED, + ) + + requester = await self.auth.get_user_by_req(request) + server_name, media_id, _ = parse_media_id(request) + + if server_name != self.server_name: + raise SynapseError( + 404, + "Non-local server name specified", + errcode=Codes.NOT_FOUND, + ) + + lock = await self.store.try_acquire_lock(_UPLOAD_MEDIA_LOCK_NAME, media_id) + if not lock: + raise SynapseError( + 409, + "Media ID is is locked and cannot be uploaded to", + errcode="FI.MAU.MSC2246_CANNOT_OVERWRITE_MEDIA", + ) + + async with lock: + await self.media_repo.verify_can_upload(media_id, requester.user) + content_length, upload_name, media_type = self._get_file_metadata(request) + + try: + content: IO = request.content # type: ignore + await self.media_repo.update_content( + media_id, + media_type, + upload_name, + content, + content_length, + requester.user, + ) + except SpamMediaException: + # For uploading of media we want to respond with a 400, instead of + # the default 404, as that would just be confusing. + raise SynapseError(400, "Bad content") + + logger.info("Uploaded content to URI %r", media_id) + respond_with_json(request, 200, {}, send_cors=True) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 13869a1fbedb..62fb04874b43 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -175,6 +175,7 @@ async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: "quarantined_by", "url_cache", "safe_from_quarantine", + "user_id", ), allow_none=True, desc="get_local_media", @@ -348,6 +349,30 @@ async def store_local_media( desc="store_local_media", ) + async def update_local_media( + self, + media_id: str, + media_type: str, + upload_name: Optional[str], + media_length: int, + user_id: UserID, + url_cache: Optional[str] = None, + ) -> None: + await self.db_pool.simple_update_one( + "local_media_repository", + keyvalues={ + "user_id": user_id.to_string(), + "media_id": media_id, + }, + updatevalues={ + "media_type": media_type, + "upload_name": upload_name, + "media_length": media_length, + "url_cache": url_cache, + }, + desc="update_local_media", + ) + async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> None: """Mark a local media as safe or unsafe from quarantining.""" await self.db_pool.simple_update_one(