Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use stream.current_token() and remove stream_positions() #7172

Merged
merged 5 commits into from
May 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7172.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use `stream.current_token()` and remove `stream_positions()`.
16 changes: 0 additions & 16 deletions synapse/app/generic_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,12 +412,6 @@ def _reset(self):
# map room IDs to sets of users currently typing
self._room_typing = {}

def stream_positions(self):
# We must update this typing token from the response of the previous
# sync. In particular, the stream id may "reset" back to zero/a low
# value which we *must* use for the next replication request.
return {"typing": self._latest_room_serial}

def process_replication_rows(self, token, rows):
if self._latest_room_serial > token:
# The master has gone backwards. To prevent inconsistent data, just
Expand Down Expand Up @@ -656,13 +650,6 @@ async def on_rdata(self, stream_name, token, rows):
)
await self.process_and_notify(stream_name, token, rows)

def get_streams_to_replicate(self):
args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate()
args.update(self.typing_handler.stream_positions())
if self.send_handler:
args.update(self.send_handler.stream_positions())
return args

async def process_and_notify(self, stream_name, token, rows):
try:
if self.send_handler:
Expand Down Expand Up @@ -797,9 +784,6 @@ def on_start(self):
def wake_destination(self, server: str):
self.federation_sender.wake_destination(server)

def stream_positions(self):
return {"federation": self.federation_position}

async def process_replication_rows(self, stream_name, token, rows):
# The federation stream contains things that we want to send out, e.g.
# presence, typing, etc.
Expand Down
15 changes: 1 addition & 14 deletions synapse/replication/slave/storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import logging
from typing import Dict, Optional
from typing import Optional

import six

Expand Down Expand Up @@ -49,19 +49,6 @@ def __init__(self, database: Database, db_conn, hs):

self.hs = hs

def stream_positions(self) -> Dict[str, int]:
"""
Get the current positions of all the streams this store wants to subscribe to

Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
pos = {}
if self._cache_id_gen:
pos["caches"] = self._cache_id_gen.get_current_token()
return pos

def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
Expand Down
8 changes: 0 additions & 8 deletions synapse/replication/slave/storage/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,6 @@ def __init__(self, database: Database, db_conn, hs):
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()

def stream_positions(self):
result = super(SlavedAccountDataStore, self).stream_positions()
position = self._account_data_id_gen.get_current_token()
result["user_account_data"] = position
result["room_account_data"] = position
result["tag_account_data"] = position
return result

def process_replication_rows(self, stream_name, token, rows):
if stream_name == "tag_account_data":
self._account_data_id_gen.advance(token)
Expand Down
5 changes: 0 additions & 5 deletions synapse/replication/slave/storage/deviceinbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,6 @@ def __init__(self, database: Database, db_conn, hs):
expiry_ms=30 * 60 * 1000,
)

def stream_positions(self):
result = super(SlavedDeviceInboxStore, self).stream_positions()
result["to_device"] = self._device_inbox_id_gen.get_current_token()
return result

def process_replication_rows(self, stream_name, token, rows):
if stream_name == "to_device":
self._device_inbox_id_gen.advance(token)
Expand Down
10 changes: 0 additions & 10 deletions synapse/replication/slave/storage/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,6 @@ def __init__(self, database: Database, db_conn, hs):
"DeviceListFederationStreamChangeCache", device_list_max
)

def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
# The user signature stream uses the same stream ID generator as the
# device list stream, so set them both to the device list ID
# generator's current token.
current_token = self._device_list_id_gen.get_current_token()
result[DeviceListsStream.NAME] = current_token
result[UserSignatureStream.NAME] = current_token
return result

def process_replication_rows(self, stream_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
Expand Down
6 changes: 0 additions & 6 deletions synapse/replication/slave/storage/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,6 @@ def get_room_max_stream_ordering(self):
def get_room_min_stream_ordering(self):
return self._backfill_id_gen.get_current_token()

def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token()
result["backfill"] = -self._backfill_id_gen.get_current_token()
return result

def process_replication_rows(self, stream_name, token, rows):
if stream_name == "events":
self._stream_id_gen.advance(token)
Expand Down
5 changes: 0 additions & 5 deletions synapse/replication/slave/storage/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,6 @@ def __init__(self, database: Database, db_conn, hs):
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()

def stream_positions(self):
result = super(SlavedGroupServerStore, self).stream_positions()
result["groups"] = self._group_updates_id_gen.get_current_token()
return result

def process_replication_rows(self, stream_name, token, rows):
if stream_name == "groups":
self._group_updates_id_gen.advance(token)
Expand Down
9 changes: 0 additions & 9 deletions synapse/replication/slave/storage/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,6 @@ def __init__(self, database: Database, db_conn, hs):
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()

def stream_positions(self):
result = super(SlavedPresenceStore, self).stream_positions()

if self.hs.config.use_presence:
position = self._presence_id_gen.get_current_token()
result["presence"] = position

return result

def process_replication_rows(self, stream_name, token, rows):
if stream_name == "presence":
self._presence_id_gen.advance(token)
Expand Down
5 changes: 0 additions & 5 deletions synapse/replication/slave/storage/push_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,6 @@ def get_push_rules_stream_token(self):
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()

def stream_positions(self):
result = super(SlavedPushRuleStore, self).stream_positions()
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
return result

def process_replication_rows(self, stream_name, token, rows):
if stream_name == "push_rules":
self._push_rules_stream_id_gen.advance(token)
Expand Down
5 changes: 0 additions & 5 deletions synapse/replication/slave/storage/pushers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,6 @@ def __init__(self, database: Database, db_conn, hs):
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)

def stream_positions(self):
result = super(SlavedPusherStore, self).stream_positions()
result["pushers"] = self._pushers_id_gen.get_current_token()
return result

def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()

Expand Down
5 changes: 0 additions & 5 deletions synapse/replication/slave/storage/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,6 @@ def __init__(self, database: Database, db_conn, hs):
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()

def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions()
result["receipts"] = self._receipts_id_gen.get_current_token()
return result

def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate_many((room_id,))
Expand Down
5 changes: 0 additions & 5 deletions synapse/replication/slave/storage/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@ def __init__(self, database: Database, db_conn, hs):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()

def stream_positions(self):
result = super(RoomStore, self).stream_positions()
result["public_rooms"] = self._public_room_id_gen.get_current_token()
return result

def process_replication_rows(self, stream_name, token, rows):
if stream_name == "public_rooms":
self._public_room_id_gen.advance(token)
Expand Down
19 changes: 1 addition & 18 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""

import logging
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING

from twisted.internet.protocol import ReconnectingClientFactory

Expand Down Expand Up @@ -100,23 +100,6 @@ async def on_rdata(self, stream_name: str, token: int, rows: list):
"""
self.store.process_replication_rows(stream_name, token, rows)

