From 0573e8dcde43f83977972de92488100628d9a9cf Mon Sep 17 00:00:00 2001 From: Nick Barrett Date: Thu, 22 Dec 2022 16:20:03 +0000 Subject: [PATCH] Update all stream IDs after processing replication rows This creates a new store method, `process_replication_position` that is called after `process_replication_rows`. By moving stream ID advances here this guarantees any relevant cache invalidations will have been applied before the stream is advanced. This avoids race conditions where Python switches between threads mid way through processing the `process_replication_rows` method where stream IDs may be advanced before caches are invalidated due to class resolution ordering. See this comment/issue for further discussion: https://github.com/matrix-org/synapse/issues/14158#issuecomment-1344048703 --- synapse/replication/tcp/client.py | 1 + synapse/storage/_base.py | 16 +++++++++++++++- synapse/storage/databases/main/account_data.py | 14 ++++++++++---- synapse/storage/databases/main/cache.py | 11 ++++++++--- synapse/storage/databases/main/deviceinbox.py | 7 +++++++ synapse/storage/databases/main/devices.py | 11 +++++++++-- synapse/storage/databases/main/events_worker.py | 15 ++++++++++----- synapse/storage/databases/main/presence.py | 8 +++++++- synapse/storage/databases/main/push_rule.py | 7 +++++++ synapse/storage/databases/main/pusher.py | 6 +++--- synapse/storage/databases/main/receipts.py | 7 +++++++ synapse/storage/databases/main/tags.py | 8 +++++++- 12 files changed, 91 insertions(+), 20 deletions(-) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 658d89210d31..3698e16b46e8 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -152,6 +152,7 @@ async def on_rdata( rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ self.store.process_replication_rows(stream_name, instance_name, token, rows) + self.store.process_replication_position(stream_name, instance_name, token) if self.send_handler: await self.send_handler.process_replication_rows(stream_name, token, rows) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 69abf6fa8741..c021c374f5d6 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -57,7 +57,21 @@ def process_replication_rows( # noqa: B027 (no-op by design) token: int, rows: Iterable[Any], ) -> None: - pass + """ + Used by storage classes to invalidate caches based on incoming replication data. + """ + + def process_replication_position( # noqa: B027 (no-op by design) + self, + stream_name: str, + instance_name: str, + token: int, + ) -> None: + """ + Used by storage classes to advance ID generators based on incoming replication data. This + is called after process_replication_rows such that caches are invalidated before any token + positions advance. + """ def _invalidate_state_caches( self, room_id: str, members_changed: Collection[str] diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 07908c41d9ca..66472c0a439a 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -411,10 +411,7 @@ def process_replication_rows( token: int, rows: Iterable[Any], ) -> None: - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - elif stream_name == AccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) + if stream_name == AccountDataStream.NAME: for row in rows: if not row.room_id: self.get_global_account_data_by_type_for_user.invalidate( @@ -429,6 +426,15 @@ def process_replication_rows( super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + elif stream_name == AccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict ) -> int: diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index a58668a38036..2179a8bf5922 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -164,9 +164,6 @@ def process_replication_rows( backfilled=True, ) elif stream_name == CachesStream.NAME: - if self._cache_id_gen: - self._cache_id_gen.advance(instance_name, token) - for row in rows: if row.cache_func == CURRENT_STATE_CACHE_NAME: if row.keys is None: @@ -182,6 +179,14 @@ def process_replication_rows( super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == CachesStream.NAME: + if self._cache_id_gen: + self._cache_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: data = row.data diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 48a54d9cb86b..713be91c5dd8 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -157,6 +157,13 @@ def process_replication_rows( ) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == ToDeviceStream.NAME: + self._device_inbox_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def get_to_device_stream_token(self) -> int: return self._device_inbox_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index a5bb4d404e20..db877e3f1374 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -162,14 +162,21 @@ def process_replication_rows( self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] ) -> None: if stream_name == DeviceListsStream.NAME: - self._device_list_id_gen.advance(instance_name, token) self._invalidate_caches_for_devices(token, rows) elif stream_name == UserSignatureStream.NAME: - self._device_list_id_gen.advance(instance_name, token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == DeviceListsStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + elif stream_name == UserSignatureStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _invalidate_caches_for_devices( self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] ) -> None: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 761b15a8150a..d150fa8a943d 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -388,11 +388,7 @@ def process_replication_rows( token: int, rows: Iterable[Any], ) -> None: - if stream_name == EventsStream.NAME: - self._stream_id_gen.advance(instance_name, token) - elif stream_name == BackfillStream.NAME: - self._backfill_id_gen.advance(instance_name, -token) - elif stream_name == UnPartialStatedEventStream.NAME: + if stream_name == UnPartialStatedEventStream.NAME: for row in rows: assert isinstance(row, UnPartialStatedEventStreamRow) @@ -405,6 +401,15 @@ def process_replication_rows( super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == EventsStream.NAME: + self._stream_id_gen.advance(instance_name, token) + elif stream_name == BackfillStream.NAME: + self._backfill_id_gen.advance(instance_name, -token) + super().process_replication_position(stream_name, instance_name, token) + async def have_censored_event(self, event_id: str) -> bool: """Check if an event has been censored, i.e. if the content of the event has been erased from the database due to a redaction. diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 9769a18a9d0c..7b60815043a6 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -439,8 +439,14 @@ def process_replication_rows( rows: Iterable[Any], ) -> None: if stream_name == PresenceStream.NAME: - self._presence_id_gen.advance(instance_name, token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) return super().process_replication_rows(stream_name, instance_name, token, rows) + + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == PresenceStream.NAME: + self._presence_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index d4c64c46ad44..d4e4b777da95 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -154,6 +154,13 @@ def process_replication_rows( self.push_rules_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == PushRulesStream.NAME: + self._push_rules_stream_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + @cached(max_entries=5000) async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules: rows = await self.db_pool.simple_select_list( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 40fd781a6ab1..7f24a3b6ec5e 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -111,12 +111,12 @@ def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]: def get_pushers_stream_token(self) -> int: return self._pushers_id_gen.get_current_token() - def process_replication_rows( - self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + def process_replication_position( + self, stream_name: str, instance_name: str, token: int ) -> None: if stream_name == PushersStream.NAME: self._pushers_id_gen.advance(instance_name, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + super().process_replication_position(stream_name, instance_name, token) async def get_pushers_by_app_id_and_pushkey( self, app_id: str, pushkey: str diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index e06725f69c31..86f5bce5f08d 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -588,6 +588,13 @@ def process_replication_rows( return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == ReceiptsStream.NAME: + self._receipts_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _insert_linearized_receipt_txn( self, txn: LoggingTransaction, diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index b0f5de67a30d..e23c927e02f2 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -300,13 +300,19 @@ def process_replication_rows( rows: Iterable[Any], ) -> None: if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) for row in rows: self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed(row.user_id, token) super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + class TagsStore(TagsWorkerStore): pass