Skip to content

Commit

Permalink
Thread through instance name to replication client. (matrix-org#7369)
Browse files Browse the repository at this point in the history
For in memory streams when fetching updates on workers we need to query the source of the stream, which currently is hard coded to be master. This PR threads through the source instance we received via `POSITION` through to the update function in each stream, which can then be passed to the replication client for in memory streams.
  • Loading branch information
erikjohnston authored and phil-flex committed Jun 16, 2020
1 parent 7a1a656 commit 87ad2ec
Show file tree
Hide file tree
Showing 12 changed files with 101 additions and 41 deletions.
1 change: 1 addition & 0 deletions changelog.d/7369.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Thread through instance name to replication client.
10 changes: 4 additions & 6 deletions synapse/app/generic_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,13 +646,11 @@ def __init__(self, hs):
else:
self.send_handler = None

async def on_rdata(self, stream_name, token, rows):
await super(GenericWorkerReplicationHandler, self).on_rdata(
stream_name, token, rows
)
await self.process_and_notify(stream_name, token, rows)
async def on_rdata(self, stream_name, instance_name, token, rows):
await super().on_rdata(stream_name, instance_name, token, rows)
await self._process_and_notify(stream_name, instance_name, token, rows)

async def process_and_notify(self, stream_name, token, rows):
async def _process_and_notify(self, stream_name, instance_name, token, rows):
try:
if self.send_handler:
await self.send_handler.process_replication_rows(
Expand Down
19 changes: 18 additions & 1 deletion synapse/replication/http/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import abc
import logging
import re
from inspect import signature
from typing import Dict, List, Tuple

from six import raise_from
Expand Down Expand Up @@ -60,6 +61,8 @@ class ReplicationEndpoint(object):
must call `register` to register the path with the HTTP server.
Requests can be sent by calling the client returned by `make_client`.
Requests are sent to master process by default, but can be sent to other
named processes by specifying an `instance_name` keyword argument.
Attributes:
NAME (str): A name for the endpoint, added to the path as well as used
Expand Down Expand Up @@ -91,6 +94,16 @@ def __init__(self, hs):
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
)

# We reserve `instance_name` as a parameter to sending requests, so we
# assert here that sub classes don't try and use the name.
assert (
"instance_name" not in self.PATH_ARGS
), "`instance_name` is a reserved paramater name"
assert (
"instance_name"
not in signature(self.__class__._serialize_payload).parameters
), "`instance_name` is a reserved paramater name"

assert self.METHOD in ("PUT", "POST", "GET")

@abc.abstractmethod
Expand Down Expand Up @@ -135,7 +148,11 @@ def make_client(cls, hs):

@trace(opname="outgoing_replication_request")
@defer.inlineCallbacks
def send_request(**kwargs):
def send_request(instance_name="master", **kwargs):
# Currently we only support sending requests to master process.
if instance_name != "master":
raise Exception("Unknown instance")

data = yield cls._serialize_payload(**kwargs)

url_args = [
Expand Down
4 changes: 3 additions & 1 deletion synapse/replication/http/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
def __init__(self, hs):
super().__init__(hs)

self._instance_name = hs.get_instance_name()

# We pull the streams from the replication steamer (if we try and make
# them ourselves we end up in an import loop).
self.streams = hs.get_replication_streamer().get_streams()
Expand All @@ -67,7 +69,7 @@ async def _handle_request(self, request, stream_name):
upto_token = parse_integer(request, "upto_token", required=True)

updates, upto_token, limited = await stream.get_updates_since(
from_token, upto_token
self._instance_name, from_token, upto_token
)

return (
Expand Down
12 changes: 7 additions & 5 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,19 @@ class ReplicationDataHandler:
def __init__(self, store: BaseSlavedStore):
self.store = store

async def on_rdata(self, stream_name: str, token: int, rows: list):
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
handle more.
Args:
stream_name (str): name of the replication stream for this batch of rows
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
stream_name: name of the replication stream for this batch of rows
instance_name: the instance that wrote the rows.
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
self.store.process_replication_rows(stream_name, token, rows)

Expand Down
20 changes: 15 additions & 5 deletions synapse/replication/tcp/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,19 +278,24 @@ async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
# Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, [])
rows.append(row)
await self.on_rdata(stream_name, cmd.token, rows)
await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)

async def on_rdata(self, stream_name: str, token: int, rows: list):
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
):
"""Called to handle a batch of replication data with a given stream token.
Args:
stream_name: name of the replication stream for this batch of rows
instance_name: the instance that wrote the rows.
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
await self._replication_data_handler.on_rdata(stream_name, token, rows)
await self._replication_data_handler.on_rdata(
stream_name, instance_name, token, rows
)

async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
Expand Down Expand Up @@ -325,7 +330,9 @@ async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
updates,
current_token,
missing_updates,
) = await stream.get_updates_since(current_token, cmd.token)
) = await stream.get_updates_since(
cmd.instance_name, current_token, cmd.token
)

# TODO: add some tests for this

Expand All @@ -334,7 +341,10 @@ async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):

for token, rows in _batch_updates(updates):
await self.on_rdata(
cmd.stream_name, token, [stream.parse_row(row) for row in rows],
cmd.stream_name,
cmd.instance_name,
token,
[stream.parse_row(row) for row in rows],
)

# We've now caught up to position sent to us, notify handler.
Expand Down
50 changes: 37 additions & 13 deletions synapse/replication/tcp/streams/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import logging
from collections import namedtuple
from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple
from typing import Any, Awaitable, Callable, List, Optional, Tuple

import attr

Expand Down Expand Up @@ -53,6 +53,7 @@
#
# The arguments are:
#
# * instance_name: the writer of the stream
# * from_token: the previous stream token: the starting point for fetching the
# updates
# * to_token: the new stream token: the point to get updates up to
Expand All @@ -62,7 +63,7 @@
# If there are more updates available, it should set `limited` in the result, and
# it will be called again to get the next batch.
#
UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]


class Stream(object):
Expand Down Expand Up @@ -93,6 +94,7 @@ def parse_row(cls, row: StreamRow):

def __init__(
self,
local_instance_name: str,
current_token_function: Callable[[], Token],
update_function: UpdateFunction,
):
Expand All @@ -108,9 +110,11 @@ def __init__(
stream tokens. See the UpdateFunction type definition for more info.
Args:
local_instance_name: The instance name of the current process
current_token_function: callback to get the current token, as above
update_function: callback go get stream updates, as above
"""
self.local_instance_name = local_instance_name
self.current_token = current_token_function
self.update_function = update_function

Expand All @@ -135,14 +139,14 @@ async def get_updates(self) -> StreamUpdateResult:
"""
current_token = self.current_token()
updates, current_token, limited = await self.get_updates_since(
self.last_token, current_token
self.local_instance_name, self.last_token, current_token
)
self.last_token = current_token

return updates, current_token, limited

async def get_updates_since(
self, from_token: Token, upto_token: Token
self, instance_name: str, from_token: Token, upto_token: Token
) -> StreamUpdateResult:
"""Like get_updates except allows specifying from when we should
stream updates
Expand All @@ -160,19 +164,19 @@ async def get_updates_since(
return [], upto_token, False

updates, upto_token, limited = await self.update_function(
from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
)
return updates, upto_token, limited


def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> UpdateFunction:
"""Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class
"""

async def update_function(from_token, upto_token, limit):
async def update_function(instance_name, from_token, upto_token, limit):
rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
limited = False
Expand All @@ -193,10 +197,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
client = ReplicationGetStreamUpdates.make_client(hs)

async def update_function(
from_token: int, upto_token: int, limit: int
instance_name: str, from_token: int, upto_token: int, limit: int
) -> StreamUpdateResult:
result = await client(
stream_name=stream_name, from_token=from_token, upto_token=upto_token,
instance_name=instance_name,
stream_name=stream_name,
from_token=from_token,
upto_token=upto_token,
)
return result["updates"], result["upto_token"], result["limited"]

Expand Down Expand Up @@ -226,6 +233,7 @@ class BackfillStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_current_backfill_token,
db_query_to_update_function(store.get_all_new_backfill_event_rows),
)
Expand Down Expand Up @@ -261,7 +269,9 @@ def __init__(self, hs):
# Query master process
update_function = make_http_update_function(hs, self.NAME)

super().__init__(store.get_current_presence_token, update_function)
super().__init__(
hs.get_instance_name(), store.get_current_presence_token, update_function
)


class TypingStream(Stream):
Expand All @@ -284,7 +294,9 @@ def __init__(self, hs):
# Query master process
update_function = make_http_update_function(hs, self.NAME)

super().__init__(typing_handler.get_current_token, update_function)
super().__init__(
hs.get_instance_name(), typing_handler.get_current_token, update_function
)


class ReceiptsStream(Stream):
Expand All @@ -305,6 +317,7 @@ class ReceiptsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_max_receipt_stream_id,
db_query_to_update_function(store.get_all_updated_receipts),
)
Expand All @@ -322,14 +335,16 @@ class PushRulesStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
super(PushRulesStream, self).__init__(
self._current_token, self._update_function
hs.get_instance_name(), self._current_token, self._update_function
)

