Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add requesting user id parameter to key claim methods in TransportLayerClient #15663

Merged
merged 6 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15663.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add requesting user id parameter to key claim methods in `TransportLayerClient`.
6 changes: 4 additions & 2 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,13 +238,15 @@ async def claim_client_keys(
self,
destination: str,
query: Dict[str, Dict[str, Dict[str, int]]],
user: UserID,
timeout: Optional[int],
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.

Args:
destination: Domain name of the remote homeserver
content: The query content.
user: The user id of the requesting user
clokep marked this conversation as resolved.
Show resolved Hide resolved

H-Shay marked this conversation as resolved.
Show resolved Hide resolved
Returns:
The JSON object from the response
Expand Down Expand Up @@ -279,7 +281,7 @@ async def claim_client_keys(
if use_unstable:
try:
return await self.transport_layer.claim_client_keys_unstable(
destination, unstable_content, timeout
destination, unstable_content, user, timeout
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
Expand All @@ -295,7 +297,7 @@ async def claim_client_keys(
logger.debug("Skipping unstable claim client keys API")

return await self.transport_layer.claim_client_keys(
destination, content, timeout
destination, content, user, timeout
)

@trace
Expand Down
16 changes: 13 additions & 3 deletions synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from synapse.federation.units import Transaction
from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser
from synapse.http.types import QueryParams
from synapse.types import JsonDict
from synapse.types import JsonDict, UserID
from synapse.util import ExceptionBundle

if TYPE_CHECKING:
Expand Down Expand Up @@ -630,7 +630,11 @@ async def query_user_devices(
)

async def claim_client_keys(
self, destination: str, query_content: JsonDict, timeout: Optional[int]
self,
destination: str,
query_content: JsonDict,
user: UserID,
timeout: Optional[int],
) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.

Expand All @@ -657,6 +661,7 @@ async def claim_client_keys(
Args:
destination: The server to query.
query_content: The user ids to query.
user: the user_id of the requesting user
Returns:
A dict containing the one-time keys.
"""
Expand All @@ -671,7 +676,11 @@ async def claim_client_keys(
)

async def claim_client_keys_unstable(
self, destination: str, query_content: JsonDict, timeout: Optional[int]
self,
destination: str,
query_content: JsonDict,
user: UserID,
timeout: Optional[int],
) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.

Expand All @@ -698,6 +707,7 @@ async def claim_client_keys_unstable(
Args:
destination: The server to query.
query_content: The user ids to query.
user: the user_id of the requesting user
Returns:
A dict containing the one-time keys.
"""
Expand Down
7 changes: 4 additions & 3 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,7 @@ async def claim_local_one_time_keys(
async def claim_one_time_keys(
self,
query: Dict[str, Dict[str, Dict[str, int]]],
user: UserID,
timeout: Optional[int],
always_include_fallback_keys: bool,
) -> JsonDict:
Expand Down Expand Up @@ -698,12 +699,12 @@ async def claim_one_time_keys(
failures: Dict[str, JsonDict] = {}

@trace
async def claim_client_keys(destination: str) -> None:
async def claim_client_keys(destination: str, user: UserID) -> None:
clokep marked this conversation as resolved.
Show resolved Hide resolved
set_tag("destination", destination)
device_keys = remote_queries[destination]
try:
remote_result = await self.federation.claim_client_keys(
destination, device_keys, timeout=timeout
destination, device_keys, user, timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
Expand All @@ -718,7 +719,7 @@ async def claim_client_keys(destination: str) -> None:
await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(claim_client_keys, destination)
run_in_background(claim_client_keys, destination, user)
for destination in remote_queries
],
consumeErrors=True,
Expand Down
10 changes: 6 additions & 4 deletions synapse/rest/client/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,8 @@ def __init__(self, hs: "HomeServer"):
self.e2e_keys_handler = hs.get_e2e_keys_handler()

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = requester.user
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)

Expand All @@ -298,7 +299,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
query.setdefault(user_id, {})[device_id] = {algorithm: 1}

result = await self.e2e_keys_handler.claim_one_time_keys(
query, timeout, always_include_fallback_keys=False
query, user, timeout, always_include_fallback_keys=False
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
)
return 200, result

Expand Down Expand Up @@ -335,7 +336,8 @@ def __init__(self, hs: "HomeServer"):
self.e2e_keys_handler = hs.get_e2e_keys_handler()

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = requester.user
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)

Expand All @@ -346,7 +348,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
query.setdefault(user_id, {})[device_id] = Counter(algorithms)

result = await self.e2e_keys_handler.claim_one_time_keys(
query, timeout, always_include_fallback_keys=True
query, user, timeout, always_include_fallback_keys=True
)
return 200, result

Expand Down
16 changes: 15 additions & 1 deletion tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from synapse.handlers.device import DeviceHandler
from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.types import JsonDict
from synapse.types import JsonDict, UserID
from synapse.util import Clock

from tests import unittest
Expand All @@ -45,6 +45,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_keys_handler()
self.store = self.hs.get_datastores().main
self.requester = UserID.from_string(f"@test_requester:{self.hs.hostname}")

def test_query_local_devices_no_devices(self) -> None:
"""If the user has no devices, we expect an empty list."""
Expand Down Expand Up @@ -161,6 +162,7 @@ def test_claim_one_time_key(self) -> None:
res2 = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
Expand Down Expand Up @@ -206,6 +208,7 @@ def test_fallback_key(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
Expand All @@ -225,6 +228,7 @@ def test_fallback_key(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
Expand Down Expand Up @@ -274,6 +278,7 @@ def test_fallback_key(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
Expand All @@ -286,6 +291,7 @@ def test_fallback_key(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
Expand All @@ -307,6 +313,7 @@ def test_fallback_key(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
Expand Down Expand Up @@ -348,6 +355,7 @@ def test_fallback_key_always_returned(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=True,
)
Expand All @@ -370,6 +378,7 @@ def test_fallback_key_always_returned(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=True,
)
Expand Down Expand Up @@ -1080,6 +1089,7 @@ def test_query_appservice(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}, device_id_2: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
Expand Down Expand Up @@ -1125,6 +1135,7 @@ def test_query_appservice_with_fallback(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=True,
)
Expand Down Expand Up @@ -1169,6 +1180,7 @@ def test_query_appservice_with_fallback(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=True,
)
Expand Down Expand Up @@ -1202,6 +1214,7 @@ def test_query_appservice_with_fallback(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=True,
)
Expand Down Expand Up @@ -1229,6 +1242,7 @@ def test_query_appservice_with_fallback(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=True,
)
Expand Down