diff --git a/changelog.d/10080.misc b/changelog.d/10080.misc new file mode 100644 index 000000000000..9adb0fbd02d3 --- /dev/null +++ b/changelog.d/10080.misc @@ -0,0 +1 @@ +Add type hints to the federation servlets. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index ace30aa45078..86562cd04f28 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -129,7 +129,7 @@ def __init__(self, hs: "HomeServer"): # come in waves. self._state_resp_cache = ResponseCache( hs.get_clock(), "state_resp", timeout_ms=30000 - ) # type: ResponseCache[Tuple[str, str]] + ) # type: ResponseCache[Tuple[str, Optional[str]]] self._state_ids_resp_cache = ResponseCache( hs.get_clock(), "state_ids_resp", timeout_ms=30000 ) # type: ResponseCache[Tuple[str, str]] @@ -406,7 +406,7 @@ async def _process_edu(edu_dict): ) async def on_room_state_request( - self, origin: str, room_id: str, event_id: str + self, origin: str, room_id: str, event_id: Optional[str] ) -> Tuple[int, Dict[str, Any]]: origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) @@ -463,7 +463,7 @@ async def _on_state_ids_request_compute(self, room_id, event_id): return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} async def _on_context_state_request_compute( - self, room_id: str, event_id: str + self, room_id: str, event_id: Optional[str] ) -> Dict[str, list]: if event_id: pdus = await self.handler.get_state_for_pdu( diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 5756fcb551ab..4bc7d2015b73 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -28,6 +28,7 @@ FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX, ) +from synapse.handlers.groups_local import GroupsLocalHandler from synapse.http.server import HttpServer, JsonResource from synapse.http.servlet import ( parse_boolean_from_args, @@ -275,10 +276,17 @@ class BaseFederationServlet: RATELIMIT = True # Whether to rate limit requests or not - def __init__(self, handler, authenticator, ratelimiter, server_name): - self.handler = handler + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + self.hs = hs self.authenticator = authenticator self.ratelimiter = ratelimiter + self.server_name = server_name def _wrap(self, func): authenticator = self.authenticator @@ -375,17 +383,30 @@ def register(self, server): ) -class FederationSendServlet(BaseFederationServlet): +class BaseFederationServerServlet(BaseFederationServlet): + """Abstract base class for federation servlet classes which provides a federation server handler. + + See BaseFederationServlet for more information. + """ + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_federation_server() + + +class FederationSendServlet(BaseFederationServerServlet): PATH = "/send/(?P[^/]*)/?" # We ratelimit manually in the handler as we queue up the requests and we # don't want to fill up the ratelimiter with blocked requests. RATELIMIT = False - def __init__(self, handler, server_name, **kwargs): - super().__init__(handler, server_name=server_name, **kwargs) - self.server_name = server_name - # This is when someone is trying to send us a bunch of data. async def on_PUT(self, origin, content, query, transaction_id): """Called on PUT /send// @@ -434,7 +455,7 @@ async def on_PUT(self, origin, content, query, transaction_id): return code, response -class FederationEventServlet(BaseFederationServlet): +class FederationEventServlet(BaseFederationServerServlet): PATH = "/event/(?P[^/]*)/?" # This is when someone asks for a data item for a given server data_id pair. @@ -442,7 +463,7 @@ async def on_GET(self, origin, content, query, event_id): return await self.handler.on_pdu_request(origin, event_id) -class FederationStateV1Servlet(BaseFederationServlet): +class FederationStateV1Servlet(BaseFederationServerServlet): PATH = "/state/(?P[^/]*)/?" # This is when someone asks for all data for a given room. @@ -454,7 +475,7 @@ async def on_GET(self, origin, content, query, room_id): ) -class FederationStateIdsServlet(BaseFederationServlet): +class FederationStateIdsServlet(BaseFederationServerServlet): PATH = "/state_ids/(?P[^/]*)/?" async def on_GET(self, origin, content, query, room_id): @@ -465,7 +486,7 @@ async def on_GET(self, origin, content, query, room_id): ) -class FederationBackfillServlet(BaseFederationServlet): +class FederationBackfillServlet(BaseFederationServerServlet): PATH = "/backfill/(?P[^/]*)/?" async def on_GET(self, origin, content, query, room_id): @@ -478,7 +499,7 @@ async def on_GET(self, origin, content, query, room_id): return await self.handler.on_backfill_request(origin, room_id, versions, limit) -class FederationQueryServlet(BaseFederationServlet): +class FederationQueryServlet(BaseFederationServerServlet): PATH = "/query/(?P[^/]*)" # This is when we receive a server-server Query @@ -488,7 +509,7 @@ async def on_GET(self, origin, content, query, query_type): return await self.handler.on_query_request(query_type, args) -class FederationMakeJoinServlet(BaseFederationServlet): +class FederationMakeJoinServlet(BaseFederationServerServlet): PATH = "/make_join/(?P[^/]*)/(?P[^/]*)" async def on_GET(self, origin, _content, query, room_id, user_id): @@ -518,7 +539,7 @@ async def on_GET(self, origin, _content, query, room_id, user_id): return 200, content -class FederationMakeLeaveServlet(BaseFederationServlet): +class FederationMakeLeaveServlet(BaseFederationServerServlet): PATH = "/make_leave/(?P[^/]*)/(?P[^/]*)" async def on_GET(self, origin, content, query, room_id, user_id): @@ -526,7 +547,7 @@ async def on_GET(self, origin, content, query, room_id, user_id): return 200, content -class FederationV1SendLeaveServlet(BaseFederationServlet): +class FederationV1SendLeaveServlet(BaseFederationServerServlet): PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): @@ -534,7 +555,7 @@ async def on_PUT(self, origin, content, query, room_id, event_id): return 200, (200, content) -class FederationV2SendLeaveServlet(BaseFederationServlet): +class FederationV2SendLeaveServlet(BaseFederationServerServlet): PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" PREFIX = FEDERATION_V2_PREFIX @@ -544,14 +565,14 @@ async def on_PUT(self, origin, content, query, room_id, event_id): return 200, content -class FederationEventAuthServlet(BaseFederationServlet): +class FederationEventAuthServlet(BaseFederationServerServlet): PATH = "/event_auth/(?P[^/]*)/(?P[^/]*)" async def on_GET(self, origin, content, query, room_id, event_id): return await self.handler.on_event_auth(origin, room_id, event_id) -class FederationV1SendJoinServlet(BaseFederationServlet): +class FederationV1SendJoinServlet(BaseFederationServerServlet): PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): @@ -561,7 +582,7 @@ async def on_PUT(self, origin, content, query, room_id, event_id): return 200, (200, content) -class FederationV2SendJoinServlet(BaseFederationServlet): +class FederationV2SendJoinServlet(BaseFederationServerServlet): PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" PREFIX = FEDERATION_V2_PREFIX @@ -573,7 +594,7 @@ async def on_PUT(self, origin, content, query, room_id, event_id): return 200, content -class FederationV1InviteServlet(BaseFederationServlet): +class FederationV1InviteServlet(BaseFederationServerServlet): PATH = "/invite/(?P[^/]*)/(?P[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): @@ -590,7 +611,7 @@ async def on_PUT(self, origin, content, query, room_id, event_id): return 200, (200, content) -class FederationV2InviteServlet(BaseFederationServlet): +class FederationV2InviteServlet(BaseFederationServerServlet): PATH = "/invite/(?P[^/]*)/(?P[^/]*)" PREFIX = FEDERATION_V2_PREFIX @@ -614,7 +635,7 @@ async def on_PUT(self, origin, content, query, room_id, event_id): return 200, content -class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet): +class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet): PATH = "/exchange_third_party_invite/(?P[^/]*)" async def on_PUT(self, origin, content, query, room_id): @@ -622,21 +643,21 @@ async def on_PUT(self, origin, content, query, room_id): return 200, {} -class FederationClientKeysQueryServlet(BaseFederationServlet): +class FederationClientKeysQueryServlet(BaseFederationServerServlet): PATH = "/user/keys/query" async def on_POST(self, origin, content, query): return await self.handler.on_query_client_keys(origin, content) -class FederationUserDevicesQueryServlet(BaseFederationServlet): +class FederationUserDevicesQueryServlet(BaseFederationServerServlet): PATH = "/user/devices/(?P[^/]*)" async def on_GET(self, origin, content, query, user_id): return await self.handler.on_query_user_devices(origin, user_id) -class FederationClientKeysClaimServlet(BaseFederationServlet): +class FederationClientKeysClaimServlet(BaseFederationServerServlet): PATH = "/user/keys/claim" async def on_POST(self, origin, content, query): @@ -644,7 +665,7 @@ async def on_POST(self, origin, content, query): return 200, response -class FederationGetMissingEventsServlet(BaseFederationServlet): +class FederationGetMissingEventsServlet(BaseFederationServerServlet): # TODO(paul): Why does this path alone end with "/?" optional? PATH = "/get_missing_events/(?P[^/]*)/?" @@ -664,7 +685,7 @@ async def on_POST(self, origin, content, query, room_id): return 200, content -class On3pidBindServlet(BaseFederationServlet): +class On3pidBindServlet(BaseFederationServerServlet): PATH = "/3pid/onbind" REQUIRE_AUTH = False @@ -694,7 +715,7 @@ async def on_POST(self, origin, content, query): return 200, {} -class OpenIdUserInfo(BaseFederationServlet): +class OpenIdUserInfo(BaseFederationServerServlet): """ Exchange a bearer token for information about a user. @@ -770,8 +791,16 @@ class PublicRoomList(BaseFederationServlet): PATH = "/publicRooms" - def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access): - super().__init__(handler, authenticator, ratelimiter, server_name) + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + allow_access: bool, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_room_list_handler() self.allow_access = allow_access async def on_GET(self, origin, content, query): @@ -856,7 +885,24 @@ async def on_GET(self, origin, content, query): ) -class FederationGroupsProfileServlet(BaseFederationServlet): +class BaseGroupsServerServlet(BaseFederationServlet): + """Abstract base class for federation servlet classes which provides a groups server handler. + + See BaseFederationServlet for more information. + """ + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_groups_server_handler() + + +class FederationGroupsProfileServlet(BaseGroupsServerServlet): """Get/set the basic profile of a group on behalf of a user""" PATH = "/groups/(?P[^/]*)/profile" @@ -882,7 +928,7 @@ async def on_POST(self, origin, content, query, group_id): return 200, new_content -class FederationGroupsSummaryServlet(BaseFederationServlet): +class FederationGroupsSummaryServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/summary" async def on_GET(self, origin, content, query, group_id): @@ -895,7 +941,7 @@ async def on_GET(self, origin, content, query, group_id): return 200, new_content -class FederationGroupsRoomsServlet(BaseFederationServlet): +class FederationGroupsRoomsServlet(BaseGroupsServerServlet): """Get the rooms in a group on behalf of a user""" PATH = "/groups/(?P[^/]*)/rooms" @@ -910,7 +956,7 @@ async def on_GET(self, origin, content, query, group_id): return 200, new_content -class FederationGroupsAddRoomsServlet(BaseFederationServlet): +class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet): """Add/remove room from group""" PATH = "/groups/(?P[^/]*)/room/(?P[^/]*)" @@ -938,7 +984,7 @@ async def on_DELETE(self, origin, content, query, group_id, room_id): return 200, new_content -class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): +class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet): """Update room config in group""" PATH = ( @@ -958,7 +1004,7 @@ async def on_POST(self, origin, content, query, group_id, room_id, config_key): return 200, result -class FederationGroupsUsersServlet(BaseFederationServlet): +class FederationGroupsUsersServlet(BaseGroupsServerServlet): """Get the users in a group on behalf of a user""" PATH = "/groups/(?P[^/]*)/users" @@ -973,7 +1019,7 @@ async def on_GET(self, origin, content, query, group_id): return 200, new_content -class FederationGroupsInvitedUsersServlet(BaseFederationServlet): +class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet): """Get the users that have been invited to a group""" PATH = "/groups/(?P[^/]*)/invited_users" @@ -990,7 +1036,7 @@ async def on_GET(self, origin, content, query, group_id): return 200, new_content -class FederationGroupsInviteServlet(BaseFederationServlet): +class FederationGroupsInviteServlet(BaseGroupsServerServlet): """Ask a group server to invite someone to the group""" PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/invite" @@ -1007,7 +1053,7 @@ async def on_POST(self, origin, content, query, group_id, user_id): return 200, new_content -class FederationGroupsAcceptInviteServlet(BaseFederationServlet): +class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet): """Accept an invitation from the group server""" PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/accept_invite" @@ -1021,7 +1067,7 @@ async def on_POST(self, origin, content, query, group_id, user_id): return 200, new_content -class FederationGroupsJoinServlet(BaseFederationServlet): +class FederationGroupsJoinServlet(BaseGroupsServerServlet): """Attempt to join a group""" PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/join" @@ -1035,7 +1081,7 @@ async def on_POST(self, origin, content, query, group_id, user_id): return 200, new_content -class FederationGroupsRemoveUserServlet(BaseFederationServlet): +class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet): """Leave or kick a user from the group""" PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/remove" @@ -1052,7 +1098,24 @@ async def on_POST(self, origin, content, query, group_id, user_id): return 200, new_content -class FederationGroupsLocalInviteServlet(BaseFederationServlet): +class BaseGroupsLocalServlet(BaseFederationServlet): + """Abstract base class for federation servlet classes which provides a groups local handler. + + See BaseFederationServlet for more information. + """ + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_groups_local_handler() + + +class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet): """A group server has invited a local user""" PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/invite" @@ -1061,12 +1124,16 @@ async def on_POST(self, origin, content, query, group_id, user_id): if get_domain_from_id(group_id) != origin: raise SynapseError(403, "group_id doesn't match origin") + assert isinstance( + self.handler, GroupsLocalHandler + ), "Workers cannot handle group invites." + new_content = await self.handler.on_invite(group_id, user_id, content) return 200, new_content -class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): +class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet): """A group server has removed a local user""" PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/remove" @@ -1075,6 +1142,10 @@ async def on_POST(self, origin, content, query, group_id, user_id): if get_domain_from_id(group_id) != origin: raise SynapseError(403, "user_id doesn't match origin") + assert isinstance( + self.handler, GroupsLocalHandler + ), "Workers cannot handle group removals." + new_content = await self.handler.user_removed_from_group( group_id, user_id, content ) @@ -1087,6 +1158,16 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): PATH = "/groups/(?P[^/]*)/renew_attestation/(?P[^/]*)" + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_groups_attestation_renewer() + async def on_POST(self, origin, content, query, group_id, user_id): # We don't need to check auth here as we check the attestation signatures @@ -1097,7 +1178,7 @@ async def on_POST(self, origin, content, query, group_id, user_id): return 200, new_content -class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): +class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet): """Add/remove a room from the group summary, with optional category. Matches both: @@ -1154,7 +1235,7 @@ async def on_DELETE(self, origin, content, query, group_id, category_id, room_id return 200, resp -class FederationGroupsCategoriesServlet(BaseFederationServlet): +class FederationGroupsCategoriesServlet(BaseGroupsServerServlet): """Get all categories for a group""" PATH = "/groups/(?P[^/]*)/categories/?" @@ -1169,7 +1250,7 @@ async def on_GET(self, origin, content, query, group_id): return 200, resp -class FederationGroupsCategoryServlet(BaseFederationServlet): +class FederationGroupsCategoryServlet(BaseGroupsServerServlet): """Add/remove/get a category in a group""" PATH = "/groups/(?P[^/]*)/categories/(?P[^/]+)" @@ -1222,7 +1303,7 @@ async def on_DELETE(self, origin, content, query, group_id, category_id): return 200, resp -class FederationGroupsRolesServlet(BaseFederationServlet): +class FederationGroupsRolesServlet(BaseGroupsServerServlet): """Get roles in a group""" PATH = "/groups/(?P[^/]*)/roles/?" @@ -1237,7 +1318,7 @@ async def on_GET(self, origin, content, query, group_id): return 200, resp -class FederationGroupsRoleServlet(BaseFederationServlet): +class FederationGroupsRoleServlet(BaseGroupsServerServlet): """Add/remove/get a role in a group""" PATH = "/groups/(?P[^/]*)/roles/(?P[^/]+)" @@ -1290,7 +1371,7 @@ async def on_DELETE(self, origin, content, query, group_id, role_id): return 200, resp -class FederationGroupsSummaryUsersServlet(BaseFederationServlet): +class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet): """Add/remove a user from the group summary, with optional role. Matches both: @@ -1345,7 +1426,7 @@ async def on_DELETE(self, origin, content, query, group_id, role_id, user_id): return 200, resp -class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): +class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet): """Get roles in a group""" PATH = "/get_groups_publicised" @@ -1358,7 +1439,7 @@ async def on_POST(self, origin, content, query): return 200, resp -class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): +class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet): """Sets whether a group is joinable without an invite or knock""" PATH = "/groups/(?P[^/]*)/settings/m.join_policy" @@ -1379,6 +1460,16 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" PATH = "/spaces/(?P[^/]*)" + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_space_summary_handler() + async def on_GET( self, origin: str, @@ -1444,16 +1535,25 @@ class RoomComplexityServlet(BaseFederationServlet): PATH = "/rooms/(?P[^/]*)/complexity" PREFIX = FEDERATION_UNSTABLE_PREFIX - async def on_GET(self, origin, content, query, room_id): - - store = self.handler.hs.get_datastore() + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self._store = self.hs.get_datastore() - is_public = await store.is_room_world_readable_or_publicly_joinable(room_id) + async def on_GET(self, origin, content, query, room_id): + is_public = await self._store.is_room_world_readable_or_publicly_joinable( + room_id + ) if not is_public: raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM) - complexity = await store.get_room_complexity(room_id) + complexity = await self._store.get_room_complexity(room_id) return 200, complexity @@ -1482,6 +1582,7 @@ async def on_GET(self, origin, content, query, room_id): On3pidBindServlet, FederationVersionServlet, RoomComplexityServlet, + FederationSpaceSummaryServlet, ) # type: Tuple[Type[BaseFederationServlet], ...] OPENID_SERVLET_CLASSES = ( @@ -1559,23 +1660,16 @@ def register_servlets( if "federation" in servlet_groups: for servletclass in FEDERATION_SERVLET_CLASSES: servletclass( - handler=hs.get_federation_server(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, ).register(resource) - FederationSpaceSummaryServlet( - handler=hs.get_space_summary_handler(), - authenticator=authenticator, - ratelimiter=ratelimiter, - server_name=hs.hostname, - ).register(resource) - if "openid" in servlet_groups: for servletclass in OPENID_SERVLET_CLASSES: servletclass( - handler=hs.get_federation_server(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, @@ -1584,7 +1678,7 @@ def register_servlets( if "room_list" in servlet_groups: for servletclass in ROOM_LIST_CLASSES: servletclass( - handler=hs.get_room_list_handler(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, @@ -1594,7 +1688,7 @@ def register_servlets( if "group_server" in servlet_groups: for servletclass in GROUP_SERVER_SERVLET_CLASSES: servletclass( - handler=hs.get_groups_server_handler(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, @@ -1603,7 +1697,7 @@ def register_servlets( if "group_local" in servlet_groups: for servletclass in GROUP_LOCAL_SERVLET_CLASSES: servletclass( - handler=hs.get_groups_local_handler(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, @@ -1612,7 +1706,7 @@ def register_servlets( if "group_attestation" in servlet_groups: for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES: servletclass( - handler=hs.get_groups_attestation_renewer(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 141c9c044400..0a26088d3215 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -44,7 +44,7 @@ def __init__(self, hs: "HomeServer"): self.enable_room_list_search = hs.config.enable_room_list_search self.response_cache = ResponseCache( hs.get_clock(), "room_list" - ) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]] + ) # type: ResponseCache[Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]] self.remote_response_cache = ResponseCache( hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000 ) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]] @@ -54,7 +54,7 @@ async def get_local_public_room_list( limit: Optional[int] = None, since_token: Optional[str] = None, search_filter: Optional[dict] = None, - network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID, + network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID, from_federation: bool = False, ) -> JsonDict: """Generate a local public room list. @@ -111,7 +111,7 @@ async def _get_public_room_list( limit: Optional[int] = None, since_token: Optional[str] = None, search_filter: Optional[dict] = None, - network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID, + network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID, from_federation: bool = False, ) -> JsonDict: """Generate a public room list. diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index d61563d39b6c..72e2ec78db41 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -295,6 +295,30 @@ def parse_strings_from_args( return default +@overload +def parse_string_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[str] = None, + required: Literal[True] = True, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> str: + ... + + +@overload +def parse_string_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[str] = None, + required: bool = False, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> Optional[str]: + ... + + def parse_string_from_args( args: Dict[bytes, List[bytes]], name: str,