def _current_token(self) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token

async def _update_function(self, from_token: Token, to_token: Token, limit: int):
async def _update_function(
self, instance_name: str, from_token: Token, to_token: Token, limit: int
):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)

limited = False
Expand All @@ -356,6 +371,7 @@ def __init__(self, hs):
store = hs.get_datastore()

super().__init__(
hs.get_instance_name(),
store.get_pushers_stream_token,
db_query_to_update_function(store.get_all_updated_pushers_rows),
)
Expand Down Expand Up @@ -387,6 +403,7 @@ class CachesStreamRow:
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_cache_stream_token,
db_query_to_update_function(store.get_all_updated_caches),
)
Expand All @@ -412,6 +429,7 @@ class PublicRoomsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_current_public_room_stream_id,
db_query_to_update_function(store.get_all_new_public_rooms),
)
Expand All @@ -432,6 +450,7 @@ class DeviceListsStreamRow:
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_device_stream_token,
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
)
Expand All @@ -449,6 +468,7 @@ class ToDeviceStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_to_device_stream_token,
db_query_to_update_function(store.get_all_new_device_messages),
)
Expand All @@ -468,6 +488,7 @@ class TagAccountDataStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_max_account_data_stream_id,
db_query_to_update_function(store.get_all_updated_tags),
)
Expand All @@ -487,6 +508,7 @@ class AccountDataStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
self.store.get_max_account_data_stream_id,
db_query_to_update_function(self._update_function),
)
Expand Down Expand Up @@ -517,6 +539,7 @@ class GroupServerStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_group_stream_token,
db_query_to_update_function(store.get_all_groups_changes),
)
Expand All @@ -534,6 +557,7 @@ class UserSignatureStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_device_stream_token,
db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes
Expand Down
Loading

0 comments on commit 87ad2ec

Please sign in to comment.