Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add unit test for event persister sharding #8433

Merged
merged 7 commits into from
Oct 2, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions synapse/replication/tcp/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,9 @@ def start_replication(self, hs):
using TCP.
"""
if hs.config.redis.redis_enabled:
import txredisapi

from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
lazyConnection,
)

logger.info(
Expand All @@ -271,7 +270,8 @@ def start_replication(self, hs):
# connection after SUBSCRIBE is called).

# First create the connection for sending commands.
outbound_redis_connection = txredisapi.lazyConnection(
outbound_redis_connection = lazyConnection(
reactor=hs.get_reactor(),
host=hs.config.redis_host,
port=hs.config.redis_port,
password=hs.config.redis.redis_password,
Expand Down
41 changes: 41 additions & 0 deletions synapse/replication/tcp/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,44 @@ def buildProtocol(self, addr):
p.password = self.password

return p


def lazyConnection(
reactor,
host="localhost",
port=6379,
dbid=None,
reconnect=True,
charset="utf-8",
password=None,
connectTimeout=None,
replyTimeout=None,
convertNumbers=True,
):
"""Equivalent to `txredisapi.lazyConnection`, except allows specifying a
reactor.
Comment on lines +245 to +246
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""

isLazy = True
poolsize = 1

uuid = "%s:%d" % (host, port)
factory = txredisapi.RedisFactory(
uuid,
dbid,
poolsize,
isLazy,
txredisapi.ConnectionHandler,
charset,
password,
replyTimeout,
convertNumbers,
)
factory.continueTrying = reconnect
for x in range(poolsize):
reactor.connectTCP(host, port, factory, connectTimeout)

if isLazy:
return factory.handler
else:
return factory.deferred
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a copy and paste from the function in txredisapi fwiw, but with reactor given as a parameter.

224 changes: 203 additions & 21 deletions tests/replication/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Any, Callable, List, Optional, Tuple

import attr
import hiredis

from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel

Expand All @@ -27,7 +28,7 @@
GenericWorkerServer,
)
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource, streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
Expand Down Expand Up @@ -197,19 +198,37 @@ def setUp(self):
self.server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = self.hs.get_replication_streamer()

# Fake in memory Redis server that servers can connect to.
self._redis_server = FakeRedisPubSubServer()

store = self.hs.get_datastore()
self.database_pool = store.db_pool

self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["localhost"] = "127.0.0.1"

# A map from a HS instance to the associated HTTP Site to use for
# handling inbound HTTP requests to that instance.
self._hs_to_site = {self.hs: self.site}

if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
"localhost", 6379, self.connect_any_redis_attempts,
)

self._worker_hs_to_resource = {}
self.hs.get_tcp_replication().start_replication(self.hs)

# When we see a connection attempt to the master replication listener we
# automatically set up the connection. This is so that tests don't
# manually have to go and explicitly set it up each time (plus sometimes
# it is impossible to write the handling explicitly in the tests).
#
# This sets registers the master replication listener:
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
self.reactor.add_tcp_client_callback(
"1.2.3.4", 8765, self._handle_http_replication_attempt
"1.2.3.4",
8765,
lambda: self._handle_http_replication_attempt(self.hs, 8765),
)

def create_test_json_resource(self):
Expand Down Expand Up @@ -253,28 +272,63 @@ def make_worker_hs(
**kwargs
)

# If the instance is in the `instance_map` config then workers may try
# and send HTTP requests to it, so we register it with
# `_handle_http_replication_attempt` like we do with the master HS.
instance_name = worker_hs.get_instance_name()
instance_loc = worker_hs.config.worker.instance_map.get(instance_name)
if instance_loc:
# Ensure the host is one that has a fake DNS entry.
if instance_loc.host not in self.reactor.lookups:
raise Exception(
"Host does not have an IP for instance_map[%r].host = %r"
% (instance_name, instance_loc.host,)
)

self.reactor.add_tcp_client_callback(
self.reactor.lookups[instance_loc.host],
instance_loc.port,
lambda: self._handle_http_replication_attempt(
worker_hs, instance_loc.port
),
)

store = worker_hs.get_datastore()
store.db_pool._db_pool = self.database_pool._db_pool

repl_handler = ReplicationCommandHandler(worker_hs)
client = ClientReplicationStreamProtocol(
worker_hs, "client", "test", self.clock, repl_handler,
)
server = self.server_factory.buildProtocol(None)
# Set up TCP replication between master and the new worker if we don't
# have Redis support enabled.
if not worker_hs.config.redis_enabled:
repl_handler = ReplicationCommandHandler(worker_hs)
client = ClientReplicationStreamProtocol(
worker_hs, "client", "test", self.clock, repl_handler,
)
server = self.server_factory.buildProtocol(None)

client_transport = FakeTransport(server, self.reactor)
client.makeConnection(client_transport)
client_transport = FakeTransport(server, self.reactor)
client.makeConnection(client_transport)

server_transport = FakeTransport(client, self.reactor)
server.makeConnection(server_transport)
server_transport = FakeTransport(client, self.reactor)
server.makeConnection(server_transport)

# Set up a resource for the worker
resource = ReplicationRestResource(self.hs)
resource = ReplicationRestResource(worker_hs)

