Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add redis
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston committed Mar 6, 2020
1 parent c0ff5b3 commit 93a25e7
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 9 deletions.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,6 @@ ignore_missing_imports = True

[mypy-jwt.*]
ignore_missing_imports = True

[mypy-txredisapi]
ignore_missing_imports = True
6 changes: 6 additions & 0 deletions synapse/app/homeserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
from synapse.module_api import ModuleApi
from synapse.python_dependencies import check_requirements
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.tcp.protocol import RedisFactory # noqa: F401
from synapse.replication.tcp.resource import ReplicationStreamer # noqa: F401
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest import ClientRestResource
from synapse.rest.admin import AdminRestResource
Expand Down Expand Up @@ -282,6 +284,10 @@ def start_listening(self, listeners):
)
for s in services:
reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)

# factory = RedisFactory(self, ReplicationStreamer(self))
# self.get_reactor().connectTCP("redis", 6379, factory)

elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warning(
Expand Down
1 change: 1 addition & 0 deletions synapse/python_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
"Jinja2>=2.9",
"bleach>=1.4.3",
"typing-extensions>=3.7.4",
"txredisapi",
]

CONDITIONAL_REQUIREMENTS = {
Expand Down
12 changes: 8 additions & 4 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from twisted.internet.protocol import ReconnectingClientFactory

from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.tcp.protocol import (
from synapse.replication.tcp.protocol import ( # RedisFactory,
AbstractReplicationClientHandler,
ClientReplicationStreamProtocol,
)
Expand Down Expand Up @@ -107,11 +107,15 @@ def start_replication(self, hs):
using TCP.
"""
client_name = hs.config.worker_name
self.factory = ReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port

self.factory = ReplicationClientFactory(hs, client_name, self)
hs.get_reactor().connectTCP(host, port, self.factory)

# self.factory = RedisFactory(hs, self)
# hs.get_reactor().connectTCP("redis", 6379, self.factory)

def new_connection(self, connection):
self.connection = connection
if connection:
Expand All @@ -122,7 +126,7 @@ def new_connection(self, connection):
def lost_connection(self, connection):
self.connection = None

def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
pass

def federation_ack(self, token):
Expand Down Expand Up @@ -249,7 +253,7 @@ def finished_connecting(self):
"""Called when we have successfully subscribed and caught up to all
streams we're interested in.
"""
logger.info("Finished connecting to server")
logger.debug("Finished connecting to server")

# We don't reset the delay any earlier as otherwise if there is a
# problem during start up we'll end up tight looping connecting to the
Expand Down
133 changes: 128 additions & 5 deletions synapse/replication/tcp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,13 @@

from six import iteritems

import txredisapi as redis
from prometheus_client import Counter

from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure

from synapse.logging.context import PreserveLoggingContext
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import (
Expand Down Expand Up @@ -420,6 +422,8 @@ class CommandHandler:
def __init__(self, hs, handler):
self.handler = handler

self.is_master = hs.config.worker.worker_app is None

self.clock = hs.get_clock()

self.streams = {
Expand Down Expand Up @@ -458,11 +462,22 @@ def lost_connection(self, connection):
self.handler.lost_connection(connection)

async def on_USER_SYNC(self, cmd: UserSyncCommand):
if not self.connection:
raise Exception("Not connected")

await self.handler.on_user_sync(
self.connection.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)

async def on_REPLICATE(self, cmd: ReplicateCommand):
# We only want to announce positions by the writer of the streams.
# Currently this is just the master process.
if not self.is_master:
return

if not self.connection:
raise Exception("Not connected")

for stream_name, stream in self.streams.items():
current_token = stream.current_token()
self.connection.send_command(PositionCommand(stream_name, current_token))
Expand All @@ -483,15 +498,14 @@ async def on_SYNC(self, cmd: SyncCommand):
self.handler.on_sync(cmd.data)

async def on_RDATA(self, cmd: RdataCommand):

stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()

try:
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception:
logger.exception(
"[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
)
logger.exception("[%s] Failed to parse RDATA: %r", stream_name, cmd.row)
raise

if cmd.token is None or stream_name in self.streams_connecting:
Expand Down Expand Up @@ -519,7 +533,7 @@ async def on_POSITION(self, cmd: PositionCommand):
return

# Fetch all updates between then and now.
limited = True
limited = cmd.token != current_token
while limited:
updates, current_token, limited = await stream.get_updates_since(
current_token, cmd.token
Expand Down Expand Up @@ -582,7 +596,7 @@ def lost_connection(self, connection):
raise NotImplementedError()

@abc.abstractmethod
def on_user_sync(
async def on_user_sync(
self, conn_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
):
"""A client has started/stopped syncing on a worker.
Expand Down Expand Up @@ -794,3 +808,112 @@ def transport_kernel_read_buffer_size(protocol, read=True):
inbound_rdata_count = Counter(
"synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
)


class RedisSubscriber(redis.SubscriberProtocol):
def connectionMade(self):
logger.info("MADE CONNECTION")
self.subscribe(self.stream_name)
self.send_command(ReplicateCommand("ALL"))

self.handler.new_connection(self)

def messageReceived(self, pattern, channel, message):
if message.strip() == "":
# Ignore blank lines
return

line = message
cmd_name, rest_of_line = line.split(" ", 1)

cmd_cls = COMMAND_MAP[cmd_name]
try:
cmd = cmd_cls.from_line(rest_of_line)
except Exception as e:
logger.exception(
"[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line
)
self.send_error(
"failed to parse line for %r: %r (%r):" % (cmd_name, e, rest_of_line)
)
return

# Now lets try and call on_<CMD_NAME> function
run_as_background_process(
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
)

async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream.
By default delegates to on_<COMMAND>, which should return an awaitable.
Args:
cmd: received command
"""
# First call any command handlers on this instance. These are for TCP
# specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)

# Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)

def connectionLost(self, reason):
logger.info("LOST CONNECTION")
self.handler.lost_connection(self)

def send_command(self, cmd):
"""Send a command if connection has been established.
Args:
cmd (Command)
"""
string = "%s %s" % (cmd.NAME, cmd.to_line())
if "\n" in string:
raise Exception("Unexpected newline in command: %r", string)

encoded_string = string.encode("utf-8")

async def _send():
with PreserveLoggingContext():
await self.redis_connection.publish(self.stream_name, encoded_string)

run_as_background_process("send-cmd", _send)

def stream_update(self, stream_name, token, data):
"""Called when a new update is available to stream to clients.
We need to check if the client is interested in the stream or not
"""
self.send_command(RdataCommand(stream_name, token, data))

def send_sync(self, data):
self.send_command(SyncCommand(data))

def send_remote_server_up(self, server: str):
self.send_command(RemoteServerUpCommand(server))


class RedisFactory(redis.SubscriberFactory):

maxDelay = 5
continueTrying = True
protocol = RedisSubscriber

def __init__(self, hs, handler):
super(RedisFactory, self).__init__()

self.handler = CommandHandler(hs, handler)
self.stream_name = hs.hostname

def buildProtocol(self, addr):
p = super(RedisFactory, self).buildProtocol(addr)
p.handler = self.handler
p.redis_connection = redis.lazyConnection("redis")
p.conn_id = random_string(5) # TODO: FIXME
p.stream_name = self.stream_name
return p

0 comments on commit 93a25e7

Please sign in to comment.