Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add functions to MultiWriterIdGen used by events stream (#8164)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston authored Aug 25, 2020
1 parent 5099bd6 commit eba98fb
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 3 deletions.
1 change: 1 addition & 0 deletions changelog.d/8164.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add functions to `MultiWriterIdGen` used by events stream.
103 changes: 101 additions & 2 deletions synapse/storage/util/id_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
# limitations under the License.

import contextlib
import heapq
import threading
from collections import deque
from typing import Dict, Set
from typing import Dict, List, Set

from typing_extensions import Deque

Expand Down Expand Up @@ -210,6 +211,23 @@ def __init__(
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]

# We track the max position where we know everything before has been
# persisted. This is done by a) looking at the min across all instances
# and b) noting that if we have seen a run of persisted positions
# without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
#
# Note: There is no guarentee that the IDs generated by the sequence
# will be gapless; gaps can form when e.g. a transaction was rolled
# back. This means that sometimes we won't be able to skip forward the
# position even though everything has been persisted. However, since
# gaps should be relatively rare it's still worth doing the book keeping
# that allows us to skip forwards when there are gapless runs of
# positions.
self._persisted_upto_position = (
min(self._current_positions.values()) if self._current_positions else 0
)
self._known_persisted_positions = [] # type: List[int]

self._sequence_gen = PostgresSequenceGenerator(sequence_name)

def _load_current_ids(
Expand All @@ -234,9 +252,12 @@ def _load_current_ids(

return current_positions

def _load_next_id_txn(self, txn):
def _load_next_id_txn(self, txn) -> int:
return self._sequence_gen.get_next_id_txn(txn)

def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
return self._sequence_gen.get_next_mult_txn(txn, n)

async def get_next(self):
"""
Usage:
Expand All @@ -262,6 +283,34 @@ def manager():

return manager()

async def get_next_mult(self, n: int):
"""
Usage:
with await stream_id_gen.get_next_mult(5) as stream_ids:
# ... persist events ...
"""
next_ids = await self._db.runInteraction(
"_load_next_mult_id", self._load_next_mult_id_txn, n
)

# Assert the fetched ID is actually greater than any ID we've already
# seen. If not, then the sequence and table have got out of sync
# somehow.
assert max(self.get_positions().values(), default=0) < min(next_ids)

with self._lock:
self._unfinished_ids.update(next_ids)

@contextlib.contextmanager
def manager():
try:
yield next_ids
finally:
for i in next_ids:
self._mark_id_as_finished(i)

return manager()

def get_next_txn(self, txn: LoggingTransaction):
"""
Usage:
Expand Down Expand Up @@ -326,3 +375,53 @@ def advance(self, instance_name: str, new_id: int):
self._current_positions[instance_name] = max(
new_id, self._current_positions.get(instance_name, 0)
)

self._add_persisted_position(new_id)

def get_persisted_upto_position(self) -> int:
"""Get the max position where all previous positions have been
persisted.
Note: In the worst case scenario this will be equal to the minimum
position across writers. This means that the returned position here can
lag if one writer doesn't write very often.
"""

with self._lock:
return self._persisted_upto_position

def _add_persisted_position(self, new_id: int):
"""Record that we have persisted a position.
This is used to keep the `_current_positions` up to date.
"""

# We require that the lock is locked by caller
assert self._lock.locked()

heapq.heappush(self._known_persisted_positions, new_id)

# We move the current min position up if the minimum current positions
# of all instances is higher (since by definition all positions less
# that that have been persisted).
min_curr = min(self._current_positions.values())
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)

# We now iterate through the seen positions, discarding those that are
# less than the current min positions, and incrementing the min position
# if its exactly one greater.
#
# This is also where we discard items from `_known_persisted_positions`
# (to ensure the list doesn't infinitely grow).
while self._known_persisted_positions:
if self._known_persisted_positions[0] <= self._persisted_upto_position:
heapq.heappop(self._known_persisted_positions)
elif (
self._known_persisted_positions[0] == self._persisted_upto_position + 1
):
heapq.heappop(self._known_persisted_positions)
self._persisted_upto_position += 1
else:
# There was a gap in seen positions, so there is nothing more to
# do.
break
8 changes: 7 additions & 1 deletion synapse/storage/util/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
import abc
import threading
from typing import Callable, Optional
from typing import Callable, List, Optional

from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.types import Cursor
Expand All @@ -39,6 +39,12 @@ def get_next_id_txn(self, txn: Cursor) -> int:
txn.execute("SELECT nextval(?)", (self._sequence_name,))
return txn.fetchone()[0]

def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
txn.execute(
"SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
)
return [i for (i,) in txn]


GetFirstCallbackType = Callable[[Cursor], int]

Expand Down
36 changes: 36 additions & 0 deletions tests/storage/test_id_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,39 @@ def _get_next_txn(txn):

self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)

def test_get_persisted_upto_position(self):
"""Test that `get_persisted_upto_position` correctly tracks updates to
positions.
"""

self._insert_rows("first", 3)
self._insert_rows("second", 5)

id_gen = self._create_id_generator("first")

# Min is 3 and there is a gap between 5, so we expect it to be 3.
self.assertEqual(id_gen.get_persisted_upto_position(), 3)

# We advance "first" straight to 6. Min is now 5 but there is no gap so
# we expect it to be 6
id_gen.advance("first", 6)
self.assertEqual(id_gen.get_persisted_upto_position(), 6)

# No gap, so we expect 7.
id_gen.advance("second", 7)
self.assertEqual(id_gen.get_persisted_upto_position(), 7)

# We haven't seen 8 yet, so we expect 7 still.
id_gen.advance("second", 9)
self.assertEqual(id_gen.get_persisted_upto_position(), 7)

# Now that we've seen 7, 8 and 9 we can got straight to 9.
id_gen.advance("first", 8)
self.assertEqual(id_gen.get_persisted_upto_position(), 9)

# Jump forward with gaps. The minimum is 11, even though we haven't seen
# 10 we know that everything before 11 must be persisted.
id_gen.advance("first", 11)
id_gen.advance("second", 15)
self.assertEqual(id_gen.get_persisted_upto_position(), 11)

0 comments on commit eba98fb

Please sign in to comment.