Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Port synapse.replication.tcp to async/await #6666

Merged
merged 5 commits into from
Jan 16, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/6666.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Port `synapse.replication.tcp` to async/await.
3 changes: 1 addition & 2 deletions synapse/app/admin_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ def build_tcp_replication(self):


class AdminCmdReplicationHandler(ReplicationClientHandler):
@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
async def on_rdata(self, stream_name, token, rows):
pass

def get_streams_to_replicate(self):
Expand Down
5 changes: 2 additions & 3 deletions synapse/app/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,8 @@ def __init__(self, hs):
super(ASReplicationHandler, self).__init__(hs.get_datastore())
self.appservice_handler = hs.get_application_service_handler()

@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
async def on_rdata(self, stream_name, token, rows):
await super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)

if stream_name == "events":
max_stream_id = self.store.get_room_max_stream_ordering()
Expand Down
5 changes: 2 additions & 3 deletions synapse/app/federation_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,8 @@ def __init__(self, hs):
super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
self.send_handler = FederationSenderHandler(hs, self)

@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
yield super(FederationSenderReplicationHandler, self).on_rdata(
async def on_rdata(self, stream_name, token, rows):
await super(FederationSenderReplicationHandler, self).on_rdata(
stream_name, token, rows
)
self.send_handler.process_replication_rows(stream_name, token, rows)
Expand Down
5 changes: 2 additions & 3 deletions synapse/app/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,8 @@ def __init__(self, hs):

self.pusher_pool = hs.get_pusherpool()

@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
async def on_rdata(self, stream_name, token, rows):
await super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
run_in_background(self.poke_pushers, stream_name, token, rows)

@defer.inlineCallbacks
Expand Down
5 changes: 2 additions & 3 deletions synapse/app/synchrotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,8 @@ def __init__(self, hs):
self.presence_handler = hs.get_presence_handler()
self.notifier = hs.get_notifier()

@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
async def on_rdata(self, stream_name, token, rows):
await super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
run_in_background(self.process_and_notify, stream_name, token, rows)

def get_streams_to_replicate(self):
Expand Down
5 changes: 2 additions & 3 deletions synapse/app/user_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,8 @@ def __init__(self, hs):
super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
self.user_directory = hs.get_user_directory_handler()

@defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
yield super(UserDirectoryReplicationHandler, self).on_rdata(
async def on_rdata(self, stream_name, token, rows):
await super(UserDirectoryReplicationHandler, self).on_rdata(
stream_name, token, rows
)
if stream_name == EventsStream.NAME:
Expand Down
4 changes: 3 additions & 1 deletion synapse/federation/send_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,9 @@ def get_current_token(self):
def federation_ack(self, token):
self._clear_queue_before_pos(token)

def get_replication_rows(self, from_token, to_token, limit, federation_ack=None):
async def get_replication_rows(
self, from_token, to_token, limit, federation_ack=None
):
"""Get rows to be sent over federation between the two tokens

Args:
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def _push_update_local(self, member, typing):
"typing_key", self._latest_room_serial, rooms=[member.room_id]
)

def get_all_typing_updates(self, last_id, current_id):
async def get_all_typing_updates(self, last_id, current_id):
if last_id == current_id:
return []

Expand Down
11 changes: 4 additions & 7 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def start_replication(self, hs):
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self.factory)

def on_rdata(self, stream_name, token, rows):
async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.

By default this just pokes the slave store. Can be overridden in subclasses to
Expand All @@ -121,20 +121,17 @@ def on_rdata(self, stream_name, token, rows):
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.

Returns:
Deferred|None
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
return self.store.process_replication_rows(stream_name, token, rows)
self.store.process_replication_rows(stream_name, token, rows)

def on_position(self, stream_name, token):
async def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes
the slave store.

Can be overridden in subclasses to handle more.
"""
return self.store.process_replication_rows(stream_name, token, [])
self.store.process_replication_rows(stream_name, token, [])

def on_sync(self, data):
"""When we receive a SYNC we wake up any deferreds that were waiting
Expand Down
76 changes: 34 additions & 42 deletions synapse/replication/tcp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,11 @@
SyncCommand,
UserSyncCommand,
)
from synapse.replication.tcp.streams import STREAMS_MAP
from synapse.types import Collection
from synapse.util import Clock
from synapse.util.stringutils import random_string

from .streams import STREAMS_MAP

connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
)
Expand Down Expand Up @@ -241,19 +240,16 @@ def lineReceived(self, line):
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
)

def handle_command(self, cmd):
async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream.

By default delegates to on_<COMMAND>
By default delegates to on_<COMMAND>, which should return an awaitable.

Args:
cmd (synapse.replication.tcp.commands.Command): received command

Returns:
Deferred
cmd: received command
"""
handler = getattr(self, "on_%s" % (cmd.NAME,))
return handler(cmd)
await handler(cmd)

def close(self):
logger.warning("[%s] Closing connection", self.id())
Expand Down Expand Up @@ -326,10 +322,10 @@ def _send_pending_commands(self):
for cmd in pending:
self.send_command(cmd)

def on_PING(self, line):
async def on_PING(self, line):
self.received_ping = True

def on_ERROR(self, cmd):
async def on_ERROR(self, cmd):
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)

def pauseProducing(self):
Expand Down Expand Up @@ -376,7 +372,7 @@ def connectionLost(self, reason):

self.on_connection_closed()

def on_connection_closed(self):
async def on_connection_closed(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this one isn't called the same way

logger.info("[%s] Connection was closed", self.id())

self.state = ConnectionStates.CLOSED
Expand Down Expand Up @@ -429,16 +425,16 @@ def connectionMade(self):
BaseReplicationStreamProtocol.connectionMade(self)
self.streamer.new_connection(self)

def on_NAME(self, cmd):
async def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data

def on_USER_SYNC(self, cmd):
return self.streamer.on_user_sync(
async def on_USER_SYNC(self, cmd):
await self.streamer.on_user_sync(
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)

def on_REPLICATE(self, cmd):
async def on_REPLICATE(self, cmd):
stream_name = cmd.stream_name
token = cmd.token

Expand All @@ -449,23 +445,23 @@ def on_REPLICATE(self, cmd):
for stream in iterkeys(self.streamer.streams_by_name)
]

return make_deferred_yieldable(
await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
else:
return self.subscribe_to_stream(stream_name, token)
await self.subscribe_to_stream(stream_name, token)

def on_FEDERATION_ACK(self, cmd):
return self.streamer.federation_ack(cmd.token)
async def on_FEDERATION_ACK(self, cmd):
self.streamer.federation_ack(cmd.token)

def on_REMOVE_PUSHER(self, cmd):
return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
async def on_REMOVE_PUSHER(self, cmd):
await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)

def on_INVALIDATE_CACHE(self, cmd):
return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
async def on_INVALIDATE_CACHE(self, cmd):
self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)

def on_USER_IP(self, cmd):
return self.streamer.on_user_ip(
async def on_USER_IP(self, cmd):
self.streamer.on_user_ip(
cmd.user_id,
cmd.access_token,
cmd.ip,
Expand All @@ -474,8 +470,7 @@ def on_USER_IP(self, cmd):
cmd.last_seen,
)

@defer.inlineCallbacks
def subscribe_to_stream(self, stream_name, token):
async def subscribe_to_stream(self, stream_name, token):
"""Subscribe the remote to a stream.

This involves checking if they've missed anything and sending those
Expand All @@ -487,7 +482,7 @@ def subscribe_to_stream(self, stream_name, token):

try:
# Get missing updates
updates, current_token = yield self.streamer.get_stream_updates(
updates, current_token = await self.streamer.get_stream_updates(
stream_name, token
)

Expand Down Expand Up @@ -572,27 +567,24 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
"""

@abc.abstractmethod
def on_rdata(self, stream_name, token, rows):
async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.

Args:
stream_name (str): name of the replication stream for this batch of rows
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.

Returns:
Deferred|None
"""
raise NotImplementedError()

@abc.abstractmethod
def on_position(self, stream_name, token):
async def on_position(self, stream_name, token):
"""Called when we get new position data."""
raise NotImplementedError()

@abc.abstractmethod
def on_sync(self, data):
async def on_sync(self, data):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the implementations of this don't seem to be async (and it's not awaited where it is called).

"""Called when we get a new SYNC command."""
raise NotImplementedError()

Expand Down Expand Up @@ -676,12 +668,12 @@ def connectionMade(self):
if not self.streams_connecting:
self.handler.finished_connecting()

def on_SERVER(self, cmd):
async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")

def on_RDATA(self, cmd):
async def on_RDATA(self, cmd):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()

Expand All @@ -701,19 +693,19 @@ def on_RDATA(self, cmd):
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
return self.handler.on_rdata(stream_name, cmd.token, rows)
await self.handler.on_rdata(stream_name, cmd.token, rows)

def on_POSITION(self, cmd):
async def on_POSITION(self, cmd):
# When we get a `POSITION` command it means we've finished getting
# missing updates for the given stream, and are now up to date.
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()

return self.handler.on_position(cmd.stream_name, cmd.token)
await self.handler.on_position(cmd.stream_name, cmd.token)

def on_SYNC(self, cmd):
return self.handler.on_sync(cmd.data)
async def on_SYNC(self, cmd):
self.handler.on_sync(cmd.data)

def replicate(self, stream_name, token):
"""Send the subscription request to the server
Expand Down
Loading