Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert calls of async database methods to async #8166

Merged
merged 6 commits into from
Aug 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8166.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
16 changes: 9 additions & 7 deletions synapse/federation/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

import logging

from synapse.federation.units import Transaction
from synapse.logging.utils import log_function
from synapse.types import JsonDict

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,15 +51,15 @@ def have_responded(self, origin, transaction):
return self.store.get_received_txn_response(transaction.transaction_id, origin)

@log_function
def set_response(self, origin, transaction, code, response):
async def set_response(
self, origin: str, transaction: Transaction, code: int, response: JsonDict
) -> None:
""" Persist how we responded to a transaction.

Returns:
Deferred
"""
if not transaction.transaction_id:
transaction_id = transaction.transaction_id # type: ignore
if not transaction_id:
raise RuntimeError("Cannot persist a transaction with no transaction_id")

return self.store.set_received_txn_response(
transaction.transaction_id, origin, code, response
await self.store.set_received_txn_response(
transaction_id, origin, code, response
)
4 changes: 1 addition & 3 deletions synapse/federation/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,7 @@ def __init__(self, transaction_id=None, pdus=[], **kwargs):
if "edus" in kwargs and not kwargs["edus"]:
del kwargs["edus"]

super(Transaction, self).__init__(
transaction_id=transaction_id, pdus=pdus, **kwargs
)
super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs)

@staticmethod
def create_new(pdus, **kwargs):
Expand Down
6 changes: 2 additions & 4 deletions synapse/storage/databases/main/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,14 @@ async def get_appservice_state(self, service):
return result.get("state")
return None