for servlet in self.servlets:
servlet(worker_hs, resource)

self._worker_hs_to_resource[worker_hs] = resource
self._hs_to_site[worker_hs] = SynapseSite(
logger_name="synapse.access.http.fake",
site_tag="{}-{}".format(
worker_hs.config.server.server_name, worker_hs.get_instance_name()
),
config=worker_hs.config.server.listeners[0],
resource=resource,
server_version_string="1",
)

if worker_hs.config.redis.redis_enabled:
worker_hs.get_tcp_replication().start_replication(worker_hs)

return worker_hs

Expand All @@ -285,7 +339,7 @@ def _get_worker_hs_config(self) -> dict:
return config

def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
render(request, self._hs_to_site[worker_hs].resource, self.reactor)

def replicate(self):
"""Tell the master side of replication that something has happened, and then
Expand All @@ -294,9 +348,9 @@ def replicate(self):
self.streamer.on_notifier_poke()
self.pump()

def _handle_http_replication_attempt(self):
"""Handles a connection attempt to the master replication HTTP
listener.
def _handle_http_replication_attempt(self, hs, repl_port):
"""Handles a connection attempt to the given HS replication HTTP
listener on the given port.
"""

# We should have at least one outbound connection attempt, where the
Expand All @@ -305,7 +359,7 @@ def _handle_http_replication_attempt(self):
self.assertGreaterEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 8765)
self.assertEqual(port, repl_port)

# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
Expand All @@ -315,7 +369,7 @@ def _handle_http_replication_attempt(self):
# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor)
channel.requestFactory = request_factory
channel.site = self.site
channel.site = self._hs_to_site[hs]

# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
Expand All @@ -333,6 +387,32 @@ def _handle_http_replication_attempt(self):
# inside `connecTCP` before the connection has been passed back to the
# code that requested the TCP connection.

def connect_any_redis_attempts(self):
"""If redis is enabled we need to deal with workers connecting to a
redis server. We don't want to use a real Redis server so we use a
fake one.
"""
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, "localhost")
self.assertEqual(port, 6379)

client_protocol = client_factory.buildProtocol(None)
server_protocol = self._redis_server.buildProtocol(None)

client_to_server_transport = FakeTransport(
server_protocol, self.reactor, client_protocol
)
client_protocol.makeConnection(client_to_server_transport)

server_to_client_transport = FakeTransport(
client_protocol, self.reactor, server_protocol
)
server_protocol.makeConnection(server_to_client_transport)

return client_to_server_transport, server_to_client_transport


class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
Expand Down Expand Up @@ -467,3 +547,105 @@ def _run_once(self):
pass

self.stopProducing()


class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub.
"""

def __init__(self):
self._subscribers = set()

def add_subscriber(self, conn):
"""A connection has called SUBSCRIBE
"""
self._subscribers.add(conn)

def remove_subscriber(self, conn):
"""A connection has called UNSUBSCRIBE
"""
self._subscribers.discard(conn)

def publish(self, conn, channel, msg) -> int:
"""A connection want to publish a message to subscribers.
"""
for sub in self._subscribers:
sub.send(["message", channel, msg])

return len(self._subscribers)

def buildProtocol(self, addr):
return FakeRedisPubSubProtocol(self)


class FakeRedisPubSubProtocol(Protocol):
"""A connection from a client talking to the fake Redis server.
"""

def __init__(self, server: FakeRedisPubSubServer):
self._server = server
self._reader = hiredis.Reader()

def dataReceived(self, data):
self._reader.feed(data)

# We might get multiple messages in one packet.
while True:
msg = self._reader.gets()

if msg is False:
# No more messages.
return

if not isinstance(msg, list):
# Inbound commands should always be a list
raise Exception("Expected redis list")

self.handle_command(msg[0], *msg[1:])

def handle_command(self, command, *args):
"""Received a Redis command from the client.
"""

# We currently only support pub/sub.
if command == b"PUBLISH":
channel, message = args
num_subscribers = self._server.publish(self, channel, message)
self.send(num_subscribers)
elif command == b"SUBSCRIBE":
(channel,) = args
self._server.add_subscriber(self)
self.send(["subscribe", channel, 1])
else:
raise Exception("Unknown command")

def send(self, msg):
"""Send a message back to the client.
"""
raw = self.encode(msg).encode("utf-8")

self.transport.write(raw)
self.transport.flush()

def encode(self, obj):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like something we should be able to pull from the txredis library or something?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You would have thought, but neither of the two redis libraries we use expose it :/

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's unfortunate. 😢

"""Encode an object to its Redis format.

Supports: strings/bytes, integers and list/tuples.
"""

if isinstance(obj, bytes):
# We assume bytes are just unicode strings.
obj = obj.decode("utf-8")

if isinstance(obj, str):
return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
if isinstance(obj, int):
return ":{val}\r\n".format(val=obj)
if isinstance(obj, (list, tuple)):
items = "".join(self.encode(a) for a in obj)
return "*{len}\r\n{items}".format(len=len(obj), items=items)

raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)

def connectionList(self, reason):
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
self._server.remove_subscriber(self)
Loading