From 1df84505812743482b843569f3d30bd0398f9497 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Mar 2022 11:17:59 -0400 Subject: [PATCH 1/2] Add type hints to the main state store. --- mypy.ini | 1 - synapse/storage/databases/main/state.py | 34 ++++++++++++++++--------- synapse/util/caches/__init__.py | 6 +++-- 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/mypy.ini b/mypy.ini index 24d4ba15d452..a9f13d1d7377 100644 --- a/mypy.ini +++ b/mypy.ini @@ -46,7 +46,6 @@ exclude = (?x) |synapse/storage/databases/main/receipts.py |synapse/storage/databases/main/roommember.py |synapse/storage/databases/main/search.py - |synapse/storage/databases/main/state.py |synapse/storage/schema/ |tests/api/test_auth.py diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 417aef1dbcf3..d74324cc29cc 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -12,9 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import collections.abc import logging -from typing import TYPE_CHECKING, Iterable, Optional, Set +from typing import TYPE_CHECKING, Dict, Iterable, Optional, Set, Tuple + +from frozendict import frozendict from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError @@ -29,7 +30,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.state import StateFilter -from synapse.types import StateMap +from synapse.types import JsonDict, JsonMapping, StateMap from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList @@ -132,7 +133,7 @@ def get_room_version_id_txn(self, txn: LoggingTransaction, room_id: str) -> str: return room_version - async def get_room_predecessor(self, room_id: str) -> Optional[dict]: + async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]: """Get the predecessor of an upgraded room if it exists. Otherwise return None. @@ -158,9 +159,10 @@ async def get_room_predecessor(self, room_id: str) -> Optional[dict]: predecessor = create_event.content.get("predecessor", None) # Ensure the key is a dictionary - if not isinstance(predecessor, collections.abc.Mapping): + if not isinstance(predecessor, (dict, frozendict)): return None + # The keys must be strings since the data is JSON. return predecessor async def get_create_event_for_room(self, room_id: str) -> EventBase: @@ -241,7 +243,9 @@ async def get_filtered_current_state_ids( # We delegate to the cached version return await self.get_current_state_ids(room_id) - def _get_filtered_current_state_ids_txn(txn): + def _get_filtered_current_state_ids_txn( + txn: LoggingTransaction, + ) -> StateMap[str]: results = {} sql = """ SELECT type, state_key, event_id FROM current_state_events @@ -281,11 +285,11 @@ async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: event_id = state.get((EventTypes.CanonicalAlias, "")) if not event_id: - return + return None event = await self.get_event(event_id, allow_none=True) if not event: - return + return None return event.content.get("canonical_alias") @@ -304,7 +308,9 @@ async def _get_state_group_for_event(self, event_id: str) -> Optional[int]: list_name="event_ids", num_args=1, ) - async def _get_state_group_for_events(self, event_ids): + async def _get_state_group_for_events( + self, event_ids: Iterable[str] + ) -> Dict[str, int]: """Returns mapping event_id -> state_group""" rows = await self.db_pool.simple_select_many_batch( table="event_to_state_groups", @@ -375,7 +381,9 @@ def __init__( self._background_remove_left_rooms, ) - async def _background_remove_left_rooms(self, progress, batch_size): + async def _background_remove_left_rooms( + self, progress: JsonDict, batch_size: int + ) -> int: """Background update to delete rows from `current_state_events` and `event_forward_extremities` tables of rooms that the server is no longer joined to. @@ -383,7 +391,9 @@ async def _background_remove_left_rooms(self, progress, batch_size): last_room_id = progress.get("last_room_id", "") - def _background_remove_left_rooms_txn(txn): + def _background_remove_left_rooms_txn( + txn: LoggingTransaction, + ) -> Tuple[bool, Set[str]]: # get a batch of room ids to consider sql = """ SELECT DISTINCT room_id FROM current_state_events @@ -515,7 +525,7 @@ def _background_remove_left_rooms_txn(txn): ) for user_id in potentially_left_users - joined_users: - await self.mark_remote_user_device_list_as_unsubscribed(user_id) + await self.mark_remote_user_device_list_as_unsubscribed(user_id) # type: ignore[attr-defined] return batch_size diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 1cbc180eda72..42f6abb5e1ad 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -17,7 +17,7 @@ import typing from enum import Enum, auto from sys import intern -from typing import Any, Callable, Dict, List, Optional, Sized +from typing import Any, Callable, Dict, List, Optional, Sized, TypeVar import attr from prometheus_client.core import Gauge @@ -195,8 +195,10 @@ def register_cache( ) } +T = TypeVar("T", Optional[str], str) -def intern_string(string: Optional[str]) -> Optional[str]: + +def intern_string(string: T) -> T: """Takes a (potentially) unicode string and interns it if it's ascii""" if string is None: return None From 0d46fcf165c19c35406ad8c143546337a4a4354b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Mar 2022 11:20:42 -0400 Subject: [PATCH 2/2] Newsfragment --- changelog.d/12267.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/12267.misc diff --git a/changelog.d/12267.misc b/changelog.d/12267.misc new file mode 100644 index 000000000000..e43844d44ab6 --- /dev/null +++ b/changelog.d/12267.misc @@ -0,0 +1 @@ +Add missing type hints for storage.