Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Move server command handling out of TCP protocol #7187

Merged
merged 16 commits into from
Apr 7, 2020
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions synapse/replication/tcp/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
UserIpCommand,
UserSyncCommand,
)
from synapse.replication.tcp.protocol import AbstractConnection
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.util.async_helpers import Linearizer

Expand Down Expand Up @@ -64,6 +65,8 @@ def __init__(self, hs):
self._presence_handler = hs.get_presence_handler()
self._store = hs.get_datastore()
self._notifier = hs.get_notifier()
self._clock = hs.get_clock()
self._instance_id = hs.get_instance_id()

# Set of streams that we've caught up with.
self._streams_connected = set() # type: Set[str]
Expand All @@ -82,7 +85,7 @@ def __init__(self, hs):
self._factory = None # type: Optional[ReplicationClientFactory]

# The currently connected connections.
self._connections = [] # type: List[Any]
self._connections = [] # type: List[AbstractConnection]

LaterGauge(
"synapse_replication_tcp_resource_total_connections",
Expand Down Expand Up @@ -118,9 +121,6 @@ async def on_REPLICATE(self, cmd: ReplicateCommand):
if not self._is_master:
return

if not self._connections:
raise Exception("Not connected")

for stream_name, stream in self._streams.items():
current_token = stream.current_token()
self.send_command(PositionCommand(stream_name, current_token))
Expand Down Expand Up @@ -281,19 +281,33 @@ def get_currently_syncing_users(self):
"""
return self._presence_handler.get_currently_syncing_users()

def new_connection(self, connection):
def new_connection(self, connection: AbstractConnection):
"""Called when we have a new connection.
"""
self._connections.append(connection)

# If we're using a ReplicationClientFactory then we reset the connection
# delay now. We don't reset the delay any earlier as otherwise if there
# is a problem during start up we'll end up tight looping connecting to
# the server.
# If we are connected to replication as a client (rather than a server)
# we need to reset the reconnection delay on the client factory (which
# is used to do exponential back off when the connection drops).
#
# Ideally we would reset the delay when we've "fully established" the
# connection (for some definition thereof) to stop us from tightlooping
# on reconnection if something fails after this point and we drop the
# connection. Unfortunately, we don't really have a better definition of
# "fully established" than the connection being established.
if self._factory:
self._factory.resetDelay()

def lost_connection(self, connection):
# Tell the server if we have any users currently syncing (should only
# happen on synchrotrons)
currently_syncing = self.get_currently_syncing_users()
now = self._clock.time_msec()
for user_id in currently_syncing:
connection.send_command(
UserSyncCommand(self._instance_id, user_id, True, now)
)

def lost_connection(self, connection: AbstractConnection):
"""Called when a connection is closed/lost.
"""
try:
Expand All @@ -304,17 +318,26 @@ def lost_connection(self, connection):
def connected(self) -> bool:
"""Do we have any replication connections open?

Used to no-op if nothing is connected.
Is used by e.g. `ReplicationStreamer` to no-op if nothing is connected.
"""
return bool(self._connections)

def send_command(self, cmd: Command):
"""Send a command to master (when we get establish a connection if we
don't have one already.)
"""Send a command to all connected connections.
"""
if self._connections:
for connection in self._connections:
connection.send_command(cmd)
try:
connection.send_command(cmd)
except Exception:
# We probably want to catch some types of exceptions here
# and log them as warnings (e.g. connection gone), but I
# can't find what those exception types would be.
logger.exception(
"Failed to write command %s to connection %s",
cmd.NAME,
connection,
)
else:
logger.warning("Dropping command as not connected: %r", cmd.NAME)

Expand Down
48 changes: 29 additions & 19 deletions synapse/replication/tcp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
> ERROR server stopping
* connection closed by server *
"""
import abc
import fcntl
import logging
import struct
Expand All @@ -71,7 +72,6 @@
PingCommand,
ReplicateCommand,
ServerCommand,
UserSyncCommand,
)
from synapse.types import Collection
from synapse.util import Clock
Expand Down Expand Up @@ -132,7 +132,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):

def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"):
self.clock = clock
self.handler = handler
self.command_handler = handler

self.last_received_command = self.clock.time_msec()
self.last_sent_command = 0
Expand Down Expand Up @@ -172,7 +172,7 @@ def connectionMade(self):
# can time us out.
self.send_command(PingCommand(self.clock.time_msec()))

self.handler.new_connection(self)
self.command_handler.new_connection(self)

def send_ping(self):
"""Periodically sends a ping and checks if we should close the connection
Expand Down Expand Up @@ -242,7 +242,10 @@ def lineReceived(self, line):
async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream.

By default delegates to on_<COMMAND>, which should return an awaitable.
First calls `self.on_<COMMAND>` if it exists, then calls
`self.command_handler.on_<COMMAND>` if it exists. This allows for
protocol level handling of commands (e.g. PINGs), before delegating to
the handler.
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved

Args:
cmd: received command
Expand All @@ -257,7 +260,7 @@ async def handle_command(self, cmd: Command):
handled = True

# Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
Expand Down Expand Up @@ -392,7 +395,7 @@ def on_connection_closed(self):
self.state = ConnectionStates.CLOSED
self.pending_commands = []

self.handler.lost_connection(self)
self.command_handler.lost_connection(self)

if self.transport:
self.transport.unregisterProducer()
Expand Down Expand Up @@ -423,13 +426,13 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
def __init__(
self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler"
):
BaseReplicationStreamProtocol.__init__(self, clock, handler) # Old style class
super().__init__(clock, handler)

self.server_name = server_name

def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
BaseReplicationStreamProtocol.connectionMade(self)
super().connectionMade()

async def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
Expand All @@ -448,27 +451,18 @@ def __init__(
clock: Clock,
command_handler: "ReplicationCommandHandler",
):
BaseReplicationStreamProtocol.__init__(self, clock, command_handler)

self.instance_id = hs.get_instance_id()
super().__init__(clock, command_handler)

self.client_name = client_name
self.server_name = server_name

def connectionMade(self):
self.send_command(NameCommand(self.client_name))
BaseReplicationStreamProtocol.connectionMade(self)
super().connectionMade()

# Once we've connected subscribe to the necessary streams
self.replicate()

# Tell the server if we have any users currently syncing (should only
# happen on synchrotrons)
currently_syncing = self.handler.get_currently_syncing_users()
now = self.clock.time_msec()
for user_id in currently_syncing:
self.send_command(UserSyncCommand(self.instance_id, user_id, True, now))
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved

async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
Expand All @@ -482,6 +476,22 @@ def replicate(self):
self.send_command(ReplicateCommand())


class AbstractConnection(abc.ABC):
"""An interface for replication connections.
"""

@abc.abstractmethod
def send_command(self, cmd: Command):
"""Send the command down the connection
"""
pass


# This tells python that `BaseReplicationStreamProtocol` implements the
# interface.
AbstractConnection.register(BaseReplicationStreamProtocol)


# The following simply registers metrics for the replication connections

pending_commands = LaterGauge(
Expand Down
13 changes: 7 additions & 6 deletions synapse/replication/tcp/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,17 @@ class ReplicationStreamProtocolFactory(Factory):
"""

def __init__(self, hs):
self.handler = hs.get_tcp_replication()
self.command_handler = hs.get_tcp_replication()
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
self.hs = hs

# Ensure the replication streamer is started if we register a
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how does this "if we register a replication server endpoint" relate to this bit of code?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, you'd always do something like listen_tcp("localhost", 80, ReplicationStreamProtocolFactory(hs)) or whatever. It's not a fantastic place to put it, but I'm not sure where else would be a good place.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The subtlety here is that we only really want to start this if we configure a replication listener. I guess we could put this in HomeServer.setup and then check if we have a replication listener? But not sure that is a better place.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok so what you're saying is:

The fact that we have got here implies that we are trying to register a replication server endpoint, so we should make sure that we start the replication streamer so that we have something to send to the clients when they connect.

This isn't terribly elegant, but we don't really want to start a replication streamer unless we have a replication server endpoint, which makes it awkward to put this call anywhere else.

which is fine, but please can you make the comment say so? :)

# replication server endpoint.
hs.get_replication_streamer()

def buildProtocol(self, addr):
return ServerReplicationStreamProtocol(
self.server_name, self.clock, self.handler
self.server_name, self.clock, self.command_handler
)


Expand Down Expand Up @@ -94,7 +93,7 @@ def __init__(self, hs):
self.is_looping = False
self.pending_updates = False

self.client = hs.get_tcp_replication()
self.command_handler = hs.get_tcp_replication()

def get_streams(self) -> Dict[str, Stream]:
"""Get a mapping from stream name to stream instance.
Expand All @@ -108,7 +107,7 @@ def on_notifier_poke(self):
This should get called each time new data is available, even if it
is currently being executed, so that nothing gets missed
"""
if not self.client.connected():
if not self.command_handler.connected():
# Don't bother if nothing is listening. We still need to advance
# the stream tokens otherwise they'll fall behind forever
for stream in self.streams:
Expand Down Expand Up @@ -183,7 +182,9 @@ async def _run_notifier_loop(self):

for token, row in batched_updates:
try:
self.client.stream_update(stream.NAME, token, row)
self.command_handler.stream_update(
stream.NAME, token, row
)
except Exception:
logger.exception("Failed to replicate")

Expand Down