From 2655d1b8c9b5003dc4782f83cf12de60e66dbfe2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 28 Aug 2020 17:11:57 +0100 Subject: [PATCH 1/4] wrap `_get_e2e_device_keys_and_signatures_txn` in a non-txn method We have three things which all call `_get_e2e_device_keys_and_signatures_txn` with their own `runInteraction`. Factor out the common code. --- changelog.d/8231.misc | 1 + synapse/storage/databases/main/devices.py | 4 +- .../storage/databases/main/end_to_end_keys.py | 52 ++++++++++++++----- 3 files changed, 40 insertions(+), 17 deletions(-) create mode 100644 changelog.d/8231.misc diff --git a/changelog.d/8231.misc b/changelog.d/8231.misc new file mode 100644 index 000000000000..979c8b227bbc --- /dev/null +++ b/changelog.d/8231.misc @@ -0,0 +1 @@ +Refactor queries for device keys and cross-signatures. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 8bedcdbdff06..f8fe948122a0 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -255,9 +255,7 @@ async def _get_device_update_edus_by_remote( List of objects representing an device update EDU """ devices = ( - await self.db_pool.runInteraction( - "get_e2e_device_keys_and_signatures_txn", - self._get_e2e_device_keys_and_signatures_txn, + await self.get_e2e_device_keys_and_signatures( query_map.keys(), include_all_devices=True, include_deleted_devices=True, diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 4059701cfda7..1c0367e51c10 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -36,7 +36,7 @@ @attr.s class DeviceKeyLookupResult: - """The type returned by _get_e2e_device_keys_and_signatures_txn""" + """The type returned by get_e2e_device_keys_and_signatures""" display_name = attr.ib(type=Optional[str]) @@ -60,11 +60,7 @@ async def 
get_e2e_device_keys_for_federation_query( """ now_stream_id = self.get_device_stream_token() - devices = await self.db_pool.runInteraction( - "get_e2e_device_keys_and_signatures_txn", - self._get_e2e_device_keys_and_signatures_txn, - [(user_id, None)], - ) + devices = await self._get_e2e_device_keys_and_signatures([(user_id, None)]) if devices: user_devices = devices[user_id] @@ -108,11 +104,7 @@ async def get_e2e_device_keys_for_cs_api( if not query_list: return {} - results = await self.db_pool.runInteraction( - "get_e2e_device_keys_and_signatures_txn", - self._get_e2e_device_keys_and_signatures_txn, - query_list, - ) + results = await self.get_e2e_device_keys_and_signatures(query_list) # Build the result structure, un-jsonify the results, and add the # "unsigned" section @@ -135,12 +127,45 @@ async def get_e2e_device_keys_for_cs_api( return rv @trace - def _get_e2e_device_keys_and_signatures_txn( - self, txn, query_list, include_all_devices=False, include_deleted_devices=False + async def get_e2e_device_keys_and_signatures( + self, + query_list: List[Tuple[str, Optional[str]]], + include_all_devices: bool = False, + include_deleted_devices: bool = False, ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: + """Fetch a list of device keys, together with their cross-signatures. + + Args: + query_list: List of pairs of user_ids and device_ids. Device id can be None + to indicate "all devices for this user" + + include_all_devices: whether to return devices without device keys + + include_deleted_devices: whether to include null entries for + devices which no longer exist (but were in the query_list). + This option only takes effect if include_all_devices is true. + + Returns: + Dict mapping from user-id to dict mapping from device_id to + key data. 
+ """ set_tag("include_all_devices", include_all_devices) set_tag("include_deleted_devices", include_deleted_devices) + result = await self.db_pool.runInteraction( + "get_e2e_device_keys", + self._get_e2e_device_keys_and_signatures_txn, + query_list, + include_all_devices, + include_deleted_devices, + ) + + log_kv(result) + return result + + def _get_e2e_device_keys_and_signatures_txn( + self, txn, query_list, include_all_devices=False, include_deleted_devices=False + ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: query_clauses = [] query_params = [] signature_query_clauses = [] @@ -230,7 +255,6 @@ def _get_e2e_device_keys_and_signatures_txn( ) signing_user_signatures[signing_key_id] = signature - log_kv(result) return result async def get_e2e_one_time_keys( From 1ea500a441bf3f0d1915432afdc5699ed0a8aab4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 2 Sep 2020 18:23:23 +0100 Subject: [PATCH 2/4] fix typo --- synapse/storage/databases/main/end_to_end_keys.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 1c0367e51c10..edf9d5e59abe 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -60,7 +60,7 @@ async def get_e2e_device_keys_for_federation_query( """ now_stream_id = self.get_device_stream_token() - devices = await self._get_e2e_device_keys_and_signatures([(user_id, None)]) + devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)]) if devices: user_devices = devices[user_id] From 35667c6d70ff0d10ff9ca74d59d2d9ac152165ed Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 28 Aug 2020 16:41:25 +0100 Subject: [PATCH 3/4] Split fetching device keys and signatures into two transactions I think this is simpler (and moves stuff out of the db threads) --- changelog.d/8233.misc | 1 + 
.../storage/databases/main/end_to_end_keys.py | 107 ++++++++++-------- 2 files changed, 63 insertions(+), 45 deletions(-) create mode 100644 changelog.d/8233.misc diff --git a/changelog.d/8233.misc b/changelog.d/8233.misc new file mode 100644 index 000000000000..979c8b227bbc --- /dev/null +++ b/changelog.d/8233.misc @@ -0,0 +1 @@ +Refactor queries for device keys and cross-signatures. diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index edf9d5e59abe..acff57269352 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -25,6 +25,7 @@ from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import make_in_list_sql_clause +from synapse.storage.types import Cursor from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -45,8 +46,9 @@ class DeviceKeyLookupResult: # key) and "signatures" (a signature of the structure by the ed25519 key) key_json = attr.ib(type=Optional[str]) - # cross-signing sigs - signatures = attr.ib(type=Optional[Dict], default=None) + # cross-signing sigs on this device. 
+ # dict from (signing user_id)->(signing device_id)->sig + signatures = attr.ib(type=Optional[Dict[str, Dict[str, str]]], factory=dict) class EndToEndKeyWorkerStore(SQLBaseStore): @@ -154,22 +156,57 @@ async def get_e2e_device_keys_and_signatures( result = await self.db_pool.runInteraction( "get_e2e_device_keys", - self._get_e2e_device_keys_and_signatures_txn, + self._get_e2e_device_keys_txn, query_list, include_all_devices, include_deleted_devices, ) + # get the (user_id, device_id) tuples to look up cross-signatures for + signature_query = ( + [user_id, device_id] + for user_id, dev in result.items() + for device_id, d in dev.items() + if d is not None + ) + + for batch in batch_iter(signature_query, 50): + cross_sigs_result = await self.db_pool.runInteraction( + "get_e2e_cross_signing_signatures", + self._get_e2e_cross_signing_signatures_txn, + batch, + ) + + # add each cross-signing signature to the correct device in the result dict. + for row in cross_sigs_result: + signing_user_id = row["user_id"] + signing_key_id = row["key_id"] + target_user_id = row["target_user_id"] + target_device_id = row["target_device_id"] + signature = row["signature"] + + target_device_result = result[target_user_id][target_device_id] + target_device_signatures = target_device_result.signatures + + signing_user_signatures = target_device_signatures.setdefault( + signing_user_id, {} + ) + signing_user_signatures[signing_key_id] = signature + log_kv(result) return result - def _get_e2e_device_keys_and_signatures_txn( + def _get_e2e_device_keys_txn( self, txn, query_list, include_all_devices=False, include_deleted_devices=False ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: + """Get information on devices from the database + + The results include the device's keys and self-signatures, but *not* any + cross-signing signatures which have been added subsequently (for which, see + get_e2e_device_keys_and_signatures) + """ query_clauses = [] query_params = [] - 
signature_query_clauses = [] - signature_query_params = [] if include_all_devices is False: include_deleted_devices = False @@ -180,20 +217,12 @@ def _get_e2e_device_keys_and_signatures_txn( for (user_id, device_id) in query_list: query_clause = "user_id = ?" query_params.append(user_id) - signature_query_clause = "target_user_id = ?" - signature_query_params.append(user_id) if device_id is not None: query_clause += " AND device_id = ?" query_params.append(device_id) - signature_query_clause += " AND target_device_id = ?" - signature_query_params.append(device_id) - - signature_query_clause += " AND user_id = ?" - signature_query_params.append(user_id) query_clauses.append(query_clause) - signature_query_clauses.append(signature_query_clause) sql = ( "SELECT user_id, device_id, " @@ -221,41 +250,29 @@ def _get_e2e_device_keys_and_signatures_txn( for user_id, device_id in deleted_devices: result.setdefault(user_id, {})[device_id] = None - # get signatures on the device - signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % ( - " OR ".join("(" + q + ")" for q in signature_query_clauses) - ) - - txn.execute(signature_sql, signature_query_params) - rows = self.db_pool.cursor_to_dict(txn) - - # add each cross-signing signature to the correct device in the result dict. - for row in rows: - signing_user_id = row["user_id"] - signing_key_id = row["key_id"] - target_user_id = row["target_user_id"] - target_device_id = row["target_device_id"] - signature = row["signature"] - - target_user_result = result.get(target_user_id) - if not target_user_result: - continue - - target_device_result = target_user_result.get(target_device_id) - if not target_device_result: - # note that target_device_result will be None for deleted devices. 
- continue + return result - target_device_signatures = target_device_result.signatures - if target_device_signatures is None: - target_device_signatures = target_device_result.signatures = {} + def _get_e2e_cross_signing_signatures_txn( + self, txn: Cursor, device_query: Iterable[Tuple[str, str]] + ) -> List[Dict]: + """Get cross-signing signatures for a given list of devices""" + signature_query_clauses = [] + signature_query_params = [] - signing_user_signatures = target_device_signatures.setdefault( - signing_user_id, {} + for (user_id, device_id) in device_query: + # XXX: I don't really know why we limit this by user_id, but there's + # probably a good reason. + signature_query_clauses.append( + "target_user_id = ? AND target_device_id = ? AND user_id = ?" ) - signing_user_signatures[signing_key_id] = signature + signature_query_params.extend([user_id, device_id, user_id]) - return result + signature_sql = "SELECT * FROM e2e_cross_signing_signatures WHERE %s" % ( + " OR ".join("(" + q + ")" for q in signature_query_clauses) + ) + + txn.execute(signature_sql, signature_query_params) + return self.db_pool.cursor_to_dict(txn) async def get_e2e_one_time_keys( self, user_id: str, device_id: str, key_ids: List[str] From 69e655fbf58a6e93fe1fe5cb6e7f704baac27583 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 28 Aug 2020 17:31:09 +0100 Subject: [PATCH 4/4] Add cross-signing sigs to the `keys` object All the callers want this info in the same place, so let's reduce the duplication by doing it here. 
--- changelog.d/8234.misc | 1 + synapse/storage/databases/main/devices.py | 12 ++---- .../storage/databases/main/end_to_end_keys.py | 38 +++++++------------ 3 files changed, 18 insertions(+), 33 deletions(-) create mode 100644 changelog.d/8234.misc diff --git a/changelog.d/8234.misc b/changelog.d/8234.misc new file mode 100644 index 000000000000..979c8b227bbc --- /dev/null +++ b/changelog.d/8234.misc @@ -0,0 +1 @@ +Refactor queries for device keys and cross-signatures. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index f8fe948122a0..add4e3ea0ec0 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -291,15 +291,9 @@ async def _get_device_update_edus_by_remote( prev_id = stream_id if device is not None: - key_json = device.key_json - if key_json: - result["keys"] = db_to_json(key_json) - - if device.signatures: - for sig_user_id, sigs in device.signatures.items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) + keys = device.keys + if keys: + result["keys"] = keys device_display_name = device.display_name if device_display_name: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index acff57269352..c420303fb0f9 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -43,12 +43,8 @@ class DeviceKeyLookupResult: # the key data from e2e_device_keys_json. Typically includes fields like # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing - # key) and "signatures" (a signature of the structure by the ed25519 key) - key_json = attr.ib(type=Optional[str]) - - # cross-signing sigs on this device. 
- # dict from (signing user_id)->(signing device_id)->sig - signatures = attr.ib(type=Optional[Dict[str, Dict[str, str]]], factory=dict) + # key) and "signatures" (a map from (user id) to (key id/device_id) to signature.) + keys = attr.ib(type=Optional[JsonDict]) class EndToEndKeyWorkerStore(SQLBaseStore): @@ -70,15 +66,9 @@ async def get_e2e_device_keys_for_federation_query( for device_id, device in user_devices.items(): result = {"device_id": device_id} - key_json = device.key_json - if key_json: - result["keys"] = db_to_json(key_json) - - if device.signatures: - for sig_user_id, sigs in device.signatures.items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) + keys = device.keys + if keys: + result["keys"] = keys device_display_name = device.display_name if device_display_name: @@ -114,16 +104,11 @@ async def get_e2e_device_keys_for_cs_api( for user_id, device_keys in results.items(): rv[user_id] = {} for device_id, device_info in device_keys.items(): - r = db_to_json(device_info.key_json) + r = device_info.keys r["unsigned"] = {} display_name = device_info.display_name if display_name is not None: r["unsigned"]["device_display_name"] = display_name - if device_info.signatures: - for sig_user_id, sigs in device_info.signatures.items(): - r.setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) rv[user_id][device_id] = r return rv @@ -137,6 +122,9 @@ async def get_e2e_device_keys_and_signatures( ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: """Fetch a list of device keys, together with their cross-signatures. + The cross-signatures are added to the `signatures` field within the `keys` + object in the response. + Args: query_list: List of pairs of user_ids and device_ids. 
Device id can be None to indicate "all devices for this user" @@ -167,7 +155,7 @@ async def get_e2e_device_keys_and_signatures( [user_id, device_id] for user_id, dev in result.items() for device_id, d in dev.items() - if d is not None + if d is not None and d.keys is not None ) for batch in batch_iter(signature_query, 50): @@ -186,7 +174,9 @@ async def get_e2e_device_keys_and_signatures( signature = row["signature"] target_device_result = result[target_user_id][target_device_id] - target_device_signatures = target_device_result.signatures + target_device_signatures = target_device_result.keys.setdefault( + "signatures", {} + ) signing_user_signatures = target_device_signatures.setdefault( signing_user_id, {} @@ -243,7 +233,7 @@ def _get_e2e_device_keys_txn( if include_deleted_devices: deleted_devices.remove((user_id, device_id)) result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult( - display_name, key_json + display_name, db_to_json(key_json) if key_json else None ) if include_deleted_devices: