diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index b5e40da5337e..af7408d8abbb 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -326,7 +326,7 @@ async def wait_for_stream_position(
             # anyway in that case we don't need to wait.
             return

-        current_position = self._streams[stream_name].current_token(self._instance_name)
+        current_position = self._streams[stream_name].current_token(instance_name)
         if position <= current_position:
             # We're already past the position
             return
@@ -345,7 +345,12 @@ async def wait_for_stream_position(
         # We measure here to get in flight counts and average waiting time.
         with Measure(self._clock, "repl.wait_for_stream_position"):
-            logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
+            logger.info(
+                "Waiting for repl stream %r to reach %s. Current position: %s",
+                stream_name,
+                position,
+                current_position,
+            )
             await make_deferred_yieldable(deferred)
             logger.info(
                 "Finished waiting for repl stream %r to reach %s", stream_name, position
             )
diff --git a/synapse/replication/tcp/streams/partial_state.py b/synapse/replication/tcp/streams/partial_state.py
index b5a2ae74b685..84df5e8590ca 100644
--- a/synapse/replication/tcp/streams/partial_state.py
+++ b/synapse/replication/tcp/streams/partial_state.py
@@ -71,6 +71,6 @@ def __init__(self, hs: "HomeServer"):
         super().__init__(
             hs.get_instance_name(),
             # TODO(faster_joins, multiple writers): we need to account for instance names
-            current_token_without_instance(store.get_un_partial_stated_events_token),
+            store.get_un_partial_stated_events_token,
             store.get_un_partial_stated_events_from_stream,
         )
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index d150fa8a943d..2e8ecac6e8ec 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -314,11 +314,12 @@ def get_chain_id_txn(txn: Cursor) -> int:
                 db_conn, "un_partial_stated_event_stream", "stream_id"
             )

-    def get_un_partial_stated_events_token(self) -> int:
-        # TODO(faster_joins, multiple writers): This is inappropriate if there are multiple
-        # writers because workers that don't write often will hold all
-        # readers up.
-        return self._un_partial_stated_events_stream_id_gen.get_current_token()
+    def get_un_partial_stated_events_token(self, instance_name: str) -> int:
+        return (
+            self._un_partial_stated_events_stream_id_gen.get_current_token_for_writer(
+                instance_name
+            )
+        )

     async def get_un_partial_stated_events_from_stream(
         self, instance_name: str, last_id: int, current_id: int, limit: int
@@ -408,6 +409,11 @@ def process_replication_position(
             self._stream_id_gen.advance(instance_name, token)
         elif stream_name == BackfillStream.NAME:
             self._backfill_id_gen.advance(instance_name, -token)
+        elif stream_name == UnPartialStatedEventStream.NAME:
+            logger.info(
+                "Advancing %s token to %s", UnPartialStatedEventStream.NAME, token
+            )
+            self._un_partial_stated_events_stream_id_gen.advance(instance_name, token)
         super().process_replication_position(stream_name, instance_name, token)

     async def have_censored_event(self, event_id: str) -> bool:
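
A note on the change to get_un_partial_stated_events_token: the TODO removed
from events_worker.py describes the problem being fixed. With multiple
writers, a single global "current token" is capped by the slowest writer, so
workers that rarely write would hold every reader back; a per-writer token
avoids that. The toy class below is a minimal sketch of the distinction;
PerWriterTokens and the writer names are hypothetical simplifications, not
Synapse's actual MultiWriterIdGenerator.

from typing import Dict


class PerWriterTokens:
    """Illustrative stand-in for a multi-writer stream ID generator."""

    def __init__(self) -> None:
        self._positions: Dict[str, int] = {}

    def advance(self, instance_name: str, token: int) -> None:
        # Record the furthest stream position seen for this writer.
        self._positions[instance_name] = max(
            self._positions.get(instance_name, 0), token
        )

    def get_current_token_for_writer(self, instance_name: str) -> int:
        # Per-writer token: how far this particular writer has got.
        return self._positions.get(instance_name, 0)

    def get_current_token(self) -> int:
        # Global token: the point every writer is known to have reached,
        # i.e. the minimum, which is capped by the slowest writer.
        return min(self._positions.values(), default=0)


tokens = PerWriterTokens()
tokens.advance("persister-1", 100)
tokens.advance("persister-2", 3)  # a worker that rarely writes

assert tokens.get_current_token() == 3  # held back by the quiet writer
assert tokens.get_current_token_for_writer("persister-1") == 100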