Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add requesting user id parameter to key claim methods in TransportLayerClient #15663

Merged
merged 6 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15663.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add requesting user id parameter to key claim methods in `TransportLayerClient`.
6 changes: 4 additions & 2 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,15 @@ async def query_user_devices(

async def claim_client_keys(
self,
user: UserID,
destination: str,
query: Dict[str, Dict[str, Dict[str, int]]],
timeout: Optional[int],
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.

Args:
user: The user id of the requesting user
destination: Domain name of the remote homeserver
content: The query content.

Expand Down Expand Up @@ -279,7 +281,7 @@ async def claim_client_keys(
if use_unstable:
try:
return await self.transport_layer.claim_client_keys_unstable(
destination, unstable_content, timeout
user, destination, unstable_content, timeout
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
Expand All @@ -295,7 +297,7 @@ async def claim_client_keys(
logger.debug("Skipping unstable claim client keys API")

return await self.transport_layer.claim_client_keys(
destination, content, timeout
user, destination, content, timeout
)

@trace
Expand Down
16 changes: 13 additions & 3 deletions synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from synapse.federation.units import Transaction
from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser
from synapse.http.types import QueryParams
from synapse.types import JsonDict
from synapse.types import JsonDict, UserID
from synapse.util import ExceptionBundle

if TYPE_CHECKING:
Expand Down Expand Up @@ -630,7 +630,11 @@ async def query_user_devices(
)

async def claim_client_keys(
self, destination: str, query_content: JsonDict, timeout: Optional[int]
self,
user: UserID,
destination: str,
query_content: JsonDict,
timeout: Optional[int],
) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.

Expand All @@ -655,6 +659,7 @@ async def claim_client_keys(
}

Args:
user: the user_id of the requesting user
destination: The server to query.
query_content: The user ids to query.
Returns:
Expand All @@ -671,7 +676,11 @@ async def claim_client_keys(
)

async def claim_client_keys_unstable(
self, destination: str, query_content: JsonDict, timeout: Optional[int]
self,
user: UserID,
destination: str,
query_content: JsonDict,
timeout: Optional[int],
) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.

Expand All @@ -696,6 +705,7 @@ async def claim_client_keys_unstable(
}

Args:
user: the user_id of the requesting user
destination: The server to query.
query_content: The user ids to query.
Returns:
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,7 @@ async def claim_local_one_time_keys(
async def claim_one_time_keys(
self,
query: Dict[str, Dict[str, Dict[str, int]]],
user: UserID,
timeout: Optional[int],
always_include_fallback_keys: bool,
) -> JsonDict:
Expand Down Expand Up @@ -703,7 +704,7 @@ async def claim_client_keys(destination: str) -> None:
device_keys = remote_queries[destination]
try:
remote_result = await self.federation.claim_client_keys(
destination, device_keys, timeout=timeout
user, destination, device_keys, timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
Expand Down
8 changes: 4 additions & 4 deletions synapse/rest/client/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def __init__(self, hs: "HomeServer"):
self.e2e_keys_handler = hs.get_e2e_keys_handler()

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
requester = await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)

Expand All @@ -298,7 +298,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
query.setdefault(user_id, {})[device_id] = {algorithm: 1}

result = await self.e2e_keys_handler.claim_one_time_keys(
query, timeout, always_include_fallback_keys=False
query, requester.user, timeout, always_include_fallback_keys=False
)
return 200, result

Expand Down Expand Up @@ -335,7 +335,7 @@ def __init__(self, hs: "HomeServer"):
self.e2e_keys_handler = hs.get_e2e_keys_handler()

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
requester = await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)

Expand All @@ -346,7 +346,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
query.setdefault(user_id, {})[device_id] = Counter(algorithms)

result = await self.e2e_keys_handler.claim_one_time_keys(
query, timeout, always_include_fallback_keys=True
query, requester.user, timeout, always_include_fallback_keys=True
)
return 200, result

Expand Down
16 changes: 15 additions & 1 deletion tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from synapse.handlers.device import DeviceHandler
from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.types import JsonDict
from synapse.types import JsonDict, UserID
from synapse.util import Clock

from tests import unittest
Expand All @@ -45,6 +45,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_keys_handler()
self.store = self.hs.get_datastores().main
self.requester = UserID.from_string(f"@test_requester:{self.hs.hostname}")

def test_query_local_devices_no_devices(self) -> None:
"""If the user has no devices, we expect an empty list."""
Expand Down Expand Up @@ -161,6 +162,7 @@ def test_claim_one_time_key(self) -> None:
res2 = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
Expand Down Expand Up @@ -206,6 +208,7 @@ def test_fallback_key(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
Expand All @@ -225,6 +228,7 @@ def test_fallback_key(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
Expand Down Expand Up @@ -274,6 +278,7 @@ def test_fallback_key(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
Expand All @@ -286,6 +291,7 @@ def test_fallback_key(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
Expand All @@ -307,6 +313,7 @@ def test_fallback_key(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
Expand Down Expand Up @@ -348,6 +355,7 @@ def test_fallback_key_always_returned(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=True,
)
Expand All @@ -370,6 +378,7 @@ def test_fallback_key_always_returned(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=True,
)
Expand Down Expand Up @@ -1080,6 +1089,7 @@ def test_query_appservice(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}, device_id_2: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
Expand Down Expand Up @@ -1125,6 +1135,7 @@ def test_query_appservice_with_fallback(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=True,
)
Expand Down Expand Up @@ -1169,6 +1180,7 @@ def test_query_appservice_with_fallback(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=True,
)
Expand Down Expand Up @@ -1202,6 +1214,7 @@ def test_query_appservice_with_fallback(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=True,
)
Expand Down Expand Up @@ -1229,6 +1242,7 @@ def test_query_appservice_with_fallback(self) -> None:
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=True,
)
Expand Down