Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Separate get_current_token into two. #8113

Merged
merged 3 commits into from
Aug 19, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8113.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Separate `get_current_token` into two since there are two different use cases for it.
8 changes: 8 additions & 0 deletions synapse/replication/slave/storage/_slaved_id_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,11 @@ def get_current_token(self):
int
"""
return self._current

def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.

For streams with single writers this is equivalent to
`get_current_token`.
"""
return self.get_current_token()
2 changes: 1 addition & 1 deletion synapse/replication/tcp/streams/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_cache_stream_token,
store.get_cache_stream_token_for_writer,
store.get_all_updated_caches,
)

Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,8 @@ def _send_invalidation_to_replication(
},
)

def get_cache_stream_token(self, instance_name):
def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
if self._cache_id_gen:
return self._cache_id_gen.get_current_token(instance_name)
return self._cache_id_gen.get_current_token_for_writer(instance_name)
else:
return 0
39 changes: 30 additions & 9 deletions synapse/storage/util/id_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,14 @@ def get_current_token(self):

return self._current

def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.

For streams with single writers this is equivalent to
`get_current_token`.
"""
return self.get_current_token()


class ChainedIdGenerator(object):
"""Used to generate new stream ids where the stream must be kept in sync
Expand Down Expand Up @@ -216,6 +224,14 @@ def advance(self, token: int):
"Attempted to advance token on source for table %r", self._table
)

def get_current_token_for_writer(self, instance_name: str) -> Tuple[int, int]:
"""Returns the position of the given writer.

For streams with single writers this is equivalent to
`get_current_token`.
"""
return self.get_current_token()


class MultiWriterIdGenerator:
"""An ID generator that tracks a stream that can have multiple writers.
Expand Down Expand Up @@ -298,7 +314,7 @@ async def get_next(self):
# Assert the fetched ID is actually greater than what we currently
# believe the ID to be. If not, then the sequence and table have got
# out of sync somehow.
assert self.get_current_token() < next_id
assert self.get_current_token_for_writer(self._instance_name) < next_id

with self._lock:
self._unfinished_ids.add(next_id)
Expand Down Expand Up @@ -344,16 +360,21 @@ def _mark_id_as_finished(self, next_id: int):
curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, next_id)

def get_current_token(self, instance_name: str = None) -> int:
"""Gets the current position of a named writer (defaults to current
instance).

Returns 0 if we don't have a position for the named writer (likely due
to it being a new writer).
def get_current_token(self) -> int:
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""

if instance_name is None:
instance_name = self._instance_name
# Currently we don't support this operation, as it's not obvious how to
# condense the stream positions of multiple writers into a single int.
raise NotImplementedError()

def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.

For streams with single writers this is equivalent to
`get_current_token`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this a copy-and-paste error or did you still want this comment here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops

"""

with self._lock:
return self._current_positions.get(instance_name, 0)
Expand Down
16 changes: 8 additions & 8 deletions tests/storage/test_id_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_single_instance(self):
id_gen = self._create_id_generator()

self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token("master"), 7)
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)

# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
Expand All @@ -98,12 +98,12 @@ async def _get_next_async():
self.assertEqual(stream_id, 8)

self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token("master"), 7)
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)

self.get_success(_get_next_async())

self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token("master"), 8)
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)

def test_multi_instance(self):
"""Test that reads and writes from multiple processes are handled
Expand All @@ -116,8 +116,8 @@ def test_multi_instance(self):
second_id_gen = self._create_id_generator("second")

self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token("first"), 3)
self.assertEqual(first_id_gen.get_current_token("second"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)

# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
Expand Down Expand Up @@ -166,7 +166,7 @@ def test_get_next_txn(self):
id_gen = self._create_id_generator()

self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token("master"), 7)
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)

# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
Expand All @@ -176,9 +176,9 @@ def _get_next_txn(txn):
self.assertEqual(stream_id, 8)

self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token("master"), 7)
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)

self.get_success(self.db_pool.runInteraction("test", _get_next_txn))

self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token("master"), 8)
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)