From 1faf04f0a52487c163a97f32988b9d56154b5cd3 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 1 Sep 2022 12:40:52 +0100 Subject: [PATCH] Pydantic for experimental `account_status` endpoint --- synapse/handlers/account.py | 4 ++-- synapse/rest/client/account.py | 20 ++++++++++++-------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py index c05a14304c1e..69d7d3180e90 100644 --- a/synapse/handlers/account.py +++ b/synapse/handlers/account.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, List, Tuple +from typing import TYPE_CHECKING, Dict, List, Sequence, Tuple from synapse.api.errors import Codes, SynapseError from synapse.types import JsonDict, UserID @@ -33,7 +33,7 @@ def __init__(self, hs: "HomeServer"): async def get_account_statuses( self, - user_ids: List[str], + user_ids: Sequence[str], allow_remote: bool, ) -> Tuple[JsonDict, List[str]]: """Get account statuses for a list of user IDs. diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 8e858eb1da4b..36716122081a 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -15,10 +15,10 @@ # limitations under the License. import logging import random -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Optional, Sequence, Tuple from urllib.parse import urlparse -from pydantic import StrictBool, StrictStr, constr +from pydantic import StrictBool, StrictStr, conlist, constr from twisted.web.server import Request @@ -842,17 +842,21 @@ def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() self._account_handler = hs.get_account_handler() + class PostBody(RequestBodyModel): + # TODO: we could validate that each user id is an mxid here, and/or parse it + # as a UserID + if TYPE_CHECKING: + user_ids: Sequence[StrictStr] + else: + user_ids: conlist(item_type=StrictStr, min_items=1) + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await self._auth.get_user_by_req(request) - body = parse_json_object_from_request(request) - if "user_ids" not in body: - raise SynapseError( - 400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM - ) + body = parse_and_validate_json_object_from_request(request, self.PostBody) statuses, failures = await self._account_handler.get_account_statuses( - body["user_ids"], + body.user_ids, allow_remote=True, )