Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit d65862c

Browse files
authored
Refactor _get_e2e_device_keys_txn to split large queries (#13956)
Instead of running a single large query, run a single query for user-only lookups and additional queries for batches of user device lookups. Resolves #13580. Signed-off-by: Sean Quah <seanq@matrix.org>
1 parent 061739d commit d65862c

File tree

3 files changed

+115
-29
lines changed

3 files changed

+115
-29
lines changed

changelog.d/13956.bugfix

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix a long-standing bug where `POST /_matrix/client/v3/keys/query` requests could result in excessively large SQL queries.

synapse/storage/database.py

+60
Original file line numberDiff line numberDiff line change
@@ -2461,6 +2461,66 @@ def make_in_list_sql_clause(
24612461
return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
24622462

24632463

2464+
# These overloads ensure that `columns` and `iterable` values have the same length.
2465+
# Suppress "Single overload definition, multiple required" complaint.
2466+
@overload # type: ignore[misc]
2467+
def make_tuple_in_list_sql_clause(
2468+
database_engine: BaseDatabaseEngine,
2469+
columns: Tuple[str, str],
2470+
iterable: Collection[Tuple[Any, Any]],
2471+
) -> Tuple[str, list]:
2472+
...
2473+
2474+
2475+
def make_tuple_in_list_sql_clause(
2476+
database_engine: BaseDatabaseEngine,
2477+
columns: Tuple[str, ...],
2478+
iterable: Collection[Tuple[Any, ...]],
2479+
) -> Tuple[str, list]:
2480+
"""Returns an SQL clause that checks the given tuple of columns is in the iterable.
2481+
2482+
Args:
2483+
database_engine
2484+
columns: Names of the columns in the tuple.
2485+
iterable: The tuples to check the columns against.
2486+
2487+
Returns:
2488+
A tuple of SQL query and the args
2489+
"""
2490+
if len(columns) == 0:
2491+
# Should be unreachable due to mypy, as long as the overloads are set up right.
2492+
if () in iterable:
2493+
return "TRUE", []
2494+
else:
2495+
return "FALSE", []
2496+
2497+
if len(columns) == 1:
2498+
# Use `= ANY(?)` on postgres.
2499+
return make_in_list_sql_clause(
2500+
database_engine, next(iter(columns)), [values[0] for values in iterable]
2501+
)
2502+
2503+
# There are multiple columns. Avoid using an `= ANY(?)` clause on postgres, as
2504+
# indices are not used when there are multiple columns. Instead, use an `IN`
2505+
# expression.
2506+
#
2507+
# `IN ((?, ...), ...)` with tuples is supported by postgres only, whereas
2508+
# `IN (VALUES (?, ...), ...)` is supported by both sqlite and postgres.
2509+
# Thus, the latter is chosen.
2510+
2511+
if len(iterable) == 0:
2512+
# A 0-length `VALUES` list is not allowed in sqlite or postgres.
2513+
# Also note that a 0-length `IN (...)` clause (not using `VALUES`) is not
2514+
# allowed in postgres.
2515+
return "FALSE", []
2516+
2517+
tuple_sql = "(%s)" % (",".join("?" for _ in columns),)
2518+
return "(%s) IN (VALUES %s)" % (
2519+
",".join(column for column in columns),
2520+
",".join(tuple_sql for _ in iterable),
2521+
), [value for values in iterable for value in values]
2522+
2523+
24642524
KV = TypeVar("KV")
24652525

24662526

synapse/storage/databases/main/end_to_end_keys.py

+54-29
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
LoggingDatabaseConnection,
4444
LoggingTransaction,
4545
make_in_list_sql_clause,
46+
make_tuple_in_list_sql_clause,
4647
)
4748
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
4849
from synapse.storage.engines import PostgresEngine
@@ -278,7 +279,7 @@ async def get_e2e_device_keys_and_signatures(
278279
def _get_e2e_device_keys_txn(
279280
self,
280281
txn: LoggingTransaction,
281-
query_list: Collection[Tuple[str, str]],
282+
query_list: Collection[Tuple[str, Optional[str]]],
282283
include_all_devices: bool = False,
283284
include_deleted_devices: bool = False,
284285
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
@@ -288,49 +289,73 @@ def _get_e2e_device_keys_txn(
288289
cross-signing signatures which have been added subsequently (for which, see
289290
get_e2e_device_keys_and_signatures)
290291
"""
291-
query_clauses = []
292-
query_params = []
292+
query_clauses: List[str] = []
293+
query_params_list: List[List[object]] = []
293294

294295
if include_all_devices is False:
295296
include_deleted_devices = False
296297

297298
if include_deleted_devices:
298299
deleted_devices = set(query_list)
299300

301+
# Split the query list into queries for users and queries for particular
302+
# devices.
303+
user_list = []
304+
user_device_list = []
300305
for (user_id, device_id) in query_list:
301-
query_clause = "user_id = ?"
302-
query_params.append(user_id)
303-
304-
if device_id is not None:
305-
query_clause += " AND device_id = ?"
306-
query_params.append(device_id)
307-
308-
query_clauses.append(query_clause)
309-
310-
sql = (
311-
"SELECT user_id, device_id, "
312-
" d.display_name, "
313-
" k.key_json"
314-
" FROM devices d"
315-
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
316-
" WHERE %s AND NOT d.hidden"
317-
) % (
318-
"LEFT" if include_all_devices else "INNER",
319-
" OR ".join("(" + q + ")" for q in query_clauses),
320-
)
306+
if device_id is None:
307+
user_list.append(user_id)
308+
else:
309+
user_device_list.append((user_id, device_id))
321310

322-
txn.execute(sql, query_params)
311+
if user_list:
312+
user_id_in_list_clause, user_args = make_in_list_sql_clause(
313+
txn.database_engine, "user_id", user_list
314+
)
315+
query_clauses.append(user_id_in_list_clause)
316+
query_params_list.append(user_args)
317+
318+
if user_device_list:
319+
# Divide the device queries into batches, to avoid excessively large
320+
# queries.
321+
for user_device_batch in batch_iter(user_device_list, 1024):
322+
(
323+
user_device_id_in_list_clause,
324+
user_device_args,
325+
) = make_tuple_in_list_sql_clause(
326+
txn.database_engine, ("user_id", "device_id"), user_device_batch
327+
)
328+
query_clauses.append(user_device_id_in_list_clause)
329+
query_params_list.append(user_device_args)
323330

324331
result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
325-
for (user_id, device_id, display_name, key_json) in txn:
326-
if include_deleted_devices:
327-
deleted_devices.remove((user_id, device_id))
328-
result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
329-
display_name, db_to_json(key_json) if key_json else None
332+
for query_clause, query_params in zip(query_clauses, query_params_list):
333+
sql = (
334+
"SELECT user_id, device_id, "
335+
" d.display_name, "
336+
" k.key_json"
337+
" FROM devices d"
338+
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
339+
" WHERE %s AND NOT d.hidden"
340+
) % (
341+
"LEFT" if include_all_devices else "INNER",
342+
query_clause,
330343
)
331344

345+
txn.execute(sql, query_params)
346+
347+
for (user_id, device_id, display_name, key_json) in txn:
348+
assert device_id is not None
349+
if include_deleted_devices:
350+
deleted_devices.remove((user_id, device_id))
351+
result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
352+
display_name, db_to_json(key_json) if key_json else None
353+
)
354+
332355
if include_deleted_devices:
333356
for user_id, device_id in deleted_devices:
357+
if device_id is None:
358+
continue
334359
result.setdefault(user_id, {})[device_id] = None
335360

336361
return result

0 commit comments

Comments
 (0)