Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add tests for StreamIdGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
David Robertson committed Nov 16, 2022
1 parent 618e4ab commit fa5c919
Showing 1 changed file with 143 additions and 2 deletions.
145 changes: 143 additions & 2 deletions tests/storage/test_id_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,156 @@
from twisted.test.proto_helpers import MemoryReactor

from synapse.server import HomeServer
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.util import Clock

from tests.unittest import HomeserverTestCase
from tests.utils import USE_POSTGRES_FOR_TESTS


class StreamIdGeneratorTestCase(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool

self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))

def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute(
"""
CREATE TABLE foobar (
stream_id BIGINT NOT NULL,
data TEXT
);
"""
)
txn.execute("INSERT INTO foobar VALUES (123, 'hello world');")

def _create_id_generator(self) -> StreamIdGenerator:
def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
return StreamIdGenerator(
db_conn=conn,
table="foobar",
column="stream_id",
)

return self.get_success_or_raise(self.db_pool.runWithConnection(_create))

def test_initial_value(self) -> None:
"""Check that we read the current token from the DB."""
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_current_token(), 123)

def test_single_gen_next(self) -> None:
"""Check that we correctly increment the current token from the DB."""
id_gen = self._create_id_generator()

async def test_gen_next() -> None:
async with id_gen.get_next() as next_id:
# We haven't persisted `next_id` yet; current token is still 123
self.assertEqual(id_gen.get_current_token(), 123)
# But we did learn what the next value is
self.assertEqual(next_id, 124)

# Once the context manager closes we assume that the `next_id` has been
# written to the DB.
self.assertEqual(id_gen.get_current_token(), 124)

self.get_success(test_gen_next())

def test_multiple_gen_nexts(self) -> None:
"""Check that we handle overlapping calls to gen_next sensibly."""
id_gen = self._create_id_generator()

async def test_gen_next() -> None:
ctx1 = id_gen.get_next()
ctx2 = id_gen.get_next()
ctx3 = id_gen.get_next()

# Request three new stream IDs.
self.assertEqual(await ctx1.__aenter__(), 124)
self.assertEqual(await ctx2.__aenter__(), 125)
self.assertEqual(await ctx3.__aenter__(), 126)

# None are persisted: current token unchanged.
self.assertEqual(id_gen.get_current_token(), 123)

# Persist each in turn.
await ctx1.__aexit__(None, None, None)
self.assertEqual(id_gen.get_current_token(), 124)
await ctx2.__aexit__(None, None, None)
self.assertEqual(id_gen.get_current_token(), 125)
await ctx3.__aexit__(None, None, None)
self.assertEqual(id_gen.get_current_token(), 126)

self.get_success(test_gen_next())

def test_multiple_gen_nexts_closed_in_different_order(self) -> None:
"""Check that we handle overlapping calls to gen_next, even when their IDs
created and persisted in different orders."""
id_gen = self._create_id_generator()

async def test_gen_next() -> None:
ctx1 = id_gen.get_next()
ctx2 = id_gen.get_next()
ctx3 = id_gen.get_next()

# Request three new stream IDs.
self.assertEqual(await ctx1.__aenter__(), 124)
self.assertEqual(await ctx2.__aenter__(), 125)
self.assertEqual(await ctx3.__aenter__(), 126)

# None are persisted: current token unchanged.
self.assertEqual(id_gen.get_current_token(), 123)

# Persist them in a different order, starting with 126 from ctx3.
await ctx3.__aexit__(None, None, None)
# We haven't persisted 124 from ctx1 yet---current token is still 123.
self.assertEqual(id_gen.get_current_token(), 123)

# Now persist 124 from ctx1.
await ctx1.__aexit__(None, None, None)
# Current token is then 124, waiting for 125 to be persisted.
self.assertEqual(id_gen.get_current_token(), 124)

# Finally persist 125 from ctx2.
await ctx2.__aexit__(None, None, None)
# Current token is then 126 (skipping over 125).
self.assertEqual(id_gen.get_current_token(), 126)

self.get_success(test_gen_next())

def test_gen_next_while_still_waiting_for_persistence(self) -> None:
"""Check that we handle overlapping calls to gen_next."""
id_gen = self._create_id_generator()

async def test_gen_next() -> None:
ctx1 = id_gen.get_next()
ctx2 = id_gen.get_next()
ctx3 = id_gen.get_next()

# Request two new stream IDs.
self.assertEqual(await ctx1.__aenter__(), 124)
self.assertEqual(await ctx2.__aenter__(), 125)

# Persist ctx2 first.
await ctx2.__aexit__(None, None, None)
# Still waiting on ctx1's ID to be persisted.
self.assertEqual(id_gen.get_current_token(), 123)

# Now request a third stream ID. It should be 126 (the smallest ID that
# we've not yet handed out.)
self.assertEqual(await ctx3.__aenter__(), 126)

self.get_success(test_gen_next())


class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
Expand Down

0 comments on commit fa5c919

Please sign in to comment.