def set_appservice_state(self, service, state):
async def set_appservice_state(self, service, state) -> None:
"""Set the application service state.

Args:
service(ApplicationService): The service whose state to set.
state(ApplicationServiceState): The connectivity state to apply.
Returns:
An Awaitable which resolves when the state was set successfully.
"""
return self.db_pool.simple_upsert(
await self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}
)

Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,11 +716,11 @@ async def get_user_ids_requiring_device_list_resync(

return {row["user_id"] for row in rows}

def mark_remote_user_device_cache_as_stale(self, user_id: str):
async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None:
"""Records that the server has reason to believe the cache of the devices
for the remote users is out of date.
"""
return self.db_pool.simple_upsert(
await self.db_pool.simple_upsert(
table="device_lists_remote_resync",
keyvalues={"user_id": user_id},
values={},
Expand Down
30 changes: 22 additions & 8 deletions synapse/storage/databases/main/group_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,13 @@ def remove_room_from_summary(self, group_id, room_id, category_id):
desc="remove_room_from_summary",
)

def upsert_group_category(self, group_id, category_id, profile, is_public):
async def upsert_group_category(
self,
group_id: str,
category_id: str,
profile: Optional[JsonDict],
is_public: Optional[bool],
) -> None:
"""Add/update room category for group
"""
insertion_values = {}
Expand All @@ -757,7 +763,7 @@ def upsert_group_category(self, group_id, category_id, profile, is_public):
else:
update_values["is_public"] = is_public

return self.db_pool.simple_upsert(
await self.db_pool.simple_upsert(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
values=update_values,
Expand All @@ -772,7 +778,13 @@ def remove_group_category(self, group_id, category_id):
desc="remove_group_category",
)

def upsert_group_role(self, group_id, role_id, profile, is_public):
async def upsert_group_role(
self,
group_id: str,
role_id: str,
profile: Optional[JsonDict],
is_public: Optional[bool],
) -> None:
"""Add/remove user role
"""
insertion_values = {}
Expand All @@ -788,7 +800,7 @@ def upsert_group_role(self, group_id, role_id, profile, is_public):
else:
update_values["is_public"] = is_public

return self.db_pool.simple_upsert(
await self.db_pool.simple_upsert(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
values=update_values,
Expand Down Expand Up @@ -937,10 +949,10 @@ def remove_user_from_summary(self, group_id, user_id, role_id):
desc="remove_user_from_summary",
)

def add_group_invite(self, group_id, user_id):
async def add_group_invite(self, group_id: str, user_id: str) -> None:
"""Record that the group server has invited a user
"""
return self.db_pool.simple_insert(
await self.db_pool.simple_insert(
table="group_invites",
values={"group_id": group_id, "user_id": user_id},
desc="add_group_invite",
Expand Down Expand Up @@ -1043,8 +1055,10 @@ def _remove_user_from_group_txn(txn):
"remove_user_from_group", _remove_user_from_group_txn
)

def add_room_to_group(self, group_id, room_id, is_public):
return self.db_pool.simple_insert(
async def add_room_to_group(
self, group_id: str, room_id: str, is_public: bool
) -> None:
await self.db_pool.simple_insert(
table="group_rooms",
values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
desc="add_room_to_group",
Expand Down
26 changes: 16 additions & 10 deletions synapse/storage/databases/main/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,22 +140,28 @@ async def store_server_verify_keys(
for i in invalidations:
invalidate((i,))

def store_server_keys_json(
self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
):
async def store_server_keys_json(
self,
server_name: str,
key_id: str,
from_server: str,
ts_now_ms: int,
ts_expires_ms: int,
key_json_bytes: bytes,
) -> None:
"""Stores the JSON bytes for a set of keys from a server
The JSON should be signed by the originating server, the intermediate
server, and by this server. Updates the value for the
(server_name, key_id, from_server) triplet if one already existed.
Args:
server_name (str): The name of the server.
key_id (str): The identifer of the key this JSON is for.
from_server (str): The server this JSON was fetched from.
ts_now_ms (int): The time now in milliseconds.
ts_valid_until_ms (int): The time when this json stops being valid.
key_json (bytes): The encoded JSON.
server_name: The name of the server.
key_id: The identifer of the key this JSON is for.
from_server: The server this JSON was fetched from.
ts_now_ms: The time now in milliseconds.
ts_valid_until_ms: The time when this json stops being valid.
key_json_bytes: The encoded JSON.
"""
return self.db_pool.simple_upsert(
await self.db_pool.simple_upsert(
table="server_keys_json",
keyvalues={
"server_name": server_name,
Expand Down
22 changes: 11 additions & 11 deletions synapse/storage/databases/main/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
desc="get_local_media",
)

def store_local_media(
async def store_local_media(
self,
media_id,
media_type,
Expand All @@ -69,8 +69,8 @@ def store_local_media(
media_length,
user_id,
url_cache=None,
):
return self.db_pool.simple_insert(
) -> None:
await self.db_pool.simple_insert(
"local_media_repository",
{
"media_id": media_id,
Expand Down Expand Up @@ -141,10 +141,10 @@ def get_url_cache_txn(txn):

return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)

def store_url_cache(
async def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
):
return self.db_pool.simple_insert(
await self.db_pool.simple_insert(
"local_media_repository_url_cache",
{
"url": url,
Expand Down Expand Up @@ -172,7 +172,7 @@ def get_local_media_thumbnails(self, media_id):
desc="get_local_media_thumbnails",
)

def store_local_thumbnail(
async def store_local_thumbnail(
self,
media_id,
thumbnail_width,
Expand All @@ -181,7 +181,7 @@ def store_local_thumbnail(
thumbnail_method,
thumbnail_length,
):
return self.db_pool.simple_insert(
await self.db_pool.simple_insert(
"local_media_repository_thumbnails",
{
"media_id": media_id,
Expand Down Expand Up @@ -212,7 +212,7 @@ async def get_cached_remote_media(
desc="get_cached_remote_media",
)

def store_cached_remote_media(
async def store_cached_remote_media(
self,
origin,
media_id,
Expand All @@ -222,7 +222,7 @@ def store_cached_remote_media(
upload_name,
filesystem_id,
):
return self.db_pool.simple_insert(
await self.db_pool.simple_insert(
"remote_media_cache",
{
"media_origin": origin,
Expand Down Expand Up @@ -286,7 +286,7 @@ def get_remote_media_thumbnails(self, origin, media_id):
desc="get_remote_media_thumbnails",
)

def store_remote_media_thumbnail(
async def store_remote_media_thumbnail(
self,
origin,
media_id,
Expand All @@ -297,7 +297,7 @@ def store_remote_media_thumbnail(
thumbnail_method,
thumbnail_length,
):
return self.db_pool.simple_insert(
await self.db_pool.simple_insert(
"remote_media_cache_thumbnails",
{
"media_origin": origin,
Expand Down
6 changes: 4 additions & 2 deletions synapse/storage/databases/main/openid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@


class OpenIdStore(SQLBaseStore):
def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
return self.db_pool.simple_insert(
async def insert_open_id_token(
self, token: str, ts_valid_until_ms: int, user_id: str
) -> None:
await self.db_pool.simple_insert(
table="open_id_tokens",
values={
"token": token,
Expand Down
10 changes: 6 additions & 4 deletions synapse/storage/databases/main/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ async def get_from_remote_profile_cache(
desc="get_from_remote_profile_cache",
)

def create_profile(self, user_localpart):
return self.db_pool.simple_insert(
async def create_profile(self, user_localpart: str) -> None:
await self.db_pool.simple_insert(
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
)

Expand All @@ -89,13 +89,15 @@ def set_profile_avatar_url(self, user_localpart, new_avatar_url):


class ProfileStore(ProfileWorkerStore):
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
async def add_remote_profile_cache(
self, user_id: str, displayname: str, avatar_url: str
) -> None:
"""Ensure we are caching the remote user's profiles.

This should only be called when `is_subscribed_remote_profile_for_user`
would return true for the user.
"""
return self.db_pool.simple_upsert(
await self.db_pool.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
Expand Down
29 changes: 14 additions & 15 deletions synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import logging
import re
from typing import Any, Awaitable, Dict, List, Optional
from typing import Any, Dict, List, Optional

from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
Expand Down Expand Up @@ -549,23 +549,22 @@ def user_delete_threepids(self, user_id: str):
desc="user_delete_threepids",
)

def add_user_bound_threepid(self, user_id, medium, address, id_server):
async def add_user_bound_threepid(
self, user_id: str, medium: str, address: str, id_server: str
):
"""The server proxied a bind request to the given identity server on
behalf of the given user. We need to remember this in case the user
asks us to unbind the threepid.

Args:
user_id (str)
medium (str)
address (str)
id_server (str)

Returns:
Awaitable
user_id
medium
address
id_server
"""
# We need to use an upsert, in case they user had already bound the
# threepid
return self.db_pool.simple_upsert(
await self.db_pool.simple_upsert(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
Expand Down Expand Up @@ -1081,17 +1080,17 @@ def _register_user(

self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))

def record_user_external_id(
async def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
) -> Awaitable:
) -> None:
"""Record a mapping from an external user id to a mxid

Args:
auth_provider: identifier for the remote auth provider
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
return self.db_pool.simple_insert(
await self.db_pool.simple_insert(
table="user_external_ids",
values={
"auth_provider": auth_provider,
Expand Down Expand Up @@ -1235,12 +1234,12 @@ async def is_guest(self, user_id: str) -> bool:

return res if res else False

def add_user_pending_deactivation(self, user_id):
async def add_user_pending_deactivation(self, user_id: str) -> None:
"""
Adds a user to the table of users who need to be parted from all the rooms they're
in
"""
return self.db_pool.simple_insert(
await self.db_pool.simple_insert(
"users_pending_deactivation",
values={"user_id": user_id},
desc="add_user_pending_deactivation",
Expand Down
Loading