Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Convert simple_update* and simple_select* to async (#8173)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Aug 27, 2020
1 parent a466b67 commit 4a739c7
Show file tree
Hide file tree
Showing 19 changed files with 164 additions and 133 deletions.
1 change: 1 addition & 0 deletions changelog.d/8173.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
6 changes: 2 additions & 4 deletions synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
create_requester,
)
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer, maybe_awaitable
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client

Expand Down Expand Up @@ -1329,9 +1329,7 @@ async def shutdown_room(
ratelimit=False,
)

aliases_for_room = await maybe_awaitable(
self.store.get_aliases_for_room(room_id)
)
aliases_for_room = await self.store.get_aliases_for_room(room_id)

await self.store.update_aliases_for_room(
room_id, new_room_id, requester_user_id
Expand Down
29 changes: 15 additions & 14 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,13 +1132,13 @@ def simple_select_onecol_txn(

return [r[0] for r in txn]

def simple_select_onecol(
async def simple_select_onecol(
self,
table: str,
keyvalues: Optional[Dict[str, Any]],
retcol: str,
desc: str = "simple_select_onecol",
) -> defer.Deferred:
) -> List[Any]:
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
Expand All @@ -1148,19 +1148,19 @@ def simple_select_onecol(
retcol: column whos value we wish to retrieve.
Returns:
Deferred: Results in a list
Results in a list
"""
return self.runInteraction(
return await self.runInteraction(
desc, self.simple_select_onecol_txn, table, keyvalues, retcol
)

def simple_select_list(
async def simple_select_list(
self,
table: str,
keyvalues: Optional[Dict[str, Any]],
retcols: Iterable[str],
desc: str = "simple_select_list",
) -> defer.Deferred:
) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Expand All @@ -1170,10 +1170,11 @@ def simple_select_list(
column names and values to select the rows with, or None to not
apply a WHERE clause.
retcols: the names of the columns to return
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
A list of dictionaries.
"""
return self.runInteraction(
return await self.runInteraction(
desc, self.simple_select_list_txn, table, keyvalues, retcols
)

Expand Down Expand Up @@ -1299,14 +1300,14 @@ def simple_select_many_txn(
txn.execute(sql, values)
return cls.cursor_to_dict(txn)

def simple_update(
async def simple_update(
self,
table: str,
keyvalues: Dict[str, Any],
updatevalues: Dict[str, Any],
desc: str,
) -> defer.Deferred:
return self.runInteraction(
) -> int:
return await self.runInteraction(
desc, self.simple_update_txn, table, keyvalues, updatevalues
)

Expand All @@ -1332,13 +1333,13 @@ def simple_update_txn(

return txn.rowcount

def simple_update_one(
async def simple_update_one(
self,
table: str,
keyvalues: Dict[str, Any],
updatevalues: Dict[str, Any],
desc: str = "simple_update_one",
) -> defer.Deferred:
) -> None:
"""Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values.
Expand All @@ -1347,7 +1348,7 @@ def simple_update_one(
keyvalues: dict of column names and values to select the row with
updatevalues: dict giving column names and values to update
"""
return self.runInteraction(
await self.runInteraction(
desc, self.simple_update_one_txn, table, keyvalues, updatevalues
)

Expand Down
8 changes: 4 additions & 4 deletions synapse/storage/databases/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import calendar
import logging
import time
from typing import Any, Dict, List

from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig
Expand Down Expand Up @@ -476,14 +477,13 @@ def _generate_user_daily_visits(txn):
"generate_user_daily_visits", _generate_user_daily_visits
)

def get_users(self):
async def get_users(self) -> List[Dict[str, Any]]:
"""Function to retrieve a list of users in users table.
Args:
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
A list of dictionaries representing users.
"""
return self.db_pool.simple_select_list(
return await self.db_pool.simple_select_list(
table="users",
keyvalues={},
retcols=[
Expand Down
6 changes: 3 additions & 3 deletions synapse/storage/databases/main/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

from collections import namedtuple
from typing import Iterable, Optional
from typing import Iterable, List, Optional

from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
Expand Down Expand Up @@ -68,8 +68,8 @@ async def get_room_alias_creator(self, room_alias: str) -> str:
)

@cached(max_entries=5000)
def get_aliases_for_room(self, room_id):
return self.db_pool.simple_select_onecol(
async def get_aliases_for_room(self, room_id: str) -> List[str]:
return await self.db_pool.simple_select_onecol(
"room_aliases",
{"room_id": room_id},
"room_alias",
Expand Down
26 changes: 16 additions & 10 deletions synapse/storage/databases/main/e2e_room_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
from synapse.storage._base import SQLBaseStore, db_to_json
Expand Down Expand Up @@ -368,18 +370,22 @@ def _create_e2e_room_keys_version_txn(txn):
)

@trace
def update_e2e_room_keys_version(
self, user_id, version, info=None, version_etag=None
):
async def update_e2e_room_keys_version(
self,
user_id: str,
version: str,
info: Optional[dict] = None,
version_etag: Optional[int] = None,
) -> None:
"""Update a given backup version
Args:
user_id(str): the user whose backup version we're updating
version(str): the version ID of the backup version we're updating
info (dict): the new backup version info to store. If None, then
the backup version info is not updated
version_etag (Optional[int]): etag of the keys in the backup. If
None, then the etag is not updated
user_id: the user whose backup version we're updating
version: the version ID of the backup version we're updating
info: the new backup version info to store. If None, then the backup
version info is not updated.
version_etag: etag of the keys in the backup. If None, then the etag
is not updated.
"""
updatevalues = {}

Expand All @@ -389,7 +395,7 @@ def update_e2e_room_keys_version(
updatevalues["etag"] = version_etag

if updatevalues:
return self.db_pool.simple_update(
await self.db_pool.simple_update(
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": version},
updatevalues=updatevalues,
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,8 @@ def _get_rooms_with_many_extremities_txn(txn):
)

@cached(max_entries=5000, iterable=True)
def get_latest_event_ids_in_room(self, room_id):
return self.db_pool.simple_select_onecol(
async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]:
return await self.db_pool.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
Expand Down
55 changes: 32 additions & 23 deletions synapse/storage/databases/main/group_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,24 +44,26 @@ async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
desc="get_group",
)

def get_users_in_group(self, group_id, include_private=False):
async def get_users_in_group(
self, group_id: str, include_private: bool = False
) -> List[Dict[str, Any]]:
# TODO: Pagination

keyvalues = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True

return self.db_pool.simple_select_list(
return await self.db_pool.simple_select_list(
table="group_users",
keyvalues=keyvalues,
retcols=("user_id", "is_public", "is_admin"),
desc="get_users_in_group",
)

def get_invited_users_in_group(self, group_id):
async def get_invited_users_in_group(self, group_id: str) -> List[str]:
# TODO: Pagination

return self.db_pool.simple_select_onecol(
return await self.db_pool.simple_select_onecol(
table="group_invites",
keyvalues={"group_id": group_id},
retcol="user_id",
Expand Down Expand Up @@ -265,15 +267,14 @@ async def get_group_role(self, group_id, role_id):

return role

def get_local_groups_for_room(self, room_id):
async def get_local_groups_for_room(self, room_id: str) -> List[str]:
"""Get all of the local group that contain a given room
Args:
room_id (str): The ID of a room
room_id: The ID of a room
Returns:
Deferred[list[str]]: A twisted.Deferred containing a list of group ids
containing this room
A list of group ids containing this room
"""
return self.db_pool.simple_select_onecol(
return await self.db_pool.simple_select_onecol(
table="group_rooms",
keyvalues={"room_id": room_id},
retcol="group_id",
Expand Down Expand Up @@ -422,10 +423,10 @@ def _get_users_membership_in_group_txn(txn):
"get_users_membership_info_in_group", _get_users_membership_in_group_txn
)

def get_publicised_groups_for_user(self, user_id):
async def get_publicised_groups_for_user(self, user_id: str) -> List[str]:
"""Get all groups a user is publicising
"""
return self.db_pool.simple_select_onecol(
return await self.db_pool.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
retcol="group_id",
Expand Down Expand Up @@ -466,8 +467,8 @@ async def get_remote_attestation(self, group_id, user_id):

return None

def get_joined_groups(self, user_id):
return self.db_pool.simple_select_onecol(
async def get_joined_groups(self, user_id: str) -> List[str]:
return await self.db_pool.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join"},
retcol="group_id",
Expand Down Expand Up @@ -585,14 +586,14 @@ def _get_all_groups_changes_txn(txn):


class GroupServerStore(GroupServerWorkerStore):
def set_group_join_policy(self, group_id, join_policy):
async def set_group_join_policy(self, group_id: str, join_policy: str) -> None:
"""Set the join policy of a group.
join_policy can be one of:
* "invite"
* "open"
"""
return self.db_pool.simple_update_one(
await self.db_pool.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
updatevalues={"join_policy": join_policy},
Expand Down Expand Up @@ -1050,8 +1051,10 @@ def add_room_to_group(self, group_id, room_id, is_public):
desc="add_room_to_group",
)

def update_room_in_group_visibility(self, group_id, room_id, is_public):
return self.db_pool.simple_update(
async def update_room_in_group_visibility(
self, group_id: str, room_id: str, is_public: bool
) -> int:
return await self.db_pool.simple_update(
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
updatevalues={"is_public": is_public},
Expand All @@ -1076,10 +1079,12 @@ def _remove_room_from_group_txn(txn):
"remove_room_from_group", _remove_room_from_group_txn
)

def update_group_publicity(self, group_id, user_id, publicise):
async def update_group_publicity(
self, group_id: str, user_id: str, publicise: bool
) -> None:
"""Update whether the user is publicising their membership of the group
"""
return self.db_pool.simple_update_one(
await self.db_pool.simple_update_one(
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"is_publicised": publicise},
Expand Down Expand Up @@ -1218,20 +1223,24 @@ async def update_group_profile(self, group_id, profile):
desc="update_group_profile",
)

def update_attestation_renewal(self, group_id, user_id, attestation):
async def update_attestation_renewal(
self, group_id: str, user_id: str, attestation: dict
) -> None:
"""Update an attestation that we have renewed
"""
return self.db_pool.simple_update_one(
await self.db_pool.simple_update_one(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
desc="update_attestation_renewal",
)

def update_remote_attestion(self, group_id, user_id, attestation):
async def update_remote_attestion(
self, group_id: str, user_id: str, attestation: dict
) -> None:
"""Update an attestation that a remote has renewed
"""
return self.db_pool.simple_update_one(
await self.db_pool.simple_update_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={
Expand Down
Loading

0 comments on commit 4a739c7

Please sign in to comment.