Type annotations in synapse.databases.main.devices (#13025)
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
David Robertson and clokep committed Jun 15, 2022
1 parent 0d1d3e0 commit 97e9fbe
Showing 5 changed files with 36 additions and 21 deletions.
1 change: 1 addition & 0 deletions changelog.d/13025.misc
@@ -0,0 +1 @@
+Add type annotations to `synapse.storage.databases.main.devices`.
1 change: 0 additions & 1 deletion mypy.ini
@@ -27,7 +27,6 @@ exclude = (?x)
^(
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/cache.py
-|synapse/storage/databases/main/devices.py
|synapse/storage/schema/

|tests/api/test_auth.py
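
Note: the `exclude` value above is a single verbose-mode regular expression (the `(?x)` flag makes the regex engine ignore the whitespace and newlines), with one `|path` alternation branch per excluded file, so deleting the `|synapse/storage/databases/main/devices.py` branch is what brings the module back under mypy's checking. A minimal, self-contained sketch of how such a pattern behaves:

import re

# Each `|path` line is one alternation branch of a single verbose-mode regex,
# so deleting a branch re-enables type checking for that file.
exclude = re.compile(
    r"""(?x)
    ^(
    |synapse/storage/databases/main/cache.py
    )$
    """
)

# Still excluded: its branch remains in the pattern.
assert exclude.fullmatch("synapse/storage/databases/main/cache.py")
# No longer excluded: the devices.py branch was deleted.
assert not exclude.fullmatch("synapse/storage/databases/main/devices.py")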
3 changes: 1 addition & 2 deletions synapse/replication/slave/storage/devices.py
@@ -19,13 +19,12 @@
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.devices import DeviceWorkerStore
-from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore

if TYPE_CHECKING:
from synapse.server import HomeServer


-class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
+class SlavedDeviceStore(DeviceWorkerStore, BaseSlavedStore):
def __init__(
self,
database: DatabasePool,
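
Dropping `EndToEndKeyWorkerStore` from the base list is safe because `DeviceWorkerStore` itself now inherits from it (see the `synapse/storage/databases/main/devices.py` diff below), so `SlavedDeviceStore` still reaches the end-to-end key methods through its MRO. A minimal sketch of the pattern, with an illustrative method name:

class EndToEndKeyWorkerStore:
    def get_e2e_device_keys(self) -> dict:
        return {}

# DeviceWorkerStore now inherits EndToEndKeyWorkerStore directly ...
class DeviceWorkerStore(EndToEndKeyWorkerStore):
    pass

class BaseSlavedStore:
    pass

# ... so SlavedDeviceStore no longer needs to list it explicitly.
class SlavedDeviceStore(DeviceWorkerStore, BaseSlavedStore):
    pass

assert EndToEndKeyWorkerStore in SlavedDeviceStore.__mro__
assert SlavedDeviceStore().get_e2e_device_keys() == {}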
1 change: 1 addition & 0 deletions synapse/storage/databases/main/__init__.py
@@ -195,6 +195,7 @@ def __init__(
self._min_stream_order_on_start = self.get_room_min_stream_ordering()

def get_device_stream_token(self) -> int:
+# TODO: shouldn't this be moved to `DeviceWorkerStore`?
return self._device_list_id_gen.get_current_token()

async def get_users(self) -> List[JsonDict]:
51 changes: 33 additions & 18 deletions synapse/storage/databases/main/devices.py
@@ -28,6 +28,8 @@
cast,
)

+from typing_extensions import Literal
+
from synapse.api.constants import EduTypes
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@@ -44,6 +46,8 @@
LoggingTransaction,
make_tuple_comparison_clause,
)
+from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.types import Cursor
from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -65,7 +69,7 @@
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"


-class DeviceWorkerStore(SQLBaseStore):
+class DeviceWorkerStore(EndToEndKeyWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -74,7 +78,9 @@ def __init__(
):
super().__init__(database, db_conn, hs)

-device_list_max = self._device_list_id_gen.get_current_token()
+# Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
+# StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).
+device_list_max = self._device_list_id_gen.get_current_token() # type: ignore[attr-defined]
device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
db_conn,
"device_lists_stream",
@@ -339,8 +345,9 @@ async def get_device_updates_by_remote(
# following this stream later.
last_processed_stream_id = from_stream_id

-query_map = {}
-cross_signing_keys_by_user = {}
+# A map of (user ID, device ID) to (stream ID, context).
+query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]] = {}
+cross_signing_keys_by_user: Dict[str, Dict[str, object]] = {}
for user_id, device_id, update_stream_id, update_context in updates:
# Calculate the remaining length budget.
# Note that, for now, each entry in `cross_signing_keys_by_user`
@@ -596,7 +603,7 @@ def _mark_as_sent_devices_by_remote_txn(
txn=txn,
table="device_lists_outbound_last_success",
key_names=("destination", "user_id"),
-key_values=((destination, user_id) for user_id, _ in rows),
+key_values=[(destination, user_id) for user_id, _ in rows],
value_names=("stream_id",),
value_values=((stream_id,) for _, stream_id in rows),
)
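
The switch from a generator expression to a list comprehension for `key_values` matters because the helper this is passed to may iterate the values more than once; a generator is exhausted after the first pass and would silently yield nothing on the second. A minimal sketch of the hazard, with illustrative data:

rows = [("@alice:example.org", 1), ("@bob:example.org", 2)]
destination = "remote.example.org"

# A generator expression can only be consumed once:
gen = ((destination, user_id) for user_id, _ in rows)
assert len(list(gen)) == 2
assert list(gen) == []  # second pass: silently empty

# A list comprehension can be iterated any number of times:
lst = [(destination, user_id) for user_id, _ in rows]
assert len(list(lst)) == len(list(lst)) == 2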
@@ -621,7 +628,9 @@ async def add_user_signature_change_to_streams(
The new stream ID.
"""

-async with self._device_list_id_gen.get_next() as stream_id:
+# TODO: this looks like it's _writing_. Should this be on DeviceStore rather
+# than DeviceWorkerStore?
+async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@@ -686,7 +695,7 @@ async def get_user_devices_from_cache(
} - users_needing_resync
user_ids_not_in_cache = user_ids - user_ids_in_cache

-results = {}
+results: Dict[str, Dict[str, JsonDict]] = {}
for user_id, device_id in query_list:
if user_id not in user_ids_in_cache:
continue
@@ -727,7 +736,7 @@ async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]
def get_cached_device_list_changes(
self,
from_key: int,
-) -> Optional[Set[str]]:
+) -> Optional[List[str]]:
"""Get set of users whose devices have changed since `from_key`, or None
if that information is not in our cache.
"""
@@ -737,7 +746,7 @@ def get_cached_device_list_changes(
async def get_users_whose_devices_changed(
self,
from_key: int,
-user_ids: Optional[Iterable[str]] = None,
+user_ids: Optional[Collection[str]] = None,
to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
@@ -757,6 +766,7 @@ async def get_users_whose_devices_changed(
"""
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
+user_ids_to_check: Optional[Collection[str]]
if user_ids is None:
# Get set of all users that have had device list changes since 'from_key'
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
@@ -772,7 +782,7 @@
return set()

def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
-changes = set()
+changes: Set[str] = set()

stream_id_where_clause = "stream_id > ?"
sql_args = [from_key]
@@ -788,6 +798,9 @@ def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
"""

# Query device changes with a batch of users at a time
+# Assertion for mypy's benefit; see also
+# https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
+assert user_ids_to_check is not None
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
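
Note: the `assert user_ids_to_check is not None` exists purely for mypy's benefit. The branches above narrow `user_ids_to_check` to a non-None value in the enclosing scope, but mypy deliberately discards such narrowing inside a nested function, since the closure could in principle run after the variable was rebound. A minimal sketch of the pattern, with illustrative names:

from typing import Collection, Optional, Set

def changed_users(user_ids: Optional[Collection[str]]) -> Set[str]:
    # Pre-declare the type so both assignment branches unify cleanly.
    user_ids_to_check: Optional[Collection[str]]
    if user_ids is None:
        user_ids_to_check = ("@a:example.org",)  # e.g. everyone in a cache
    else:
        user_ids_to_check = user_ids

    def txn_fn() -> Set[str]:
        # Without this assert, mypy treats user_ids_to_check as possibly
        # None here: outer-scope narrowing does not apply in inner functions.
        assert user_ids_to_check is not None
        return set(user_ids_to_check)

    return txn_fn()

assert changed_users(None) == {"@a:example.org"}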
@@ -854,7 +867,9 @@ async def get_all_device_list_changes_for_remotes(
if last_id == current_id:
return [], current_id, False

-def _get_all_device_list_changes_for_remotes(txn):
+def _get_all_device_list_changes_for_remotes(
+txn: Cursor,
+) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries.
sql = """
@@ -913,7 +928,7 @@ async def get_device_list_last_stream_id_for_remotes(
desc="get_device_list_last_stream_id_for_remotes",
)

-results = {user_id: None for user_id in user_ids}
+results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
results.update({row["user_id"]: row["stream_id"] for row in rows})

return results
@@ -1337,9 +1352,9 @@ def __init__(

# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
-self.device_id_exists_cache = LruCache(
-cache_name="device_id_exists", max_size=10000
-)
+self.device_id_exists_cache: LruCache[
+Tuple[str, str], Literal[True]
+] = LruCache(cache_name="device_id_exists", max_size=10000)

async def store_device(
self,
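
Note on the new `LruCache[Tuple[str, str], Literal[True]]` annotation: the cache only ever stores the value `True`, because an entry's presence is the signal (a miss means "unknown", not "the device does not exist"), so the value type can be pinned to `Literal[True]` rather than `bool`. A minimal dict-based sketch of the same idea (the real `LruCache` additionally bounds the size and evicts old entries):

from typing import Dict, Optional, Tuple
from typing_extensions import Literal

device_id_exists: Dict[Tuple[str, str], Literal[True]] = {}

# Presence of a key is the signal; the stored value is always True.
device_id_exists[("@alice:example.org", "DEVICE1")] = True

hit: Optional[Literal[True]] = device_id_exists.get(("@alice:example.org", "DEVICE1"))
miss = device_id_exists.get(("@bob:example.org", "DEVICE2"))
assert hit is True and miss is None

# mypy rejects storing anything other than the literal True:
# device_id_exists[("@x:example.org", "D")] = False  # error: expected Literal[True]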
@@ -1651,7 +1666,7 @@ def add_device_changes_txn(
context,
)

-async with self._device_list_id_gen.get_next_mult(
+async with self._device_list_id_gen.get_next_mult( # type: ignore[attr-defined]
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
@@ -1704,7 +1719,7 @@ def _add_device_outbound_poke_to_stream_txn(
device_ids: Iterable[str],
hosts: Collection[str],
stream_ids: List[int],
-context: Dict[str, str],
+context: Optional[Dict[str, str]],
) -> None:
for host in hosts:
txn.call_after(
@@ -1875,7 +1890,7 @@ def add_device_list_outbound_pokes_txn(
[],
)

-async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
+async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: # type: ignore[attr-defined]
return await self.db_pool.runInteraction(
"add_device_list_outbound_pokes",
add_device_list_outbound_pokes_txn,