def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.

Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
args = self.store.stream_positions()
user_account_data = args.pop("user_account_data", None)
room_account_data = args.pop("room_account_data", None)
if user_account_data:
args["account_data"] = user_account_data
elif room_account_data:
args["account_data"] = room_account_data
return args

async def on_position(self, stream_name: str, token: int):
self.store.process_replication_rows(stream_name, token, [])

Expand Down
10 changes: 1 addition & 9 deletions synapse/replication/tcp/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,15 +303,7 @@ async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
self._pending_batches.pop(cmd.stream_name, [])

# Find where we previously streamed up to.
current_token = self._replication_data_handler.get_streams_to_replicate().get(
cmd.stream_name
)
if current_token is None:
logger.warning(
"Got POSITION for stream we're not subscribed to: %s",
cmd.stream_name,
)
return
current_token = stream.current_token()

# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
Expand Down
30 changes: 10 additions & 20 deletions tests/replication/tcp/streams/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,23 @@
# limitations under the License.

import logging
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, List, Optional, Tuple

import attr

from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel

from synapse.app.generic_worker import GenericWorkerServer
from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
GenericWorkerServer,
)
from synapse.http.site import SynapseRequest
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from synapse.util import Clock

from tests import unittest
Expand Down Expand Up @@ -77,7 +79,7 @@ def prepare(self, reactor, clock, hs):
self._server_transport = None

def _build_replication_data_handler(self):
return TestReplicationDataHandler(self.worker_hs.get_datastore())
return TestReplicationDataHandler(self.worker_hs)

def reconnect(self):
if self._client_transport:
Expand Down Expand Up @@ -172,32 +174,20 @@ def assert_request_is_get_repl_stream_updates(
self.assertEqual(request.method, b"GET")


class TestReplicationDataHandler(ReplicationDataHandler):
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""

def __init__(self, store: BaseSlavedStore):
super().__init__(store)

# streams to subscribe to: map from stream id to position
self.stream_positions = {} # type: Dict[str, int]
def __init__(self, hs: HomeServer):
super().__init__(hs)

# list of received (stream_name, token, row) tuples
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]

def get_streams_to_replicate(self):
return self.stream_positions

async def on_rdata(self, stream_name, token, rows):
await super().on_rdata(stream_name, token, rows)
for r in rows:
self.received_rdata_rows.append((stream_name, token, r))

if (
stream_name in self.stream_positions
and token > self.stream_positions[stream_name]
):
self.stream_positions[stream_name] = token


@attr.s()
class OneShotRequestFactory:
Expand Down
24 changes: 16 additions & 8 deletions tests/replication/tcp/streams/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def prepare(self, reactor, clock, hs):
self.user_tok = self.login("u1", "pass")

self.reconnect()
self.test_handler.stream_positions["events"] = 0

self.room_id = self.helper.create_room_as(tok=self.user_tok)
self.test_handler.received_rdata_rows.clear()
Expand Down Expand Up @@ -80,8 +79,12 @@ def test_update_function_event_row_limit(self):
self.reconnect()
self.replicate()

# we should have received all the expected rows in the right order
received_rows = self.test_handler.received_rdata_rows
# we should have received all the expected rows in the right order (as
# well as various cache invalidation updates which we ignore)
received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]

for event in events:
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
Expand Down Expand Up @@ -184,7 +187,8 @@ def test_update_function_huge_state_change(self):
self.reconnect()
self.replicate()

# now we should have received all the expected rows in the right order.
# we should have received all the expected rows in the right order (as
# well as various cache invalidation updates which we ignore)
#
# we expect:
#
Expand All @@ -193,7 +197,9 @@ def test_update_function_huge_state_change(self):
# of the states that got reverted.
# - two rows for state2

received_rows = self.test_handler.received_rdata_rows
received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]

# first check the first two rows, which should be state1

Expand Down Expand Up @@ -334,9 +340,11 @@ def test_update_function_state_row_limit(self):
self.reconnect()
self.replicate()

# we should have received all the expected rows in the right order

received_rows = self.test_handler.received_rdata_rows
# we should have received all the expected rows in the right order (as
# well as various cache invalidation updates which we ignore)
received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]
self.assertGreaterEqual(len(received_rows), len(events))
for i in range(NUM_USERS):
# for each user, we expect the PL event row, followed by state rows for
Expand Down
3 changes: 0 additions & 3 deletions tests/replication/tcp/streams/test_receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ def _build_replication_data_handler(self):
def test_receipt(self):
self.reconnect()

# make the client subscribe to the receipts stream
self.test_handler.stream_positions.update({"receipts": 0})

# tell the master to send a new receipt
self.get_success(
self.hs.get_datastore().insert_receipt(
Expand Down
Loading