Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add ability to wait for replication streams #7542

Merged
merged 9 commits into from
May 22, 2020
1 change: 1 addition & 0 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def __init__(self, hs):
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
self.http_client = hs.get_simple_http_client()
self._replication = hs.get_replication_data_handler()

self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client(
hs
Expand Down
5 changes: 1 addition & 4 deletions synapse/replication/http/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,7 @@ def __init__(self, hs):
super().__init__(hs)

self._instance_name = hs.get_instance_name()

# We pull the streams from the replication handler (if we try and make
# them ourselves we end up in an import loop).
self.streams = hs.get_tcp_replication().get_streams()
self.streams = hs.get_replication_streams()

@staticmethod
def _serialize_payload(stream_name, from_token, upto_token):
Expand Down
76 changes: 74 additions & 2 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,23 @@
# limitations under the License.
"""A replication client for use by synapse workers.
"""

import heapq
import logging
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Dict, List, Tuple

from twisted.internet.defer import Deferred
from twisted.internet.protocol import ReconnectingClientFactory

from synapse.api.constants import EventTypes
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamEventRow,
EventsStreamRow,
)
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -92,6 +96,16 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.pusher_pool = hs.get_pusherpool()
self.notifier = hs.get_notifier()
self._reactor = hs.get_reactor()
self._clock = hs.get_clock()
self._streams = hs.get_replication_streams()
self._instance_name = hs.get_instance_name()

# Map from stream to list of deferreds waiting for stream to particular
# position. The lists are sorted by the stream position.
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
self._streams_to_waiters = (
{}
) # type: Dict[str, List[Tuple[int, Deferred[None]]]]

async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
Expand Down Expand Up @@ -131,8 +145,66 @@ async def on_rdata(

await self.pusher_pool.on_new_notifications(token, token)

# Notify any waiting deferreds
waiting_list = self._streams_to_waiters.get(stream_name, [])

# Index of first item with a position after the current token, i.e we
# have called all deferreds before this index. If not overwritten by
# loop below means either a) no items in list so no-op or b) all items
# in list were called and so the list should be cleared. Setting it to
# `len(list)` works for both cases.
index_of_first_deferred_not_called = len(waiting_list)

for idx, (position, deferred) in enumerate(waiting_list):
if position <= token:
try:
with PreserveLoggingContext():
deferred.callback(None)
except Exception:
# The deferred has been cancelled or timed out.
pass
else:
index_of_first_deferred_not_called = idx
break
anoadragon453 marked this conversation as resolved.
Show resolved Hide resolved

# (This maintains the order so no need to resort)
waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
anoadragon453 marked this conversation as resolved.
Show resolved Hide resolved

async def on_position(self, stream_name: str, instance_name: str, token: int):
    """Called when the stream `stream_name` has advanced to `token` on the
    writer `instance_name` without any rows for this process to handle.

    Forwards the new token to the datastore with an empty row list so it
    can still record the advanced position.
    """
    self.store.process_replication_rows(stream_name, instance_name, token, [])

def on_remote_server_up(self, server: str):
    """Called when we get a new REMOTE_SERVER_UP command.

    Deliberately a no-op here; handlers that care about a remote server
    coming back online can override this.
    """

async def wait_for_stream_position(
    self, instance_name: str, stream_name: str, position: int
):
    """Wait until this instance has received updates up to and including
    the given stream position.

    Args:
        instance_name: The instance that wrote the update we are waiting for.
        stream_name: The name of the stream the position belongs to.
        position: The stream token to wait for.

    Returns once the local position has reached `position`, or raises if
    the 30 second timeout expires first.
    """

    if instance_name == self._instance_name:
        # We don't get told about updates written by this process, and
        # anyway in that case we don't need to wait.
        return

    current_position = self._streams[stream_name].current_token(self._instance_name)
    if position <= current_position:
        # We're already past the position
        return

    # Create a new deferred that times out after N seconds, as we don't want
    # to wedge here forever.
    deferred = Deferred()
    deferred = timeout_deferred(deferred, 30, self._reactor)

    waiting_list = self._streams_to_waiters.setdefault(stream_name, [])

    # Insert while keeping the list fully sorted by position. We must NOT
    # use `heapq.heappush` here: (a) a heap is only partially ordered,
    # while `on_rdata` walks the list assuming full sorted order and stops
    # at the first entry past the token, and (b) if two waiters share a
    # position, tuple comparison would fall through to comparing the
    # Deferreds, which do not define an ordering and raise TypeError.
    waiting_list.append((position, deferred))
    waiting_list.sort(key=lambda entry: entry[0])

    # We measure here to get in flight counts and average waiting time.
    with Measure(self._clock, "repl.wait_for_stream_position"):
        logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
        await make_deferred_yieldable(deferred)
        logger.info(
            "Finished waiting for repl stream %r to reach %s", stream_name, position
        )
5 changes: 5 additions & 0 deletions synapse/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.replication.tcp.streams import STREAMS_MAP
from synapse.rest.media.v1.media_repository import (
MediaRepository,
MediaRepositoryResource,
Expand Down Expand Up @@ -210,6 +211,7 @@ def build_DEPENDENCY(self)
"storage",
"replication_streamer",
"replication_data_handler",
"replication_streams",
]

REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
Expand Down Expand Up @@ -583,6 +585,9 @@ def build_replication_streamer(self) -> ReplicationStreamer:
def build_replication_data_handler(self):
    """Factory for the `ReplicationDataHandler` singleton exposed via
    `get_replication_data_handler`."""
    return ReplicationDataHandler(self)

def build_replication_streams(self):
    """Factory for the stream-name -> `Stream` instance map exposed via
    `get_replication_streams`, built from `STREAMS_MAP`."""
    streams = {}
    for stream_class in STREAMS_MAP.values():
        streams[stream_class.NAME] = stream_class(self)
    return streams

def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

Expand Down
5 changes: 5 additions & 0 deletions synapse/server.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Dict

import twisted.internet

import synapse.api.auth
Expand Down Expand Up @@ -28,6 +30,7 @@ import synapse.server_notices.server_notices_sender
import synapse.state
import synapse.storage
from synapse.events.builder import EventBuilderFactory
from synapse.replication.tcp.streams import Stream

class HomeServer(object):
@property
Expand Down Expand Up @@ -136,3 +139,5 @@ class HomeServer(object):
pass
def get_pusherpool(self) -> synapse.push.pusherpool.PusherPool:
pass
def get_replication_streams(self) -> Dict[str, Stream]:
    # Type stub for `HomeServer.get_replication_streams`; the real
    # implementation lives in synapse/server.py.
    pass
5 changes: 4 additions & 1 deletion tests/handlers/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ def make_homeserver(self, reactor, clock):
reactor.pump((1000,))

hs = self.setup_test_homeserver(
notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring
notifier=Mock(),
http_client=mock_federation_client,
keyring=mock_keyring,
replication_streams={},
)

hs.datastores = datastores
Expand Down