diff --git a/changelog.d/8231.misc b/changelog.d/8231.misc new file mode 100644 index 000000000000..979c8b227bbc --- /dev/null +++ b/changelog.d/8231.misc @@ -0,0 +1 @@ +Refactor queries for device keys and cross-signatures. diff --git a/changelog.d/8233.misc b/changelog.d/8233.misc new file mode 100644 index 000000000000..979c8b227bbc --- /dev/null +++ b/changelog.d/8233.misc @@ -0,0 +1 @@ +Refactor queries for device keys and cross-signatures. diff --git a/changelog.d/8234.misc b/changelog.d/8234.misc new file mode 100644 index 000000000000..979c8b227bbc --- /dev/null +++ b/changelog.d/8234.misc @@ -0,0 +1 @@ +Refactor queries for device keys and cross-signatures. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 8bedcdbdff06..add4e3ea0ec0 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -255,9 +255,7 @@ async def _get_device_update_edus_by_remote( List of objects representing an device update EDU """ devices = ( - await self.db_pool.runInteraction( - "get_e2e_device_keys_and_signatures_txn", - self._get_e2e_device_keys_and_signatures_txn, + await self.get_e2e_device_keys_and_signatures( query_map.keys(), include_all_devices=True, include_deleted_devices=True, @@ -293,15 +291,9 @@ async def _get_device_update_edus_by_remote( prev_id = stream_id if device is not None: - key_json = device.key_json - if key_json: - result["keys"] = db_to_json(key_json) - - if device.signatures: - for sig_user_id, sigs in device.signatures.items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) + keys = device.keys + if keys: + result["keys"] = keys device_display_name = device.display_name if device_display_name: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 4059701cfda7..c420303fb0f9 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -25,6 +25,7 @@ from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import make_in_list_sql_clause +from synapse.storage.types import Cursor from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -36,17 +37,14 @@ @attr.s class DeviceKeyLookupResult: - """The type returned by _get_e2e_device_keys_and_signatures_txn""" + """The type returned by get_e2e_device_keys_and_signatures""" display_name = attr.ib(type=Optional[str]) # the key data from e2e_device_keys_json. Typically includes fields like # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing - # key) and "signatures" (a signature of the structure by the ed25519 key) - key_json = attr.ib(type=Optional[str]) - - # cross-signing sigs - signatures = attr.ib(type=Optional[Dict], default=None) + # key) and signatures" (a map from (user id) to (key id/device_id) to signature.) + keys = attr.ib(type=Optional[JsonDict]) class EndToEndKeyWorkerStore(SQLBaseStore): @@ -60,11 +58,7 @@ async def get_e2e_device_keys_for_federation_query( """ now_stream_id = self.get_device_stream_token() - devices = await self.db_pool.runInteraction( - "get_e2e_device_keys_and_signatures_txn", - self._get_e2e_device_keys_and_signatures_txn, - [(user_id, None)], - ) + devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)]) if devices: user_devices = devices[user_id] @@ -72,15 +66,9 @@ async def get_e2e_device_keys_for_federation_query( for device_id, device in user_devices.items(): result = {"device_id": device_id} - key_json = device.key_json - if key_json: - result["keys"] = db_to_json(key_json) - - if device.signatures: - for sig_user_id, sigs in device.signatures.items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) + keys = device.keys + if keys: + result["keys"] = keys device_display_name = device.display_name if device_display_name: @@ -108,11 +96,7 @@ async def get_e2e_device_keys_for_cs_api( if not query_list: return {} - results = await self.db_pool.runInteraction( - "get_e2e_device_keys_and_signatures_txn", - self._get_e2e_device_keys_and_signatures_txn, - query_list, - ) + results = await self.get_e2e_device_keys_and_signatures(query_list) # Build the result structure, un-jsonify the results, and add the # "unsigned" section @@ -120,31 +104,99 @@ async def get_e2e_device_keys_for_cs_api( for user_id, device_keys in results.items(): rv[user_id] = {} for device_id, device_info in device_keys.items(): - r = db_to_json(device_info.key_json) + r = device_info.keys r["unsigned"] = {} display_name = device_info.display_name if display_name is not None: r["unsigned"]["device_display_name"] = display_name - if device_info.signatures: - for sig_user_id, sigs in device_info.signatures.items(): - r.setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) rv[user_id][device_id] = r return rv @trace - def _get_e2e_device_keys_and_signatures_txn( - self, txn, query_list, include_all_devices=False, include_deleted_devices=False + async def get_e2e_device_keys_and_signatures( + self, + query_list: List[Tuple[str, Optional[str]]], + include_all_devices: bool = False, + include_deleted_devices: bool = False, ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: + """Fetch a list of device keys, together with their cross-signatures. + + The cross-signatures are added to the `signatures` field within the `keys` + object in the response. + + Args: + query_list: List of pairs of user_ids and device_ids. Device id can be None + to indicate "all devices for this user" + + include_all_devices: whether to return devices without device keys + + include_deleted_devices:whether to include null entries for + devices which no longer exist (but were in the query_list). + This option only takes effect if include_all_devices is true. + + Returns: + Dict mapping from user-id to dict mapping from device_id to + key data. + """ set_tag("include_all_devices", include_all_devices) set_tag("include_deleted_devices", include_deleted_devices) + result = await self.db_pool.runInteraction( + "get_e2e_device_keys", + self._get_e2e_device_keys_txn, + query_list, + include_all_devices, + include_deleted_devices, + ) + + # get a the (user_id, device_id) tuples to look up cross-signatures for + signature_query = ( + [user_id, device_id] + for user_id, dev in result.items() + for device_id, d in dev.items() + if d is not None and d.keys is not None + ) + + for batch in batch_iter(signature_query, 50): + cross_sigs_result = await self.db_pool.runInteraction( + "get_e2e_cross_signing_signatures", + self._get_e2e_cross_signing_signatures_txn, + batch, + ) + + # add each cross-signing signature to the correct device in the result dict. + for row in cross_sigs_result: + signing_user_id = row["user_id"] + signing_key_id = row["key_id"] + target_user_id = row["target_user_id"] + target_device_id = row["target_device_id"] + signature = row["signature"] + + target_device_result = result[target_user_id][target_device_id] + target_device_signatures = target_device_result.keys.setdefault( + "signatures", {} + ) + + signing_user_signatures = target_device_signatures.setdefault( + signing_user_id, {} + ) + signing_user_signatures[signing_key_id] = signature + + log_kv(result) + return result + + def _get_e2e_device_keys_txn( + self, txn, query_list, include_all_devices=False, include_deleted_devices=False + ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: + """Get information on devices from the database + + The results include the device's keys and self-signatures, but *not* any + cross-signing signatures which have been added subsequently (for which, see + get_e2e_device_keys_and_signatures) + """ query_clauses = [] query_params = [] - signature_query_clauses = [] - signature_query_params = [] if include_all_devices is False: include_deleted_devices = False @@ -155,20 +207,12 @@ def _get_e2e_device_keys_and_signatures_txn( for (user_id, device_id) in query_list: query_clause = "user_id = ?" query_params.append(user_id) - signature_query_clause = "target_user_id = ?" - signature_query_params.append(user_id) if device_id is not None: query_clause += " AND device_id = ?" query_params.append(device_id) - signature_query_clause += " AND target_device_id = ?" - signature_query_params.append(device_id) - - signature_query_clause += " AND user_id = ?" - signature_query_params.append(user_id) query_clauses.append(query_clause) - signature_query_clauses.append(signature_query_clause) sql = ( "SELECT user_id, device_id, " @@ -189,49 +233,36 @@ def _get_e2e_device_keys_and_signatures_txn( if include_deleted_devices: deleted_devices.remove((user_id, device_id)) result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult( - display_name, key_json + display_name, db_to_json(key_json) if key_json else None ) if include_deleted_devices: for user_id, device_id in deleted_devices: result.setdefault(user_id, {})[device_id] = None - # get signatures on the device - signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % ( - " OR ".join("(" + q + ")" for q in signature_query_clauses) - ) - - txn.execute(signature_sql, signature_query_params) - rows = self.db_pool.cursor_to_dict(txn) - - # add each cross-signing signature to the correct device in the result dict. - for row in rows: - signing_user_id = row["user_id"] - signing_key_id = row["key_id"] - target_user_id = row["target_user_id"] - target_device_id = row["target_device_id"] - signature = row["signature"] - - target_user_result = result.get(target_user_id) - if not target_user_result: - continue - - target_device_result = target_user_result.get(target_device_id) - if not target_device_result: - # note that target_device_result will be None for deleted devices. - continue + return result - target_device_signatures = target_device_result.signatures - if target_device_signatures is None: - target_device_signatures = target_device_result.signatures = {} + def _get_e2e_cross_signing_signatures_txn( + self, txn: Cursor, device_query: Iterable[Tuple[str, str]] + ) -> List[Dict]: + """Get cross-signing signatures for a given list of devices""" + signature_query_clauses = [] + signature_query_params = [] - signing_user_signatures = target_device_signatures.setdefault( - signing_user_id, {} + for (user_id, device_id) in device_query: + # XXX: I don't really know why we limit this by user_id, but there's + # probably a good reason. + signature_query_clauses.append( + "target_user_id = ? AND target_device_id = ? AND user_id = ?" ) - signing_user_signatures[signing_key_id] = signature + signature_query_params.extend([user_id, device_id, user_id]) - log_kv(result) - return result + signature_sql = "SELECT * FROM e2e_cross_signing_signatures WHERE %s" % ( + " OR ".join("(" + q + ")" for q in signature_query_clauses) + ) + + txn.execute(signature_sql, signature_query_params) + return self.db_pool.cursor_to_dict(txn) async def get_e2e_one_time_keys( self, user_id: str, device_id: str, key_ids: List[str]