This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Allow moving account data and receipts streams off master #9104

Merged · 14 commits · Jan 18, 2021

Changes from 1 commit
synapse/replication/slave/storage/_base.py (7 additions, 3 deletions)

@@ -33,9 +33,13 @@ def __init__(self, database: DatabasePool, db_conn, hs):
                 database,
                 stream_name="caches",
                 instance_name=hs.get_instance_name(),
-                table="cache_invalidation_stream_by_instance",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[
+                    (
+                        "cache_invalidation_stream_by_instance",
+                        "instance_name",
+                        "stream_id",
+                    )
+                ],
                 sequence_name="cache_invalidation_stream_seq",
                 writers=[],
             )  # type: Optional[MultiWriterIdGenerator]
synapse/storage/databases/main/__init__.py (7 additions, 3 deletions)

@@ -160,9 +160,13 @@ def __init__(self, database: DatabasePool, db_conn, hs):
                 database,
                 stream_name="caches",
                 instance_name=hs.get_instance_name(),
-                table="cache_invalidation_stream_by_instance",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[
+                    (
+                        "cache_invalidation_stream_by_instance",
+                        "instance_name",
+                        "stream_id",
+                    )
+                ],
                 sequence_name="cache_invalidation_stream_seq",
                 writers=[],
             )
synapse/storage/databases/main/deviceinbox.py (1 addition, 3 deletions)

@@ -54,9 +54,7 @@ def __init__(self, database: DatabasePool, db_conn, hs):
                 db=database,
                 stream_name="to_device",
                 instance_name=self._instance_name,
-                table="device_inbox",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[("device_inbox", "instance_name", "stream_id")],
                 sequence_name="device_inbox_sequence",
                 writers=hs.config.worker.writers.to_device,
             )
synapse/storage/databases/main/events_worker.py (2 additions, 6 deletions)

@@ -96,9 +96,7 @@ def __init__(self, database: DatabasePool, db_conn, hs):
                 db=database,
                 stream_name="events",
                 instance_name=hs.get_instance_name(),
-                table="events",
-                instance_column="instance_name",
-                id_column="stream_ordering",
+                tables=[("events", "instance_name", "stream_ordering")],
                 sequence_name="events_stream_seq",
                 writers=hs.config.worker.writers.events,
             )
@@ -107,9 +105,7 @@ def __init__(self, database: DatabasePool, db_conn, hs):
                 db=database,
                 stream_name="backfill",
                 instance_name=hs.get_instance_name(),
-                table="events",
-                instance_column="instance_name",
-                id_column="stream_ordering",
+                tables=[("events", "instance_name", "stream_ordering")],
                 sequence_name="events_backfill_stream_seq",
                 positive=False,
                 writers=hs.config.worker.writers.events,
synapse/storage/util/id_generators.py (48 additions, 36 deletions)

@@ -17,7 +17,7 @@
 import threading
 from collections import deque
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Set, Union
+from typing import Dict, List, Optional, Set, Tuple, Union

 import attr
 from typing_extensions import Deque
@@ -186,11 +186,12 @@ class MultiWriterIdGenerator:
     Args:
         db_conn
         db
-        stream_name: A name for the stream.
+        stream_name: A name for the stream, for use in the `stream_positions`
+            table. (Does not need to be the same as the replication stream name)
         instance_name: The name of this instance.
-        table: Database table associated with stream.
-        instance_column: Column that stores the row's writer's instance name
-        id_column: Column that stores the stream ID.
+        tables: List of tables associated with the stream. Tuple of table
+            name, column name that stores the writer's instance name, and
+            column name that stores the stream ID.
         sequence_name: The name of the postgres sequence used to generate new
             IDs.
         writers: A list of known writers to use to populate current positions
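For illustration, a caller wiring one stream across several tables now passes one tuple per table in the documented (table name, instance column, stream ID column) order. A minimal sketch of the new calling convention, with hypothetical names not taken from this diff:

    # Hedged sketch: the table names, stream name and sequence name here
    # are illustrative, and `db_conn`, `database` and `hs` are assumed to
    # be in scope, as at the call sites shown above.
    self._account_data_id_gen = MultiWriterIdGenerator(
        db_conn,
        database,
        stream_name="account_data",
        instance_name=hs.get_instance_name(),
        tables=[
            # one (table, instance column, stream ID column) tuple per table
            ("room_account_data", "instance_name", "stream_id"),
            ("room_tags_revisions", "instance_name", "stream_id"),
        ],
        sequence_name="account_data_sequence",
        writers=["master"],
    )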
@@ -206,9 +207,7 @@ def __init__(
         db: DatabasePool,
         stream_name: str,
         instance_name: str,
-        table: str,
-        instance_column: str,
-        id_column: str,
+        tables: List[Tuple[str, str, str]],
         sequence_name: str,
         writers: List[str],
         positive: bool = True,
@@ -260,15 +259,16 @@ def __init__(
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)

         # We check that the table and sequence haven't diverged.
-        self._sequence_gen.check_consistency(
-            db_conn, table=table, id_column=id_column, positive=positive
-        )
+        for table, _, id_column in tables:
+            self._sequence_gen.check_consistency(
+                db_conn, table=table, id_column=id_column, positive=positive
+            )

         # This goes and fills out the above state from the database.
-        self._load_current_ids(db_conn, table, instance_column, id_column)
+        self._load_current_ids(db_conn, tables)

     def _load_current_ids(
-        self, db_conn, table: str, instance_column: str, id_column: str
+        self, db_conn, tables: List[Tuple[str, str, str]],
     ):
         cur = db_conn.cursor(txn_name="_load_current_ids")
@@ -306,17 +306,22 @@ def _load_current_ids(
             # We add a GREATEST here to ensure that the result is always
             # positive. (This can be a problem for e.g. backfill streams where
             # the server has never backfilled).
-            sql = """
-                SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
-                FROM %(table)s
-            """ % {
-                "id": id_column,
-                "table": table,
-                "agg": "MAX" if self._positive else "-MIN",
-            }
-            cur.execute(sql)
-            (stream_id,) = cur.fetchone()
-            self._persisted_upto_position = stream_id
+            max_stream_id = 1
+            for table, _, id_column in tables:
+                sql = """
+                    SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+                    FROM %(table)s
+                """ % {
+                    "id": id_column,
+                    "table": table,
+                    "agg": "MAX" if self._positive else "-MIN",
+                }
+                cur.execute(sql)
+                (stream_id,) = cur.fetchone()
+
+                max_stream_id = max(max_stream_id, stream_id)
+
+            self._persisted_upto_position = max_stream_id
         else:
             # If we have a min_stream_id then we pull out everything greater
             # than it from the DB so that we can prefill
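To make the %-interpolation concrete: for the ("events", "instance_name", "stream_ordering") entry seen in events_worker.py above, one iteration of the loop renders and executes the following (an illustrative expansion, not new code in the diff):

    # Rendered form of the template for one table entry, with positive=True:
    sql = """
        SELECT GREATEST(COALESCE(MAX(stream_ordering), 1), 1)
        FROM events
    """
    # With positive=False the aggregate becomes -MIN(stream_ordering); the
    # GREATEST(..., 1) floor keeps the result at least 1 even for an empty
    # table, e.g. a backfill stream on a server that has never backfilled.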
@@ -329,21 +334,28 @@ def _load_current_ids(
             # stream positions table before restart (or the stream position
             # table otherwise got out of date).

-            sql = """
-                SELECT %(instance)s, %(id)s FROM %(table)s
-                WHERE ? %(cmp)s %(id)s
-            """ % {
-                "id": id_column,
-                "table": table,
-                "instance": instance_column,
-                "cmp": "<=" if self._positive else ">=",
-            }
-            cur.execute(sql, (min_stream_id * self._return_factor,))
-
             self._persisted_upto_position = min_stream_id

+            rows = []
+            for table, instance_column, id_column in tables:
+                sql = """
+                    SELECT %(instance)s, %(id)s FROM %(table)s
+                    WHERE ? %(cmp)s %(id)s
+                """ % {
+                    "id": id_column,
+                    "table": table,
+                    "instance": instance_column,
+                    "cmp": "<=" if self._positive else ">=",
+                }
+                cur.execute(sql, (min_stream_id * self._return_factor,))
+
+                rows.extend(cur)
+
+            # Sort so that we handle rows in order for each instance.
+            rows.sort()
+
             with self._lock:
-                for (instance, stream_id,) in cur:
+                for (instance, stream_id,) in rows:
                     stream_id = self._return_factor * stream_id
                     self._add_persisted_position(stream_id)
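A self-contained toy example (not from the diff) of why a plain rows.sort() suffices after merging: the (instance_name, stream_id) tuples sort by instance first and stream ID second, so each writer's rows are handled in ascending order, which is what _add_persisted_position expects.

    rows = []
    rows.extend([("first", 5), ("first", 3)])    # as if pulled from table one
    rows.extend([("second", 6), ("second", 4)])  # as if pulled from table two
    rows.sort()
    assert rows == [("first", 3), ("first", 5), ("second", 4), ("second", 6)]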
tests/storage/test_id_generators.py (102 additions, 6 deletions)

@@ -51,9 +51,7 @@ def _create(conn):
                 self.db_pool,
                 stream_name="test_stream",
                 instance_name=instance_name,
-                table="foobar",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[("foobar", "instance_name", "stream_id")],
                 sequence_name="foobar_seq",
                 writers=writers,
             )
@@ -487,9 +485,7 @@ def _create(conn):
                 self.db_pool,
                 stream_name="test_stream",
                 instance_name=instance_name,
-                table="foobar",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[("foobar", "instance_name", "stream_id")],
                 sequence_name="foobar_seq",
                 writers=writers,
                 positive=False,
@@ -579,3 +575,103 @@ async def _get_next_async2():
         self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
         self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
         self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
+
+
+class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+    if not USE_POSTGRES_FOR_TESTS:
+        skip = "Requires Postgres"
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.db_pool = self.store.db_pool  # type: DatabasePool
+
+        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+    def _setup_db(self, txn):
+        txn.execute("CREATE SEQUENCE foobar_seq")
+        txn.execute(
+            """
+            CREATE TABLE foobar1 (
+                stream_id BIGINT NOT NULL,
+                instance_name TEXT NOT NULL,
+                data TEXT
+            );
+            """
+        )
+
+        txn.execute(
+            """
+            CREATE TABLE foobar2 (
+                stream_id BIGINT NOT NULL,
+                instance_name TEXT NOT NULL,
+                data TEXT
+            );
+            """
+        )
+
+    def _create_id_generator(
+        self, instance_name="master", writers=["master"]
+    ) -> MultiWriterIdGenerator:
+        def _create(conn):
+            return MultiWriterIdGenerator(
+                conn,
+                self.db_pool,
+                stream_name="test_stream",
+                instance_name=instance_name,
+                tables=[
+                    ("foobar1", "instance_name", "stream_id"),
+                    ("foobar2", "instance_name", "stream_id"),
+                ],
+                sequence_name="foobar_seq",
+                writers=writers,
+            )
+
+        return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
+
+    def _insert_rows(
+        self, table: str, instance_name: str, number: int, update_stream_table=True
+    ):
+        """Insert N rows as the given instance, inserting with stream IDs pulled
+        from the postgres sequence.
+        """
+
+        def _insert(txn):
+            for _ in range(number):
+                txn.execute(
+                    "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
+                    (instance_name,),
+                )
+                if update_stream_table:
+                    txn.execute(
+                        """
+                        INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
+                        ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
+                        """,
+                        (instance_name,),
+                    )
+
+        self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
+
+    def test_load_existing_stream(self):
+        """Test creating ID gens with multiple tables that have rows from after
+        the position in `stream_positions` table.
+        """
+        self._insert_rows("foobar1", "first", 3)
+        self._insert_rows("foobar2", "second", 3)
+        self._insert_rows("foobar2", "second", 1, update_stream_table=False)
+
+        first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+        second_id_gen = self._create_id_generator("second", writers=["first", "second"])
+
+        # The first ID gen will notice that it can advance its token to 7 as it
+        # has no in progress writes...
+        self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6})
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
+        self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
+
+        # ... but the second ID gen doesn't know that.
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+        self.assertEqual(second_id_gen.get_persisted_upto_position(), 7)
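To spell out why these assertions hold (assuming foobar_seq starts at 1): "first" writes IDs 1-3 to foobar1 and its stream_positions row ends at 3; "second" writes IDs 4-6 to foobar2 with its row ending at 6; the final insert gives "second" ID 7 without touching stream_positions. On startup each generator scans both tables from the minimum recorded position, sees IDs 3-7 persisted with no gaps, and can therefore advance its own token to 7, while for its peer it can only trust the stale stream_positions entry.

    # Worked summary (illustrative) of the positions asserted above:
    #   first_id_gen:  own token advanced to 7, peer "second" seen at 6
    #   second_id_gen: own token advanced to 7, peer "first" seen at 3
    expected = {
        "first": {"first": 7, "second": 6},
        "second": {"first": 3, "second": 7},
    }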