Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Convert events worker database to async/await.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Aug 12, 2020
1 parent fbe930d commit ebc1916
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 131 deletions.
1 change: 1 addition & 0 deletions changelog.d/8071.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
2 changes: 1 addition & 1 deletion synapse/spam_checker_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,5 @@ def get_state_events_in_room(self, room_id: str, types: tuple) -> defer.Deferred
state_ids = yield self._store.get_filtered_current_state_ids(
room_id=room_id, state_filter=StateFilter.from_types(types)
)
state = yield self._store.get_events(state_ids.values())
state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
return state.values()
2 changes: 1 addition & 1 deletion synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def get_events(self, event_ids, allow_rejected=False):
allow_rejected (bool): If True return rejected events.
Returns:
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event.
"""

return self.store.get_events(
Expand Down
30 changes: 14 additions & 16 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
def get_auth_chain(self, event_ids, include_given=False):
async def get_auth_chain(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events.
Args:
Expand All @@ -40,9 +40,10 @@ def get_auth_chain(self, event_ids, include_given=False):
Returns:
list of events
"""
return self.get_auth_chain_ids(
event_ids = await self.get_auth_chain_ids(
event_ids, include_given=include_given
).addCallback(self.get_events_as_list)
)
return await self.get_events_as_list(event_ids)

def get_auth_chain_ids(
self,
Expand Down Expand Up @@ -472,7 +473,7 @@ def get_forward_extremeties_for_room_txn(txn):
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)

