Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

clean up device key/cross-signature handling. #8206

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8231.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.
1 change: 1 addition & 0 deletions changelog.d/8233.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.
1 change: 1 addition & 0 deletions changelog.d/8234.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.
16 changes: 4 additions & 12 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,7 @@ async def _get_device_update_edus_by_remote(
List of objects representing an device update EDU
"""
devices = (
await self.db_pool.runInteraction(
"get_e2e_device_keys_and_signatures_txn",
self._get_e2e_device_keys_and_signatures_txn,
await self.get_e2e_device_keys_and_signatures(
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
Expand Down Expand Up @@ -293,15 +291,9 @@ async def _get_device_update_edus_by_remote(
prev_id = stream_id

if device is not None:
key_json = device.key_json
if key_json:
result["keys"] = db_to_json(key_json)

if device.signatures:
for sig_user_id, sigs in device.signatures.items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
keys = device.keys
if keys:
result["keys"] = keys

device_display_name = device.display_name
if device_display_name:
Expand Down
183 changes: 107 additions & 76 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
Expand All @@ -36,17 +37,14 @@

@attr.s
class DeviceKeyLookupResult:
"""The type returned by _get_e2e_device_keys_and_signatures_txn"""
"""The type returned by get_e2e_device_keys_and_signatures"""

display_name = attr.ib(type=Optional[str])

# the key data from e2e_device_keys_json. Typically includes fields like
# "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
# key) and "signatures" (a signature of the structure by the ed25519 key)
key_json = attr.ib(type=Optional[str])

# cross-signing sigs
signatures = attr.ib(type=Optional[Dict], default=None)
# key) and signatures" (a map from (user id) to (key id/device_id) to signature.)
keys = attr.ib(type=Optional[JsonDict])


class EndToEndKeyWorkerStore(SQLBaseStore):
Expand All @@ -60,27 +58,17 @@ async def get_e2e_device_keys_for_federation_query(
"""
now_stream_id = self.get_device_stream_token()

devices = await self.db_pool.runInteraction(
"get_e2e_device_keys_and_signatures_txn",
self._get_e2e_device_keys_and_signatures_txn,
[(user_id, None)],
)
devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])

if devices:
user_devices = devices[user_id]
results = []
for device_id, device in user_devices.items():
result = {"device_id": device_id}

key_json = device.key_json
if key_json:
result["keys"] = db_to_json(key_json)

if device.signatures:
for sig_user_id, sigs in device.signatures.items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
keys = device.keys
if keys:
result["keys"] = keys

