diff --git a/changelog.d/8168.misc b/changelog.d/8168.misc new file mode 100644 index 000000000000..dfe4c03171d6 --- /dev/null +++ b/changelog.d/8168.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 56818f4df883..0db900fa0e0b 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -414,13 +414,14 @@ async def updater(progress, batch_size): self.register_background_update_handler(update_name, updater) - def _end_background_update(self, update_name): + async def _end_background_update(self, update_name: str) -> None: """Removes a completed background update task from the queue. Args: - update_name(str): The name of the completed task to remove + update_name:: The name of the completed task to remove + Returns: - A deferred that completes once the task is removed. + None, completes once the task is removed. """ if update_name != self._current_background_update: raise Exception( @@ -428,7 +429,7 @@ def _end_background_update(self, update_name): % update_name ) self._current_background_update = None - return self.db_pool.simple_delete_one( + await self.db_pool.simple_delete_one( "background_updates", keyvalues={"update_name": update_name} ) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 38010af60041..2f6f49a4bf84 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -605,7 +605,13 @@ def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]: results = [dict(zip(col_headers, row)) for row in cursor] return results - def execute(self, desc: str, decoder: Callable, query: str, *args: Any): + async def execute( + self, + desc: str, + decoder: Optional[Callable[[Cursor], R]], + query: str, + *args: Any + ) -> R: """Runs a single query for a result set. Args: @@ -614,7 +620,7 @@ def execute(self, desc: str, decoder: Callable, query: str, *args: Any): query - The query string to execute *args - Query args. Returns: - Deferred which results to the result of decoder(results) + The result of decoder(results) """ def interaction(txn): @@ -624,7 +630,7 @@ def interaction(txn): else: return txn.fetchall() - return self.runInteraction(desc, interaction) + return await self.runInteraction(desc, interaction) # "Simple" SQL API methods that operate on a single table with no JOINs, # no complex WHERE clauses, just a dict of values for columns. @@ -673,15 +679,30 @@ def simple_insert_txn( txn.execute(sql, vals) - def simple_insert_many( + async def simple_insert_many( self, table: str, values: List[Dict[str, Any]], desc: str - ) -> defer.Deferred: - return self.runInteraction(desc, self.simple_insert_many_txn, table, values) + ) -> None: + """Executes an INSERT query on the named table. + + Args: + table: string giving the table name + values: dict of new column names and values for them + desc: string giving a description of the transaction + """ + await self.runInteraction(desc, self.simple_insert_many_txn, table, values) @staticmethod def simple_insert_many_txn( txn: LoggingTransaction, table: str, values: List[Dict[str, Any]] ) -> None: + """Executes an INSERT query on the named table. + + Args: + txn: The transaction to use. + table: string giving the table name + values: dict of new column names and values for them + desc: string giving a description of the transaction + """ if not values: return @@ -1397,9 +1418,9 @@ def simple_select_one_txn( return dict(zip(retcols, row)) - def simple_delete_one( + async def simple_delete_one( self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one" - ) -> defer.Deferred: + ) -> None: """Executes a DELETE query on the named table, expecting to delete a single row. @@ -1407,7 +1428,7 @@ def simple_delete_one( table: string giving the table name keyvalues: dict of column names and values to select the row with """ - return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) + await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) @staticmethod def simple_delete_one_txn( @@ -1446,15 +1467,15 @@ def simple_delete_txn( txn.execute(sql, list(keyvalues.values())) return txn.rowcount - def simple_delete_many( + async def simple_delete_many( self, table: str, column: str, iterable: Iterable[Any], keyvalues: Dict[str, Any], desc: str, - ) -> defer.Deferred: - return self.runInteraction( + ) -> int: + return await self.runInteraction( desc, self.simple_delete_many_txn, table, column, iterable, keyvalues ) @@ -1537,52 +1558,6 @@ def get_cache_dict( return cache, min_val - def simple_select_list_paginate( - self, - table: str, - orderby: str, - start: int, - limit: int, - retcols: Iterable[str], - filters: Optional[Dict[str, Any]] = None, - keyvalues: Optional[Dict[str, Any]] = None, - order_direction: str = "ASC", - desc: str = "simple_select_list_paginate", - ) -> defer.Deferred: - """ - Executes a SELECT query on the named table with start and limit, - of row numbers, which may return zero or number of rows from start to limit, - returning the result as a list of dicts. - - Args: - table: the table name - orderby: Column to order the results by. - start: Index to begin the query at. - limit: Number of results to return. - retcols: the names of the columns to return - filters: - column names and values to filter the rows with, or None to not - apply a WHERE ? LIKE ? clause. - keyvalues: - column names and values to select the rows with, or None to not - apply a WHERE clause. - order_direction: Whether the results should be ordered "ASC" or "DESC". - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self.runInteraction( - desc, - self.simple_select_list_paginate_txn, - table, - orderby, - start, - limit, - retcols, - filters=filters, - keyvalues=keyvalues, - order_direction=order_direction, - ) - @classmethod def simple_select_list_paginate_txn( cls, @@ -1647,14 +1622,14 @@ def simple_select_list_paginate_txn( return cls.cursor_to_dict(txn) - def simple_search_list( + async def simple_search_list( self, table: str, term: Optional[str], col: str, retcols: Iterable[str], desc="simple_search_list", - ): + ) -> Optional[List[Dict[str, Any]]]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1665,10 +1640,10 @@ def simple_search_list( retcols: the names of the columns to return Returns: - defer.Deferred: resolves to list[dict[str, Any]] or None + A list of dictionaries or None. """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_search_list_txn, table, term, col, retcols ) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 8b9b6eb47241..70cf15dd7f46 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -18,7 +18,7 @@ import calendar import logging import time -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from synapse.api.constants import PresenceState from synapse.config.homeserver import HomeServerConfig @@ -559,17 +559,17 @@ def get_users_paginate_txn(txn): "get_users_paginate_txn", get_users_paginate_txn ) - def search_users(self, term): + async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]: """Function to search users list for one or more users with the matched term. Args: - term (str): search term - col (str): column to query term should be matched to + term: search term + Returns: - defer.Deferred: resolves to list[dict[str, Any]] + A list of dictionaries or None. """ - return self.db_pool.simple_search_list( + return await self.db_pool.simple_search_list( table="users", term=term, col="name", diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 385868bdab3f..af0b85e2c92c 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json @@ -27,6 +27,9 @@ from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter +if TYPE_CHECKING: + from synapse.handlers.e2e_keys import SignatureListItem + class EndToEndKeyWorkerStore(SQLBaseStore): @trace @@ -730,14 +733,16 @@ async def set_e2e_cross_signing_key(self, user_id, key_type, key): stream_id, ) - def store_e2e_cross_signing_signatures(self, user_id, signatures): + async def store_e2e_cross_signing_signatures( + self, user_id: str, signatures: "Iterable[SignatureListItem]" + ) -> None: """Stores cross-signing signatures. Args: - user_id (str): the user who made the signatures - signatures (iterable[SignatureListItem]): signatures to add + user_id: the user who made the signatures + signatures: signatures to add """ - return self.db_pool.simple_insert_many( + await self.db_pool.simple_insert_many( "e2e_cross_signing_signatures", [ { diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index fc223f5a2ab3..8361dd63d964 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -314,14 +314,14 @@ def store_remote_media_thumbnail( desc="store_remote_media_thumbnail", ) - def get_remote_media_before(self, before_ts): + async def get_remote_media_before(self, before_ts): sql = ( "SELECT media_origin, media_id, filesystem_id" " FROM remote_media_cache" " WHERE last_access_ts < ?" ) - return self.db_pool.execute( + return await self.db_pool.execute( "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts ) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 2efbc97c2e62..1a1c59256cfc 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -67,13 +67,12 @@ def update(progress, count): # second step: complete the update # we should now get run with a much bigger number of items to update - @defer.inlineCallbacks - def update(progress, count): + async def update(progress, count): self.assertEqual(progress, {"my_key": 2}) self.assertAlmostEqual( count, target_background_update_duration_ms / duration_ms, places=0, ) - yield self.updates._end_background_update("test_update") + await self.updates._end_background_update("test_update") return count self.update_handler.side_effect = update diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 64abe8cc49d1..40ba652248ce 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -197,8 +197,10 @@ def test_update_one_4cols(self): def test_delete_one(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_delete_one( - table="tablename", keyvalues={"keycol": "Go away"} + yield defer.ensureDeferred( + self.datastore.db_pool.simple_delete_one( + table="tablename", keyvalues={"keycol": "Go away"} + ) ) self.mock_txn.execute.assert_called_with(