def get_backfill_events(self, room_id, event_list, limit):
async def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit`
Expand All @@ -482,17 +483,15 @@ def get_backfill_events(self, room_id, event_list, limit):
event_list (list)
limit (int)
"""
return (
self.db_pool.runInteraction(
"get_backfill_events",
self._get_backfill_events,
room_id,
event_list,
limit,
)
.addCallback(self.get_events_as_list)
.addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
event_ids = await self.db_pool.runInteraction(
"get_backfill_events",
self._get_backfill_events,
room_id,
event_list,
limit,
)
events = await self.get_events_as_list(event_ids)
return sorted(events, key=lambda e: -e.depth)

def _get_backfill_events(self, txn, room_id, event_list, limit):
logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
Expand Down Expand Up @@ -553,8 +552,7 @@ async def get_missing_events(self, room_id, earliest_events, latest_events, limi
latest_events,
limit,
)
events = await self.get_events_as_list(ids)
return events
return await self.get_events_as_list(ids)

def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):

Expand Down
81 changes: 36 additions & 45 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging
import threading
from collections import namedtuple
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

from constantly import NamedConstant, Names

Expand All @@ -32,7 +32,7 @@
EventFormatVersions,
RoomVersions,
)
from synapse.events import make_event_from_dict
from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
Expand All @@ -43,7 +43,7 @@
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
from synapse.util.caches.descriptors import Cache, cached
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure

Expand Down Expand Up @@ -173,16 +173,15 @@ def _get_approximate_received_ts_txn(txn):
"get_approximate_received_ts", _get_approximate_received_ts_txn
)

@defer.inlineCallbacks
def get_event(
async def get_event(
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
allow_none: bool = False,
check_room_id: Optional[str] = None,
):
) -> Optional[EventBase]:
"""Get an event from the database by event_id.
Args:
Expand All @@ -207,12 +206,12 @@ def get_event(
If there is a mismatch, behave as per allow_none.
Returns:
Deferred[EventBase|None]
The event, or None if the event was not found.
"""
if not isinstance(event_id, str):
raise TypeError("Invalid event event_id %r" % (event_id,))

events = yield self.get_events_as_list(
events = await self.get_events_as_list(
[event_id],
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
Expand All @@ -230,14 +229,13 @@ def get_event(

return event

@defer.inlineCallbacks
def get_events(
async def get_events(
self,
event_ids: List[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
):
) -> Dict[str, EventBase]:
"""Get events from the database
Args:
Expand All @@ -256,9 +254,9 @@ def get_events(
omits rejected events from the response.
Returns:
Deferred : Dict from event_id to event.
A mapping from event_id to event.
"""
events = yield self.get_events_as_list(
events = await self.get_events_as_list(
event_ids,
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
Expand All @@ -267,8 +265,7 @@ def get_events(

return {e.event_id: e for e in events}

@defer.inlineCallbacks
def get_events_as_list(
async def get_events_as_list(
self,
event_ids: List[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
Expand All @@ -295,7 +292,7 @@ def get_events_as_list(
omits rejected events from the response.
Returns:
Deferred[list[EventBase]]: List of events fetched from the database. The
list[EventBase]: List of events fetched from the database. The
events are in the same order as `event_ids` arg.
Note that the returned list may be smaller than the list of event
Expand All @@ -306,7 +303,7 @@ def get_events_as_list(
return []

# there may be duplicates so we cast the list to a set
event_entry_map = yield self._get_events_from_cache_or_db(
event_entry_map = await self._get_events_from_cache_or_db(
set(event_ids), allow_rejected=allow_rejected
)

Expand Down Expand Up @@ -341,7 +338,7 @@ def get_events_as_list(
continue

redacted_event_id = entry.event.redacts
event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
event_map = await self._get_events_from_cache_or_db([redacted_event_id])
original_event_entry = event_map.get(redacted_event_id)
if not original_event_entry:
# we don't have the redacted event (or it was rejected).
Expand Down Expand Up @@ -407,7 +404,7 @@ def get_events_as_list(

if get_prev_content:
if "replaces_state" in event.unsigned:
prev = yield self.get_event(
prev = await self.get_event(
event.unsigned["replaces_state"],
get_prev_content=False,
allow_none=True,
Expand All @@ -419,8 +416,7 @@ def get_events_as_list(

return events

@defer.inlineCallbacks
def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
Expand All @@ -435,7 +431,7 @@ def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
rejected events are omitted from the response.
Returns:
Deferred[Dict[str, _EventCacheEntry]]:
Dict[str, _EventCacheEntry]:
map from event id to result
"""
event_entry_map = self._get_events_from_cache(
Expand All @@ -453,7 +449,7 @@ def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
# the events have been redacted, and if so pulling the redaction event out
# of the database to check it.
#
missing_events = yield self._get_events_from_db(
missing_events = await self._get_events_from_db(
missing_events_ids, allow_rejected=allow_rejected
)

Expand Down Expand Up @@ -561,8 +557,7 @@ def fire(evs, exc):
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, e)

@defer.inlineCallbacks
def _get_events_from_db(self, event_ids, allow_rejected=False):
async def _get_events_from_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the database.
Returned events will be added to the cache for future lookups.
Expand All @@ -576,15 +571,15 @@ def _get_events_from_db(self, event_ids, allow_rejected=False):
rejected events are omitted from the response.
Returns:
Deferred[Dict[str, _EventCacheEntry]]:
Dict[str, _EventCacheEntry]:
map from event id to result. May return extra events which
weren't asked for.
"""
fetched_events = {}
events_to_fetch = event_ids

while events_to_fetch:
row_map = yield self._enqueue_events(events_to_fetch)
row_map = await self._enqueue_events(events_to_fetch)

# we need to recursively fetch any redactions of those events
redaction_ids = set()
Expand Down Expand Up @@ -686,8 +681,7 @@ def _get_events_from_db(self, event_ids, allow_rejected=False):

return result_map

@defer.inlineCallbacks
def _enqueue_events(self, events):
async def _enqueue_events(self, events):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
Expand All @@ -696,7 +690,7 @@ def _enqueue_events(self, events):
events (Iterable[str]): events to be fetched.
Returns:
Deferred[Dict[str, Dict]]: map from event id to row data from the database.
Dict[str, Dict]: map from event id to row data from the database.
May contain events that weren't requested.
"""

Expand All @@ -719,7 +713,7 @@ def _enqueue_events(self, events):

logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
row_map = yield events_d
row_map = await events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))

return row_map
Expand Down Expand Up @@ -878,12 +872,11 @@ def _maybe_redact_event_row(self, original_ev, redactions, event_map):
# no valid redaction found for this event
return None

@defer.inlineCallbacks
def have_events_in_timeline(self, event_ids):
async def have_events_in_timeline(self, event_ids):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
Expand All @@ -894,15 +887,14 @@ def have_events_in_timeline(self, event_ids):

return {r["event_id"] for r in rows}

@defer.inlineCallbacks
def have_seen_events(self, event_ids):
async def have_seen_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
Args:
event_ids (iterable[str]):
Returns:
Deferred[set[str]]: The events we have already seen.
set[str]: The events we have already seen.
"""
results = set()

Expand All @@ -918,7 +910,7 @@ def have_seen_events_txn(txn, chunk):
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"have_seen_events", have_seen_events_txn, chunk
)
return results
Expand Down Expand Up @@ -978,8 +970,7 @@ def get_current_state_event_counts(self, room_id):
room_id,
)

@defer.inlineCallbacks
def get_room_complexity(self, room_id):
async def get_room_complexity(self, room_id):
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
Expand All @@ -990,9 +981,9 @@ def get_room_complexity(self, room_id):
room_id (str)
Returns:
Deferred[dict[str:int]] of complexity version to complexity.
dict[str:int] of complexity version to complexity.
"""
state_events = yield self.get_current_state_event_counts(room_id)
state_events = await self.get_current_state_event_counts(room_id)

# Call this one "v1", so we can introduce new ones as we want to develop
# it.
Expand Down Expand Up @@ -1320,9 +1311,9 @@ async def is_event_after(self, event_id1, event_id2):
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)

@cachedInlineCallbacks(max_entries=5000)
def get_event_ordering(self, event_id):
res = yield self.db_pool.simple_select_one(
@cached(max_entries=5000)
async def get_event_ordering(self, event_id):
res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
Expand Down
Loading

0 comments on commit ebc1916

Please sign in to comment.