From 5aff64e5eb5b5fbe1dc70b481032b5cc532835d9 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 3 Sep 2020 16:35:11 -0400 Subject: [PATCH 01/13] add support for fallback keys --- synapse/handlers/e2e_keys.py | 16 ++++ synapse/handlers/sync.py | 8 ++ synapse/rest/client/v2_alpha/sync.py | 1 + .../storage/databases/main/end_to_end_keys.py | 77 +++++++++++++++++++ .../main/schema/delta/58/11fallback.sql | 24 ++++++ 5 files changed, 126 insertions(+) create mode 100644 synapse/storage/databases/main/schema/delta/58/11fallback.sql diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index dd40fd129936..40f0787d18e6 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -496,6 +496,22 @@ async def upload_keys_for_user(self, user_id, device_id, keys): log_kv( {"message": "Did not update one_time_keys", "reason": "no keys given"} ) + fallback_keys = keys.get("fallback_keys", None) + if fallback_keys and isinstance(fallback_keys, dict): + log_kv( + { + "message": "Updating fallback_keys for device.", + "user_id": user_id, + "device_id": device_id, + } + ) + await self.store.set_e2e_fallback_keys( + user_id, device_id, fallback_keys + ) + else: + log_kv( + {"message": "Did not update fallback_keys", "reason": "no keys given"} + ) # the device should have been registered already, but it may have been # deleted due to a race with a DELETE request. Or we may be using an diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 9b3a4f638b13..06da3e20631e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -201,6 +201,8 @@ class SyncResult: device_lists: List of user_ids whose devices have changed device_one_time_keys_count: Dict of algorithm to count for one time keys for this device + device_unused_fallback_keys: List of key types that have an unused fallback + key groups: Group updates, if any """ @@ -213,6 +215,7 @@ class SyncResult: to_device = attr.ib(type=List[JsonDict]) device_lists = attr.ib(type=DeviceLists) device_one_time_keys_count = attr.ib(type=JsonDict) + device_unused_fallback_keys = attr.ib(type=List[str]) groups = attr.ib(type=Optional[GroupsSyncResult]) def __bool__(self) -> bool: @@ -1014,10 +1017,14 @@ async def generate_sync_result( logger.debug("Fetching OTK data") device_id = sync_config.device_id one_time_key_counts = {} # type: JsonDict + unused_fallback_keys = [] # type: list if device_id: one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id ) + unused_fallback_keys = await self.store.get_e2e_unused_fallback_keys( + user_id, device_id + ) logger.debug("Fetching group data") await self._generate_sync_entry_for_groups(sync_result_builder) @@ -1041,6 +1048,7 @@ async def generate_sync_result( device_lists=device_lists, groups=sync_result_builder.groups, device_one_time_keys_count=one_time_key_counts, + device_unused_fallback_keys=unused_fallback_keys, next_batch=sync_result_builder.now_token, ) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index a0b00135e1cb..9ed34761024c 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -236,6 +236,7 @@ async def encode_response(self, time_now, sync_result, access_token_id, filter): "leave": sync_result.groups.leave, }, "device_one_time_keys_count": sync_result.device_one_time_keys_count, + "device_unused_fallback_keys": sync_result.device_unused_fallback_keys, "next_batch": sync_result.next_batch.to_string(), } diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index c8df0bcb3fe5..5112849b1b58 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -367,6 +367,44 @@ def _count_e2e_one_time_keys(txn): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) + async def set_e2e_fallback_keys( + self, user_id: str, device_id: str, fallback_keys: dict + ): + # fallback_keys will usually only have one item in it, so using a for + # loop (as opposed to calling simple_upsert_many_txn) won't be too bad + # FIXME: make sure that only one key per algorithm is uploaded + for key_id, fallback_key in fallback_keys.items(): + algorithm, key_id = key_id.split(":", 1) + await self.db_pool.simple_upsert( + "e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm + }, + values={ + "key_id": key_id, + "key_json": json_encoder.encode(fallback_key), + "used": 0 + }, + desc="set_e2e_fallback_key" + ) + + @cached(max_entries=10000) + async def get_e2e_unused_fallback_keys( + self, user_id: str, device_id: str + ): + return await self.db_pool.simple_select_onecol( + "e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "used": 0 + }, + retcol="algorithm", + desc="get_e2e_unused_fallback_keys" + ) + async def get_e2e_cross_signing_key( self, user_id: str, key_type: str, from_user_id: Optional[str] = None ) -> Optional[dict]: @@ -701,15 +739,29 @@ def _claim_e2e_one_time_keys(txn): " WHERE user_id = ? AND device_id = ? AND algorithm = ?" " LIMIT 1" ) + fallback_sql = ( + "SELECT key_id, key_json, used FROM e2e_fallback_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" + " LIMIT 1" + ) result = {} delete = [] + used_fallbacks = [] for user_id, device_id, algorithm in query_list: user_result = result.setdefault(user_id, {}) device_result = user_result.setdefault(device_id, {}) txn.execute(sql, (user_id, device_id, algorithm)) + found = False for key_id, key_json in txn: + found = True device_result[algorithm + ":" + key_id] = key_json delete.append((user_id, device_id, algorithm, key_id)) + if not found: + txn.execute(fallback_sql, (user_id, device_id, algorithm)) + for key_id, key_json, used in txn: + device_result[algorithm + ":" + key_id] = key_json + if used == 0: + used_fallbacks.append((user_id, device_id, algorithm, key_id)) sql = ( "DELETE FROM e2e_one_time_keys_json" " WHERE user_id = ? AND device_id = ? AND algorithm = ?" @@ -726,6 +778,23 @@ def _claim_e2e_one_time_keys(txn): self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + for user_id, device_id, algorithm, key_id in used_fallbacks: + self.db_pool.simple_update_txn( + txn, + "e2e_fallback_keys_json", + { + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + "key_id": key_id + }, + { + "used": 1 + } + ) + self._invalidate_cache_and_stream( + txn, self.get_e2e_unused_fallback_keys, (user_id, device_id) + ) return result return await self.db_pool.runInteraction( @@ -754,6 +823,14 @@ def delete_e2e_keys_by_device_txn(txn): self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + self.db_pool.simple_delete_txn( + txn, + table="e2e_fallback_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._invalidate_cache_and_stream( + txn, self.get_e2e_unused_fallback_keys, (user_id, device_id) + ) await self.db_pool.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn diff --git a/synapse/storage/databases/main/schema/delta/58/11fallback.sql b/synapse/storage/databases/main/schema/delta/58/11fallback.sql new file mode 100644 index 000000000000..272314a4a832 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/11fallback.sql @@ -0,0 +1,24 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json ( + user_id TEXT NOT NULL, -- The user this fallback key is for. + device_id TEXT NOT NULL, -- The device this fallback key is for. + algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for. + key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. + key_json TEXT NOT NULL, -- The key as a JSON blob. + used SMALLINT NOT NULL DEFAULT 0, -- Whether the key has been used or not. + CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm) +); From c8ab46923e1d423d4640fdae2a4b215993bccdd5 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Mon, 14 Sep 2020 22:54:12 -0400 Subject: [PATCH 02/13] add changelog and run black --- changelog.d/8312.feature | 1 + synapse/handlers/e2e_keys.py | 4 +-- .../storage/databases/main/end_to_end_keys.py | 30 ++++++++----------- 3 files changed, 14 insertions(+), 21 deletions(-) create mode 100644 changelog.d/8312.feature diff --git a/changelog.d/8312.feature b/changelog.d/8312.feature new file mode 100644 index 000000000000..041ef9659967 --- /dev/null +++ b/changelog.d/8312.feature @@ -0,0 +1 @@ +Add support for olm fallback keys. ([MSC2732](https://github.com/matrix-org/matrix-doc/pull/2732)) \ No newline at end of file diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 40f0787d18e6..b1a861874968 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -505,9 +505,7 @@ async def upload_keys_for_user(self, user_id, device_id, keys): "device_id": device_id, } ) - await self.store.set_e2e_fallback_keys( - user_id, device_id, fallback_keys - ) + await self.store.set_e2e_fallback_keys(user_id, device_id, fallback_keys) else: log_kv( {"message": "Did not update fallback_keys", "reason": "no keys given"} diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 5112849b1b58..fe8bb7cdeddc 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -368,7 +368,7 @@ def _count_e2e_one_time_keys(txn): ) async def set_e2e_fallback_keys( - self, user_id: str, device_id: str, fallback_keys: dict + self, user_id: str, device_id: str, fallback_keys: dict ): # fallback_keys will usually only have one item in it, so using a for # loop (as opposed to calling simple_upsert_many_txn) won't be too bad @@ -380,29 +380,23 @@ async def set_e2e_fallback_keys( keyvalues={ "user_id": user_id, "device_id": device_id, - "algorithm": algorithm + "algorithm": algorithm, }, values={ "key_id": key_id, "key_json": json_encoder.encode(fallback_key), - "used": 0 + "used": 0, }, - desc="set_e2e_fallback_key" + desc="set_e2e_fallback_key", ) @cached(max_entries=10000) - async def get_e2e_unused_fallback_keys( - self, user_id: str, device_id: str - ): + async def get_e2e_unused_fallback_keys(self, user_id: str, device_id: str): return await self.db_pool.simple_select_onecol( "e2e_fallback_keys_json", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - "used": 0 - }, + keyvalues={"user_id": user_id, "device_id": device_id, "used": 0}, retcol="algorithm", - desc="get_e2e_unused_fallback_keys" + desc="get_e2e_unused_fallback_keys", ) async def get_e2e_cross_signing_key( @@ -761,7 +755,9 @@ def _claim_e2e_one_time_keys(txn): for key_id, key_json, used in txn: device_result[algorithm + ":" + key_id] = key_json if used == 0: - used_fallbacks.append((user_id, device_id, algorithm, key_id)) + used_fallbacks.append( + (user_id, device_id, algorithm, key_id) + ) sql = ( "DELETE FROM e2e_one_time_keys_json" " WHERE user_id = ? AND device_id = ? AND algorithm = ?" @@ -786,11 +782,9 @@ def _claim_e2e_one_time_keys(txn): "user_id": user_id, "device_id": device_id, "algorithm": algorithm, - "key_id": key_id + "key_id": key_id, }, - { - "used": 1 - } + {"used": 1}, ) self._invalidate_cache_and_stream( txn, self.get_e2e_unused_fallback_keys, (user_id, device_id) From 3188692c1eede927a6c79db4606f86938027f086 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Mon, 14 Sep 2020 23:01:04 -0400 Subject: [PATCH 03/13] fix news file --- changelog.d/8312.feature | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/changelog.d/8312.feature b/changelog.d/8312.feature index 041ef9659967..222a1b032a4d 100644 --- a/changelog.d/8312.feature +++ b/changelog.d/8312.feature @@ -1 +1 @@ -Add support for olm fallback keys. ([MSC2732](https://github.com/matrix-org/matrix-doc/pull/2732)) \ No newline at end of file +Add support for olm fallback keys ([MSC2732](https://github.com/matrix-org/matrix-doc/pull/2732)). \ No newline at end of file From aac48e0e90237abba94ab8982f52eed81f246fa8 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Tue, 15 Sep 2020 16:59:04 -0400 Subject: [PATCH 04/13] add docs, comments, and tests --- .../storage/databases/main/end_to_end_keys.py | 27 +++++++++++- tests/handlers/test_e2e_keys.py | 41 +++++++++++++++++++ 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index fe8bb7cdeddc..fa84e09b7560 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -369,7 +369,15 @@ def _count_e2e_one_time_keys(txn): async def set_e2e_fallback_keys( self, user_id: str, device_id: str, fallback_keys: dict - ): + ) -> None: + """Set the user's e2e fallback keys. + + Args: + user_id: the user whose keys are being set + device_id: the device whose keys are being set + fallback_keys: the keys to set. This is a map from key ID (which is + of the form "algorithm:id") to key data. + """ # fallback_keys will usually only have one item in it, so using a for # loop (as opposed to calling simple_upsert_many_txn) won't be too bad # FIXME: make sure that only one key per algorithm is uploaded @@ -391,7 +399,16 @@ async def set_e2e_fallback_keys( ) @cached(max_entries=10000) - async def get_e2e_unused_fallback_keys(self, user_id: str, device_id: str): + async def get_e2e_unused_fallback_keys(self, user_id: str, device_id: str) -> List[str]: + """Returns the fallback key types that have an unused key. + + Args: + user_id: the user whose keys are being queried + device_id: the device whose keys are being queried + + Returns: + a list of key types + """ return await self.db_pool.simple_select_onecol( "e2e_fallback_keys_json", keyvalues={"user_id": user_id, "device_id": device_id, "used": 0}, @@ -751,6 +768,8 @@ def _claim_e2e_one_time_keys(txn): device_result[algorithm + ":" + key_id] = key_json delete.append((user_id, device_id, algorithm, key_id)) if not found: + # no one-time key available, so see if there's a fallback + # key txn.execute(fallback_sql, (user_id, device_id, algorithm)) for key_id, key_json, used in txn: device_result[algorithm + ":" + key_id] = key_json @@ -758,6 +777,8 @@ def _claim_e2e_one_time_keys(txn): used_fallbacks.append( (user_id, device_id, algorithm, key_id) ) + + # drop any one-time keys that were claimed sql = ( "DELETE FROM e2e_one_time_keys_json" " WHERE user_id = ? AND device_id = ? AND algorithm = ?" @@ -774,6 +795,7 @@ def _claim_e2e_one_time_keys(txn): self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + # mark fallback keys as used for user_id, device_id, algorithm, key_id in used_fallbacks: self.db_pool.simple_update_txn( txn, @@ -789,6 +811,7 @@ def _claim_e2e_one_time_keys(txn): self._invalidate_cache_and_stream( txn, self.get_e2e_unused_fallback_keys, (user_id, device_id) ) + return result return await self.db_pool.runInteraction( diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 210ddcbb882f..def2feb6df66 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -171,6 +171,47 @@ def test_claim_one_time_key(self): }, ) + @defer.inlineCallbacks + def test_fallback_key(self): + local_user = "@boris:" + self.hs.hostname + device_id = "xyz" + keys = {"alg1:k1": "key1"} + + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"fallback_keys": keys} + ) + ) + + # claiming an OTK when no OTKs are available should return the fallback + # key + res2 = yield defer.ensureDeferred( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) + ) + self.assertEqual( + res2, + { + "failures": {}, + "one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}}, + }, + ) + + # claiming an OTK again should return the same fallback key + res3 = yield defer.ensureDeferred( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) + ) + self.assertEqual( + res3, + { + "failures": {}, + "one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}}, + }, + ) + @defer.inlineCallbacks def test_replace_master_key(self): """uploading a new signing key should make the old signing key unavailable""" From 411a92bc1f2db411a5a9c4bf01a00d8bb1983320 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Tue, 15 Sep 2020 17:09:37 -0400 Subject: [PATCH 05/13] black --- synapse/storage/databases/main/end_to_end_keys.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index fa84e09b7560..39877e921667 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -399,7 +399,9 @@ async def set_e2e_fallback_keys( ) @cached(max_entries=10000) - async def get_e2e_unused_fallback_keys(self, user_id: str, device_id: str) -> List[str]: + async def get_e2e_unused_fallback_keys( + self, user_id: str, device_id: str + ) -> List[str]: """Returns the fallback key types that have an unused key. Args: From 424989f6436d8c13c238a650702a3f547dfdb40f Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Tue, 15 Sep 2020 17:13:11 -0400 Subject: [PATCH 06/13] lint test --- tests/handlers/test_e2e_keys.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index def2feb6df66..695868fb4614 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -177,7 +177,7 @@ def test_fallback_key(self): device_id = "xyz" keys = {"alg1:k1": "key1"} - res = yield defer.ensureDeferred( + yield defer.ensureDeferred( self.handler.upload_keys_for_user( local_user, device_id, {"fallback_keys": keys} ) From ce0d898083a98525cb1af92491afd8348feee5d6 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 23 Sep 2020 17:31:56 -0400 Subject: [PATCH 07/13] apply some changes from review --- synapse/handlers/e2e_keys.py | 4 ++++ synapse/handlers/sync.py | 2 +- synapse/storage/databases/main/end_to_end_keys.py | 8 ++++---- .../storage/databases/main/schema/delta/58/11fallback.sql | 2 +- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index b1a861874968..568fd64d9414 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -506,6 +506,10 @@ async def upload_keys_for_user(self, user_id, device_id, keys): } ) await self.store.set_e2e_fallback_keys(user_id, device_id, fallback_keys) + elif fallback_keys: + log_kv( + {"message": "Did not update fallback_keys", "reason": "not a dict"} + ) else: log_kv( {"message": "Did not update fallback_keys", "reason": "no keys given"} diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 06da3e20631e..b191245f29b5 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1017,7 +1017,7 @@ async def generate_sync_result( logger.debug("Fetching OTK data") device_id = sync_config.device_id one_time_key_counts = {} # type: JsonDict - unused_fallback_keys = [] # type: list + unused_fallback_keys = [] # type: List[str] if device_id: one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 39877e921667..55e852845c8e 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -393,7 +393,7 @@ async def set_e2e_fallback_keys( values={ "key_id": key_id, "key_json": json_encoder.encode(fallback_key), - "used": 0, + "used": False, }, desc="set_e2e_fallback_key", ) @@ -413,7 +413,7 @@ async def get_e2e_unused_fallback_keys( """ return await self.db_pool.simple_select_onecol( "e2e_fallback_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id, "used": 0}, + keyvalues={"user_id": user_id, "device_id": device_id, "used": False}, retcol="algorithm", desc="get_e2e_unused_fallback_keys", ) @@ -775,7 +775,7 @@ def _claim_e2e_one_time_keys(txn): txn.execute(fallback_sql, (user_id, device_id, algorithm)) for key_id, key_json, used in txn: device_result[algorithm + ":" + key_id] = key_json - if used == 0: + if not used: used_fallbacks.append( (user_id, device_id, algorithm, key_id) ) @@ -808,7 +808,7 @@ def _claim_e2e_one_time_keys(txn): "algorithm": algorithm, "key_id": key_id, }, - {"used": 1}, + {"used": True}, ) self._invalidate_cache_and_stream( txn, self.get_e2e_unused_fallback_keys, (user_id, device_id) diff --git a/synapse/storage/databases/main/schema/delta/58/11fallback.sql b/synapse/storage/databases/main/schema/delta/58/11fallback.sql index 272314a4a832..4ed981dbf89e 100644 --- a/synapse/storage/databases/main/schema/delta/58/11fallback.sql +++ b/synapse/storage/databases/main/schema/delta/58/11fallback.sql @@ -19,6 +19,6 @@ CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json ( algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for. key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. key_json TEXT NOT NULL, -- The key as a JSON blob. - used SMALLINT NOT NULL DEFAULT 0, -- Whether the key has been used or not. + used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not. CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm) ); From 22aae133ad47e8e9b7f55475e86a67a343429fc1 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 23 Sep 2020 17:59:08 -0400 Subject: [PATCH 08/13] fix format --- synapse/handlers/e2e_keys.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 568fd64d9414..df21893f86ef 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -507,9 +507,7 @@ async def upload_keys_for_user(self, user_id, device_id, keys): ) await self.store.set_e2e_fallback_keys(user_id, device_id, fallback_keys) elif fallback_keys: - log_kv( - {"message": "Did not update fallback_keys", "reason": "not a dict"} - ) + log_kv({"message": "Did not update fallback_keys", "reason": "not a dict"}) else: log_kv( {"message": "Did not update fallback_keys", "reason": "no keys given"} From d3262a6b80ab3e3c1e24c65b1690fdb173a649a3 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Mon, 28 Sep 2020 23:25:39 -0400 Subject: [PATCH 09/13] use unstable prefix --- synapse/handlers/e2e_keys.py | 2 +- synapse/rest/client/v2_alpha/sync.py | 2 +- tests/handlers/test_e2e_keys.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index df21893f86ef..611742ae72d5 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -496,7 +496,7 @@ async def upload_keys_for_user(self, user_id, device_id, keys): log_kv( {"message": "Did not update one_time_keys", "reason": "no keys given"} ) - fallback_keys = keys.get("fallback_keys", None) + fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None) if fallback_keys and isinstance(fallback_keys, dict): log_kv( { diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 9ed34761024c..1e41ccc1918b 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -236,7 +236,7 @@ async def encode_response(self, time_now, sync_result, access_token_id, filter): "leave": sync_result.groups.leave, }, "device_one_time_keys_count": sync_result.device_one_time_keys_count, - "device_unused_fallback_keys": sync_result.device_unused_fallback_keys, + "org.matrix.msc2732.device_unused_fallback_keys": sync_result.device_unused_fallback_keys, "next_batch": sync_result.next_batch.to_string(), } diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 695868fb4614..e6d9512e9121 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -179,7 +179,7 @@ def test_fallback_key(self): yield defer.ensureDeferred( self.handler.upload_keys_for_user( - local_user, device_id, {"fallback_keys": keys} + local_user, device_id, {"org.matrix.msc2732.fallback_keys": keys} ) ) From c8b52f609955c10d50f6af679c8d942c66f1afa0 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Fri, 2 Oct 2020 15:25:24 -0400 Subject: [PATCH 10/13] apply changes from review --- scripts/synapse_port_db | 1 + synapse/handlers/sync.py | 10 ++-- synapse/rest/client/v2_alpha/sync.py | 2 +- .../storage/databases/main/end_to_end_keys.py | 10 ++-- tests/handlers/test_e2e_keys.py | 52 ++++++++++++++----- 5 files changed, 50 insertions(+), 25 deletions(-) diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index a34bdf18302c..028db4a12755 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -89,6 +89,7 @@ BOOLEAN_COLUMNS = { "redactions": ["have_censored"], "room_stats_state": ["is_federatable"], "local_media_repository": ["safe_from_quarantine"], + "e2e_fallback_keys_json": ["used"], } diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b191245f29b5..bb27efa0cdd6 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -201,7 +201,7 @@ class SyncResult: device_lists: List of user_ids whose devices have changed device_one_time_keys_count: Dict of algorithm to count for one time keys for this device - device_unused_fallback_keys: List of key types that have an unused fallback + device_unused_fallback_key_types: List of key types that have an unused fallback key groups: Group updates, if any """ @@ -215,7 +215,7 @@ class SyncResult: to_device = attr.ib(type=List[JsonDict]) device_lists = attr.ib(type=DeviceLists) device_one_time_keys_count = attr.ib(type=JsonDict) - device_unused_fallback_keys = attr.ib(type=List[str]) + device_unused_fallback_key_types = attr.ib(type=List[str]) groups = attr.ib(type=Optional[GroupsSyncResult]) def __bool__(self) -> bool: @@ -1017,12 +1017,12 @@ async def generate_sync_result( logger.debug("Fetching OTK data") device_id = sync_config.device_id one_time_key_counts = {} # type: JsonDict - unused_fallback_keys = [] # type: List[str] + unused_fallback_key_types = [] # type: List[str] if device_id: one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id ) - unused_fallback_keys = await self.store.get_e2e_unused_fallback_keys( + unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types( user_id, device_id ) @@ -1048,7 +1048,7 @@ async def generate_sync_result( device_lists=device_lists, groups=sync_result_builder.groups, device_one_time_keys_count=one_time_key_counts, - device_unused_fallback_keys=unused_fallback_keys, + device_unused_fallback_key_types=unused_fallback_key_types, next_batch=sync_result_builder.now_token, ) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 1e41ccc1918b..ea4d9176795b 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -236,7 +236,7 @@ async def encode_response(self, time_now, sync_result, access_token_id, filter): "leave": sync_result.groups.leave, }, "device_one_time_keys_count": sync_result.device_one_time_keys_count, - "org.matrix.msc2732.device_unused_fallback_keys": sync_result.device_unused_fallback_keys, + "org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types, "next_batch": sync_result.next_batch.to_string(), } diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 55e852845c8e..219f3e95f14e 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -368,7 +368,7 @@ def _count_e2e_one_time_keys(txn): ) async def set_e2e_fallback_keys( - self, user_id: str, device_id: str, fallback_keys: dict + self, user_id: str, device_id: str, fallback_keys: JsonDict ) -> None: """Set the user's e2e fallback keys. @@ -399,7 +399,7 @@ async def set_e2e_fallback_keys( ) @cached(max_entries=10000) - async def get_e2e_unused_fallback_keys( + async def get_e2e_unused_fallback_key_types( self, user_id: str, device_id: str ) -> List[str]: """Returns the fallback key types that have an unused key. @@ -415,7 +415,7 @@ async def get_e2e_unused_fallback_keys( "e2e_fallback_keys_json", keyvalues={"user_id": user_id, "device_id": device_id, "used": False}, retcol="algorithm", - desc="get_e2e_unused_fallback_keys", + desc="get_e2e_unused_fallback_key_types", ) async def get_e2e_cross_signing_key( @@ -811,7 +811,7 @@ def _claim_e2e_one_time_keys(txn): {"used": True}, ) self._invalidate_cache_and_stream( - txn, self.get_e2e_unused_fallback_keys, (user_id, device_id) + txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) ) return result @@ -848,7 +848,7 @@ def delete_e2e_keys_by_device_txn(txn): keyvalues={"user_id": user_id, "device_id": device_id}, ) self._invalidate_cache_and_stream( - txn, self.get_e2e_unused_fallback_keys, (user_id, device_id) + txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) ) await self.db_pool.runInteraction( diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index e6d9512e9121..493ae052ea39 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -175,41 +175,65 @@ def test_claim_one_time_key(self): def test_fallback_key(self): local_user = "@boris:" + self.hs.hostname device_id = "xyz" - keys = {"alg1:k1": "key1"} + fallback_key = {"alg1:k1": "key1"} + otk = {"alg1:k2": "key2"} yield defer.ensureDeferred( self.handler.upload_keys_for_user( - local_user, device_id, {"org.matrix.msc2732.fallback_keys": keys} + local_user, + device_id, + {"org.matrix.msc2732.fallback_keys": fallback_key}, ) ) # claiming an OTK when no OTKs are available should return the fallback # key - res2 = yield defer.ensureDeferred( + res = yield defer.ensureDeferred( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) ) self.assertEqual( - res2, - { - "failures": {}, - "one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}}, - }, + res, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}},}, ) # claiming an OTK again should return the same fallback key - res3 = yield defer.ensureDeferred( + res = yield defer.ensureDeferred( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) ) self.assertEqual( - res3, - { - "failures": {}, - "one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}}, - }, + res, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}},}, + ) + + # if the user uploads a one-time key, the next claim should fetch the + # one-time key, and then go back to the fallback + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": otk} + ) + ) + + res = yield defer.ensureDeferred( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) + ) + self.assertEqual( + res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}},}, + ) + + res = yield defer.ensureDeferred( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) + ) + self.assertEqual( + res, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}},}, ) @defer.inlineCallbacks From f714b3f24776a9287a23c774f03d6b550aaf5276 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Fri, 2 Oct 2020 15:32:16 -0400 Subject: [PATCH 11/13] lint --- tests/handlers/test_e2e_keys.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 493ae052ea39..d496613b0c0a 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -195,7 +195,7 @@ def test_fallback_key(self): ) self.assertEqual( res, - {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}},}, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, ) # claiming an OTK again should return the same fallback key @@ -206,7 +206,7 @@ def test_fallback_key(self): ) self.assertEqual( res, - {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}},}, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, ) # if the user uploads a one-time key, the next claim should fetch the @@ -223,7 +223,7 @@ def test_fallback_key(self): ) ) self.assertEqual( - res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}},}, + res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}}, ) res = yield defer.ensureDeferred( @@ -233,7 +233,7 @@ def test_fallback_key(self): ) self.assertEqual( res, - {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}},}, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, ) @defer.inlineCallbacks From 88c09bbe42a636e6929ad484b86692417471e22b Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Fri, 2 Oct 2020 15:45:21 -0400 Subject: [PATCH 12/13] use txn.fetchone instead of looping --- synapse/storage/databases/main/end_to_end_keys.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 219f3e95f14e..fe1863864058 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -764,12 +764,8 @@ def _claim_e2e_one_time_keys(txn): user_result = result.setdefault(user_id, {}) device_result = user_result.setdefault(device_id, {}) txn.execute(sql, (user_id, device_id, algorithm)) - found = False - for key_id, key_json in txn: - found = True - device_result[algorithm + ":" + key_id] = key_json - delete.append((user_id, device_id, algorithm, key_id)) - if not found: + row = txn.fetchone() + if row is None: # no one-time key available, so see if there's a fallback # key txn.execute(fallback_sql, (user_id, device_id, algorithm)) @@ -779,6 +775,10 @@ def _claim_e2e_one_time_keys(txn): used_fallbacks.append( (user_id, device_id, algorithm, key_id) ) + else: + (key_id, key_json) = row + device_result[algorithm + ":" + key_id] = key_json + delete.append((user_id, device_id, algorithm, key_id)) # drop any one-time keys that were claimed sql = ( From 19df8b6428c52bbe12f90c7b6b2886b91f6a4245 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Mon, 5 Oct 2020 15:12:40 -0400 Subject: [PATCH 13/13] apply changes from review and reorder clauses so it makes more sense --- .../storage/databases/main/end_to_end_keys.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 3f9763cfc66f..8c97f2af5ce5 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -764,21 +764,23 @@ def _claim_e2e_one_time_keys(txn): user_result = result.setdefault(user_id, {}) device_result = user_result.setdefault(device_id, {}) txn.execute(sql, (user_id, device_id, algorithm)) - row = txn.fetchone() - if row is None: + otk_row = txn.fetchone() + if otk_row is not None: + key_id, key_json = otk_row + device_result[algorithm + ":" + key_id] = key_json + delete.append((user_id, device_id, algorithm, key_id)) + else: # no one-time key available, so see if there's a fallback # key txn.execute(fallback_sql, (user_id, device_id, algorithm)) - for key_id, key_json, used in txn: + fallback_row = txn.fetchone() + if fallback_row is not None: + key_id, key_json, used = fallback_row device_result[algorithm + ":" + key_id] = key_json if not used: used_fallbacks.append( (user_id, device_id, algorithm, key_id) ) - else: - (key_id, key_json) = row - device_result[algorithm + ":" + key_id] = key_json - delete.append((user_id, device_id, algorithm, key_id)) # drop any one-time keys that were claimed sql = (