Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Move server command handling out of TCP protocol #7187

Merged
merged 16 commits into from
Apr 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7187.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Move server command handling out of TCP protocol.
177 changes: 159 additions & 18 deletions synapse/replication/tcp/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,23 @@

from prometheus_client import Counter

from synapse.metrics import LaterGauge
from synapse.replication.tcp.client import ReplicationClientFactory
from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
Command,
FederationAckCommand,
InvalidateCacheCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
RemovePusherCommand,
ReplicateCommand,
SyncCommand,
UserIpCommand,
UserSyncCommand,
)
from synapse.replication.tcp.protocol import AbstractConnection
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.util.async_helpers import Linearizer

Expand All @@ -42,6 +46,13 @@
inbound_rdata_count = Counter(
"synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
)
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
invalidate_cache_counter = Counter(
"synapse_replication_tcp_resource_invalidate_cache", ""
)
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")


class ReplicationCommandHandler:
Expand All @@ -52,6 +63,10 @@ class ReplicationCommandHandler:
def __init__(self, hs):
self._replication_data_handler = hs.get_replication_data_handler()
self._presence_handler = hs.get_presence_handler()
self._store = hs.get_datastore()
self._notifier = hs.get_notifier()
self._clock = hs.get_clock()
self._instance_id = hs.get_instance_id()

# Set of streams that we've caught up with.
self._streams_connected = set() # type: Set[str]
Expand All @@ -69,8 +84,26 @@ def __init__(self, hs):
# The factory used to create connections.
self._factory = None # type: Optional[ReplicationClientFactory]

# The current connection. None if we are currently (re)connecting
self._connection = None
# The currently connected connections.
self._connections = [] # type: List[AbstractConnection]

LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
[],
lambda: len(self._connections),
)

self._is_master = hs.config.worker_app is None

self._federation_sender = None
if self._is_master and not hs.config.send_federation:
self._federation_sender = hs.get_federation_sender()

self._server_notices_sender = None
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
self._notifier.add_remote_server_up_callback(self.send_remote_server_up)

def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
Expand All @@ -82,6 +115,70 @@ def start_replication(self, hs):
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory)

async def on_REPLICATE(self, cmd: ReplicateCommand):
# We only want to announce positions by the writer of the streams.
# Currently this is just the master process.
if not self._is_master:
return

for stream_name, stream in self._streams.items():
current_token = stream.current_token()
self.send_command(PositionCommand(stream_name, current_token))

async def on_USER_SYNC(self, cmd: UserSyncCommand):
user_sync_counter.inc()

if self._is_master:
await self._presence_handler.update_external_syncs_row(
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)

async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
if self._is_master:
await self._presence_handler.update_external_syncs_clear(cmd.instance_id)

async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
federation_ack_counter.inc()

if self._federation_sender:
self._federation_sender.federation_ack(cmd.token)

async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
remove_pusher_counter.inc()

if self._is_master:
await self._store.delete_pusher_by_app_id_pushkey_user_id(
app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
)

self._notifier.on_new_replication_data()

async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
invalidate_cache_counter.inc()

if self._is_master:
# We invalidate the cache locally, but then also stream that to other
# workers.
await self._store.invalidate_cache_and_stream(
cmd.cache_func, tuple(cmd.keys)
)

async def on_USER_IP(self, cmd: UserIpCommand):
user_ip_cache_counter.inc()

if self._is_master:
await self._store.insert_client_ip(
cmd.user_id,
cmd.access_token,
cmd.ip,
cmd.user_agent,
cmd.device_id,
cmd.last_seen,
)

if self._server_notices_sender:
await self._server_notices_sender.on_user_ip(cmd.user_id)

async def on_RDATA(self, cmd: RdataCommand):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
Expand Down Expand Up @@ -174,36 +271,73 @@ async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
""""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)

if self._is_master:
self._notifier.notify_remote_server_up(cmd.data)

def get_currently_syncing_users(self):
"""Get the list of currently syncing users (if any). This is called
when a connection has been established and we need to send the
currently syncing users.
"""
return self._presence_handler.get_currently_syncing_users()

def update_connection(self, connection):
"""Called when a connection has been established (or lost with None).
def new_connection(self, connection: AbstractConnection):
"""Called when we have a new connection.
"""
self._connection = connection
self._connections.append(connection)

# If we are connected to replication as a client (rather than a server)
# we need to reset the reconnection delay on the client factory (which
# is used to do exponential back off when the connection drops).
#
# Ideally we would reset the delay when we've "fully established" the
# connection (for some definition thereof) to stop us from tightlooping
# on reconnection if something fails after this point and we drop the
# connection. Unfortunately, we don't really have a better definition of
# "fully established" than the connection being established.
if self._factory:
self._factory.resetDelay()

# Tell the server if we have any users currently syncing (should only
# happen on synchrotrons)
currently_syncing = self.get_currently_syncing_users()
now = self._clock.time_msec()
for user_id in currently_syncing:
connection.send_command(
UserSyncCommand(self._instance_id, user_id, True, now)
)

def finished_connecting(self):
"""Called when we have successfully subscribed and caught up to all
streams we're interested in.
def lost_connection(self, connection: AbstractConnection):
"""Called when a connection is closed/lost.
"""
logger.info("Finished connecting to server")
try:
self._connections.remove(connection)
except ValueError:
pass

# We don't reset the delay any earlier as otherwise if there is a
# problem during start up we'll end up tight looping connecting to the
# server.
if self._factory:
self._factory.resetDelay()
def connected(self) -> bool:
"""Do we have any replication connections open?

Is used by e.g. `ReplicationStreamer` to no-op if nothing is connected.
"""
return bool(self._connections)

def send_command(self, cmd: Command):
"""Send a command to master (when we get establish a connection if we
don't have one already.)
"""Send a command to all connected connections.
"""
if self._connection:
self._connection.send_command(cmd)
if self._connections:
for connection in self._connections:
try:
connection.send_command(cmd)
except Exception:
# We probably want to catch some types of exceptions here
# and log them as warnings (e.g. connection gone), but I
# can't find what those exception types they would be.
logger.exception(
"Failed to write command %s to connection %s",
cmd.NAME,
connection,
)
else:
logger.warning("Dropping command as not connected: %r", cmd.NAME)

Expand Down Expand Up @@ -250,3 +384,10 @@ def send_user_ip(

def send_remote_server_up(self, server: str):
self.send_command(RemoteServerUpCommand(server))

def stream_update(self, stream_name: str, token: str, data: Any):
"""Called when a new update is available to stream to clients.

We need to check if the client is interested in the stream or not
"""
self.send_command(RdataCommand(stream_name, token, data))
Loading