diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 27493a53a94a..12a1cfd6d1c4 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -121,16 +121,16 @@ async def on_POSITION(self, cmd: PositionCommand): logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) return - # We're about to go and catch up with the stream, so mark as connecting - # to stop RDATA being handled at the same time by removing stream from - # list of connected streams. We also clear any batched up RDATA from - # before we got the POSITION. - self._streams_connected.discard(cmd.stream_name) - self._pending_batches.clear() - # We protect catching up with a linearizer in case the replication # connection reconnects under us. with await self._position_linearizer.queue(cmd.stream_name): + # We're about to go and catch up with the stream, so mark as connecting + # to stop RDATA being handled at the same time by removing stream from + # list of connected streams. We also clear any batched up RDATA from + # before we got the POSITION. + self._streams_connected.discard(cmd.stream_name) + self._pending_batches.clear() + # Find where we previously streamed up to. current_token = self._replication_data_handler.get_streams_to_replicate().get( cmd.stream_name @@ -158,41 +158,14 @@ async def on_POSITION(self, cmd: PositionCommand): # We've now caught up to position sent to us, notify handler. await self._replication_data_handler.on_position(cmd.stream_name, cmd.token) - self._streams_connected.add(cmd.stream_name) - - # Handle any RDATA that came in while we were catching up. - rows = self._pending_batches.pop(cmd.stream_name, []) - if rows: - # We need to make sure we filter out RDATA rows with a token less - # than what we've caught up to. This is slightly fiddly because of - # "batched" rows which have a `None` token, indicating that they - # have the same token as the next row with a non-None token. - # - # We do this by walking the list backwards, first removing any RDATA - # rows that are part of an uncompeted batch, then taking rows while - # their token is either None or greater than where we've caught up - # to. - uncompleted_batch = [] - unfinished_batch = True - filtered_rows = [] - for row in reversed(rows): - if row.token is not None: - unfinished_batch = False - if cmd.token < row.token: - filtered_rows.append(row) - else: - break - elif unfinished_batch: - uncompleted_batch.append(row) - else: - filtered_rows.append(row) - - filtered_rows.reverse() - uncompleted_batch.reverse() - if uncompleted_batch: - self._pending_batches[cmd.stream_name] = uncompleted_batch - - await self.on_rdata(cmd.stream_name, rows[-1].token, filtered_rows) + # Handle any RDATA that came in while we were catching up. + rows = self._pending_batches.pop(cmd.stream_name, []) + if rows: + await self._replication_data_handler.on_rdata( + cmd.stream_name, rows[-1].token, rows + ) + + self._streams_connected.add(cmd.stream_name) async def on_SYNC(self, cmd: SyncCommand): pass