Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert misc database code to async #8087

Merged
merged 3 commits into from
Aug 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8087.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
14 changes: 5 additions & 9 deletions synapse/storage/background_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

from canonicaljson import json

from twisted.internet import defer

from synapse.metrics.background_process_metrics import run_as_background_process

from . import engines
Expand Down Expand Up @@ -308,9 +306,8 @@ def register_noop_background_update(self, update_name):
update_name (str): Name of update
"""

@defer.inlineCallbacks
def noop_update(progress, batch_size):
yield self._end_background_update(update_name)
async def noop_update(progress, batch_size):
await self._end_background_update(update_name)
return 1

self.register_background_update_handler(update_name, noop_update)
Expand Down Expand Up @@ -409,12 +406,11 @@ def create_index_sqlite(conn):
else:
runner = create_index_sqlite

@defer.inlineCallbacks
def updater(progress, batch_size):
async def updater(progress, batch_size):
if runner is not None:
logger.info("Adding index %s to %s", index_name, table)
yield self.db_pool.runWithConnection(runner)
yield self._end_background_update(update_name)
await self.db_pool.runWithConnection(runner)
await self._end_background_update(update_name)
return 1

self.register_background_update_handler(update_name, updater)
Expand Down
5 changes: 2 additions & 3 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,10 +671,9 @@ def get_device_list_last_stream_id_for_remote(self, user_id: str):
@cachedList(
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
inlineCallbacks=True,
)
def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
rows = yield self.db_pool.simple_select_many_batch(
async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
Expand Down
9 changes: 4 additions & 5 deletions synapse/storage/databases/main/event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.descriptors import cached

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,18 +86,17 @@ def __init__(self, database: DatabasePool, db_conn, hs):
self._rotate_delay = 3
self._rotate_count = 10000

@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user(
@cached(num_args=3, tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
ret = yield self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
room_id,
user_id,
last_read_event_id,
)
return ret

def _get_unread_counts_by_receipt_txn(
self, txn, room_id, user_id, last_read_event_id
Expand Down
9 changes: 3 additions & 6 deletions synapse/storage/databases/main/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,10 @@ def _get_presence_for_user(self, user_id):
raise NotImplementedError()

@cachedList(
cached_method_name="_get_presence_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
)
def get_presence_for_users(self, user_ids):
rows = yield self.db_pool.simple_select_many_batch(
async def get_presence_for_users(self, user_ids):
rows = await self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
iterable=user_ids,
Expand Down
16 changes: 6 additions & 10 deletions synapse/storage/databases/main/push_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,18 +170,15 @@ def have_push_rules_changed_txn(txn):
)

@cachedList(
cached_method_name="get_push_rules_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
)
def bulk_get_push_rules(self, user_ids):
async def bulk_get_push_rules(self, user_ids):
if not user_ids:
return {}

results = {user_id: [] for user_id in user_ids}

rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
column="user_name",
iterable=user_ids,
Expand All @@ -194,7 +191,7 @@ def bulk_get_push_rules(self, user_ids):
for row in rows:
results.setdefault(row["user_name"], []).append(row)

enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)

for user_id, rules in results.items():
use_new_defaults = user_id in self._users_new_default_push_rules
Expand Down Expand Up @@ -260,15 +257,14 @@ def copy_push_rules_from_room_to_room_for_user(
cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
)
def bulk_get_push_rules_enabled(self, user_ids):
async def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids:
return {}

results = {user_id: {} for user_id in user_ids}

rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="push_rules_enable",
column="user_name",
iterable=user_ids,
Expand Down
9 changes: 3 additions & 6 deletions synapse/storage/databases/main/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,10 @@ def get_if_user_has_pusher(self, user_id):
raise NotImplementedError()

@cachedList(
cached_method_name="get_if_user_has_pusher",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
)
def get_if_users_have_pushers(self, user_ids):
rows = yield self.db_pool.simple_select_many_batch(
async def get_if_users_have_pushers(self, user_ids):
rows = await self.db_pool.simple_select_many_batch(
table="pushers",
column="user_name",
iterable=user_ids,
Expand Down
5 changes: 2 additions & 3 deletions synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,8 @@ def f(txn):
cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids",
num_args=3,
inlineCallbacks=True,
)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
return {}

Expand Down Expand Up @@ -243,7 +242,7 @@ def f(txn):

return self.db_pool.cursor_to_dict(txn)

txn_results = yield self.db_pool.runInteraction(
txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)

Expand Down
17 changes: 6 additions & 11 deletions synapse/storage/databases/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
import logging
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set

from twisted.internet import defer

from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
Expand Down Expand Up @@ -92,8 +90,7 @@ def __init__(self, database: DatabasePool, db_conn, hs):
lambda: self._known_servers_count,
)

@defer.inlineCallbacks
def _count_known_servers(self):
async def _count_known_servers(self):
"""
Count the servers that this server knows about.

Expand Down Expand Up @@ -121,7 +118,7 @@ def _transact(txn):
txn.execute(query)
return list(txn)[0][0]

count = yield self.db_pool.runInteraction("get_known_servers", _transact)
count = await self.db_pool.runInteraction("get_known_servers", _transact)

# We always know about ourselves, even if we have nothing in
# room_memberships (for example, the server is new).
Expand Down Expand Up @@ -589,23 +586,21 @@ def _get_joined_profile_from_event_id(self, event_id):
raise NotImplementedError()

@cachedList(
cached_method_name="_get_joined_profile_from_event_id",
list_name="event_ids",
inlineCallbacks=True,
cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
)
def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.

Args:
event_ids: The member event IDs to lookup

Returns:
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
to `user_id` and ProfileInfo (or None if not join event).
"""

rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
Expand Down
5 changes: 2 additions & 3 deletions synapse/storage/databases/main/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,12 +273,11 @@ def _get_state_group_for_event(self, event_id):
cached_method_name="_get_state_group_for_event",
list_name="event_ids",
num_args=1,
inlineCallbacks=True,
)
def _get_state_group_for_events(self, event_ids):
async def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
iterable=event_ids,
Expand Down
13 changes: 5 additions & 8 deletions synapse/storage/databases/main/user_erasure_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,23 @@ def is_user_erased(self, user_id):
desc="is_user_erased",
).addCallback(operator.truth)

@cachedList(
cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
)
def are_users_erased(self, user_ids):
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
async def are_users_erased(self, user_ids):
"""
Checks which users in a list have requested erasure

Args:
user_ids (iterable[str]): full user id to check

Returns:
Deferred[dict[str, bool]]:
dict[str, bool]:
for each user, whether the user has requested erasure.
"""
# this serves the dual purpose of (a) making sure we can do len and
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))

rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="erased_users",
column="user_id",
iterable=user_ids,
Expand All @@ -65,8 +63,7 @@ def are_users_erased(self, user_ids):
)
erased_users = {row["user_id"] for row in rows}

res = {u: u in erased_users for u in user_ids}
return res
return {u: u in erased_users for u in user_ids}


class UserErasureStore(UserErasureWorkerStore):
Expand Down