Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert additional database methods to async (select list, search, insert_many, delete_*) #8168

Merged
merged 12 commits into from
Aug 27, 2020
1 change: 1 addition & 0 deletions changelog.d/8168.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
9 changes: 5 additions & 4 deletions synapse/storage/background_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,21 +414,22 @@ async def updater(progress, batch_size):

self.register_background_update_handler(update_name, updater)

def _end_background_update(self, update_name):
async def _end_background_update(self, update_name: str) -> None:
"""Removes a completed background update task from the queue.

Args:
update_name(str): The name of the completed task to remove
update_name:: The name of the completed task to remove

Returns:
A deferred that completes once the task is removed.
None, completes once the task is removed.
"""
if update_name != self._current_background_update:
raise Exception(
"Cannot end background update %s which isn't currently running"
% update_name
)
self._current_background_update = None
return self.db_pool.simple_delete_one(
await self.db_pool.simple_delete_one(
"background_updates", keyvalues={"update_name": update_name}
)

Expand Down
99 changes: 37 additions & 62 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,13 @@ def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
results = [dict(zip(col_headers, row)) for row in cursor]
return results

def execute(self, desc: str, decoder: Callable, query: str, *args: Any):
async def execute(
self,
desc: str,
decoder: Optional[Callable[[Cursor], R]],
query: str,
*args: Any
) -> R:
"""Runs a single query for a result set.

Args:
Expand All @@ -614,7 +620,7 @@ def execute(self, desc: str, decoder: Callable, query: str, *args: Any):
query - The query string to execute
*args - Query args.
Returns:
Deferred which results to the result of decoder(results)
The result of decoder(results)
"""

def interaction(txn):
Expand All @@ -624,7 +630,7 @@ def interaction(txn):
else:
return txn.fetchall()

return self.runInteraction(desc, interaction)
return await self.runInteraction(desc, interaction)

# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
Expand Down Expand Up @@ -673,15 +679,30 @@ def simple_insert_txn(

txn.execute(sql, vals)

def simple_insert_many(
async def simple_insert_many(
self, table: str, values: List[Dict[str, Any]], desc: str
) -> defer.Deferred:
return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
) -> None:
"""Executes an INSERT query on the named table.

Args:
table: string giving the table name
values: dict of new column names and values for them
desc: string giving a description of the transaction
"""
await self.runInteraction(desc, self.simple_insert_many_txn, table, values)

@staticmethod
def simple_insert_many_txn(
txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
) -> None:
"""Executes an INSERT query on the named table.

Args:
txn: The transaction to use.
table: string giving the table name
values: dict of new column names and values for them
desc: string giving a description of the transaction
"""
if not values:
return

Expand Down Expand Up @@ -1397,17 +1418,17 @@ def simple_select_one_txn(

return dict(zip(retcols, row))

def simple_delete_one(
async def simple_delete_one(
self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
) -> defer.Deferred:
) -> None:
"""Executes a DELETE query on the named table, expecting to delete a
single row.

Args:
table: string giving the table name
keyvalues: dict of column names and values to select the row with
"""
return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)

@staticmethod
def simple_delete_one_txn(
Expand Down Expand Up @@ -1446,15 +1467,15 @@ def simple_delete_txn(
txn.execute(sql, list(keyvalues.values()))
return txn.rowcount

def simple_delete_many(
async def simple_delete_many(
self,
table: str,
column: str,
iterable: Iterable[Any],
keyvalues: Dict[str, Any],
desc: str,
) -> defer.Deferred:
return self.runInteraction(
) -> int:
return await self.runInteraction(
desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
)

Expand Down Expand Up @@ -1537,52 +1558,6 @@ def get_cache_dict(

return cache, min_val

def simple_select_list_paginate(
self,
table: str,
orderby: str,
start: int,
limit: int,
retcols: Iterable[str],
filters: Optional[Dict[str, Any]] = None,
keyvalues: Optional[Dict[str, Any]] = None,
order_direction: str = "ASC",
desc: str = "simple_select_list_paginate",
) -> defer.Deferred:
"""
Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts.

Args:
table: the table name
orderby: Column to order the results by.
start: Index to begin the query at.
limit: Number of results to return.
retcols: the names of the columns to return
filters:
column names and values to filter the rows with, or None to not
apply a WHERE ? LIKE ? clause.
keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
order_direction: Whether the results should be ordered "ASC" or "DESC".
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.runInteraction(
desc,
self.simple_select_list_paginate_txn,
table,
orderby,
start,
limit,
retcols,
filters=filters,
keyvalues=keyvalues,
order_direction=order_direction,
)

@classmethod
def simple_select_list_paginate_txn(
cls,
Expand Down Expand Up @@ -1647,14 +1622,14 @@ def simple_select_list_paginate_txn(

return cls.cursor_to_dict(txn)

def simple_search_list(
async def simple_search_list(
self,
table: str,
term: Optional[str],
col: str,
retcols: Iterable[str],
desc="simple_search_list",
):
) -> Optional[List[Dict[str, Any]]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.

Expand All @@ -1665,10 +1640,10 @@ def simple_search_list(
retcols: the names of the columns to return

Returns:
defer.Deferred: resolves to list[dict[str, Any]] or None
A list of dictionaries or None.
"""

return self.runInteraction(
return await self.runInteraction(
desc, self.simple_search_list_txn, table, term, col, retcols
)

Expand Down
12 changes: 6 additions & 6 deletions synapse/storage/databases/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import calendar
import logging
import time
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig
Expand Down Expand Up @@ -559,17 +559,17 @@ def get_users_paginate_txn(txn):
"get_users_paginate_txn", get_users_paginate_txn
)

def search_users(self, term):
async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
"""Function to search users list for one or more users with
the matched term.

Args:
term (str): search term
col (str): column to query term should be matched to
term: search term

Returns:
defer.Deferred: resolves to list[dict[str, Any]]
A list of dictionaries or None.
"""
return self.db_pool.simple_search_list(
return await self.db_pool.simple_search_list(
table="users",
term=term,
col="name",
Expand Down
15 changes: 10 additions & 5 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple

from canonicaljson import encode_canonical_json

Expand All @@ -27,6 +27,9 @@
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
from synapse.handlers.e2e_keys import SignatureListItem


class EndToEndKeyWorkerStore(SQLBaseStore):
@trace
Expand Down Expand Up @@ -730,14 +733,16 @@ async def set_e2e_cross_signing_key(self, user_id, key_type, key):
stream_id,
)

def store_e2e_cross_signing_signatures(self, user_id, signatures):
async def store_e2e_cross_signing_signatures(
self, user_id: str, signatures: "Iterable[SignatureListItem]"
) -> None:
"""Stores cross-signing signatures.

Args:
user_id (str): the user who made the signatures
signatures (iterable[SignatureListItem]): signatures to add
user_id: the user who made the signatures
signatures: signatures to add
"""
return self.db_pool.simple_insert_many(
await self.db_pool.simple_insert_many(
"e2e_cross_signing_signatures",
[
{
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,14 +314,14 @@ def store_remote_media_thumbnail(
desc="store_remote_media_thumbnail",
)

def get_remote_media_before(self, before_ts):
async def get_remote_media_before(self, before_ts):
sql = (
"SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache"
" WHERE last_access_ts < ?"
)

return self.db_pool.execute(
return await self.db_pool.execute(
"get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
)

Expand Down
5 changes: 2 additions & 3 deletions tests/storage/test_background_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,12 @@ def update(progress, count):

# second step: complete the update
# we should now get run with a much bigger number of items to update
@defer.inlineCallbacks
def update(progress, count):
async def update(progress, count):
self.assertEqual(progress, {"my_key": 2})
self.assertAlmostEqual(
count, target_background_update_duration_ms / duration_ms, places=0,
)
yield self.updates._end_background_update("test_update")
await self.updates._end_background_update("test_update")
return count

self.update_handler.side_effect = update
Expand Down
6 changes: 4 additions & 2 deletions tests/storage/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,10 @@ def test_update_one_4cols(self):
def test_delete_one(self):
self.mock_txn.rowcount = 1

yield self.datastore.db_pool.simple_delete_one(
table="tablename", keyvalues={"keycol": "Go away"}
yield defer.ensureDeferred(
self.datastore.db_pool.simple_delete_one(
table="tablename", keyvalues={"keycol": "Go away"}
)
)

self.mock_txn.execute.assert_called_with(
Expand Down