diff --git a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py index 1d004a398ea..9856b8b706f 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py +++ b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py @@ -117,3 +117,18 @@ def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool: return True except InvalidSignature: return False + + +def ssh_types_to_elliptic_curve( + private_key: serialization.SSHPrivateKeyTypes, + public_key: serialization.SSHPublicKeyTypes, +) -> Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]: + """Cast SSH key types to elliptic curve.""" + if isinstance(private_key, ec.EllipticCurvePrivateKey) and isinstance( + public_key, ec.EllipticCurvePublicKey + ): + return (private_key, public_key) + + raise TypeError( + "The provided key is not an EllipticCurvePrivateKey or EllipticCurvePublicKey" + ) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 70e53da765d..7e06062311d 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -16,15 +16,21 @@ import argparse import asyncio +import csv import importlib.util import sys import threading from logging import ERROR, INFO, WARN from os.path import isfile from pathlib import Path -from typing import List, Optional, Tuple +from typing import List, Optional, Sequence, Set, Tuple import grpc +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.serialization import ( + load_ssh_private_key, + load_ssh_public_key, +) from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event from flwr.common.address import parse_address @@ -36,6 +42,10 @@ ) from flwr.common.exit_handlers import register_exit_handlers from flwr.common.logger import log +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + public_key_to_bytes, + ssh_types_to_elliptic_curve, +) from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611 add_FleetServicer_to_server, ) @@ -51,6 +61,7 @@ start_grpc_server, ) from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer +from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor from .superlink.fleet.vce import start_vce from .superlink.state import StateFactory @@ -354,10 +365,28 @@ def run_superlink() -> None: sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.") host, port, is_v6 = parsed_address address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" + + maybe_keys = _try_setup_client_authentication(args, certificates) + interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None + if maybe_keys is not None: + ( + client_public_keys, + server_private_key, + server_public_key, + ) = maybe_keys + interceptors = [ + AuthenticateServerInterceptor( + client_public_keys, + server_private_key, + server_public_key, + ) + ] + fleet_server = _run_fleet_api_grpc_rere( address=address, state_factory=state_factory, certificates=certificates, + interceptors=interceptors, ) grpc_servers.append(fleet_server) elif args.fleet_api_type == TRANSPORT_TYPE_VCE: @@ -390,6 +419,70 @@ def run_superlink() -> None: driver_server.wait_for_termination(timeout=1) +def _try_setup_client_authentication( + args: argparse.Namespace, + certificates: Optional[Tuple[bytes, bytes, bytes]], +) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]: + if not args.require_client_authentication: + return None + + if certificates is None: + sys.exit( + "Client authentication only works over secure connections. " + "Please provide certificate paths using '--certificates' when " + "enabling '--require-client-authentication'." + ) + + client_keys_file_path = Path(args.require_client_authentication[0]) + if not client_keys_file_path.exists(): + sys.exit( + "The provided path to the client public keys CSV file does not exist: " + f"{client_keys_file_path}. " + "Please provide the CSV file path containing known client public keys " + "to '--require-client-authentication'." + ) + + client_public_keys: Set[bytes] = set() + ssh_private_key = load_ssh_private_key( + Path(args.require_client_authentication[1]).read_bytes(), + None, + ) + ssh_public_key = load_ssh_public_key( + Path(args.require_client_authentication[2]).read_bytes() + ) + + try: + server_private_key, server_public_key = ssh_types_to_elliptic_curve( + ssh_private_key, ssh_public_key + ) + except TypeError: + sys.exit( + "The file paths provided could not be read as a private and public " + "key pair. Client authentication requires an elliptic curve public and " + "private key pair. Please provide the file paths containing elliptic " + "curve private and public keys to '--require-client-authentication'." + ) + + with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile: + reader = csv.reader(csvfile) + for row in reader: + for element in row: + public_key = load_ssh_public_key(element.encode()) + if isinstance(public_key, ec.EllipticCurvePublicKey): + client_public_keys.add(public_key_to_bytes(public_key)) + else: + sys.exit( + "Error: Unable to parse the public keys in the .csv " + "file. Please ensure that the .csv file contains valid " + "SSH public keys and try again." + ) + return ( + client_public_keys, + server_private_key, + server_public_key, + ) + + def _try_obtain_certificates( args: argparse.Namespace, ) -> Optional[Tuple[bytes, bytes, bytes]]: @@ -417,6 +510,7 @@ def _run_fleet_api_grpc_rere( address: str, state_factory: StateFactory, certificates: Optional[Tuple[bytes, bytes, bytes]], + interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None, ) -> grpc.Server: """Run Fleet API (gRPC, request-response).""" # Create Fleet API gRPC server @@ -429,6 +523,7 @@ def _run_fleet_api_grpc_rere( server_address=address, max_message_length=GRPC_MAX_MESSAGE_LENGTH, certificates=certificates, + interceptors=interceptors, ) log(INFO, "Flower ECE: Starting Fleet API (gRPC-rere) on %s", address) @@ -606,6 +701,15 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None: "Flower will just create a state in memory.", default=DATABASE, ) + parser.add_argument( + "--require-client-authentication", + nargs=3, + metavar=("CLIENT_KEYS", "SERVER_PRIVATE_KEY", "SERVER_PUBLIC_KEY"), + type=str, + help="Provide three file paths: (1) a .csv file containing a list of " + "known client public keys for authentication, (2) the server's private " + "key file, and (3) the server's public key file.", + ) def _add_args_driver_api(parser: argparse.ArgumentParser) -> None: diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py index 274e5289fee..51071c13f89 100644 --- a/src/py/flwr/server/server_test.py +++ b/src/py/flwr/server/server_test.py @@ -15,9 +15,22 @@ """Flower server tests.""" +import argparse +import csv +import tempfile +from pathlib import Path from typing import List, Optional import numpy as np +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.serialization import ( + Encoding, + NoEncryption, + PrivateFormat, + PublicFormat, + load_ssh_private_key, + load_ssh_public_key, +) from flwr.common import ( Code, @@ -35,8 +48,14 @@ Status, ndarray_to_bytes, ) +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + generate_key_pairs, + private_key_to_bytes, + public_key_to_bytes, +) from flwr.server.client_manager import SimpleClientManager +from .app import _try_setup_client_authentication from .client_proxy import ClientProxy from .server import Server, evaluate_clients, fit_clients @@ -182,3 +201,71 @@ def test_set_max_workers() -> None: # Assert assert server.max_workers == 42 + + +def test_setup_client_auth() -> None: # pylint: disable=R0914 + """Test setup client authentication.""" + # Prepare + _, first_public_key = generate_key_pairs() + private_key, public_key = generate_key_pairs() + + server_public_key = public_key.public_bytes( + encoding=Encoding.OpenSSH, format=PublicFormat.OpenSSH + ) + server_private_key = private_key.private_bytes( + Encoding.PEM, PrivateFormat.OpenSSH, NoEncryption() + ) + _, second_public_key = generate_key_pairs() + + # Execute + with tempfile.TemporaryDirectory() as temp_dir: + # Initialize temporary files + client_keys_file_path = Path(temp_dir) / "client_keys.csv" + server_private_key_path = Path(temp_dir) / "server_private_key" + server_public_key_path = Path(temp_dir) / "server_public_key" + + # Fill the files with relevant keys + with open(client_keys_file_path, "w", newline="", encoding="utf-8") as csvfile: + writer = csv.writer(csvfile) + writer.writerow( + [ + first_public_key.public_bytes( + encoding=Encoding.OpenSSH, format=PublicFormat.OpenSSH + ).decode(), + second_public_key.public_bytes( + encoding=Encoding.OpenSSH, format=PublicFormat.OpenSSH + ).decode(), + ] + ) + server_public_key_path.write_bytes(server_public_key) + server_private_key_path.write_bytes(server_private_key) + + # Mock argparse with `require-client-authentication`` flag + mock_args = argparse.Namespace( + require_client_authentication=[ + str(client_keys_file_path), + str(server_private_key_path), + str(server_public_key_path), + ] + ) + + # Run _try_setup_client_authentication + result = _try_setup_client_authentication(mock_args, (b"", b"", b"")) + + expected_private_key = load_ssh_private_key(server_private_key, None) + expected_public_key = load_ssh_public_key(server_public_key) + + # Assert + assert isinstance(expected_private_key, ec.EllipticCurvePrivateKey) + assert isinstance(expected_public_key, ec.EllipticCurvePublicKey) + assert result is not None + assert result[0] == { + public_key_to_bytes(first_public_key), + public_key_to_bytes(second_public_key), + } + assert private_key_to_bytes(result[1]) == private_key_to_bytes( + expected_private_key + ) + assert public_key_to_bytes(result[2]) == public_key_to_bytes( + expected_public_key + ) diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py index 82f049844bd..6aeaa7ef413 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py @@ -18,7 +18,7 @@ import concurrent.futures import sys from logging import ERROR -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Optional, Sequence, Tuple, Union import grpc @@ -162,6 +162,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, keepalive_time_ms: int = 210000, certificates: Optional[Tuple[bytes, bytes, bytes]] = None, + interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None, ) -> grpc.Server: """Create a gRPC server with a single servicer. @@ -249,6 +250,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments # returning RESOURCE_EXHAUSTED status, or None to indicate no limit. maximum_concurrent_rpcs=max_concurrent_workers, options=options, + interceptors=interceptors, ) add_servicer_to_server_fn(servicer, server) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py new file mode 100644 index 00000000000..7532364336a --- /dev/null +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -0,0 +1,169 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower server interceptor.""" + + +import base64 +from logging import INFO +from typing import Any, Callable, Sequence, Set, Tuple, Union + +import grpc +from cryptography.hazmat.primitives.asymmetric import ec + +from flwr.common.logger import log +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + bytes_to_public_key, + generate_shared_key, + public_key_to_bytes, + verify_hmac, +) +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 + CreateNodeRequest, + CreateNodeResponse, + DeleteNodeRequest, + DeleteNodeResponse, + GetRunRequest, + GetRunResponse, + PullTaskInsRequest, + PullTaskInsResponse, + PushTaskResRequest, + PushTaskResResponse, +) + +_PUBLIC_KEY_HEADER = "public-key" +_AUTH_TOKEN_HEADER = "auth-token" + +Request = Union[ + CreateNodeRequest, + DeleteNodeRequest, + PullTaskInsRequest, + PushTaskResRequest, + GetRunRequest, +] + +Response = Union[ + CreateNodeResponse, + DeleteNodeResponse, + PullTaskInsResponse, + PushTaskResResponse, + GetRunResponse, +] + + +def _get_value_from_tuples( + key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]] +) -> bytes: + value = next((value for key, value in tuples if key == key_string), "") + if isinstance(value, str): + return value.encode() + + return value + + +class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore + """Server interceptor for client authentication.""" + + def __init__( + self, + client_public_keys: Set[bytes], + private_key: ec.EllipticCurvePrivateKey, + public_key: ec.EllipticCurvePublicKey, + ): + self.server_private_key = private_key + self.client_public_keys = client_public_keys + self.encoded_server_public_key = base64.urlsafe_b64encode( + public_key_to_bytes(public_key) + ) + log( + INFO, + "Client authentication enabled with %d known public keys", + len(client_public_keys), + ) + + def intercept_service( + self, + continuation: Callable[[Any], Any], + handler_call_details: grpc.HandlerCallDetails, + ) -> grpc.RpcMethodHandler: + """Flower server interceptor authentication logic. + + Intercept all unary calls from clients and authenticate clients by validating + auth metadata sent by the client. Continue RPC call if client is authenticated, + else, terminate RPC call by setting context to abort. + """ + # One of the method handlers in + # `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer` + method_handler: grpc.RpcMethodHandler = continuation(handler_call_details) + return self._generic_auth_unary_method_handler(method_handler) + + def _generic_auth_unary_method_handler( + self, method_handler: grpc.RpcMethodHandler + ) -> grpc.RpcMethodHandler: + def _generic_method_handler( + request: Request, + context: grpc.ServicerContext, + ) -> Response: + client_public_key_bytes = base64.urlsafe_b64decode( + _get_value_from_tuples( + _PUBLIC_KEY_HEADER, context.invocation_metadata() + ) + ) + is_public_key_known = client_public_key_bytes in self.client_public_keys + if not is_public_key_known: + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied") + + if isinstance(request, CreateNodeRequest): + context.send_initial_metadata( + ( + ( + _PUBLIC_KEY_HEADER, + self.encoded_server_public_key, + ), + ) + ) + elif isinstance( + request, + ( + DeleteNodeRequest, + PullTaskInsRequest, + PushTaskResRequest, + GetRunRequest, + ), + ): + hmac_value = base64.urlsafe_b64decode( + _get_value_from_tuples( + _AUTH_TOKEN_HEADER, context.invocation_metadata() + ) + ) + client_public_key = bytes_to_public_key(client_public_key_bytes) + shared_secret = generate_shared_key( + self.server_private_key, + client_public_key, + ) + verify = verify_hmac( + shared_secret, request.SerializeToString(True), hmac_value + ) + if not verify: + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied") + else: + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied") + + return method_handler.unary_unary(request, context) # type: ignore + + return grpc.unary_unary_rpc_method_handler( + _generic_method_handler, + request_deserializer=method_handler.request_deserializer, + response_serializer=method_handler.response_serializer, + ) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py new file mode 100644 index 00000000000..b68d41f304a --- /dev/null +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -0,0 +1,339 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower server interceptor tests.""" + + +import base64 +import unittest + +import grpc + +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + compute_hmac, + generate_key_pairs, + generate_shared_key, + public_key_to_bytes, +) +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 + CreateNodeRequest, + CreateNodeResponse, + DeleteNodeRequest, + DeleteNodeResponse, + GetRunRequest, + GetRunResponse, + PullTaskInsRequest, + PullTaskInsResponse, + PushTaskResRequest, + PushTaskResResponse, +) +from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611 +from flwr.server.app import ADDRESS_FLEET_API_GRPC_RERE, _run_fleet_api_grpc_rere +from flwr.server.superlink.state.state_factory import StateFactory + +from .server_interceptor import ( + _AUTH_TOKEN_HEADER, + _PUBLIC_KEY_HEADER, + AuthenticateServerInterceptor, +) + + +class TestServerInterceptor(unittest.TestCase): # pylint: disable=R0902 + """Server interceptor tests.""" + + def setUp(self) -> None: + """Initialize mock stub and server interceptor.""" + self._client_private_key, self._client_public_key = generate_key_pairs() + self._server_private_key, self._server_public_key = generate_key_pairs() + + state_factory = StateFactory(":flwr-in-memory-state:") + + self._server_interceptor = AuthenticateServerInterceptor( + {public_key_to_bytes(self._client_public_key)}, + self._server_private_key, + self._server_public_key, + ) + self._server: grpc.Server = _run_fleet_api_grpc_rere( + ADDRESS_FLEET_API_GRPC_RERE, state_factory, None, [self._server_interceptor] + ) + + self._channel = grpc.insecure_channel("localhost:9092") + self._create_node = self._channel.unary_unary( + "/flwr.proto.Fleet/CreateNode", + request_serializer=CreateNodeRequest.SerializeToString, + response_deserializer=CreateNodeResponse.FromString, + ) + self._delete_node = self._channel.unary_unary( + "/flwr.proto.Fleet/DeleteNode", + request_serializer=DeleteNodeRequest.SerializeToString, + response_deserializer=DeleteNodeResponse.FromString, + ) + self._pull_task_ins = self._channel.unary_unary( + "/flwr.proto.Fleet/PullTaskIns", + request_serializer=PullTaskInsRequest.SerializeToString, + response_deserializer=PullTaskInsResponse.FromString, + ) + self._push_task_res = self._channel.unary_unary( + "/flwr.proto.Fleet/PushTaskRes", + request_serializer=PushTaskResRequest.SerializeToString, + response_deserializer=PushTaskResResponse.FromString, + ) + self._get_run = self._channel.unary_unary( + "/flwr.proto.Fleet/GetRun", + request_serializer=GetRunRequest.SerializeToString, + response_deserializer=GetRunResponse.FromString, + ) + + def tearDown(self) -> None: + """Clean up grpc server.""" + self._server.stop(None) + + def test_successful_create_node_with_metadata(self) -> None: + """Test server interceptor for creating node.""" + # Prepare + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ) + + # Execute + response, call = self._create_node.with_call( + request=CreateNodeRequest(), + metadata=((_PUBLIC_KEY_HEADER, public_key_bytes),), + ) + + expected_metadata = ( + _PUBLIC_KEY_HEADER, + base64.urlsafe_b64encode( + public_key_to_bytes(self._server_public_key) + ).decode(), + ) + + # Assert + assert call.initial_metadata()[0] == expected_metadata + assert isinstance(response, CreateNodeResponse) + + def test_unsuccessful_create_node_with_metadata(self) -> None: + """Test server interceptor for creating node unsuccessfully.""" + # Prepare + _, client_public_key = generate_key_pairs() + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(client_public_key) + ) + + # Execute & Assert + with self.assertRaises(grpc.RpcError): + self._create_node.with_call( + request=CreateNodeRequest(), + metadata=((_PUBLIC_KEY_HEADER, public_key_bytes),), + ) + + def test_successful_delete_node_with_metadata(self) -> None: + """Test server interceptor for deleting node.""" + # Prepare + request = DeleteNodeRequest() + shared_secret = generate_shared_key( + self._client_private_key, self._server_public_key + ) + hmac_value = base64.urlsafe_b64encode( + compute_hmac(shared_secret, request.SerializeToString(True)) + ) + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ) + + # Execute + response, call = self._delete_node.with_call( + request=request, + metadata=( + (_PUBLIC_KEY_HEADER, public_key_bytes), + (_AUTH_TOKEN_HEADER, hmac_value), + ), + ) + + # Assert + assert isinstance(response, DeleteNodeResponse) + assert grpc.StatusCode.OK == call.code() + + def test_unsuccessful_delete_node_with_metadata(self) -> None: + """Test server interceptor for deleting node unsuccessfully.""" + # Prepare + request = DeleteNodeRequest() + client_private_key, _ = generate_key_pairs() + shared_secret = generate_shared_key(client_private_key, self._server_public_key) + hmac_value = base64.urlsafe_b64encode( + compute_hmac(shared_secret, request.SerializeToString(True)) + ) + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ) + + # Execute & Assert + with self.assertRaises(grpc.RpcError): + self._delete_node.with_call( + request=request, + metadata=( + (_PUBLIC_KEY_HEADER, public_key_bytes), + (_AUTH_TOKEN_HEADER, hmac_value), + ), + ) + + def test_successful_pull_task_ins_with_metadata(self) -> None: + """Test server interceptor for pull task ins.""" + # Prepare + request = PullTaskInsRequest() + shared_secret = generate_shared_key( + self._client_private_key, self._server_public_key + ) + hmac_value = base64.urlsafe_b64encode( + compute_hmac(shared_secret, request.SerializeToString(True)) + ) + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ) + + # Execute + response, call = self._pull_task_ins.with_call( + request=request, + metadata=( + (_PUBLIC_KEY_HEADER, public_key_bytes), + (_AUTH_TOKEN_HEADER, hmac_value), + ), + ) + + # Assert + assert isinstance(response, PullTaskInsResponse) + assert grpc.StatusCode.OK == call.code() + + def test_unsuccessful_pull_task_ins_with_metadata(self) -> None: + """Test server interceptor for pull task ins unsuccessfully.""" + # Prepare + request = PullTaskInsRequest() + client_private_key, _ = generate_key_pairs() + shared_secret = generate_shared_key(client_private_key, self._server_public_key) + hmac_value = base64.urlsafe_b64encode( + compute_hmac(shared_secret, request.SerializeToString(True)) + ) + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ) + + # Execute & Assert + with self.assertRaises(grpc.RpcError): + self._pull_task_ins.with_call( + request=request, + metadata=( + (_PUBLIC_KEY_HEADER, public_key_bytes), + (_AUTH_TOKEN_HEADER, hmac_value), + ), + ) + + def test_successful_push_task_res_with_metadata(self) -> None: + """Test server interceptor for push task res.""" + # Prepare + request = PushTaskResRequest(task_res_list=[TaskRes()]) + shared_secret = generate_shared_key( + self._client_private_key, self._server_public_key + ) + hmac_value = base64.urlsafe_b64encode( + compute_hmac(shared_secret, request.SerializeToString(True)) + ) + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ) + + # Execute + response, call = self._push_task_res.with_call( + request=request, + metadata=( + (_PUBLIC_KEY_HEADER, public_key_bytes), + (_AUTH_TOKEN_HEADER, hmac_value), + ), + ) + + # Assert + assert isinstance(response, PushTaskResResponse) + assert grpc.StatusCode.OK == call.code() + + def test_unsuccessful_push_task_res_with_metadata(self) -> None: + """Test server interceptor for push task res unsuccessfully.""" + # Prepare + request = PushTaskResRequest(task_res_list=[TaskRes()]) + client_private_key, _ = generate_key_pairs() + shared_secret = generate_shared_key(client_private_key, self._server_public_key) + hmac_value = base64.urlsafe_b64encode( + compute_hmac(shared_secret, request.SerializeToString(True)) + ) + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ) + + # Execute & Assert + with self.assertRaises(grpc.RpcError): + self._push_task_res.with_call( + request=request, + metadata=( + (_PUBLIC_KEY_HEADER, public_key_bytes), + (_AUTH_TOKEN_HEADER, hmac_value), + ), + ) + + def test_successful_get_run_with_metadata(self) -> None: + """Test server interceptor for pull task ins.""" + # Prepare + request = GetRunRequest(run_id=0) + shared_secret = generate_shared_key( + self._client_private_key, self._server_public_key + ) + hmac_value = base64.urlsafe_b64encode( + compute_hmac(shared_secret, request.SerializeToString(True)) + ) + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ) + + # Execute + response, call = self._get_run.with_call( + request=request, + metadata=( + (_PUBLIC_KEY_HEADER, public_key_bytes), + (_AUTH_TOKEN_HEADER, hmac_value), + ), + ) + + # Assert + assert isinstance(response, GetRunResponse) + assert grpc.StatusCode.OK == call.code() + + def test_unsuccessful_get_run_with_metadata(self) -> None: + """Test server interceptor for pull task ins unsuccessfully.""" + # Prepare + request = GetRunRequest(run_id=0) + client_private_key, _ = generate_key_pairs() + shared_secret = generate_shared_key(client_private_key, self._server_public_key) + hmac_value = base64.urlsafe_b64encode( + compute_hmac(shared_secret, request.SerializeToString(True)) + ) + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ) + + # Execute & Assert + with self.assertRaises(grpc.RpcError): + self._get_run.with_call( + request=request, + metadata=( + (_PUBLIC_KEY_HEADER, public_key_bytes), + (_AUTH_TOKEN_HEADER, hmac_value), + ), + )