Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert appservice, group server, profile and more databases to async #8066

Merged
merged 5 commits into from
Aug 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8066.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
34 changes: 13 additions & 21 deletions synapse/storage/databases/main/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

from canonicaljson import json

from twisted.internet import defer

from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
from synapse.storage._base import SQLBaseStore, db_to_json
Expand Down Expand Up @@ -124,17 +122,15 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
class ApplicationServiceTransactionWorkerStore(
ApplicationServiceWorkerStore, EventsWorkerStore
):
@defer.inlineCallbacks
def get_appservices_by_state(self, state):
async def get_appservices_by_state(self, state):
"""Get a list of application services based on their state.

Args:
state(ApplicationServiceState): The state to filter on.
Returns:
A Deferred which resolves to a list of ApplicationServices, which
may be empty.
A list of ApplicationServices, which may be empty.
"""
results = yield self.db_pool.simple_select_list(
results = await self.db_pool.simple_select_list(
"application_services_state", {"state": state}, ["as_id"]
)
# NB: This assumes this class is linked with ApplicationServiceStore
Expand All @@ -147,16 +143,15 @@ def get_appservices_by_state(self, state):
services.append(service)
return services

@defer.inlineCallbacks
def get_appservice_state(self, service):
async def get_appservice_state(self, service):
"""Get the application service state.

Args:
service(ApplicationService): The service whose state to set.
Returns:
A Deferred which resolves to ApplicationServiceState.
An ApplicationServiceState.
"""
result = yield self.db_pool.simple_select_one(
result = await self.db_pool.simple_select_one(
"application_services_state",
{"as_id": service.id},
["state"],
Expand Down Expand Up @@ -270,16 +265,14 @@ def _complete_appservice_txn(txn):
"complete_appservice_txn", _complete_appservice_txn
)

@defer.inlineCallbacks
def get_oldest_unsent_txn(self, service):
async def get_oldest_unsent_txn(self, service):
"""Get the oldest transaction which has not been sent for this
service.

Args:
service(ApplicationService): The app service to get the oldest txn.
Returns:
A Deferred which resolves to an AppServiceTransaction or
None.
An AppServiceTransaction or None.
"""

def _get_oldest_unsent_txn(txn):
Expand All @@ -298,7 +291,7 @@ def _get_oldest_unsent_txn(txn):

return entry

entry = yield self.db_pool.runInteraction(
entry = await self.db_pool.runInteraction(
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
)

Expand All @@ -307,7 +300,7 @@ def _get_oldest_unsent_txn(txn):

event_ids = db_to_json(entry["event_ids"])

events = yield self.get_events_as_list(event_ids)
events = await self.get_events_as_list(event_ids)

return AppServiceTransaction(service=service, id=entry["txn_id"], events=events)

Expand All @@ -332,8 +325,7 @@ def set_appservice_last_pos_txn(txn):
"set_appservice_last_pos", set_appservice_last_pos_txn
)

@defer.inlineCallbacks
def get_new_events_for_appservice(self, current_id, limit):
async def get_new_events_for_appservice(self, current_id, limit):
"""Get all new evnets"""

def get_new_events_for_appservice_txn(txn):
Expand All @@ -357,11 +349,11 @@ def get_new_events_for_appservice_txn(txn):

return upper_bound, [row[1] for row in rows]

upper_bound, event_ids = yield self.db_pool.runInteraction(
upper_bound, event_ids = await self.db_pool.runInteraction(
"get_new_events_for_appservice", get_new_events_for_appservice_txn
)

events = yield self.get_events_as_list(event_ids)
events = await self.get_events_as_list(event_ids)

return upper_bound, events

Expand Down
8 changes: 4 additions & 4 deletions synapse/storage/databases/main/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@

from synapse.api.errors import Codes, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.descriptors import cached


class FilteringStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2)
def get_user_filter(self, user_localpart, filter_id):
@cached(num_args=2)
async def get_user_filter(self, user_localpart, filter_id):
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
# with a coherent error message rather than 500 M_UNKNOWN.
try:
int(filter_id)
except ValueError:
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)

def_json = yield self.db_pool.simple_select_one_onecol(
def_json = await self.db_pool.simple_select_one_onecol(
table="user_filters",
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
retcol="filter_json",
Expand Down
86 changes: 39 additions & 47 deletions synapse/storage/databases/main/group_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple

from twisted.internet import defer
from typing import List, Optional, Tuple

from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.types import JsonDict
from synapse.util import json_encoder

# The category ID for the "default" category. We don't store as null in the
Expand Down Expand Up @@ -210,9 +209,8 @@ def _get_rooms_for_summary_txn(txn):
"get_rooms_for_summary", _get_rooms_for_summary_txn
)

@defer.inlineCallbacks
def get_group_categories(self, group_id):
rows = yield self.db_pool.simple_select_list(
async def get_group_categories(self, group_id):
rows = await self.db_pool.simple_select_list(
table="group_room_categories",
keyvalues={"group_id": group_id},
retcols=("category_id", "is_public", "profile"),
Expand All @@ -227,9 +225,8 @@ def get_group_categories(self, group_id):
for row in rows
}

@defer.inlineCallbacks
def get_group_category(self, group_id, category_id):
category = yield self.db_pool.simple_select_one(
async def get_group_category(self, group_id, category_id):
category = await self.db_pool.simple_select_one(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
retcols=("is_public", "profile"),
Expand All @@ -240,9 +237,8 @@ def get_group_category(self, group_id, category_id):

return category

@defer.inlineCallbacks
def get_group_roles(self, group_id):
rows = yield self.db_pool.simple_select_list(
async def get_group_roles(self, group_id):
rows = await self.db_pool.simple_select_list(
table="group_roles",
keyvalues={"group_id": group_id},
retcols=("role_id", "is_public", "profile"),
Expand All @@ -257,9 +253,8 @@ def get_group_roles(self, group_id):
for row in rows
}

@defer.inlineCallbacks
def get_group_role(self, group_id, role_id):
role = yield self.db_pool.simple_select_one(
async def get_group_role(self, group_id, role_id):
role = await self.db_pool.simple_select_one(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
retcols=("is_public", "profile"),
Expand Down Expand Up @@ -448,12 +443,11 @@ def _get_attestations_need_renewals_txn(txn):
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
)

@defer.inlineCallbacks
def get_remote_attestation(self, group_id, user_id):
async def get_remote_attestation(self, group_id, user_id):
"""Get the attestation that proves the remote agrees that the user is
in the group.
"""
row = yield self.db_pool.simple_select_one(
row = await self.db_pool.simple_select_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
retcols=("valid_until_ms", "attestation_json"),
Expand Down Expand Up @@ -499,13 +493,13 @@ def _get_all_groups_for_user_txn(txn):
"get_all_groups_for_user", _get_all_groups_for_user_txn
)

def get_groups_changes_for_user(self, user_id, from_token, to_token):
async def get_groups_changes_for_user(self, user_id, from_token, to_token):
from_token = int(from_token)
has_changed = self._group_updates_stream_cache.has_entity_changed(
user_id, from_token
)
if not has_changed:
return defer.succeed([])
return []

def _get_groups_changes_for_user_txn(txn):
sql = """
Expand All @@ -525,7 +519,7 @@ def _get_groups_changes_for_user_txn(txn):
for group_id, membership, gtype, content_json in txn
]

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_groups_changes_for_user", _get_groups_changes_for_user_txn
)

Expand Down Expand Up @@ -1087,31 +1081,31 @@ def update_group_publicity(self, group_id, user_id, publicise):
desc="update_group_publicity",
)

@defer.inlineCallbacks
def register_user_group_membership(
async def register_user_group_membership(
self,
group_id,
user_id,
membership,
is_admin=False,
content={},
local_attestation=None,
remote_attestation=None,
is_publicised=False,
):
group_id: str,
user_id: str,
membership: str,
is_admin: bool = False,
content: JsonDict = {},
local_attestation: Optional[dict] = None,
remote_attestation: Optional[dict] = None,
is_publicised: bool = False,
) -> int:
"""Registers that a local user is a member of a (local or remote) group.

Args:
group_id (str)
user_id (str)
membership (str)
is_admin (bool)
content (dict): Content of the membership, e.g. includes the inviter
group_id: The group the member is being added to.
user_id: THe user ID to add to the group.
membership: The type of group membership.
is_admin: Whether the user should be added as a group admin.
content: Content of the membership, e.g. includes the inviter
if the user has been invited.
local_attestation (dict): If remote group then store the fact that we
local_attestation: If remote group then store the fact that we
have given out an attestation, else None.
remote_attestation (dict): If remote group then store the remote
remote_attestation: If remote group then store the remote
attestation from the group, else None.
is_publicised: Whether this should be publicised.
"""

def _register_user_group_membership_txn(txn, next_id):
Expand Down Expand Up @@ -1188,18 +1182,17 @@ def _register_user_group_membership_txn(txn, next_id):
return next_id

with self._group_updates_id_gen.get_next() as next_id:
res = yield self.db_pool.runInteraction(
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
next_id,
)
return res

@defer.inlineCallbacks
def create_group(
async def create_group(
self, group_id, user_id, name, avatar_url, short_description, long_description
):
yield self.db_pool.simple_insert(
) -> None:
await self.db_pool.simple_insert(
table="groups",
values={
"group_id": group_id,
Expand All @@ -1212,9 +1205,8 @@ def create_group(
desc="create_group",
)

@defer.inlineCallbacks
def update_group_profile(self, group_id, profile):
yield self.db_pool.simple_update_one(
async def update_group_profile(self, group_id, profile):
await self.db_pool.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
updatevalues=profile,
Expand Down
7 changes: 2 additions & 5 deletions synapse/storage/databases/main/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,20 @@

from typing import List, Tuple

from twisted.internet import defer

from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.presence import UserPresenceState
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter


class PresenceStore(SQLBaseStore):
@defer.inlineCallbacks
def update_presence(self, presence_states):
async def update_presence(self, presence_states):
stream_ordering_manager = self._presence_id_gen.get_next_mult(
len(presence_states)
)

with stream_ordering_manager as stream_orderings:
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"update_presence",
self._update_presence_txn,
stream_orderings,
Expand Down
Loading