From 46ed200c2ba27aac812fd6476f250857db80e95f Mon Sep 17 00:00:00 2001
From: Sumner Evans
Date: Fri, 9 Jun 2023 17:20:25 -0600
Subject: [PATCH] media/create: enforce limit on number of pending uploads

Signed-off-by: Sumner Evans
---
 synapse/media/media_repository.py             | 17 ++++++++++
 synapse/rest/media/create_resource.py         | 16 ++++++++++
 .../databases/main/media_repository.py        | 31 +++++++++++++++++++
 3 files changed, 64 insertions(+)

diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 0682709596eb..d281b089ce9e 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -197,6 +197,23 @@ async def create_media_id(self, auth_user: UserID) -> Tuple[str, int]:
         )
         return f"mxc://{self.server_name}/{media_id}", unused_expires_at
 
+    async def reached_pending_media_limit(
+        self, auth_user: UserID, limit: int
+    ) -> Tuple[bool, int]:
+        """Check if the user is over the limit for pending media uploads.
+        Args:
+            auth_user: The user_id of the uploader
+            limit: The maximum number of pending media uploads a user is allowed to have
+        Returns:
+            A tuple with a boolean and an integer indicating whether the user has too
+            many pending media uploads and the timestamp at which the first pending
+            media will expire, respectively.
+        """
+        pending, first_expiration_ts = await self.store.count_pending_media(
+            user_id=auth_user
+        )
+        return pending >= limit, first_expiration_ts
+
     async def verify_can_upload(self, media_id: str, auth_user: UserID) -> None:
         """Verify that the media ID can be uploaded to by the given user.
         This function checks that:
diff --git a/synapse/rest/media/create_resource.py b/synapse/rest/media/create_resource.py
index a6a2120c5c8a..1e9ae60579d4 100644
--- a/synapse/rest/media/create_resource.py
+++ b/synapse/rest/media/create_resource.py
@@ -36,6 +36,7 @@ def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
         self.media_repo = media_repo
         self.clock = hs.get_clock()
         self.auth = hs.get_auth()
+        self.max_pending_media_uploads = hs.config.media.max_pending_media_uploads
 
         # A rate limiter for creating new media IDs.
         self._create_media_rate_limiter = Ratelimiter(
@@ -62,6 +63,21 @@ async def _async_render_POST(self, request: SynapseRequest) -> None:
                 retry_after_ms=int(1000 * (time_allowed - time_now_s))
             )
 
+        (
+            reached_pending_limit,
+            first_expiration_ts,
+        ) = await self.media_repo.reached_pending_media_limit(
+            requester.user, self.max_pending_media_uploads
+        )
+        if reached_pending_limit:
+            # first_expiration_ts can already be in the past by now (it is also 0 when
+            # nothing is pending), so clamp to keep retry_after_ms non-negative.
+            raise LimitExceededError(
+                msg="You have too many uploads pending. Please finish uploading your "
+                "existing files.",
+                retry_after_ms=max(0, first_expiration_ts - self.clock.time_msec()),
+            )
+
         content_uri, unused_expires_at = await self.media_repo.create_media_id(
             requester.user
         )
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 42c50762db9b..83d88d8dcf65 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -27,6 +27,7 @@
 )
 
 from synapse.api.constants import Direction
+from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
@@ -405,6 +406,36 @@ async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> No
             desc="mark_local_media_as_safe",
         )
 
+    async def count_pending_media(self, user_id: UserID) -> Tuple[int, int]:
+        """Count the number of pending media for a user.
+        Args:
+            user_id: The user whose pending media uploads should be counted
+        Returns:
+            A tuple of two integers: the total pending media requests and the earliest
+            expiration timestamp (0 if the user has no pending media).
+        """
+
+        def get_pending_media_txn(txn: LoggingTransaction) -> Tuple[int, int]:
+            # "Pending" rows are the user's non-quarantined entries that have no
+            # media_length recorded and whose unused_expires_at is still in the future.
+            sql = (
+                "SELECT COUNT(*), MIN(unused_expires_at)"
+                " FROM local_media_repository"
+                " WHERE user_id = ?"
+                " AND quarantined_by IS NULL"
+                " AND unused_expires_at > ?"
+                " AND media_length IS NULL"
+            )
+            txn.execute(sql, (user_id.to_string(), self._clock.time_msec()))
+            row = txn.fetchone()
+            if not row:
+                raise StoreError(404, "Failed to count pending media for user")
+            return row[0], row[1] or 0
+
+        return await self.db_pool.runInteraction(
+            "count_pending_media", get_pending_media_txn
+        )
+
     async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
         """Get the media_id and ts for a cached URL as of the given timestamp
         Returns: