Issue one time keys in upload order #17903

Merged · 5 commits · Nov 6, 2024
1 change: 1 addition & 0 deletions changelog.d/17903.bugfix
@@ -0,0 +1 @@
Fix a long-standing bug in Synapse which could cause one-time keys to be issued in the incorrect order, causing message decryption failures.
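For context, the comments added in the storage layer below explain the failure mode: a client typically discards the private halves of its oldest one-time keys once enough newer ones exist, so a public key that lingers on the server and is only issued long after upload may no longer be decryptable. The snippet below is a purely illustrative sketch of that scenario; the names (`server_otks`, `client_private_halves`) are hypothetical and not Synapse code.

```python
# Purely illustrative; these names and structures are hypothetical, not Synapse code.
from collections import deque

# Public one-time keys as stored on the server, in upload order.
server_otks = deque(["k1", "k2", "k3"])

# A client that only retains the private halves of its most recent keys may have
# already discarded "k1" by the time the server finally issues it.
client_private_halves = {"k2", "k3"}

# If keys are issued in an arbitrary (here: newest-first) order, "k1" sits on the
# server until last, long after the client dropped its private half.
issued = [server_otks.pop() for _ in range(len(server_otks))]  # ["k3", "k2", "k1"]
last_issued = issued[-1]                                       # "k1"
print(last_issued in client_private_halves)                    # False -> decryption failure
```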
2 changes: 1 addition & 1 deletion synapse/handlers/e2e_keys.py
@@ -615,7 +615,7 @@ async def claim_local_one_time_keys(
3. Attempt to fetch fallback keys from the database.

Args:
local_query: An iterable of tuples of (user ID, device ID, algorithm).
local_query: An iterable of tuples of (user ID, device ID, algorithm, number of keys).
always_include_fallback_keys: True to always include fallback keys.

Returns:
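As a hedged illustration of the updated signature, a `local_query` now carries a per-entry key count. The values below are made up; only the tuple shape comes from the docstring above.

```python
# Hypothetical example values; only the tuple shape (user ID, device ID,
# algorithm, number of keys) comes from the docstring above.
local_query = [
    ("@alice:example.com", "DEVICEA", "signed_curve25519", 1),
    ("@bob:example.com", "DEVICEB", "signed_curve25519", 2),
]
```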
25 changes: 23 additions & 2 deletions synapse/storage/databases/main/end_to_end_keys.py
@@ -99,6 +99,13 @@ def __init__(
unique=True,
)

self.db_pool.updates.register_background_index_update(
update_name="add_otk_ts_added_index",
index_name="e2e_one_time_keys_json_user_id_device_id_algorithm_ts_added_idx",
table="e2e_one_time_keys_json",
columns=("user_id", "device_id", "algorithm", "ts_added_ms"),
)


class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
def __init__(
@@ -1122,7 +1129,7 @@ async def claim_e2e_one_time_keys(
"""Take a list of one time keys out of the database.

Args:
query_list: An iterable of tuples of (user ID, device ID, algorithm).
query_list: An iterable of tuples of (user ID, device ID, algorithm, number of keys).

Returns:
A tuple (results, missing) of:
@@ -1310,9 +1317,14 @@ def _claim_e2e_one_time_key_simple(
OTK was found.
"""

# Return the oldest keys from this device (based on `ts_added_ms`).
# Doing so means that keys are issued in the same order they were uploaded,
# which reduces the chances of a client expiring its copy of a (private)
# key while the public key is still on the server, waiting to be issued.
sql = """
SELECT key_id, key_json FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
ORDER BY ts_added_ms
LIMIT ?
"""

@@ -1354,13 +1366,22 @@ def _claim_e2e_one_time_keys_bulk(
A list of tuples (user_id, device_id, algorithm, key_id, key_json)
for each OTK claimed.
"""
# Find, delete, and return the oldest keys from each device (based on
# `ts_added_ms`).
#
# Doing so means that keys are issued in the same order they were uploaded,
# which reduces the chances of a client expiring its copy of a (private)
# key while the public key is still on the server, waiting to be issued.
sql = """
WITH claims(user_id, device_id, algorithm, claim_count) AS (
VALUES ?
), ranked_keys AS (
SELECT
user_id, device_id, algorithm, key_id, claim_count,
ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
ROW_NUMBER() OVER (
PARTITION BY (user_id, device_id, algorithm)
ORDER BY ts_added_ms
) AS r
FROM e2e_one_time_keys_json
JOIN claims USING (user_id, device_id, algorithm)
)
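The tail of the bulk query is also collapsed in this diff, but the role of `ranked_keys` is visible: `ROW_NUMBER()` numbers each device's keys in upload order, so the rows ranked within the requested `claim_count` are the oldest ones to hand out. The standalone snippet below demonstrates that ranking against a heavily simplified table with made-up values; it is not Synapse code.

```python
# Standalone demonstration (not Synapse code) of the ranking performed by the
# CTE above: ROW_NUMBER() numbers each device's keys in upload order.
import sqlite3

conn = sqlite3.connect(":memory:")  # requires SQLite >= 3.25 for window functions
conn.execute(
    "CREATE TABLE e2e_one_time_keys_json "
    "(user_id TEXT, device_id TEXT, algorithm TEXT, key_id TEXT, ts_added_ms BIGINT)"
)
conn.executemany(
    "INSERT INTO e2e_one_time_keys_json VALUES (?, ?, ?, ?, ?)",
    [
        ("@alice:example.com", "DEV1", "alg1", "k20", 1000),
        ("@alice:example.com", "DEV1", "alg1", "k10", 2000),
        ("@bob:example.com", "DEV2", "alg1", "k5", 500),
    ],
)

# Rank each device's keys by upload time, oldest first.
ranked = conn.execute(
    """
    SELECT user_id, device_id, key_id,
           ROW_NUMBER() OVER (
               PARTITION BY user_id, device_id, algorithm
               ORDER BY ts_added_ms
           ) AS r
    FROM e2e_one_time_keys_json
    ORDER BY user_id, r
    """
).fetchall()
for row in ranked:
    print(row)
# ('@alice:example.com', 'DEV1', 'k20', 1)  <- first uploaded, ranked first
# ('@alice:example.com', 'DEV1', 'k10', 2)
# ('@bob:example.com', 'DEV2', 'k5', 1)
```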
@@ -0,0 +1,18 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2024 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.


-- Add an index on (user_id, device_id, algorithm, ts_added_ms) on e2e_one_time_keys_json, so that OTKs can
-- efficiently be issued in the same order they were uploaded.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(8803, 'add_otk_ts_added_index', '{}');
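For reference, the `add_otk_ts_added_index` background update registered in `end_to_end_keys.py` above ends up building an index roughly equivalent to the DDL below, shown here as a SQL string in the style of the surrounding Python code. The exact statement is generated by Synapse's background updater and may differ in detail; for example, PostgreSQL can build such indexes concurrently.

```python
# Rough equivalent of the index built by the 'add_otk_ts_added_index' background
# update; the real DDL is generated by Synapse's background updater and may
# differ in detail (e.g. CREATE INDEX CONCURRENTLY on PostgreSQL).
otk_ts_added_index_sql = """
    CREATE INDEX IF NOT EXISTS
        e2e_one_time_keys_json_user_id_device_id_algorithm_ts_added_idx
    ON e2e_one_time_keys_json (user_id, device_id, algorithm, ts_added_ms)
"""
```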
78 changes: 73 additions & 5 deletions tests/handlers/test_e2e_keys.py
@@ -151,18 +151,30 @@ def test_change_one_time_keys(self) -> None:
def test_claim_one_time_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
keys = {"alg1:k1": "key1"}

res = self.get_success(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}
local_user, device_id, {"one_time_keys": {"alg1:k1": "key1"}}
)
)
self.assertDictEqual(
res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}}
)

res2 = self.get_success(
# Keys should be returned in the order they were uploaded. To test, advance time
# a little, then upload a second key with an earlier key ID; it should get
# returned second.
self.reactor.advance(1)
res = self.get_success(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg1:k0": "key0"}}
)
)
self.assertDictEqual(
res, {"one_time_key_counts": {"alg1": 2, "signed_curve25519": 0}}
)

# Now claim both keys back. They should come back in the order they were uploaded.
res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
@@ -171,12 +183,27 @@
)
)
self.assertEqual(
res2,
res,
{
"failures": {},
"one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
},
)
res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
res,
{
"failures": {},
"one_time_keys": {local_user: {device_id: {"alg1:k0": "key0"}}},
},
)

def test_claim_one_time_key_bulk(self) -> None:
"""Like test_claim_one_time_key but claims multiple keys in one handler call."""
@@ -336,6 +363,47 @@ def assertAllOtks(user_id: str, device_id: str, *alg_key_pairs: str) -> None:
counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}"
)

def test_claim_one_time_key_bulk_ordering(self) -> None:
"""Keys returned by the bulk claim call should be returned in the correct order"""

# Alice has lots of keys, uploaded in a specific order
alice = f"@alice:{self.hs.hostname}"
alice_dev = "alice_dev_1"

self.get_success(
self.handler.upload_keys_for_user(
alice,
alice_dev,
{"one_time_keys": {"alg1:k20": 20, "alg1:k21": 21, "alg1:k22": 22}},
)
)
# Advance time by 1s, to ensure that there is a difference in upload time.
self.reactor.advance(1)
self.get_success(
self.handler.upload_keys_for_user(
alice,
alice_dev,
{"one_time_keys": {"alg1:k10": 10, "alg1:k11": 11, "alg1:k12": 12}},
)
)

# Now claim some, and check we get the right ones.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{alice: {alice_dev: {"alg1": 2}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
)
# We should get the first-uploaded keys, even though they have later key IDs:
# an arbitrary two of k20, k21, k22 (all three were uploaded at the same time,
# so ties among them may come back in any order).
self.assertEqual(claim_res["failures"], {})
claimed_keys = claim_res["one_time_keys"]["@alice:test"]["alice_dev_1"]
self.assertEqual(len(claimed_keys), 2)
for key_id in claimed_keys.keys():
self.assertIn(key_id, ["alg1:k20", "alg1:k21", "alg1:k22"])

def test_fallback_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"