From c5d0ea7396ccc9ee17bb1f763b7345b52dd32086 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 2 Dec 2022 14:41:37 +0000 Subject: [PATCH 1/5] Split get_users_whose_devices_changed --- synapse/handlers/appservice.py | 4 +- synapse/storage/databases/main/devices.py | 81 ++++++++++++----------- 2 files changed, 45 insertions(+), 40 deletions(-) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 66f5b8d108be..f68027aaed29 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -615,8 +615,8 @@ async def _get_device_list_summary( ) # Fetch the users who have modified their device list since then. - users_with_changed_device_lists = ( - await self.store.get_users_whose_devices_changed(from_key, to_key=new_key) + users_with_changed_device_lists = await self.store.get_all_devices_changed( + from_key, to_key=new_key ) # Filter out any users the application service is not interested in diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 8ba995df3b61..db7a219c2fd5 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -806,11 +806,36 @@ def get_cached_device_list_changes( return self._device_list_stream_cache.get_all_entities_changed(from_key) + @cancellable + async def get_all_devices_changed( + self, + from_key: int, + to_key: int, + ) -> Set[str]: + user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed( + from_key + ) + + if user_ids_to_check is not None: + return await self.get_users_whose_devices_changed( + from_key, user_ids_to_check, to_key + ) + + sql = """ + SELECT DISTINCT user_id FROM device_lists_stream + WHERE ? < stream_id AND stream_id <= ? + """ + + rows = await self.db_pool.execute( + "get_all_devices_changed", None, sql, (from_key, to_key) + ) + return {u for u, in rows} + @cancellable async def get_users_whose_devices_changed( self, from_key: int, - user_ids: Optional[Collection[str]] = None, + user_ids: Collection[str], to_key: Optional[int] = None, ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that @@ -830,52 +855,32 @@ async def get_users_whose_devices_changed( """ # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. - user_ids_to_check: Optional[Collection[str]] - if user_ids is None: - # Get set of all users that have had device list changes since 'from_key' - user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed( - from_key - ) - else: - # The same as above, but filter results to only those users in 'user_ids' - user_ids_to_check = self._device_list_stream_cache.get_entities_changed( - user_ids, from_key - ) + user_ids_to_check = self._device_list_stream_cache.get_entities_changed( + user_ids, from_key + ) # If an empty set was returned, there's nothing to do. - if user_ids_to_check is not None and not user_ids_to_check: + if not user_ids_to_check: return set() - def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: - stream_id_where_clause = "stream_id > ?" - sql_args = [from_key] - - if to_key: - stream_id_where_clause += " AND stream_id <= ?" - sql_args.append(to_key) + if to_key is None: + to_key = self._device_list_id_gen.get_current_token() - sql = f""" + def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: + sql = """ SELECT DISTINCT user_id FROM device_lists_stream - WHERE {stream_id_where_clause} + WHERE ? < stream_id AND stream_id <= ? AND %s """ - # If the stream change cache gave us no information, fetch *all* - # users between the stream IDs. - if user_ids_to_check is None: - txn.execute(sql, sql_args) - return {user_id for user_id, in txn} + changes: Set[str] = set() - # Otherwise, fetch changes for the given users. - else: - changes: Set[str] = set() - - # Query device changes with a batch of users at a time - for chunk in batch_iter(user_ids_to_check, 100): - clause, args = make_in_list_sql_clause( - txn.database_engine, "user_id", chunk - ) - txn.execute(sql + " AND " + clause, sql_args + args) - changes.update(user_id for user_id, in txn) + # Query device changes with a batch of users at a time + for chunk in batch_iter(user_ids_to_check, 100): + clause, args = make_in_list_sql_clause( + txn.database_engine, "user_id", chunk + ) + txn.execute(sql % (clause,), [from_key, to_key] + args) + changes.update(user_id for user_id, in txn) return changes From db1936ed259b358f1edac6a89476c643f41674c3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 2 Dec 2022 14:30:41 +0000 Subject: [PATCH 2/5] Better return type for `get_all_entities_changed` This helps prevent people from using it wrong. --- synapse/handlers/presence.py | 12 ++++--- synapse/handlers/sync.py | 6 ++-- synapse/handlers/typing.py | 8 ++--- synapse/storage/databases/main/devices.py | 15 ++++---- synapse/util/caches/stream_change_cache.py | 42 +++++++++++++++++----- tests/util/test_stream_change_cache.py | 20 ++++++----- 6 files changed, 67 insertions(+), 36 deletions(-) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 1799174c2f1a..2af90b25a39c 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1692,10 +1692,12 @@ async def get_new_events( if from_key is not None: # First get all users that have had a presence update - updated_users = stream_change_cache.get_all_entities_changed(from_key) + result = stream_change_cache.get_all_entities_changed(from_key) # Cross-reference users we're interested in with those that have had updates. - if updated_users is not None: + if result.hit: + updated_users = result.entities + # If we have the full list of changes for presence we can # simply check which ones share a room with the user. get_updates_counter.labels("stream").inc() @@ -1767,9 +1769,9 @@ async def _filter_all_presence_updates_for_user( updated_users = None if from_key: # Only return updates since the last sync - updated_users = self.store.presence_stream_cache.get_all_entities_changed( - from_key - ) + result = self.store.presence_stream_cache.get_all_entities_changed(from_key) + if result.hit: + updated_users = result.entities if updated_users is not None: # Get the actual presence update for each change diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c8858b22ddf6..0b395a104d17 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1528,10 +1528,12 @@ async def _generate_sync_entry_for_device_list( # # If we don't have that info cached then we get all the users that # share a room with our user and check if those users have changed. - changed_users = self.store.get_cached_device_list_changes( + cache_result = self.store.get_cached_device_list_changes( since_token.device_list_key ) - if changed_users is not None: + if cache_result.hit: + changed_users = cache_result.entities + result = await self.store.get_rooms_for_users(changed_users) for changed_user_id, entries in result.items(): diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index a0ea71943053..3f656ea4f508 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -420,11 +420,11 @@ async def get_all_typing_updates( if last_id == current_id: return [], current_id, False - changed_rooms: Optional[ - Iterable[str] - ] = self._typing_stream_change_cache.get_all_entities_changed(last_id) + result = self._typing_stream_change_cache.get_all_entities_changed(last_id) - if changed_rooms is None: + if result.hit: + changed_rooms: Iterable[str] = result.entities + else: changed_rooms = self._room_serials rows = [] diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index db7a219c2fd5..e55a3502cdd9 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -58,7 +58,10 @@ from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.caches.stream_change_cache import ( + AllEntitiesChangedResult, + StreamChangeCache, +) from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -799,7 +802,7 @@ async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict] def get_cached_device_list_changes( self, from_key: int, - ) -> Optional[List[str]]: + ) -> AllEntitiesChangedResult: """Get set of users whose devices have changed since `from_key`, or None if that information is not in our cache. """ @@ -812,13 +815,11 @@ async def get_all_devices_changed( from_key: int, to_key: int, ) -> Set[str]: - user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed( - from_key - ) + result = self._device_list_stream_cache.get_all_entities_changed(from_key) - if user_ids_to_check is not None: + if result.hit: return await self.get_users_whose_devices_changed( - from_key, user_ids_to_check, to_key + from_key, result.entities, to_key ) sql = """ diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 666f4b6895bf..b6d6ae74c326 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -16,6 +16,7 @@ import math from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union +import attr from sortedcontainers import SortedDict from synapse.util import caches @@ -26,6 +27,29 @@ EntityType = str +@attr.s(auto_attribs=True, frozen=True, slots=True) +class AllEntitiesChangedResult: + """Return type of `get_all_entities_changed`. + + Callers must check that there was a cache hit, via `result.hit`, before + using the entities in `result.entities`. + + This specifically does *not* implement helpers such as `__bool__` to ensure + that callers do the correct checks. + """ + + _entities: Optional[List[EntityType]] + + @property + def hit(self) -> bool: + return self._entities is not None + + @property + def entities(self) -> List[EntityType]: + assert self._entities is not None + return self._entities + + class StreamChangeCache: """Keeps track of the stream positions of the latest change in a set of entities. @@ -109,19 +133,19 @@ def get_entities_changed( position. Entities unknown to the cache will be returned. If the position is too old it will just return the given list. """ - changed_entities = self.get_all_entities_changed(stream_pos) - if changed_entities is not None: + cache_result = self.get_all_entities_changed(stream_pos) + if cache_result.hit: # We now do an intersection, trying to do so in the most efficient # way possible (some of these sets are *large*). First check in the # given iterable is already set that we can reuse, otherwise we # create a set of the *smallest* of the two iterables and call # `intersection(..)` on it (this can be twice as fast as the reverse). if isinstance(entities, (set, frozenset)): - result = entities.intersection(changed_entities) - elif len(changed_entities) < len(entities): - result = set(changed_entities).intersection(entities) + result = entities.intersection(cache_result.entities) + elif len(cache_result.entities) < len(entities): + result = set(cache_result.entities).intersection(entities) else: - result = set(entities).intersection(changed_entities) + result = set(entities).intersection(cache_result.entities) self.metrics.inc_hits() else: result = set(entities) @@ -144,7 +168,7 @@ def has_any_entity_changed(self, stream_pos: int) -> bool: self.metrics.inc_misses() return True - def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]: + def get_all_entities_changed(self, stream_pos: int) -> AllEntitiesChangedResult: """Returns all entities that have had new things since the given position. If the position is too old it will return None. @@ -153,13 +177,13 @@ def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType] assert type(stream_pos) is int if stream_pos < self._earliest_known_stream_pos: - return None + return AllEntitiesChangedResult(None) changed_entities: List[EntityType] = [] for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)): changed_entities.extend(self._cache[k]) - return changed_entities + return AllEntitiesChangedResult(changed_entities) def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None: """Informs the cache that the entity has been changed at the given diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 9ed01f7e0c97..4689c7512a91 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -70,10 +70,10 @@ def test_entity_has_changed_pops_off_start(self): self.assertTrue("user@foo.com" not in cache._entity_to_key) self.assertEqual( - cache.get_all_entities_changed(2), + cache.get_all_entities_changed(2).entities, ["bar@baz.net", "user@elsewhere.org"], ) - self.assertIsNone(cache.get_all_entities_changed(1)) + self.assertFalse(cache.get_all_entities_changed(1).hit) # If we update an existing entity, it keeps the two existing entities cache.entity_has_changed("bar@baz.net", 5) @@ -81,10 +81,10 @@ def test_entity_has_changed_pops_off_start(self): {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key) ) self.assertEqual( - cache.get_all_entities_changed(2), + cache.get_all_entities_changed(2).entities, ["user@elsewhere.org", "bar@baz.net"], ) - self.assertIsNone(cache.get_all_entities_changed(1)) + self.assertFalse(cache.get_all_entities_changed(1).hit) def test_get_all_entities_changed(self): """ @@ -114,13 +114,15 @@ def test_get_all_entities_changed(self): "bar@baz.net", "user@elsewhere.org", ] - self.assertTrue(r == ok1 or r == ok2) + self.assertTrue(r.entities == ok1 or r.entities == ok2) r = cache.get_all_entities_changed(2) - self.assertTrue(r == ok1[1:] or r == ok2[1:]) + self.assertTrue(r.entities == ok1[1:] or r.entities == ok2[1:]) - self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) - self.assertEqual(cache.get_all_entities_changed(0), None) + self.assertEqual( + cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"] + ) + self.assertFalse(cache.get_all_entities_changed(0).hit) # ... later, things gest more updates cache.entity_has_changed("user@foo.com", 5) @@ -140,7 +142,7 @@ def test_get_all_entities_changed(self): "anotheruser@foo.com", ] r = cache.get_all_entities_changed(3) - self.assertTrue(r == ok1 or r == ok2) + self.assertTrue(r.entities == ok1 or r.entities == ok2) def test_has_any_entity_changed(self): """ From 54c41c978fb1c40830c5fd6e60bb494624362275 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 2 Dec 2022 16:30:00 +0000 Subject: [PATCH 3/5] Newsfile --- changelog.d/14604.bugfix | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/14604.bugfix diff --git a/changelog.d/14604.bugfix b/changelog.d/14604.bugfix new file mode 100644 index 000000000000..149ee99dd716 --- /dev/null +++ b/changelog.d/14604.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where a device list update might not be sent to clients in certain circumstances. From ec7127f7dea81e87f8897f3d84b363847390045e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 5 Dec 2022 10:05:27 +0000 Subject: [PATCH 4/5] Docstrings --- synapse/storage/databases/main/devices.py | 21 +++++++++++++++++++++ synapse/util/caches/stream_change_cache.py | 3 ++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index e55a3502cdd9..f640a9f83c9e 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -815,13 +815,34 @@ async def get_all_devices_changed( from_key: int, to_key: int, ) -> Set[str]: + """Get all users whose devices have changed in the given range. + + Args: + from_key: The minimum device lists stream token to query device list + changes for, exclusive. + to_key: The maximum device lists stream token to query device list + changes for, inclusive. + + Returns: + The set of user_ids whose devices have changed since `from_key` + (exclusive) until `to_key` (inclusive). + """ + result = self._device_list_stream_cache.get_all_entities_changed(from_key) if result.hit: + # We know which users might have changed devices. + if not result.entities: + # If no users then we can return early. + return set() + + # Otherwise we need to filter down the list return await self.get_users_whose_devices_changed( from_key, result.entities, to_key ) + # If the cache didn't tell us anything, we just need to query the full + # range. sql = """ SELECT DISTINCT user_id FROM device_lists_stream WHERE ? < stream_id AND stream_id <= ? diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index b6d6ae74c326..441ae5a8738a 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -172,7 +172,8 @@ def get_all_entities_changed(self, stream_pos: int) -> AllEntitiesChangedResult: """Returns all entities that have had new things since the given position. If the position is too old it will return None. - Returns the entities in the order that they were changed. + Returns a class indicating if we have the requested data cached, and if + so includes the entities in the order they were changed. """ assert type(stream_pos) is int From 4b6cd6a0012304983c7a862222de524fa7c7bba6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 5 Dec 2022 15:00:42 +0000 Subject: [PATCH 5/5] Fix typo --- synapse/storage/databases/main/devices.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index f640a9f83c9e..a5bb4d404e20 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -849,7 +849,11 @@ async def get_all_devices_changed( """ rows = await self.db_pool.execute( - "get_all_devices_changed", None, sql, (from_key, to_key) + "get_all_devices_changed", + None, + sql, + from_key, + to_key, ) return {u for u, in rows}