device_display_name = device.display_name
if device_display_name:
Expand Down Expand Up @@ -108,43 +96,107 @@ async def get_e2e_device_keys_for_cs_api(
if not query_list:
return {}

results = await self.db_pool.runInteraction(
"get_e2e_device_keys_and_signatures_txn",
self._get_e2e_device_keys_and_signatures_txn,
query_list,
)
results = await self.get_e2e_device_keys_and_signatures(query_list)

# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
rv = {}
for user_id, device_keys in results.items():
rv[user_id] = {}
for device_id, device_info in device_keys.items():
r = db_to_json(device_info.key_json)
r = device_info.keys
r["unsigned"] = {}
display_name = device_info.display_name
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
if device_info.signatures:
for sig_user_id, sigs in device_info.signatures.items():
r.setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
rv[user_id][device_id] = r

return rv

@trace
def _get_e2e_device_keys_and_signatures_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
async def get_e2e_device_keys_and_signatures(
self,
query_list: List[Tuple[str, Optional[str]]],
include_all_devices: bool = False,
include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Fetch a list of device keys, together with their cross-signatures.

The cross-signatures are added to the `signatures` field within the `keys`
object in the response.

Args:
query_list: List of pairs of user_ids and device_ids. Device id can be None
to indicate "all devices for this user"

include_all_devices: whether to return devices without device keys

include_deleted_devices:whether to include null entries for
devices which no longer exist (but were in the query_list).
This option only takes effect if include_all_devices is true.

Returns:
Dict mapping from user-id to dict mapping from device_id to
key data.
"""
set_tag("include_all_devices", include_all_devices)
set_tag("include_deleted_devices", include_deleted_devices)

result = await self.db_pool.runInteraction(
"get_e2e_device_keys",
self._get_e2e_device_keys_txn,
query_list,
include_all_devices,
include_deleted_devices,
)

# get a the (user_id, device_id) tuples to look up cross-signatures for
signature_query = (
[user_id, device_id]
for user_id, dev in result.items()
for device_id, d in dev.items()
if d is not None and d.keys is not None
)

for batch in batch_iter(signature_query, 50):
cross_sigs_result = await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_txn,
batch,
)

# add each cross-signing signature to the correct device in the result dict.
for row in cross_sigs_result:
signing_user_id = row["user_id"]
signing_key_id = row["key_id"]
target_user_id = row["target_user_id"]
target_device_id = row["target_device_id"]
signature = row["signature"]

target_device_result = result[target_user_id][target_device_id]
target_device_signatures = target_device_result.keys.setdefault(
"signatures", {}
)

signing_user_signatures = target_device_signatures.setdefault(
signing_user_id, {}
)
signing_user_signatures[signing_key_id] = signature

log_kv(result)
return result

def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Get information on devices from the database

The results include the device's keys and self-signatures, but *not* any
cross-signing signatures which have been added subsequently (for which, see
get_e2e_device_keys_and_signatures)
"""
query_clauses = []
query_params = []
signature_query_clauses = []
signature_query_params = []

if include_all_devices is False:
include_deleted_devices = False
Expand All @@ -155,20 +207,12 @@ def _get_e2e_device_keys_and_signatures_txn(
for (user_id, device_id) in query_list:
query_clause = "user_id = ?"
query_params.append(user_id)
signature_query_clause = "target_user_id = ?"
signature_query_params.append(user_id)

if device_id is not None:
query_clause += " AND device_id = ?"
query_params.append(device_id)
signature_query_clause += " AND target_device_id = ?"
signature_query_params.append(device_id)

signature_query_clause += " AND user_id = ?"
signature_query_params.append(user_id)

query_clauses.append(query_clause)
signature_query_clauses.append(signature_query_clause)

sql = (
"SELECT user_id, device_id, "
Expand All @@ -189,49 +233,36 @@ def _get_e2e_device_keys_and_signatures_txn(
if include_deleted_devices:
deleted_devices.remove((user_id, device_id))
result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
display_name, key_json
display_name, db_to_json(key_json) if key_json else None
)

if include_deleted_devices:
for user_id, device_id in deleted_devices:
result.setdefault(user_id, {})[device_id] = None

# get signatures on the device
signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % (
" OR ".join("(" + q + ")" for q in signature_query_clauses)
)

txn.execute(signature_sql, signature_query_params)
rows = self.db_pool.cursor_to_dict(txn)

# add each cross-signing signature to the correct device in the result dict.
for row in rows:
signing_user_id = row["user_id"]
signing_key_id = row["key_id"]
target_user_id = row["target_user_id"]
target_device_id = row["target_device_id"]
signature = row["signature"]

target_user_result = result.get(target_user_id)
if not target_user_result:
continue

target_device_result = target_user_result.get(target_device_id)
if not target_device_result:
# note that target_device_result will be None for deleted devices.
continue
return result

target_device_signatures = target_device_result.signatures
if target_device_signatures is None:
target_device_signatures = target_device_result.signatures = {}
def _get_e2e_cross_signing_signatures_txn(
self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
) -> List[Dict]:
"""Get cross-signing signatures for a given list of devices"""
signature_query_clauses = []
signature_query_params = []

signing_user_signatures = target_device_signatures.setdefault(
signing_user_id, {}
for (user_id, device_id) in device_query:
# XXX: I don't really know why we limit this by user_id, but there's
# probably a good reason.
signature_query_clauses.append(
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
)
signing_user_signatures[signing_key_id] = signature
signature_query_params.extend([user_id, device_id, user_id])

log_kv(result)
return result
signature_sql = "SELECT * FROM e2e_cross_signing_signatures WHERE %s" % (
" OR ".join("(" + q + ")" for q in signature_query_clauses)
)

txn.execute(signature_sql, signature_query_params)
return self.db_pool.cursor_to_dict(txn)

async def get_e2e_one_time_keys(
self, user_id: str, device_id: str, key_ids: List[str]
Expand Down