Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert simple_update* and simple_select* to async #8173

Merged
merged 5 commits into from
Aug 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8173.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
6 changes: 2 additions & 4 deletions synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
create_requester,
)
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer, maybe_awaitable
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client

Expand Down Expand Up @@ -1329,9 +1329,7 @@ async def shutdown_room(
ratelimit=False,
)

aliases_for_room = await maybe_awaitable(
self.store.get_aliases_for_room(room_id)
)
aliases_for_room = await self.store.get_aliases_for_room(room_id)

await self.store.update_aliases_for_room(
room_id, new_room_id, requester_user_id
Expand Down
29 changes: 15 additions & 14 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,13 +1132,13 @@ def simple_select_onecol_txn(

return [r[0] for r in txn]

def simple_select_onecol(
async def simple_select_onecol(
self,
table: str,
keyvalues: Optional[Dict[str, Any]],
retcol: str,
desc: str = "simple_select_onecol",
) -> defer.Deferred:
) -> List[Any]:
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
Expand All @@ -1148,19 +1148,19 @@ def simple_select_onecol(
retcol: column whos value we wish to retrieve.
Returns:
Deferred: Results in a list
Results in a list
"""
return self.runInteraction(
return await self.runInteraction(
desc, self.simple_select_onecol_txn, table, keyvalues, retcol
)

def simple_select_list(
async def simple_select_list(
self,
table: str,
keyvalues: Optional[Dict[str, Any]],
retcols: Iterable[str],
desc: str = "simple_select_list",
) -> defer.Deferred:
) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Expand All @@ -1170,10 +1170,11 @@ def simple_select_list(
column names and values to select the rows with, or None to not
apply a WHERE clause.
retcols: the names of the columns to return
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
A list of dictionaries.
"""
return self.runInteraction(
return await self.runInteraction(
desc, self.simple_select_list_txn, table, keyvalues, retcols
)

Expand Down Expand Up @@ -1299,14 +1300,14 @@ def simple_select_many_txn(
txn.execute(sql, values)
return cls.cursor_to_dict(txn)

def simple_update(
async def simple_update(
self,
table: str,
keyvalues: Dict[str, Any],
updatevalues: Dict[str, Any],
desc: str,
) -> defer.Deferred:
return self.runInteraction(
) -> int:
return await self.runInteraction(
desc, self.simple_update_txn, table, keyvalues, updatevalues
)

Expand All @@ -1332,13 +1333,13 @@ def simple_update_txn(

return txn.rowcount

def simple_update_one(
async def simple_update_one(
self,
table: str,
keyvalues: Dict[str, Any],
updatevalues: Dict[str, Any],
desc: str = "simple_update_one",
) -> defer.Deferred:
) -> None:
"""Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values.
Expand All @@ -1347,7 +1348,7 @@ def simple_update_one(
keyvalues: dict of column names and values to select the row with
updatevalues: dict giving column names and values to update
"""
return self.runInteraction(
await self.runInteraction(
desc, self.simple_update_one_txn, table, keyvalues, updatevalues
)

Expand Down
8 changes: 4 additions & 4 deletions synapse/storage/databases/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import calendar
import logging
import time
from typing import Any, Dict, List

from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig
Expand Down Expand Up @@ -476,14 +477,13 @@ def _generate_user_daily_visits(txn):
"generate_user_daily_visits", _generate_user_daily_visits
)

def get_users(self):
async def get_users(self) -> List[Dict[str, Any]]:
"""Function to retrieve a list of users in users table.
Args:
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
A list of dictionaries representing users.
"""
return self.db_pool.simple_select_list(
return await self.db_pool.simple_select_list(
table="users",
keyvalues={},
retcols=[
Expand Down
6 changes: 3 additions & 3 deletions synapse/storage/databases/main/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

from collections import namedtuple
from typing import Iterable, Optional
from typing import Iterable, List, Optional

from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
Expand Down Expand Up @@ -68,8 +68,8 @@ async def get_room_alias_creator(self, room_alias: str) -> str:
)

@cached(max_entries=5000)
def get_aliases_for_room(self, room_id):
return self.db_pool.simple_select_onecol(
async def get_aliases_for_room(self, room_id: str) -> List[str]:
return await self.db_pool.simple_select_onecol(
"room_aliases",
{"room_id": room_id},
"room_alias",
Expand Down
26 changes: 16 additions & 10 deletions synapse/storage/databases/main/e2e_room_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
from synapse.storage._base import SQLBaseStore, db_to_json
Expand Down Expand Up @@ -368,18 +370,22 @@ def _create_e2e_room_keys_version_txn(txn):
)

@trace
def update_e2e_room_keys_version(
self, user_id, version, info=None, version_etag=None
):
async def update_e2e_room_keys_version(
self,
user_id: str,
version: str,
info: Optional[dict] = None,
version_etag: Optional[int] = None,
) -> None:
"""Update a given backup version
Args:
user_id(str): the user whose backup version we're updating
version(str): the version ID of the backup version we're updating
info (dict): the new backup version info to store. If None, then
the backup version info is not updated
version_etag (Optional[int]): etag of the keys in the backup. If
None, then the etag is not updated
user_id: the user whose backup version we're updating
version: the version ID of the backup version we're updating
info: the new backup version info to store. If None, then the backup
version info is not updated.
version_etag: etag of the keys in the backup. If None, then the etag
is not updated.
"""
updatevalues = {}

Expand All @@ -389,7 +395,7 @@ def update_e2e_room_keys_version(
updatevalues["etag"] = version_etag

if updatevalues:
return self.db_pool.simple_update(
await self.db_pool.simple_update(
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": version},
updatevalues=updatevalues,
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,8 @@ def _get_rooms_with_many_extremities_txn(txn):
)

@cached(max_entries=5000, iterable=True)
def get_latest_event_ids_in_room(self, room_id):
return self.db_pool.simple_select_onecol(
async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]:
return await self.db_pool.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
Expand Down
55 changes: 32 additions & 23 deletions synapse/storage/databases/main/group_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,24 +44,26 @@ async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
desc="get_group",
)

def get_users_in_group(self, group_id, include_private=False):
async def get_users_in_group(
self, group_id: str, include_private: bool = False
) -> List[Dict[str, Any]]:
# TODO: Pagination

keyvalues = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True

return self.db_pool.simple_select_list(
return await self.db_pool.simple_select_list(
table="group_users",
keyvalues=keyvalues,
retcols=("user_id", "is_public", "is_admin"),
desc="get_users_in_group",
)

def get_invited_users_in_group(self, group_id):
async def get_invited_users_in_group(self, group_id: str) -> List[str]:
# TODO: Pagination

return self.db_pool.simple_select_onecol(
return await self.db_pool.simple_select_onecol(
table="group_invites",
keyvalues={"group_id": group_id},
retcol="user_id",
Expand Down Expand Up @@ -265,15 +267,14 @@ async def get_group_role(self, group_id, role_id):

return role

def get_local_groups_for_room(self, room_id):
async def get_local_groups_for_room(self, room_id: str) -> List[str]:
"""Get all of the local group that contain a given room
Args:
room_id (str): The ID of a room
room_id: The ID of a room
Returns:
Deferred[list[str]]: A twisted.Deferred containing a list of group ids
containing this room
A list of group ids containing this room
"""
return self.db_pool.simple_select_onecol(
return await self.db_pool.simple_select_onecol(
table="group_rooms",
keyvalues={"room_id": room_id},
retcol="group_id",
Expand Down Expand Up @@ -422,10 +423,10 @@ def _get_users_membership_in_group_txn(txn):
"get_users_membership_info_in_group", _get_users_membership_in_group_txn
)

def get_publicised_groups_for_user(self, user_id):
async def get_publicised_groups_for_user(self, user_id: str) -> List[str]:
"""Get all groups a user is publicising
"""
return self.db_pool.simple_select_onecol(
return await self.db_pool.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
retcol="group_id",
Expand Down Expand Up @@ -466,8 +467,8 @@ async def get_remote_attestation(self, group_id, user_id):

return None

def get_joined_groups(self, user_id):
return self.db_pool.simple_select_onecol(
async def get_joined_groups(self, user_id: str) -> List[str]:
return await self.db_pool.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join"},
retcol="group_id",
Expand Down Expand Up @@ -585,14 +586,14 @@ def _get_all_groups_changes_txn(txn):


class GroupServerStore(GroupServerWorkerStore):
def set_group_join_policy(self, group_id, join_policy):
async def set_group_join_policy(self, group_id: str, join_policy: str) -> None:
"""Set the join policy of a group.
join_policy can be one of:
* "invite"
* "open"
"""
return self.db_pool.simple_update_one(
await self.db_pool.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
updatevalues={"join_policy": join_policy},
Expand Down Expand Up @@ -1050,8 +1051,10 @@ def add_room_to_group(self, group_id, room_id, is_public):
desc="add_room_to_group",
)

def update_room_in_group_visibility(self, group_id, room_id, is_public):
return self.db_pool.simple_update(
async def update_room_in_group_visibility(
self, group_id: str, room_id: str, is_public: bool
) -> int:
return await self.db_pool.simple_update(
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
updatevalues={"is_public": is_public},
Expand All @@ -1076,10 +1079,12 @@ def _remove_room_from_group_txn(txn):
"remove_room_from_group", _remove_room_from_group_txn
)

def update_group_publicity(self, group_id, user_id, publicise):
async def update_group_publicity(
self, group_id: str, user_id: str, publicise: bool
) -> None:
"""Update whether the user is publicising their membership of the group
"""
return self.db_pool.simple_update_one(
await self.db_pool.simple_update_one(
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"is_publicised": publicise},
Expand Down Expand Up @@ -1218,20 +1223,24 @@ async def update_group_profile(self, group_id, profile):
desc="update_group_profile",
)

def update_attestation_renewal(self, group_id, user_id, attestation):
async def update_attestation_renewal(
self, group_id: str, user_id: str, attestation: dict
) -> None:
"""Update an attestation that we have renewed
"""
return self.db_pool.simple_update_one(
await self.db_pool.simple_update_one(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
desc="update_attestation_renewal",
)

def update_remote_attestion(self, group_id, user_id, attestation):
async def update_remote_attestion(
self, group_id: str, user_id: str, attestation: dict
) -> None:
"""Update an attestation that a remote has renewed
"""
return self.db_pool.simple_update_one(
await self.db_pool.simple_update_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={
Expand Down
Loading