From d0d4ecde858f4ab56b1365acb71e6e4643adca2f Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 13:33:51 +0100 Subject: [PATCH 01/73] Add authentication state and test --- .../crypto/symmetric_encryption.py | 14 ++++ .../authentication/authentication_state.py | 52 ++++++++++++++ .../authentication_state_test.py | 65 +++++++++++++++++ .../authentication/in_memory_auth_state.py | 66 ++++++++++++++++++ .../state/authentication/sqlite_auth_state.py | 69 +++++++++++++++++++ src/py/flwr/server/state/sqlite_state.py | 24 ++++++- 6 files changed, 289 insertions(+), 1 deletion(-) create mode 100644 src/py/flwr/server/state/authentication/authentication_state.py create mode 100644 src/py/flwr/server/state/authentication/authentication_state_test.py create mode 100644 src/py/flwr/server/state/authentication/in_memory_auth_state.py create mode 100644 src/py/flwr/server/state/authentication/sqlite_auth_state.py diff --git a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py index 844a93f3bde9..7b22565e2803 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py +++ b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py @@ -98,3 +98,17 @@ def decrypt(key: bytes, ciphertext: bytes) -> bytes: # The input key must be url safe fernet = Fernet(key) return fernet.decrypt(ciphertext) + +def compute_hmac(key: bytes, message: bytes) -> bytes: + computed_hmac = hmac.HMAC(key, hashes.SHA256()) + computed_hmac.update(message) + return computed_hmac.finalize() + +def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool: + computed_hmac = hmac.HMAC(key, hashes.SHA256()) + computed_hmac.update(message) + try: + computed_hmac.verify(hmac_value) + return True + except: + return False diff --git a/src/py/flwr/server/state/authentication/authentication_state.py b/src/py/flwr/server/state/authentication/authentication_state.py new file mode 100644 index 000000000000..4a6fc9a6ab57 --- /dev/null +++ b/src/py/flwr/server/state/authentication/authentication_state.py @@ -0,0 +1,52 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Abstract base class AuthenticationState.""" + +import abc +from typing import Set + +class AuthenticationState(abc.ABC): + """Abstract State.""" + @abc.abstractmethod + def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: + """Store `node_id` and `public_key` as key-value pair in state.""" + + @abc.abstractmethod + def get_public_key_from_node_id(self, node_id: int) -> bytes: + """Get client's public key in urlsafe bytes for `node_id`.""" + + @abc.abstractmethod + def store_server_public_private_key(self, public_key: bytes, private_key: bytes) -> None: + """Store server's `public_key` and `private_key` in state.""" + + @abc.abstractmethod + def get_server_private_key(self) -> bytes: + """Get server private key in urlsafe bytes.""" + + @abc.abstractmethod + def get_server_public_key(self) -> bytes: + """Get server public key in urlsafe bytes.""" + + @abc.abstractmethod + def store_client_public_keys(self, public_keys: Set[bytes]) -> None: + """Store a set of client public keys in state.""" + + @abc.abstractmethod + def store_client_public_key(self, public_key: bytes) -> None: + """Retrieve a client public key in state.""" + + @abc.abstractmethod + def get_client_public_keys(self) -> Set[bytes]: + """Retrieve all currently stored client public keys as a set.""" diff --git a/src/py/flwr/server/state/authentication/authentication_state_test.py b/src/py/flwr/server/state/authentication/authentication_state_test.py new file mode 100644 index 000000000000..473f8e37eb36 --- /dev/null +++ b/src/py/flwr/server/state/authentication/authentication_state_test.py @@ -0,0 +1,65 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test for authentication state.""" + +import os +from in_memory_auth_state import InMemoryAuthState +from sqlite_auth_state import SqliteAuthState +from common.secure_aggregation.crypto.symmetric_encryption import ( + generate_key_pairs, + public_key_to_bytes, + generate_shared_key, + verify_hmac, + compute_hmac +) + +def test_client_public_keys() -> None: + key_pairs = [generate_key_pairs() for _ in range(3)] + public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} + + in_memory_auth_state = InMemoryAuthState() + in_memory_auth_state.store_client_public_keys(public_keys) + + assert in_memory_auth_state.get_client_public_keys == public_keys + +def test_node_id_public_key_pair() -> None: + node_id = int.from_bytes(os.urandom(8), "little", signed=True) + public_key = public_key_to_bytes(generate_key_pairs()[1]) + + in_memory_auth_state = InMemoryAuthState() + in_memory_auth_state.store_node_id_public_key_pair(node_id, public_key) + + assert in_memory_auth_state.get_public_key_from_node_id(node_id) == public_key + +def test_generate_shared_key() -> None: + client_keys = generate_key_pairs() + server_keys = generate_key_pairs() + + client_shared_secret = generate_shared_key(client_keys[0], server_keys[1]) + server_shared_secret = generate_shared_key(server_keys[0], client_keys[1]) + + assert client_shared_secret == server_shared_secret + +def test_hmac() -> None: + client_keys = generate_key_pairs() + server_keys = generate_key_pairs() + client_shared_secret = generate_shared_key(client_keys[0], server_keys[1]) + server_shared_secret = generate_shared_key(server_keys[0], client_keys[1]) + message = b"Flower is the future of AI" + + client_compute_hmac = compute_hmac(client_shared_secret, message) + + assert verify_hmac(server_shared_secret, message, client_compute_hmac) + \ No newline at end of file diff --git a/src/py/flwr/server/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/state/authentication/in_memory_auth_state.py new file mode 100644 index 000000000000..d4b387881962 --- /dev/null +++ b/src/py/flwr/server/state/authentication/in_memory_auth_state.py @@ -0,0 +1,66 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""In-memory Authentication State implementation.""" + +from in_memory_state import InMemoryState +from authentication_state import AuthenticationState +from typing import Dict, Set + +class InMemoryAuthState(AuthenticationState, InMemoryState): + def __init__(self) -> None: + super().__init__() + self.node_id_public_key_dict: Dict[int, bytes] = {} + self.client_public_keys: Set[bytes] = set() + self.server_public_key: bytes = bytes() + self.server_private_key: bytes = bytes() + + def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: + """Store `node_id` and `public_key` as key-value pair in state.""" + if node_id not in self.node_ids: + raise ValueError(f"Node {node_id} not found") + if node_id in self.node_id_public_key_dict: + raise ValueError(f"Node {node_id} has already assigned a public key") + self.node_id_public_key_dict[node_id] = public_key + + def get_public_key_from_node_id(self, node_id: int) -> bytes: + """Get client's public key in urlsafe bytes for `node_id`.""" + if node_id in self.node_id_public_key_dict: + return self.node_id_public_key_dict[node_id] + return bytes() + + def store_server_public_private_key(self, public_key: bytes, private_key: bytes) -> None: + """Store server's `public_key` and `private_key` in state.""" + self.server_private_key = private_key + self.server_public_key = public_key + + def get_server_private_key(self) -> bytes: + """Get server private key in urlsafe bytes.""" + return self.server_private_key + + def get_server_public_key(self) -> bytes: + """Get server public key in urlsafe bytes.""" + return self.server_public_key + + def store_client_public_keys(self, public_keys: Set[bytes]) -> None: + """Store a set of client public keys in state.""" + self.client_public_keys = public_keys + + def store_client_public_key(self, public_key: bytes) -> None: + """Retrieve a client public key in state.""" + self.client_public_keys.add(public_key) + + def get_client_public_keys(self) -> Set[bytes]: + """Retrieve all currently stored client public keys as a set.""" + return self.client_public_keys diff --git a/src/py/flwr/server/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/state/authentication/sqlite_auth_state.py new file mode 100644 index 000000000000..852629896778 --- /dev/null +++ b/src/py/flwr/server/state/authentication/sqlite_auth_state.py @@ -0,0 +1,69 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SQLite based implementation of server authentication state.""" + +from sqlite_state import SqliteState +from authentication_state import AuthenticationState +from typing import Set + +class SqliteAuthState(AuthenticationState, SqliteState): + def __init__(self) -> None: + super().__init__() + + def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: + """Store `node_id` and `public_key` as key-value pair in state.""" + query = "INSERT OR REPLACE INTO node_key (node_id, public_key) VALUES (:node_id, :public_key)" + self.query(query, {"node_id": node_id, "public_key": public_key}) + + def get_public_key_from_node_id(self, node_id: int) -> bytes: + """Get client's public key in urlsafe bytes for `node_id`.""" + query = "SELECT public_key FROM node_key WHERE node_id = :node_id" + rows = self.query(query, {"node_id": node_id}) + return rows[0]["public_key"] + + def store_server_public_private_key(self, public_key: bytes, private_key: bytes) -> None: + """Store server's `public_key` and `private_key` in state.""" + query = "INSERT OR REPLACE INTO credential (public_key, private_key) VALUES (:public_key, :private_key)" + self.query(query, {"public_key": public_key, "private_key": private_key}) + + def get_server_private_key(self) -> bytes: + """Get server private key in urlsafe bytes.""" + query = "SELECT private_key FROM credential" + rows = self.query(query) + return rows[0]["private_key"] + + def get_server_public_key(self) -> bytes: + """Get server public key in urlsafe bytes.""" + query = "SELECT public_key FROM credential" + rows = self.query(query) + return rows[0]["public_key"] + + def store_client_public_keys(self, public_keys: Set[bytes]) -> None: + """Store a set of client public keys in state.""" + query = "INSERT INTO public_key (public_key) VALUES (:public_key)" + for public_key in public_keys: + self.query(query, {"public_key": public_key}) + + def store_client_public_key(self, public_key: bytes) -> None: + """Retrieve a client public key in state.""" + query = "INSERT INTO public_key (public_key) VALUES (:public_key)" + self.query(query, {"public_key": public_key}) + + def get_client_public_keys(self) -> Set[bytes]: + """Retrieve all currently stored client public keys as a set.""" + query = "SELECT public_key FROM public_key" + rows = self.query(query) + result: Set[bytes] = {row["public_key"] for row in rows} + return result diff --git a/src/py/flwr/server/state/sqlite_state.py b/src/py/flwr/server/state/sqlite_state.py index 224c16cdf013..f89df6301334 100644 --- a/src/py/flwr/server/state/sqlite_state.py +++ b/src/py/flwr/server/state/sqlite_state.py @@ -37,6 +37,26 @@ ); """ +SQL_CREATE_TABLE_NODE_KEY = """ +CREATE TABLE IF NOT EXISTS node_key( + node_id INTEGER PRIMARY KEY, + public_key BLOB +); +""" + +SQL_CREATE_TABLE_CREDENTIAL = """ +CREATE TABLE IF NOT EXISTS credential( + public_key BLOB PRIMARY KEY, + private_key BLOB +); +""" + +SQL_CREATE_TABLE_PUBLIC_KEY = """ +CREATE TABLE IF NOT EXISTS public_key( + public_key BLOB UNIQUE +); +""" + SQL_CREATE_TABLE_RUN = """ CREATE TABLE IF NOT EXISTS run( run_id INTEGER UNIQUE @@ -123,6 +143,9 @@ def initialize(self, log_queries: bool = False) -> List[Tuple[str]]: cur.execute(SQL_CREATE_TABLE_TASK_INS) cur.execute(SQL_CREATE_TABLE_TASK_RES) cur.execute(SQL_CREATE_TABLE_NODE) + cur.execute(SQL_CREATE_TABLE_CREDENTIAL) + cur.execute(SQL_CREATE_TABLE_NODE_KEY) + cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY) res = cur.execute("SELECT name FROM sqlite_schema;") return res.fetchall() @@ -519,7 +542,6 @@ def create_run(self) -> int: log(ERROR, "Unexpected run creation failure.") return 0 - def dict_factory( cursor: sqlite3.Cursor, row: sqlite3.Row, From 91a2f18d523cb25f773e3366e2d80b3e5196b6ed Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 13:46:44 +0100 Subject: [PATCH 02/73] Fix isort --- .../secure_aggregation/crypto/symmetric_encryption.py | 2 +- .../state/authentication/authentication_state_test.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py index 7b22565e2803..67b49d85cc53 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py +++ b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py @@ -19,7 +19,7 @@ from typing import Tuple, cast from cryptography.fernet import Fernet -from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives import hashes, hmac, serialization from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.kdf.hkdf import HKDF diff --git a/src/py/flwr/server/state/authentication/authentication_state_test.py b/src/py/flwr/server/state/authentication/authentication_state_test.py index 473f8e37eb36..4f7328894152 100644 --- a/src/py/flwr/server/state/authentication/authentication_state_test.py +++ b/src/py/flwr/server/state/authentication/authentication_state_test.py @@ -15,15 +15,15 @@ """Test for authentication state.""" import os -from in_memory_auth_state import InMemoryAuthState -from sqlite_auth_state import SqliteAuthState from common.secure_aggregation.crypto.symmetric_encryption import ( - generate_key_pairs, - public_key_to_bytes, + compute_hmac, + generate_key_pairs, generate_shared_key, + public_key_to_bytes, verify_hmac, - compute_hmac ) +from in_memory_auth_state import InMemoryAuthState + def test_client_public_keys() -> None: key_pairs = [generate_key_pairs() for _ in range(3)] From db16c1003e0504ecfbd4fd7a209ae4e88248fb60 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 13:50:39 +0100 Subject: [PATCH 03/73] Fix isort --- .../server/state/authentication/authentication_state.py | 1 + .../state/authentication/authentication_state_test.py | 1 + .../server/state/authentication/in_memory_auth_state.py | 6 ++++-- .../flwr/server/state/authentication/sqlite_auth_state.py | 6 ++++-- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/server/state/authentication/authentication_state.py b/src/py/flwr/server/state/authentication/authentication_state.py index 4a6fc9a6ab57..1a9831cf2a12 100644 --- a/src/py/flwr/server/state/authentication/authentication_state.py +++ b/src/py/flwr/server/state/authentication/authentication_state.py @@ -17,6 +17,7 @@ import abc from typing import Set + class AuthenticationState(abc.ABC): """Abstract State.""" @abc.abstractmethod diff --git a/src/py/flwr/server/state/authentication/authentication_state_test.py b/src/py/flwr/server/state/authentication/authentication_state_test.py index 4f7328894152..db4a18b27512 100644 --- a/src/py/flwr/server/state/authentication/authentication_state_test.py +++ b/src/py/flwr/server/state/authentication/authentication_state_test.py @@ -15,6 +15,7 @@ """Test for authentication state.""" import os + from common.secure_aggregation.crypto.symmetric_encryption import ( compute_hmac, generate_key_pairs, diff --git a/src/py/flwr/server/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/state/authentication/in_memory_auth_state.py index d4b387881962..ba5db177f99c 100644 --- a/src/py/flwr/server/state/authentication/in_memory_auth_state.py +++ b/src/py/flwr/server/state/authentication/in_memory_auth_state.py @@ -14,10 +14,12 @@ # ============================================================================== """In-memory Authentication State implementation.""" -from in_memory_state import InMemoryState -from authentication_state import AuthenticationState from typing import Dict, Set +from authentication_state import AuthenticationState +from in_memory_state import InMemoryState + + class InMemoryAuthState(AuthenticationState, InMemoryState): def __init__(self) -> None: super().__init__() diff --git a/src/py/flwr/server/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/state/authentication/sqlite_auth_state.py index 852629896778..542fab26ca08 100644 --- a/src/py/flwr/server/state/authentication/sqlite_auth_state.py +++ b/src/py/flwr/server/state/authentication/sqlite_auth_state.py @@ -14,10 +14,12 @@ # ============================================================================== """SQLite based implementation of server authentication state.""" -from sqlite_state import SqliteState -from authentication_state import AuthenticationState from typing import Set +from authentication_state import AuthenticationState +from sqlite_state import SqliteState + + class SqliteAuthState(AuthenticationState, SqliteState): def __init__(self) -> None: super().__init__() From 28876bc415ee577bf8e48ca87007bb27dcff5015 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 13:56:06 +0100 Subject: [PATCH 04/73] Run format.sh --- .../crypto/symmetric_encryption.py | 2 ++ .../authentication/authentication_state.py | 11 +++++++---- .../authentication_state_test.py | 4 +++- .../authentication/in_memory_auth_state.py | 18 ++++++++++-------- .../state/authentication/sqlite_auth_state.py | 4 +++- src/py/flwr/server/state/sqlite_state.py | 1 + 6 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py index 67b49d85cc53..0ad3bef18045 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py +++ b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py @@ -99,11 +99,13 @@ def decrypt(key: bytes, ciphertext: bytes) -> bytes: fernet = Fernet(key) return fernet.decrypt(ciphertext) + def compute_hmac(key: bytes, message: bytes) -> bytes: computed_hmac = hmac.HMAC(key, hashes.SHA256()) computed_hmac.update(message) return computed_hmac.finalize() + def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool: computed_hmac = hmac.HMAC(key, hashes.SHA256()) computed_hmac.update(message) diff --git a/src/py/flwr/server/state/authentication/authentication_state.py b/src/py/flwr/server/state/authentication/authentication_state.py index 1a9831cf2a12..a886f9b6510d 100644 --- a/src/py/flwr/server/state/authentication/authentication_state.py +++ b/src/py/flwr/server/state/authentication/authentication_state.py @@ -20,6 +20,7 @@ class AuthenticationState(abc.ABC): """Abstract State.""" + @abc.abstractmethod def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: """Store `node_id` and `public_key` as key-value pair in state.""" @@ -29,7 +30,9 @@ def get_public_key_from_node_id(self, node_id: int) -> bytes: """Get client's public key in urlsafe bytes for `node_id`.""" @abc.abstractmethod - def store_server_public_private_key(self, public_key: bytes, private_key: bytes) -> None: + def store_server_public_private_key( + self, public_key: bytes, private_key: bytes + ) -> None: """Store server's `public_key` and `private_key` in state.""" @abc.abstractmethod @@ -42,12 +45,12 @@ def get_server_public_key(self) -> bytes: @abc.abstractmethod def store_client_public_keys(self, public_keys: Set[bytes]) -> None: - """Store a set of client public keys in state.""" + """Store a set of client public keys in state.""" @abc.abstractmethod def store_client_public_key(self, public_key: bytes) -> None: - """Retrieve a client public key in state.""" + """Retrieve a client public key in state.""" @abc.abstractmethod def get_client_public_keys(self) -> Set[bytes]: - """Retrieve all currently stored client public keys as a set.""" + """Retrieve all currently stored client public keys as a set.""" diff --git a/src/py/flwr/server/state/authentication/authentication_state_test.py b/src/py/flwr/server/state/authentication/authentication_state_test.py index db4a18b27512..32c3363ab5d3 100644 --- a/src/py/flwr/server/state/authentication/authentication_state_test.py +++ b/src/py/flwr/server/state/authentication/authentication_state_test.py @@ -35,6 +35,7 @@ def test_client_public_keys() -> None: assert in_memory_auth_state.get_client_public_keys == public_keys + def test_node_id_public_key_pair() -> None: node_id = int.from_bytes(os.urandom(8), "little", signed=True) public_key = public_key_to_bytes(generate_key_pairs()[1]) @@ -44,6 +45,7 @@ def test_node_id_public_key_pair() -> None: assert in_memory_auth_state.get_public_key_from_node_id(node_id) == public_key + def test_generate_shared_key() -> None: client_keys = generate_key_pairs() server_keys = generate_key_pairs() @@ -53,6 +55,7 @@ def test_generate_shared_key() -> None: assert client_shared_secret == server_shared_secret + def test_hmac() -> None: client_keys = generate_key_pairs() server_keys = generate_key_pairs() @@ -63,4 +66,3 @@ def test_hmac() -> None: client_compute_hmac = compute_hmac(client_shared_secret, message) assert verify_hmac(server_shared_secret, message, client_compute_hmac) - \ No newline at end of file diff --git a/src/py/flwr/server/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/state/authentication/in_memory_auth_state.py index ba5db177f99c..c6dbcf4a5e1b 100644 --- a/src/py/flwr/server/state/authentication/in_memory_auth_state.py +++ b/src/py/flwr/server/state/authentication/in_memory_auth_state.py @@ -27,7 +27,7 @@ def __init__(self) -> None: self.client_public_keys: Set[bytes] = set() self.server_public_key: bytes = bytes() self.server_private_key: bytes = bytes() - + def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: """Store `node_id` and `public_key` as key-value pair in state.""" if node_id not in self.node_ids: @@ -42,7 +42,9 @@ def get_public_key_from_node_id(self, node_id: int) -> bytes: return self.node_id_public_key_dict[node_id] return bytes() - def store_server_public_private_key(self, public_key: bytes, private_key: bytes) -> None: + def store_server_public_private_key( + self, public_key: bytes, private_key: bytes + ) -> None: """Store server's `public_key` and `private_key` in state.""" self.server_private_key = private_key self.server_public_key = public_key @@ -56,13 +58,13 @@ def get_server_public_key(self) -> bytes: return self.server_public_key def store_client_public_keys(self, public_keys: Set[bytes]) -> None: - """Store a set of client public keys in state.""" - self.client_public_keys = public_keys + """Store a set of client public keys in state.""" + self.client_public_keys = public_keys def store_client_public_key(self, public_key: bytes) -> None: - """Retrieve a client public key in state.""" - self.client_public_keys.add(public_key) + """Retrieve a client public key in state.""" + self.client_public_keys.add(public_key) def get_client_public_keys(self) -> Set[bytes]: - """Retrieve all currently stored client public keys as a set.""" - return self.client_public_keys + """Retrieve all currently stored client public keys as a set.""" + return self.client_public_keys diff --git a/src/py/flwr/server/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/state/authentication/sqlite_auth_state.py index 542fab26ca08..381240df5d1a 100644 --- a/src/py/flwr/server/state/authentication/sqlite_auth_state.py +++ b/src/py/flwr/server/state/authentication/sqlite_auth_state.py @@ -35,7 +35,9 @@ def get_public_key_from_node_id(self, node_id: int) -> bytes: rows = self.query(query, {"node_id": node_id}) return rows[0]["public_key"] - def store_server_public_private_key(self, public_key: bytes, private_key: bytes) -> None: + def store_server_public_private_key( + self, public_key: bytes, private_key: bytes + ) -> None: """Store server's `public_key` and `private_key` in state.""" query = "INSERT OR REPLACE INTO credential (public_key, private_key) VALUES (:public_key, :private_key)" self.query(query, {"public_key": public_key, "private_key": private_key}) diff --git a/src/py/flwr/server/state/sqlite_state.py b/src/py/flwr/server/state/sqlite_state.py index f89df6301334..e91d8553863c 100644 --- a/src/py/flwr/server/state/sqlite_state.py +++ b/src/py/flwr/server/state/sqlite_state.py @@ -542,6 +542,7 @@ def create_run(self) -> int: log(ERROR, "Unexpected run creation failure.") return 0 + def dict_factory( cursor: sqlite3.Cursor, row: sqlite3.Row, From 42a7d386274b56d53d8268cafc698420b655911f Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 14:00:44 +0100 Subject: [PATCH 05/73] Add init.py --- .../flwr/server/state/authentication/__init__.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 src/py/flwr/server/state/authentication/__init__.py diff --git a/src/py/flwr/server/state/authentication/__init__.py b/src/py/flwr/server/state/authentication/__init__.py new file mode 100644 index 000000000000..3203b3230b5c --- /dev/null +++ b/src/py/flwr/server/state/authentication/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower server authentication state.""" From 8ec63c96203a8c19545958c47ed7661c81c53c32 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 14:17:30 +0100 Subject: [PATCH 06/73] Fix line too long --- .../secure_aggregation/crypto/symmetric_encryption.py | 4 +++- .../state/authentication/authentication_state_test.py | 4 ++++ .../server/state/authentication/in_memory_auth_state.py | 7 ++++--- .../flwr/server/state/authentication/sqlite_auth_state.py | 7 +++++-- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py index 0ad3bef18045..e38bdb6d7859 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py +++ b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py @@ -101,16 +101,18 @@ def decrypt(key: bytes, ciphertext: bytes) -> bytes: def compute_hmac(key: bytes, message: bytes) -> bytes: + """Compute hmac of a message using key as hash.""" computed_hmac = hmac.HMAC(key, hashes.SHA256()) computed_hmac.update(message) return computed_hmac.finalize() def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool: + """Verify hmac of a message using key as hash.""" computed_hmac = hmac.HMAC(key, hashes.SHA256()) computed_hmac.update(message) try: computed_hmac.verify(hmac_value) return True - except: + except Exception: return False diff --git a/src/py/flwr/server/state/authentication/authentication_state_test.py b/src/py/flwr/server/state/authentication/authentication_state_test.py index 32c3363ab5d3..35d36e7c8782 100644 --- a/src/py/flwr/server/state/authentication/authentication_state_test.py +++ b/src/py/flwr/server/state/authentication/authentication_state_test.py @@ -27,6 +27,7 @@ def test_client_public_keys() -> None: + """Test client public keys store and get from state.""" key_pairs = [generate_key_pairs() for _ in range(3)] public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} @@ -37,6 +38,7 @@ def test_client_public_keys() -> None: def test_node_id_public_key_pair() -> None: + """Test store and get node_id public_key pair.""" node_id = int.from_bytes(os.urandom(8), "little", signed=True) public_key = public_key_to_bytes(generate_key_pairs()[1]) @@ -47,6 +49,7 @@ def test_node_id_public_key_pair() -> None: def test_generate_shared_key() -> None: + """Test util function generate_shared_key.""" client_keys = generate_key_pairs() server_keys = generate_key_pairs() @@ -57,6 +60,7 @@ def test_generate_shared_key() -> None: def test_hmac() -> None: + """Test util function compute and verify hmac.""" client_keys = generate_key_pairs() server_keys = generate_key_pairs() client_shared_secret = generate_shared_key(client_keys[0], server_keys[1]) diff --git a/src/py/flwr/server/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/state/authentication/in_memory_auth_state.py index c6dbcf4a5e1b..2d51494e5158 100644 --- a/src/py/flwr/server/state/authentication/in_memory_auth_state.py +++ b/src/py/flwr/server/state/authentication/in_memory_auth_state.py @@ -22,11 +22,12 @@ class InMemoryAuthState(AuthenticationState, InMemoryState): def __init__(self) -> None: + """Init InMemoryAuthState.""" super().__init__() self.node_id_public_key_dict: Dict[int, bytes] = {} self.client_public_keys: Set[bytes] = set() - self.server_public_key: bytes = bytes() - self.server_private_key: bytes = bytes() + self.server_public_key: bytes = b"" + self.server_private_key: bytes = b"" def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: """Store `node_id` and `public_key` as key-value pair in state.""" @@ -40,7 +41,7 @@ def get_public_key_from_node_id(self, node_id: int) -> bytes: """Get client's public key in urlsafe bytes for `node_id`.""" if node_id in self.node_id_public_key_dict: return self.node_id_public_key_dict[node_id] - return bytes() + return b"" def store_server_public_private_key( self, public_key: bytes, private_key: bytes diff --git a/src/py/flwr/server/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/state/authentication/sqlite_auth_state.py index 381240df5d1a..5513814348f5 100644 --- a/src/py/flwr/server/state/authentication/sqlite_auth_state.py +++ b/src/py/flwr/server/state/authentication/sqlite_auth_state.py @@ -22,11 +22,13 @@ class SqliteAuthState(AuthenticationState, SqliteState): def __init__(self) -> None: + """Init SqliteAuthState.""" super().__init__() def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: """Store `node_id` and `public_key` as key-value pair in state.""" - query = "INSERT OR REPLACE INTO node_key (node_id, public_key) VALUES (:node_id, :public_key)" + query = "INSERT OR REPLACE INTO node_key (node_id, public_key) "\ + "VALUES (:node_id, :public_key)" self.query(query, {"node_id": node_id, "public_key": public_key}) def get_public_key_from_node_id(self, node_id: int) -> bytes: @@ -39,7 +41,8 @@ def store_server_public_private_key( self, public_key: bytes, private_key: bytes ) -> None: """Store server's `public_key` and `private_key` in state.""" - query = "INSERT OR REPLACE INTO credential (public_key, private_key) VALUES (:public_key, :private_key)" + query = "INSERT OR REPLACE INTO credential (public_key, private_key) "\ + "VALUES (:public_key, :private_key)" self.query(query, {"public_key": public_key, "private_key": private_key}) def get_server_private_key(self) -> bytes: From 8f04e25ca1977ee437776e7fe2ef41eb714c036e Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 14:21:32 +0100 Subject: [PATCH 07/73] Fix line too long --- .../state/authentication/in_memory_auth_state.py | 1 + .../state/authentication/sqlite_auth_state.py | 14 ++++++++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/server/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/state/authentication/in_memory_auth_state.py index 2d51494e5158..67cdeb1347ef 100644 --- a/src/py/flwr/server/state/authentication/in_memory_auth_state.py +++ b/src/py/flwr/server/state/authentication/in_memory_auth_state.py @@ -21,6 +21,7 @@ class InMemoryAuthState(AuthenticationState, InMemoryState): + """In-memory-based authentication state implementation.""" def __init__(self) -> None: """Init InMemoryAuthState.""" super().__init__() diff --git a/src/py/flwr/server/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/state/authentication/sqlite_auth_state.py index 5513814348f5..287e75a3c4fc 100644 --- a/src/py/flwr/server/state/authentication/sqlite_auth_state.py +++ b/src/py/flwr/server/state/authentication/sqlite_auth_state.py @@ -21,14 +21,18 @@ class SqliteAuthState(AuthenticationState, SqliteState): + """SQLite-based authentication state implementation.""" + def __init__(self) -> None: """Init SqliteAuthState.""" super().__init__() def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: """Store `node_id` and `public_key` as key-value pair in state.""" - query = "INSERT OR REPLACE INTO node_key (node_id, public_key) "\ - "VALUES (:node_id, :public_key)" + query = ( + "INSERT OR REPLACE INTO node_key (node_id, public_key) " + "VALUES (:node_id, :public_key)" + ) self.query(query, {"node_id": node_id, "public_key": public_key}) def get_public_key_from_node_id(self, node_id: int) -> bytes: @@ -41,8 +45,10 @@ def store_server_public_private_key( self, public_key: bytes, private_key: bytes ) -> None: """Store server's `public_key` and `private_key` in state.""" - query = "INSERT OR REPLACE INTO credential (public_key, private_key) "\ - "VALUES (:public_key, :private_key)" + query = ( + "INSERT OR REPLACE INTO credential (public_key, private_key) " + "VALUES (:public_key, :private_key)" + ) self.query(query, {"public_key": public_key, "private_key": private_key}) def get_server_private_key(self) -> bytes: From e8813fcf729293cf239bf79c43b2a64f23da1d4c Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 14:24:18 +0100 Subject: [PATCH 08/73] Fix line too long --- src/py/flwr/server/state/authentication/in_memory_auth_state.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/py/flwr/server/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/state/authentication/in_memory_auth_state.py index 67cdeb1347ef..9bd68ee1b7db 100644 --- a/src/py/flwr/server/state/authentication/in_memory_auth_state.py +++ b/src/py/flwr/server/state/authentication/in_memory_auth_state.py @@ -22,6 +22,7 @@ class InMemoryAuthState(AuthenticationState, InMemoryState): """In-memory-based authentication state implementation.""" + def __init__(self) -> None: """Init InMemoryAuthState.""" super().__init__() From d9f3fb04b565696eb888316afc4103b12f0865f0 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 14:35:36 +0100 Subject: [PATCH 09/73] Fix subclassing --- src/py/flwr/server/state/authentication/__init__.py | 10 ++++++++++ .../server/state/authentication/sqlite_auth_state.py | 9 ++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/server/state/authentication/__init__.py b/src/py/flwr/server/state/authentication/__init__.py index 3203b3230b5c..95f3e3fbbd57 100644 --- a/src/py/flwr/server/state/authentication/__init__.py +++ b/src/py/flwr/server/state/authentication/__init__.py @@ -13,3 +13,13 @@ # limitations under the License. # ============================================================================== """Flower server authentication state.""" + +from .authentication_state import AuthenticationState as AuthenticationState +from .in_memory_auth_state import InMemoryAuthState as InMemoryAuthState +from .sqlite_auth_state import SqliteAuthState as SqliteAuthState + +__all__ = [ + "AuthenticationState", + "InMemoryAuthState", + "SqliteAuthState", +] \ No newline at end of file diff --git a/src/py/flwr/server/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/state/authentication/sqlite_auth_state.py index 287e75a3c4fc..c93a148d9956 100644 --- a/src/py/flwr/server/state/authentication/sqlite_auth_state.py +++ b/src/py/flwr/server/state/authentication/sqlite_auth_state.py @@ -39,7 +39,8 @@ def get_public_key_from_node_id(self, node_id: int) -> bytes: """Get client's public key in urlsafe bytes for `node_id`.""" query = "SELECT public_key FROM node_key WHERE node_id = :node_id" rows = self.query(query, {"node_id": node_id}) - return rows[0]["public_key"] + public_key: bytes = rows[0]["public_key"] + return public_key def store_server_public_private_key( self, public_key: bytes, private_key: bytes @@ -55,13 +56,15 @@ def get_server_private_key(self) -> bytes: """Get server private key in urlsafe bytes.""" query = "SELECT private_key FROM credential" rows = self.query(query) - return rows[0]["private_key"] + private_key: bytes = rows[0]["private_key"] + return private_key def get_server_public_key(self) -> bytes: """Get server public key in urlsafe bytes.""" query = "SELECT public_key FROM credential" rows = self.query(query) - return rows[0]["public_key"] + public_key: bytes = rows[0]["public_key"] + return public_key def store_client_public_keys(self, public_keys: Set[bytes]) -> None: """Store a set of client public keys in state.""" From caf6695d7e9d87f19db2b8fd7b4c935607c1ffa4 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 14:39:29 +0100 Subject: [PATCH 10/73] Fix subclassing --- src/py/flwr/server/state/authentication/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/server/state/authentication/__init__.py b/src/py/flwr/server/state/authentication/__init__.py index 95f3e3fbbd57..8f5c0a97ab1f 100644 --- a/src/py/flwr/server/state/authentication/__init__.py +++ b/src/py/flwr/server/state/authentication/__init__.py @@ -22,4 +22,4 @@ "AuthenticationState", "InMemoryAuthState", "SqliteAuthState", -] \ No newline at end of file +] From fa217ae7f16430c48babf84d160717c792a0d6a8 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 15:03:56 +0100 Subject: [PATCH 11/73] Fix subclassing --- .../flwr/server/state/authentication/authentication_state.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/server/state/authentication/authentication_state.py b/src/py/flwr/server/state/authentication/authentication_state.py index a886f9b6510d..c881af432f4b 100644 --- a/src/py/flwr/server/state/authentication/authentication_state.py +++ b/src/py/flwr/server/state/authentication/authentication_state.py @@ -17,8 +17,10 @@ import abc from typing import Set +from state import State -class AuthenticationState(abc.ABC): + +class AuthenticationState(State, abc.ABC): """Abstract State.""" @abc.abstractmethod From 6edddd631dee4bba954c68074ebbe683fcf6b888 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 15:08:53 +0100 Subject: [PATCH 12/73] Fix subclassing --- .../flwr/server/state/authentication/authentication_state.py | 2 +- .../flwr/server/state/authentication/in_memory_auth_state.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/server/state/authentication/authentication_state.py b/src/py/flwr/server/state/authentication/authentication_state.py index c881af432f4b..fb538038dcbb 100644 --- a/src/py/flwr/server/state/authentication/authentication_state.py +++ b/src/py/flwr/server/state/authentication/authentication_state.py @@ -21,7 +21,7 @@ class AuthenticationState(State, abc.ABC): - """Abstract State.""" + """Abstract Authentication State.""" @abc.abstractmethod def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: diff --git a/src/py/flwr/server/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/state/authentication/in_memory_auth_state.py index 9bd68ee1b7db..9ddc958c18d3 100644 --- a/src/py/flwr/server/state/authentication/in_memory_auth_state.py +++ b/src/py/flwr/server/state/authentication/in_memory_auth_state.py @@ -16,8 +16,8 @@ from typing import Dict, Set -from authentication_state import AuthenticationState -from in_memory_state import InMemoryState +from .authentication_state import AuthenticationState +from flwr.server.state.in_memory_state import InMemoryState class InMemoryAuthState(AuthenticationState, InMemoryState): From 8bb15a569927b8c16f214a286309ee05ccb5a3b6 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 15:11:49 +0100 Subject: [PATCH 13/73] Fix subclassing --- src/py/flwr/server/state/authentication/in_memory_auth_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/server/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/state/authentication/in_memory_auth_state.py index 9ddc958c18d3..fe10c1301b11 100644 --- a/src/py/flwr/server/state/authentication/in_memory_auth_state.py +++ b/src/py/flwr/server/state/authentication/in_memory_auth_state.py @@ -16,7 +16,7 @@ from typing import Dict, Set -from .authentication_state import AuthenticationState +from flwr.server.state.authentication.authentication_state import AuthenticationState from flwr.server.state.in_memory_state import InMemoryState From c5bac4f2ba5807190c501a40440b048de3c2c9f5 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Sun, 11 Feb 2024 09:20:04 +0000 Subject: [PATCH 14/73] fixes --- .../server/state/authentication/authentication_state.py | 2 +- .../state/authentication/authentication_state_test.py | 7 ++++--- .../flwr/server/state/authentication/sqlite_auth_state.py | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/server/state/authentication/authentication_state.py b/src/py/flwr/server/state/authentication/authentication_state.py index fb538038dcbb..3adb450dc215 100644 --- a/src/py/flwr/server/state/authentication/authentication_state.py +++ b/src/py/flwr/server/state/authentication/authentication_state.py @@ -17,7 +17,7 @@ import abc from typing import Set -from state import State +from flwr.server.state import State class AuthenticationState(State, abc.ABC): diff --git a/src/py/flwr/server/state/authentication/authentication_state_test.py b/src/py/flwr/server/state/authentication/authentication_state_test.py index 35d36e7c8782..2aaf736a8d68 100644 --- a/src/py/flwr/server/state/authentication/authentication_state_test.py +++ b/src/py/flwr/server/state/authentication/authentication_state_test.py @@ -16,14 +16,15 @@ import os -from common.secure_aggregation.crypto.symmetric_encryption import ( +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( compute_hmac, generate_key_pairs, generate_shared_key, public_key_to_bytes, verify_hmac, ) -from in_memory_auth_state import InMemoryAuthState + +from .in_memory_auth_state import InMemoryAuthState def test_client_public_keys() -> None: @@ -34,7 +35,7 @@ def test_client_public_keys() -> None: in_memory_auth_state = InMemoryAuthState() in_memory_auth_state.store_client_public_keys(public_keys) - assert in_memory_auth_state.get_client_public_keys == public_keys + assert in_memory_auth_state.get_client_public_keys() == public_keys def test_node_id_public_key_pair() -> None: diff --git a/src/py/flwr/server/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/state/authentication/sqlite_auth_state.py index c93a148d9956..0e0436f0bae1 100644 --- a/src/py/flwr/server/state/authentication/sqlite_auth_state.py +++ b/src/py/flwr/server/state/authentication/sqlite_auth_state.py @@ -16,8 +16,8 @@ from typing import Set -from authentication_state import AuthenticationState -from sqlite_state import SqliteState +from flwr.server.state.authentication.authentication_state import AuthenticationState +from flwr.server.state.sqlite_state import SqliteState class SqliteAuthState(AuthenticationState, SqliteState): From c856b7c441a79bc60c82ec3359291dc30287a469 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sun, 11 Feb 2024 11:13:13 +0100 Subject: [PATCH 15/73] Fix state tests --- .../server/state/authentication/authentication_state_test.py | 5 ++--- src/py/flwr/server/state/authentication/sqlite_auth_state.py | 4 ---- src/py/flwr/server/state/state_test.py | 4 ++-- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/py/flwr/server/state/authentication/authentication_state_test.py b/src/py/flwr/server/state/authentication/authentication_state_test.py index 2aaf736a8d68..1495fdf5084e 100644 --- a/src/py/flwr/server/state/authentication/authentication_state_test.py +++ b/src/py/flwr/server/state/authentication/authentication_state_test.py @@ -14,7 +14,6 @@ # ============================================================================== """Test for authentication state.""" -import os from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( compute_hmac, @@ -40,10 +39,10 @@ def test_client_public_keys() -> None: def test_node_id_public_key_pair() -> None: """Test store and get node_id public_key pair.""" - node_id = int.from_bytes(os.urandom(8), "little", signed=True) + in_memory_auth_state = InMemoryAuthState() + node_id = in_memory_auth_state.create_node() public_key = public_key_to_bytes(generate_key_pairs()[1]) - in_memory_auth_state = InMemoryAuthState() in_memory_auth_state.store_node_id_public_key_pair(node_id, public_key) assert in_memory_auth_state.get_public_key_from_node_id(node_id) == public_key diff --git a/src/py/flwr/server/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/state/authentication/sqlite_auth_state.py index 0e0436f0bae1..55e4bc73a63b 100644 --- a/src/py/flwr/server/state/authentication/sqlite_auth_state.py +++ b/src/py/flwr/server/state/authentication/sqlite_auth_state.py @@ -23,10 +23,6 @@ class SqliteAuthState(AuthenticationState, SqliteState): """SQLite-based authentication state implementation.""" - def __init__(self) -> None: - """Init SqliteAuthState.""" - super().__init__() - def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: """Store `node_id` and `public_key` as key-value pair in state.""" query = ( diff --git a/src/py/flwr/server/state/state_test.py b/src/py/flwr/server/state/state_test.py index 95d764792ff3..9395083e2648 100644 --- a/src/py/flwr/server/state/state_test.py +++ b/src/py/flwr/server/state/state_test.py @@ -477,7 +477,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 8 + assert len(result) == 13 class SqliteFileBasedTest(StateTest, unittest.TestCase): @@ -502,7 +502,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 8 + assert len(result) == 13 if __name__ == "__main__": From 475850733ef370442203dea3ccc28bb281463b79 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sun, 11 Feb 2024 11:17:10 +0100 Subject: [PATCH 16/73] Fix too broad exception --- .../common/secure_aggregation/crypto/symmetric_encryption.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py index e38bdb6d7859..1d004a398ea8 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py +++ b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py @@ -18,6 +18,7 @@ import base64 from typing import Tuple, cast +from cryptography.exceptions import InvalidSignature from cryptography.fernet import Fernet from cryptography.hazmat.primitives import hashes, hmac, serialization from cryptography.hazmat.primitives.asymmetric import ec @@ -114,5 +115,5 @@ def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool: try: computed_hmac.verify(hmac_value) return True - except Exception: + except InvalidSignature: return False From e666da57e1c62d1d0f6c9a50f8d86c073a707e0f Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sun, 11 Feb 2024 14:16:50 +0100 Subject: [PATCH 17/73] Add sqlite auth state test --- .../authentication_state_test.py | 29 +++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/state/authentication/authentication_state_test.py b/src/py/flwr/server/state/authentication/authentication_state_test.py index 1495fdf5084e..f18c428d3044 100644 --- a/src/py/flwr/server/state/authentication/authentication_state_test.py +++ b/src/py/flwr/server/state/authentication/authentication_state_test.py @@ -24,9 +24,10 @@ ) from .in_memory_auth_state import InMemoryAuthState +from .sqlite_auth_state import SqliteAuthState -def test_client_public_keys() -> None: +def test_in_memory_client_public_keys() -> None: """Test client public keys store and get from state.""" key_pairs = [generate_key_pairs() for _ in range(3)] public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} @@ -37,7 +38,19 @@ def test_client_public_keys() -> None: assert in_memory_auth_state.get_client_public_keys() == public_keys -def test_node_id_public_key_pair() -> None: +def test_sqlite_client_public_keys() -> None: + """Test client public keys store and get from state.""" + key_pairs = [generate_key_pairs() for _ in range(3)] + public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} + + sqlite_auth_state = SqliteAuthState(":memory:") + sqlite_auth_state.initialize() + sqlite_auth_state.store_client_public_keys(public_keys) + + assert sqlite_auth_state.get_client_public_keys() == public_keys + + +def test_in_memory_node_id_public_key_pair() -> None: """Test store and get node_id public_key pair.""" in_memory_auth_state = InMemoryAuthState() node_id = in_memory_auth_state.create_node() @@ -48,6 +61,18 @@ def test_node_id_public_key_pair() -> None: assert in_memory_auth_state.get_public_key_from_node_id(node_id) == public_key +def test_sqlite_node_id_public_key_pair() -> None: + """Test store and get node_id public_key pair.""" + sqlite_auth_state = SqliteAuthState(":memory:") + sqlite_auth_state.initialize() + node_id = sqlite_auth_state.create_node() + public_key = public_key_to_bytes(generate_key_pairs()[1]) + + sqlite_auth_state.store_node_id_public_key_pair(node_id, public_key) + + assert sqlite_auth_state.get_public_key_from_node_id(node_id) == public_key + + def test_generate_shared_key() -> None: """Test util function generate_shared_key.""" client_keys = generate_key_pairs() From 5c49a55cd43d3c04f7d9814fd7d2363e7f7be555 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Mon, 12 Feb 2024 16:41:18 +0100 Subject: [PATCH 18/73] Add server interceptor --- src/py/flwr/server/server_interceptor.py | 131 ++++++++++++++++++ src/py/flwr/server/server_interceptor_test.py | 43 ++++++ 2 files changed, 174 insertions(+) create mode 100644 src/py/flwr/server/server_interceptor.py create mode 100644 src/py/flwr/server/server_interceptor_test.py diff --git a/src/py/flwr/server/server_interceptor.py b/src/py/flwr/server/server_interceptor.py new file mode 100644 index 000000000000..15a0d1d4dfee --- /dev/null +++ b/src/py/flwr/server/server_interceptor.py @@ -0,0 +1,131 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower server interceptor.""" + +import grpc +from cryptography.hazmat.primitives.asymmetric import ec +from typing import Callable, Sequence, Tuple, Union +from flwr.server.state.authentication import AuthenticationState +from flwr.common.secure_aggregation.crypto.symmetric_encryption import generate_shared_key, bytes_to_public_key, public_key_to_bytes, verify_hmac +from flwr.proto.fleet_pb2 import ( + CreateNodeRequest, + CreateNodeResponse, +) +from flwr.server.fleet.message_handler import message_handler +from flwr.server.state import StateFactory, State + +_PUBLIC_KEY_HEADER = "public-key" +_AUTH_TOKEN_HEADER = "auth-token" + +def _unary_unary_rpc_terminator(): + + def terminate(_, context): + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") + + return grpc.unary_unary_rpc_method_handler(terminate) + +def _create_node_with_public_key(state: State, server_public_key: bytes): + + def send_public_key(request: CreateNodeRequest, context: grpc.ServicerContext) -> CreateNodeResponse: + context.set_trailing_metadata( + ( + (_PUBLIC_KEY_HEADER, server_public_key), + ) + ) + return message_handler.create_node(request, state) + + return grpc.unary_unary_rpc_method_handler(send_public_key) + +def _create_node_with_public_key(state: State, server_public_key: bytes): + + def send_public_key(request: CreateNodeRequest, context: grpc.ServicerContext) -> CreateNodeResponse: + context.set_trailing_metadata( + ( + (_PUBLIC_KEY_HEADER, server_public_key), + ) + ) + return message_handler.create_node(request, state) + + return grpc.unary_unary_rpc_method_handler(send_public_key) + +def _handle_authentication(public_key, private_key): + return generate_shared_key(public_key, private_key) + +def _is_public_key_known(state: AuthenticationState, public_key: bytes) -> bool: + return public_key in state.get_client_public_keys() + +def _get_value_from_tuples(key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]]) -> Union[str, bytes]: + return next((value[::-1] for key, value in tuples if key == key_string), "") + +class AuthenticateServerInterceptor(grpc.ServerInterceptor): + + def __init__(self, state_factory: StateFactory, private_key: ec.EllipticCurvePrivateKey, public_key: ec.EllipticCurvePublicKey): + self._private_key = private_key + self._public_key = public_key + self._state_factory = state_factory + self._terminator = _unary_unary_rpc_terminator() + self._create_node_handler = _create_node_with_public_key() + + def intercept_service(self, continuation: Callable, handler_call_details: grpc.HandlerCallDetails): + method_name = handler_call_details.method.split("/")[-1] + client_public_key_bytes = _get_value_from_tuples(_PUBLIC_KEY_HEADER, handler_call_details.invocation_metadata) + client_public_key = bytes_to_public_key(client_public_key_bytes) + + if _is_public_key_known(self._state_factory.state, client_public_key_bytes): + if method_name == 'CreateNode': + return _create_node_with_public_key(self._state_factory.state, self._public_key) + elif method_name in {'DeleteNode', 'PullTaskIns', 'PushTaskRes'}: + state: AuthenticationState = self._state_factory.state + shared_secret = generate_shared_key(self._private_key, client_public_key) + hmac = _get_value_from_tuples(_AUTH_TOKEN_HEADER, handler_call_details.invocation_metadata) + if verify_hmac(shared_secret, ) + state.get_client_public_keys() + expected_metadata = (_AUTH_TOKEN_HEADER, generate_shared_key()) + + + if (self._header, self._value) in handler_call_details.invocation_metadata: + grpc.unary_unary_rpc_method_handler + return continuation(handler_call_details) + else: + return self._terminator + + def intercept_service(self, continuation: Callable, handler_call_details: grpc.HandlerCallDetails): + client_public_key_bytes = _get_value_from_tuples(_PUBLIC_KEY_HEADER, handler_call_details.invocation_metadata) + if _is_public_key_known(self._state_factory.state, client_public_key_bytes): + message_handler: grpc.RpcMethodHandler = continuation(handler_call_details) + message_handler. + return grpc.unary_unary_rpc_method_handler(message_handler.unary_unary, request_deserializer=message_handler.request_deserializer, response_serializer=message_handler.response_serializer) + if message_handler is None: + return + else: + return self._terminator + + handler_factory, next_handler_method = _get_factory_and_method(next_handler) + + + def invoke_intercept_method(request_or_iterator, context): + method_name = handler_call_details.method + return self.intercept( + next_handler_method, + request_or_iterator, + context, + method_name, + ) + + return handler_factory( + invoke_intercept_method, + request_deserializer=next_handler.request_deserializer, + response_serializer=next_handler.response_serializer, + ) diff --git a/src/py/flwr/server/server_interceptor_test.py b/src/py/flwr/server/server_interceptor_test.py new file mode 100644 index 000000000000..d3a528addddd --- /dev/null +++ b/src/py/flwr/server/server_interceptor_test.py @@ -0,0 +1,43 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower server interceptor tests.""" + +import unittest +import grpc + +from app import _run_fleet_api_grpc_rere, ADDRESS_FLEET_API_GRPC_RERE +from flwr.common import GRPC_MAX_MESSAGE_LENGTH +from common.constant import TRANSPORT_TYPE_GRPC_RERE +from client.app import _init_connection +from state.state_factory import StateFactory + +class TestServerInterceptor(unittest.TestCase): + def setUp(self): + self._state_factory = StateFactory(":flwr-in-memory-state:") + self._server: grpc.Server = _run_fleet_api_grpc_rere(ADDRESS_FLEET_API_GRPC_RERE, self._state_factory) + self._connection, self._address = _init_connection(TRANSPORT_TYPE_GRPC_RERE, ADDRESS_FLEET_API_GRPC_RERE) + with self._connection( + self._address, + True, + GRPC_MAX_MESSAGE_LENGTH, + ) as conn: + self._receive, self._send, self._create_node, self._delete_node = conn + + def tearDown(self): + self._server.stop(None) + + def test_successful_create_node_with_metadata(self) -> None: + self._create_node() + From 986961e393c33c57230647d6875afcbc22ffd441 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 14 Feb 2024 18:18:32 +0100 Subject: [PATCH 19/73] Move state to superlink --- .../state/authentication/__init__.py | 25 +++++ .../authentication/authentication_state.py | 58 +++++++++++ .../authentication_state_test.py | 97 +++++++++++++++++++ .../authentication/in_memory_auth_state.py | 73 ++++++++++++++ .../state/authentication/sqlite_auth_state.py | 81 ++++++++++++++++ 5 files changed, 334 insertions(+) create mode 100644 src/py/flwr/server/superlink/state/authentication/__init__.py create mode 100644 src/py/flwr/server/superlink/state/authentication/authentication_state.py create mode 100644 src/py/flwr/server/superlink/state/authentication/authentication_state_test.py create mode 100644 src/py/flwr/server/superlink/state/authentication/in_memory_auth_state.py create mode 100644 src/py/flwr/server/superlink/state/authentication/sqlite_auth_state.py diff --git a/src/py/flwr/server/superlink/state/authentication/__init__.py b/src/py/flwr/server/superlink/state/authentication/__init__.py new file mode 100644 index 000000000000..8f5c0a97ab1f --- /dev/null +++ b/src/py/flwr/server/superlink/state/authentication/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower server authentication state.""" + +from .authentication_state import AuthenticationState as AuthenticationState +from .in_memory_auth_state import InMemoryAuthState as InMemoryAuthState +from .sqlite_auth_state import SqliteAuthState as SqliteAuthState + +__all__ = [ + "AuthenticationState", + "InMemoryAuthState", + "SqliteAuthState", +] diff --git a/src/py/flwr/server/superlink/state/authentication/authentication_state.py b/src/py/flwr/server/superlink/state/authentication/authentication_state.py new file mode 100644 index 000000000000..3adb450dc215 --- /dev/null +++ b/src/py/flwr/server/superlink/state/authentication/authentication_state.py @@ -0,0 +1,58 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Abstract base class AuthenticationState.""" + +import abc +from typing import Set + +from flwr.server.state import State + + +class AuthenticationState(State, abc.ABC): + """Abstract Authentication State.""" + + @abc.abstractmethod + def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: + """Store `node_id` and `public_key` as key-value pair in state.""" + + @abc.abstractmethod + def get_public_key_from_node_id(self, node_id: int) -> bytes: + """Get client's public key in urlsafe bytes for `node_id`.""" + + @abc.abstractmethod + def store_server_public_private_key( + self, public_key: bytes, private_key: bytes + ) -> None: + """Store server's `public_key` and `private_key` in state.""" + + @abc.abstractmethod + def get_server_private_key(self) -> bytes: + """Get server private key in urlsafe bytes.""" + + @abc.abstractmethod + def get_server_public_key(self) -> bytes: + """Get server public key in urlsafe bytes.""" + + @abc.abstractmethod + def store_client_public_keys(self, public_keys: Set[bytes]) -> None: + """Store a set of client public keys in state.""" + + @abc.abstractmethod + def store_client_public_key(self, public_key: bytes) -> None: + """Retrieve a client public key in state.""" + + @abc.abstractmethod + def get_client_public_keys(self) -> Set[bytes]: + """Retrieve all currently stored client public keys as a set.""" diff --git a/src/py/flwr/server/superlink/state/authentication/authentication_state_test.py b/src/py/flwr/server/superlink/state/authentication/authentication_state_test.py new file mode 100644 index 000000000000..f18c428d3044 --- /dev/null +++ b/src/py/flwr/server/superlink/state/authentication/authentication_state_test.py @@ -0,0 +1,97 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test for authentication state.""" + + +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + compute_hmac, + generate_key_pairs, + generate_shared_key, + public_key_to_bytes, + verify_hmac, +) + +from .in_memory_auth_state import InMemoryAuthState +from .sqlite_auth_state import SqliteAuthState + + +def test_in_memory_client_public_keys() -> None: + """Test client public keys store and get from state.""" + key_pairs = [generate_key_pairs() for _ in range(3)] + public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} + + in_memory_auth_state = InMemoryAuthState() + in_memory_auth_state.store_client_public_keys(public_keys) + + assert in_memory_auth_state.get_client_public_keys() == public_keys + + +def test_sqlite_client_public_keys() -> None: + """Test client public keys store and get from state.""" + key_pairs = [generate_key_pairs() for _ in range(3)] + public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} + + sqlite_auth_state = SqliteAuthState(":memory:") + sqlite_auth_state.initialize() + sqlite_auth_state.store_client_public_keys(public_keys) + + assert sqlite_auth_state.get_client_public_keys() == public_keys + + +def test_in_memory_node_id_public_key_pair() -> None: + """Test store and get node_id public_key pair.""" + in_memory_auth_state = InMemoryAuthState() + node_id = in_memory_auth_state.create_node() + public_key = public_key_to_bytes(generate_key_pairs()[1]) + + in_memory_auth_state.store_node_id_public_key_pair(node_id, public_key) + + assert in_memory_auth_state.get_public_key_from_node_id(node_id) == public_key + + +def test_sqlite_node_id_public_key_pair() -> None: + """Test store and get node_id public_key pair.""" + sqlite_auth_state = SqliteAuthState(":memory:") + sqlite_auth_state.initialize() + node_id = sqlite_auth_state.create_node() + public_key = public_key_to_bytes(generate_key_pairs()[1]) + + sqlite_auth_state.store_node_id_public_key_pair(node_id, public_key) + + assert sqlite_auth_state.get_public_key_from_node_id(node_id) == public_key + + +def test_generate_shared_key() -> None: + """Test util function generate_shared_key.""" + client_keys = generate_key_pairs() + server_keys = generate_key_pairs() + + client_shared_secret = generate_shared_key(client_keys[0], server_keys[1]) + server_shared_secret = generate_shared_key(server_keys[0], client_keys[1]) + + assert client_shared_secret == server_shared_secret + + +def test_hmac() -> None: + """Test util function compute and verify hmac.""" + client_keys = generate_key_pairs() + server_keys = generate_key_pairs() + client_shared_secret = generate_shared_key(client_keys[0], server_keys[1]) + server_shared_secret = generate_shared_key(server_keys[0], client_keys[1]) + message = b"Flower is the future of AI" + + client_compute_hmac = compute_hmac(client_shared_secret, message) + + assert verify_hmac(server_shared_secret, message, client_compute_hmac) diff --git a/src/py/flwr/server/superlink/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/superlink/state/authentication/in_memory_auth_state.py new file mode 100644 index 000000000000..fe10c1301b11 --- /dev/null +++ b/src/py/flwr/server/superlink/state/authentication/in_memory_auth_state.py @@ -0,0 +1,73 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""In-memory Authentication State implementation.""" + +from typing import Dict, Set + +from flwr.server.state.authentication.authentication_state import AuthenticationState +from flwr.server.state.in_memory_state import InMemoryState + + +class InMemoryAuthState(AuthenticationState, InMemoryState): + """In-memory-based authentication state implementation.""" + + def __init__(self) -> None: + """Init InMemoryAuthState.""" + super().__init__() + self.node_id_public_key_dict: Dict[int, bytes] = {} + self.client_public_keys: Set[bytes] = set() + self.server_public_key: bytes = b"" + self.server_private_key: bytes = b"" + + def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: + """Store `node_id` and `public_key` as key-value pair in state.""" + if node_id not in self.node_ids: + raise ValueError(f"Node {node_id} not found") + if node_id in self.node_id_public_key_dict: + raise ValueError(f"Node {node_id} has already assigned a public key") + self.node_id_public_key_dict[node_id] = public_key + + def get_public_key_from_node_id(self, node_id: int) -> bytes: + """Get client's public key in urlsafe bytes for `node_id`.""" + if node_id in self.node_id_public_key_dict: + return self.node_id_public_key_dict[node_id] + return b"" + + def store_server_public_private_key( + self, public_key: bytes, private_key: bytes + ) -> None: + """Store server's `public_key` and `private_key` in state.""" + self.server_private_key = private_key + self.server_public_key = public_key + + def get_server_private_key(self) -> bytes: + """Get server private key in urlsafe bytes.""" + return self.server_private_key + + def get_server_public_key(self) -> bytes: + """Get server public key in urlsafe bytes.""" + return self.server_public_key + + def store_client_public_keys(self, public_keys: Set[bytes]) -> None: + """Store a set of client public keys in state.""" + self.client_public_keys = public_keys + + def store_client_public_key(self, public_key: bytes) -> None: + """Retrieve a client public key in state.""" + self.client_public_keys.add(public_key) + + def get_client_public_keys(self) -> Set[bytes]: + """Retrieve all currently stored client public keys as a set.""" + return self.client_public_keys diff --git a/src/py/flwr/server/superlink/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/superlink/state/authentication/sqlite_auth_state.py new file mode 100644 index 000000000000..55e4bc73a63b --- /dev/null +++ b/src/py/flwr/server/superlink/state/authentication/sqlite_auth_state.py @@ -0,0 +1,81 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SQLite based implementation of server authentication state.""" + +from typing import Set + +from flwr.server.state.authentication.authentication_state import AuthenticationState +from flwr.server.state.sqlite_state import SqliteState + + +class SqliteAuthState(AuthenticationState, SqliteState): + """SQLite-based authentication state implementation.""" + + def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: + """Store `node_id` and `public_key` as key-value pair in state.""" + query = ( + "INSERT OR REPLACE INTO node_key (node_id, public_key) " + "VALUES (:node_id, :public_key)" + ) + self.query(query, {"node_id": node_id, "public_key": public_key}) + + def get_public_key_from_node_id(self, node_id: int) -> bytes: + """Get client's public key in urlsafe bytes for `node_id`.""" + query = "SELECT public_key FROM node_key WHERE node_id = :node_id" + rows = self.query(query, {"node_id": node_id}) + public_key: bytes = rows[0]["public_key"] + return public_key + + def store_server_public_private_key( + self, public_key: bytes, private_key: bytes + ) -> None: + """Store server's `public_key` and `private_key` in state.""" + query = ( + "INSERT OR REPLACE INTO credential (public_key, private_key) " + "VALUES (:public_key, :private_key)" + ) + self.query(query, {"public_key": public_key, "private_key": private_key}) + + def get_server_private_key(self) -> bytes: + """Get server private key in urlsafe bytes.""" + query = "SELECT private_key FROM credential" + rows = self.query(query) + private_key: bytes = rows[0]["private_key"] + return private_key + + def get_server_public_key(self) -> bytes: + """Get server public key in urlsafe bytes.""" + query = "SELECT public_key FROM credential" + rows = self.query(query) + public_key: bytes = rows[0]["public_key"] + return public_key + + def store_client_public_keys(self, public_keys: Set[bytes]) -> None: + """Store a set of client public keys in state.""" + query = "INSERT INTO public_key (public_key) VALUES (:public_key)" + for public_key in public_keys: + self.query(query, {"public_key": public_key}) + + def store_client_public_key(self, public_key: bytes) -> None: + """Retrieve a client public key in state.""" + query = "INSERT INTO public_key (public_key) VALUES (:public_key)" + self.query(query, {"public_key": public_key}) + + def get_client_public_keys(self) -> Set[bytes]: + """Retrieve all currently stored client public keys as a set.""" + query = "SELECT public_key FROM public_key" + rows = self.query(query) + result: Set[bytes] = {row["public_key"] for row in rows} + return result From dddbbc96a4ee3c2c43b84b01e83553a09e35f837 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 14 Feb 2024 18:20:49 +0100 Subject: [PATCH 20/73] Move state to superlink --- .../server/state/authentication/__init__.py | 25 ----- .../authentication/authentication_state.py | 58 ----------- .../authentication_state_test.py | 97 ------------------- .../authentication/in_memory_auth_state.py | 73 -------------- .../state/authentication/sqlite_auth_state.py | 81 ---------------- 5 files changed, 334 deletions(-) delete mode 100644 src/py/flwr/server/state/authentication/__init__.py delete mode 100644 src/py/flwr/server/state/authentication/authentication_state.py delete mode 100644 src/py/flwr/server/state/authentication/authentication_state_test.py delete mode 100644 src/py/flwr/server/state/authentication/in_memory_auth_state.py delete mode 100644 src/py/flwr/server/state/authentication/sqlite_auth_state.py diff --git a/src/py/flwr/server/state/authentication/__init__.py b/src/py/flwr/server/state/authentication/__init__.py deleted file mode 100644 index 8f5c0a97ab1f..000000000000 --- a/src/py/flwr/server/state/authentication/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Flower server authentication state.""" - -from .authentication_state import AuthenticationState as AuthenticationState -from .in_memory_auth_state import InMemoryAuthState as InMemoryAuthState -from .sqlite_auth_state import SqliteAuthState as SqliteAuthState - -__all__ = [ - "AuthenticationState", - "InMemoryAuthState", - "SqliteAuthState", -] diff --git a/src/py/flwr/server/state/authentication/authentication_state.py b/src/py/flwr/server/state/authentication/authentication_state.py deleted file mode 100644 index 3adb450dc215..000000000000 --- a/src/py/flwr/server/state/authentication/authentication_state.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Abstract base class AuthenticationState.""" - -import abc -from typing import Set - -from flwr.server.state import State - - -class AuthenticationState(State, abc.ABC): - """Abstract Authentication State.""" - - @abc.abstractmethod - def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: - """Store `node_id` and `public_key` as key-value pair in state.""" - - @abc.abstractmethod - def get_public_key_from_node_id(self, node_id: int) -> bytes: - """Get client's public key in urlsafe bytes for `node_id`.""" - - @abc.abstractmethod - def store_server_public_private_key( - self, public_key: bytes, private_key: bytes - ) -> None: - """Store server's `public_key` and `private_key` in state.""" - - @abc.abstractmethod - def get_server_private_key(self) -> bytes: - """Get server private key in urlsafe bytes.""" - - @abc.abstractmethod - def get_server_public_key(self) -> bytes: - """Get server public key in urlsafe bytes.""" - - @abc.abstractmethod - def store_client_public_keys(self, public_keys: Set[bytes]) -> None: - """Store a set of client public keys in state.""" - - @abc.abstractmethod - def store_client_public_key(self, public_key: bytes) -> None: - """Retrieve a client public key in state.""" - - @abc.abstractmethod - def get_client_public_keys(self) -> Set[bytes]: - """Retrieve all currently stored client public keys as a set.""" diff --git a/src/py/flwr/server/state/authentication/authentication_state_test.py b/src/py/flwr/server/state/authentication/authentication_state_test.py deleted file mode 100644 index f18c428d3044..000000000000 --- a/src/py/flwr/server/state/authentication/authentication_state_test.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Test for authentication state.""" - - -from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - compute_hmac, - generate_key_pairs, - generate_shared_key, - public_key_to_bytes, - verify_hmac, -) - -from .in_memory_auth_state import InMemoryAuthState -from .sqlite_auth_state import SqliteAuthState - - -def test_in_memory_client_public_keys() -> None: - """Test client public keys store and get from state.""" - key_pairs = [generate_key_pairs() for _ in range(3)] - public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} - - in_memory_auth_state = InMemoryAuthState() - in_memory_auth_state.store_client_public_keys(public_keys) - - assert in_memory_auth_state.get_client_public_keys() == public_keys - - -def test_sqlite_client_public_keys() -> None: - """Test client public keys store and get from state.""" - key_pairs = [generate_key_pairs() for _ in range(3)] - public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} - - sqlite_auth_state = SqliteAuthState(":memory:") - sqlite_auth_state.initialize() - sqlite_auth_state.store_client_public_keys(public_keys) - - assert sqlite_auth_state.get_client_public_keys() == public_keys - - -def test_in_memory_node_id_public_key_pair() -> None: - """Test store and get node_id public_key pair.""" - in_memory_auth_state = InMemoryAuthState() - node_id = in_memory_auth_state.create_node() - public_key = public_key_to_bytes(generate_key_pairs()[1]) - - in_memory_auth_state.store_node_id_public_key_pair(node_id, public_key) - - assert in_memory_auth_state.get_public_key_from_node_id(node_id) == public_key - - -def test_sqlite_node_id_public_key_pair() -> None: - """Test store and get node_id public_key pair.""" - sqlite_auth_state = SqliteAuthState(":memory:") - sqlite_auth_state.initialize() - node_id = sqlite_auth_state.create_node() - public_key = public_key_to_bytes(generate_key_pairs()[1]) - - sqlite_auth_state.store_node_id_public_key_pair(node_id, public_key) - - assert sqlite_auth_state.get_public_key_from_node_id(node_id) == public_key - - -def test_generate_shared_key() -> None: - """Test util function generate_shared_key.""" - client_keys = generate_key_pairs() - server_keys = generate_key_pairs() - - client_shared_secret = generate_shared_key(client_keys[0], server_keys[1]) - server_shared_secret = generate_shared_key(server_keys[0], client_keys[1]) - - assert client_shared_secret == server_shared_secret - - -def test_hmac() -> None: - """Test util function compute and verify hmac.""" - client_keys = generate_key_pairs() - server_keys = generate_key_pairs() - client_shared_secret = generate_shared_key(client_keys[0], server_keys[1]) - server_shared_secret = generate_shared_key(server_keys[0], client_keys[1]) - message = b"Flower is the future of AI" - - client_compute_hmac = compute_hmac(client_shared_secret, message) - - assert verify_hmac(server_shared_secret, message, client_compute_hmac) diff --git a/src/py/flwr/server/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/state/authentication/in_memory_auth_state.py deleted file mode 100644 index fe10c1301b11..000000000000 --- a/src/py/flwr/server/state/authentication/in_memory_auth_state.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""In-memory Authentication State implementation.""" - -from typing import Dict, Set - -from flwr.server.state.authentication.authentication_state import AuthenticationState -from flwr.server.state.in_memory_state import InMemoryState - - -class InMemoryAuthState(AuthenticationState, InMemoryState): - """In-memory-based authentication state implementation.""" - - def __init__(self) -> None: - """Init InMemoryAuthState.""" - super().__init__() - self.node_id_public_key_dict: Dict[int, bytes] = {} - self.client_public_keys: Set[bytes] = set() - self.server_public_key: bytes = b"" - self.server_private_key: bytes = b"" - - def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: - """Store `node_id` and `public_key` as key-value pair in state.""" - if node_id not in self.node_ids: - raise ValueError(f"Node {node_id} not found") - if node_id in self.node_id_public_key_dict: - raise ValueError(f"Node {node_id} has already assigned a public key") - self.node_id_public_key_dict[node_id] = public_key - - def get_public_key_from_node_id(self, node_id: int) -> bytes: - """Get client's public key in urlsafe bytes for `node_id`.""" - if node_id in self.node_id_public_key_dict: - return self.node_id_public_key_dict[node_id] - return b"" - - def store_server_public_private_key( - self, public_key: bytes, private_key: bytes - ) -> None: - """Store server's `public_key` and `private_key` in state.""" - self.server_private_key = private_key - self.server_public_key = public_key - - def get_server_private_key(self) -> bytes: - """Get server private key in urlsafe bytes.""" - return self.server_private_key - - def get_server_public_key(self) -> bytes: - """Get server public key in urlsafe bytes.""" - return self.server_public_key - - def store_client_public_keys(self, public_keys: Set[bytes]) -> None: - """Store a set of client public keys in state.""" - self.client_public_keys = public_keys - - def store_client_public_key(self, public_key: bytes) -> None: - """Retrieve a client public key in state.""" - self.client_public_keys.add(public_key) - - def get_client_public_keys(self) -> Set[bytes]: - """Retrieve all currently stored client public keys as a set.""" - return self.client_public_keys diff --git a/src/py/flwr/server/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/state/authentication/sqlite_auth_state.py deleted file mode 100644 index 55e4bc73a63b..000000000000 --- a/src/py/flwr/server/state/authentication/sqlite_auth_state.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""SQLite based implementation of server authentication state.""" - -from typing import Set - -from flwr.server.state.authentication.authentication_state import AuthenticationState -from flwr.server.state.sqlite_state import SqliteState - - -class SqliteAuthState(AuthenticationState, SqliteState): - """SQLite-based authentication state implementation.""" - - def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: - """Store `node_id` and `public_key` as key-value pair in state.""" - query = ( - "INSERT OR REPLACE INTO node_key (node_id, public_key) " - "VALUES (:node_id, :public_key)" - ) - self.query(query, {"node_id": node_id, "public_key": public_key}) - - def get_public_key_from_node_id(self, node_id: int) -> bytes: - """Get client's public key in urlsafe bytes for `node_id`.""" - query = "SELECT public_key FROM node_key WHERE node_id = :node_id" - rows = self.query(query, {"node_id": node_id}) - public_key: bytes = rows[0]["public_key"] - return public_key - - def store_server_public_private_key( - self, public_key: bytes, private_key: bytes - ) -> None: - """Store server's `public_key` and `private_key` in state.""" - query = ( - "INSERT OR REPLACE INTO credential (public_key, private_key) " - "VALUES (:public_key, :private_key)" - ) - self.query(query, {"public_key": public_key, "private_key": private_key}) - - def get_server_private_key(self) -> bytes: - """Get server private key in urlsafe bytes.""" - query = "SELECT private_key FROM credential" - rows = self.query(query) - private_key: bytes = rows[0]["private_key"] - return private_key - - def get_server_public_key(self) -> bytes: - """Get server public key in urlsafe bytes.""" - query = "SELECT public_key FROM credential" - rows = self.query(query) - public_key: bytes = rows[0]["public_key"] - return public_key - - def store_client_public_keys(self, public_keys: Set[bytes]) -> None: - """Store a set of client public keys in state.""" - query = "INSERT INTO public_key (public_key) VALUES (:public_key)" - for public_key in public_keys: - self.query(query, {"public_key": public_key}) - - def store_client_public_key(self, public_key: bytes) -> None: - """Retrieve a client public key in state.""" - query = "INSERT INTO public_key (public_key) VALUES (:public_key)" - self.query(query, {"public_key": public_key}) - - def get_client_public_keys(self) -> Set[bytes]: - """Retrieve all currently stored client public keys as a set.""" - query = "SELECT public_key FROM public_key" - rows = self.query(query) - result: Set[bytes] = {row["public_key"] for row in rows} - return result From fbbcb2abc67dc1bde55c3e07c6ef503c676bb8a9 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 15 Feb 2024 12:49:28 +0100 Subject: [PATCH 21/73] Fix server interceptor --- src/py/flwr/server/server_interceptor.py | 207 +++++++++--------- src/py/flwr/server/server_interceptor_test.py | 24 +- 2 files changed, 125 insertions(+), 106 deletions(-) diff --git a/src/py/flwr/server/server_interceptor.py b/src/py/flwr/server/server_interceptor.py index 15a0d1d4dfee..77815ee1a113 100644 --- a/src/py/flwr/server/server_interceptor.py +++ b/src/py/flwr/server/server_interceptor.py @@ -14,118 +14,127 @@ # ============================================================================== """Flower server interceptor.""" +import threading +from typing import Callable, Sequence, Tuple, Union + import grpc from cryptography.hazmat.primitives.asymmetric import ec -from typing import Callable, Sequence, Tuple, Union -from flwr.server.state.authentication import AuthenticationState -from flwr.common.secure_aggregation.crypto.symmetric_encryption import generate_shared_key, bytes_to_public_key, public_key_to_bytes, verify_hmac + +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + bytes_to_public_key, + generate_shared_key, + public_key_to_bytes, + verify_hmac, +) from flwr.proto.fleet_pb2 import ( CreateNodeRequest, CreateNodeResponse, + DeleteNodeRequest, + DeleteNodeResponse, + PullTaskInsRequest, + PullTaskInsResponse, + PushTaskResRequest, + PushTaskResResponse, ) -from flwr.server.fleet.message_handler import message_handler -from flwr.server.state import StateFactory, State +from flwr.server.superlink.state import StateFactory _PUBLIC_KEY_HEADER = "public-key" _AUTH_TOKEN_HEADER = "auth-token" -def _unary_unary_rpc_terminator(): - - def terminate(_, context): - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") - return grpc.unary_unary_rpc_method_handler(terminate) +def _get_value_from_tuples( + key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]] +) -> bytes: + value = next((value[::-1] for key, value in tuples if key == key_string), "") + if isinstance(value, str): + return value.encode() -def _create_node_with_public_key(state: State, server_public_key: bytes): - - def send_public_key(request: CreateNodeRequest, context: grpc.ServicerContext) -> CreateNodeResponse: - context.set_trailing_metadata( - ( - (_PUBLIC_KEY_HEADER, server_public_key), - ) - ) - return message_handler.create_node(request, state) - - return grpc.unary_unary_rpc_method_handler(send_public_key) - -def _create_node_with_public_key(state: State, server_public_key: bytes): - - def send_public_key(request: CreateNodeRequest, context: grpc.ServicerContext) -> CreateNodeResponse: - context.set_trailing_metadata( - ( - (_PUBLIC_KEY_HEADER, server_public_key), - ) - ) - return message_handler.create_node(request, state) + return value - return grpc.unary_unary_rpc_method_handler(send_public_key) - -def _handle_authentication(public_key, private_key): - return generate_shared_key(public_key, private_key) - -def _is_public_key_known(state: AuthenticationState, public_key: bytes) -> bool: - return public_key in state.get_client_public_keys() - -def _get_value_from_tuples(key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]]) -> Union[str, bytes]: - return next((value[::-1] for key, value in tuples if key == key_string), "") class AuthenticateServerInterceptor(grpc.ServerInterceptor): - - def __init__(self, state_factory: StateFactory, private_key: ec.EllipticCurvePrivateKey, public_key: ec.EllipticCurvePublicKey): - self._private_key = private_key - self._public_key = public_key - self._state_factory = state_factory - self._terminator = _unary_unary_rpc_terminator() - self._create_node_handler = _create_node_with_public_key() - - def intercept_service(self, continuation: Callable, handler_call_details: grpc.HandlerCallDetails): - method_name = handler_call_details.method.split("/")[-1] - client_public_key_bytes = _get_value_from_tuples(_PUBLIC_KEY_HEADER, handler_call_details.invocation_metadata) - client_public_key = bytes_to_public_key(client_public_key_bytes) - - if _is_public_key_known(self._state_factory.state, client_public_key_bytes): - if method_name == 'CreateNode': - return _create_node_with_public_key(self._state_factory.state, self._public_key) - elif method_name in {'DeleteNode', 'PullTaskIns', 'PushTaskRes'}: - state: AuthenticationState = self._state_factory.state - shared_secret = generate_shared_key(self._private_key, client_public_key) - hmac = _get_value_from_tuples(_AUTH_TOKEN_HEADER, handler_call_details.invocation_metadata) - if verify_hmac(shared_secret, ) - state.get_client_public_keys() - expected_metadata = (_AUTH_TOKEN_HEADER, generate_shared_key()) - - - if (self._header, self._value) in handler_call_details.invocation_metadata: - grpc.unary_unary_rpc_method_handler - return continuation(handler_call_details) - else: - return self._terminator - - def intercept_service(self, continuation: Callable, handler_call_details: grpc.HandlerCallDetails): - client_public_key_bytes = _get_value_from_tuples(_PUBLIC_KEY_HEADER, handler_call_details.invocation_metadata) - if _is_public_key_known(self._state_factory.state, client_public_key_bytes): - message_handler: grpc.RpcMethodHandler = continuation(handler_call_details) - message_handler. - return grpc.unary_unary_rpc_method_handler(message_handler.unary_unary, request_deserializer=message_handler.request_deserializer, response_serializer=message_handler.response_serializer) - if message_handler is None: - return - else: - return self._terminator - - handler_factory, next_handler_method = _get_factory_and_method(next_handler) - - - def invoke_intercept_method(request_or_iterator, context): - method_name = handler_call_details.method - return self.intercept( - next_handler_method, - request_or_iterator, - context, - method_name, - ) - - return handler_factory( - invoke_intercept_method, - request_deserializer=next_handler.request_deserializer, - response_serializer=next_handler.response_serializer, + """Server interceptor for client authentication.""" + + def __init__( + self, + state_factory: StateFactory, + private_key: ec.EllipticCurvePrivateKey, + public_key: ec.EllipticCurvePublicKey, + ): + self._lock = threading.Lock + self.server_private_key = private_key + self.server_public_key = public_key + self.state_factory = state_factory + + def intercept_service( + self, continuation: Callable, handler_call_details: grpc.HandlerCallDetails + ) -> grpc.RpcMethodHandler: + """Flower server interceptor authentication logic.""" + message_handler: grpc.RpcMethodHandler = continuation(handler_call_details) + return self._generic_auth_unary_method_handler(message_handler) + + def _generic_auth_unary_method_handler( + self, existing_handler: grpc.RpcMethodHandler + ) -> grpc.RpcMethodHandler: + def _generic_method_handler( + request: Union[ + CreateNodeRequest, + DeleteNodeRequest, + PullTaskInsRequest, + PushTaskResRequest, + ], + context: grpc.ServicerContext, + ) -> Union[ + CreateNodeResponse, + DeleteNodeResponse, + PullTaskInsResponse, + PushTaskResResponse, + ]: + with self._lock: + if isinstance(request, CreateNodeRequest): + client_public_key_bytes = _get_value_from_tuples( + _PUBLIC_KEY_HEADER, context.invocation_metadata() + ) + is_public_key_known = ( + client_public_key_bytes in self.state.get_client_public_keys() + ) + if is_public_key_known: + context.set_trailing_metadata( + ( + ( + _PUBLIC_KEY_HEADER, + public_key_to_bytes(self.server_public_key), + ), + ) + ) + else: + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") + elif isinstance( + request, (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest) + ): + hmac_value = _get_value_from_tuples( + _AUTH_TOKEN_HEADER, context.invocation_metadata() + ) + node_id: int = ( + -1 if request.node.anonymous else request.node.node_id + ) + client_public_key_bytes = state.get_public_key_from_node_id(node_id) + shared_secret = generate_shared_key( + self.server_private_key, + bytes_to_public_key(client_public_key_bytes), + ) + verify = verify_hmac( + shared_secret, request.SerializeToString(True), hmac_value + ) + if not verify: + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") + else: + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") + + return existing_handler.unary_unary + + return grpc.unary_unary_rpc_method_handler( + _generic_method_handler, + request_deserializer=existing_handler.request_deserializer, + response_serializer=existing_handler.response_serializer, ) diff --git a/src/py/flwr/server/server_interceptor_test.py b/src/py/flwr/server/server_interceptor_test.py index d3a528addddd..bcd9ab5cc7e0 100644 --- a/src/py/flwr/server/server_interceptor_test.py +++ b/src/py/flwr/server/server_interceptor_test.py @@ -15,19 +15,28 @@ """Flower server interceptor tests.""" import unittest -import grpc -from app import _run_fleet_api_grpc_rere, ADDRESS_FLEET_API_GRPC_RERE -from flwr.common import GRPC_MAX_MESSAGE_LENGTH -from common.constant import TRANSPORT_TYPE_GRPC_RERE +import grpc +from app import ADDRESS_FLEET_API_GRPC_RERE, _run_fleet_api_grpc_rere from client.app import _init_connection +from common.constant import TRANSPORT_TYPE_GRPC_RERE from state.state_factory import StateFactory +from flwr.common import GRPC_MAX_MESSAGE_LENGTH + + class TestServerInterceptor(unittest.TestCase): + """Server interceptor tests.""" + def setUp(self): + """Initialize mock stub and server interceptor.""" self._state_factory = StateFactory(":flwr-in-memory-state:") - self._server: grpc.Server = _run_fleet_api_grpc_rere(ADDRESS_FLEET_API_GRPC_RERE, self._state_factory) - self._connection, self._address = _init_connection(TRANSPORT_TYPE_GRPC_RERE, ADDRESS_FLEET_API_GRPC_RERE) + self._server: grpc.Server = _run_fleet_api_grpc_rere( + ADDRESS_FLEET_API_GRPC_RERE, self._state_factory + ) + self._connection, self._address = _init_connection( + TRANSPORT_TYPE_GRPC_RERE, ADDRESS_FLEET_API_GRPC_RERE + ) with self._connection( self._address, True, @@ -36,8 +45,9 @@ def setUp(self): self._receive, self._send, self._create_node, self._delete_node = conn def tearDown(self): + """Clean up grpc server.""" self._server.stop(None) def test_successful_create_node_with_metadata(self) -> None: + """Test server interceptor for create node.""" self._create_node() - From 77e5c3cb161d776544dbb2c97bb7667c47ad3e45 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 15 Feb 2024 12:59:08 +0100 Subject: [PATCH 22/73] Fix authentication state --- .../authentication/authentication_state.py | 58 ------------- .../authentication/in_memory_auth_state.py | 73 ----------------- .../state/authentication/sqlite_auth_state.py | 81 ------------------- .../server/superlink/state/in_memory_state.py | 45 +++++++++++ .../server/superlink/state/sqlite_state.py | 57 +++++++++++++ src/py/flwr/server/superlink/state/state.py | 34 ++++++++ 6 files changed, 136 insertions(+), 212 deletions(-) delete mode 100644 src/py/flwr/server/superlink/state/authentication/authentication_state.py delete mode 100644 src/py/flwr/server/superlink/state/authentication/in_memory_auth_state.py delete mode 100644 src/py/flwr/server/superlink/state/authentication/sqlite_auth_state.py diff --git a/src/py/flwr/server/superlink/state/authentication/authentication_state.py b/src/py/flwr/server/superlink/state/authentication/authentication_state.py deleted file mode 100644 index 3adb450dc215..000000000000 --- a/src/py/flwr/server/superlink/state/authentication/authentication_state.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Abstract base class AuthenticationState.""" - -import abc -from typing import Set - -from flwr.server.state import State - - -class AuthenticationState(State, abc.ABC): - """Abstract Authentication State.""" - - @abc.abstractmethod - def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: - """Store `node_id` and `public_key` as key-value pair in state.""" - - @abc.abstractmethod - def get_public_key_from_node_id(self, node_id: int) -> bytes: - """Get client's public key in urlsafe bytes for `node_id`.""" - - @abc.abstractmethod - def store_server_public_private_key( - self, public_key: bytes, private_key: bytes - ) -> None: - """Store server's `public_key` and `private_key` in state.""" - - @abc.abstractmethod - def get_server_private_key(self) -> bytes: - """Get server private key in urlsafe bytes.""" - - @abc.abstractmethod - def get_server_public_key(self) -> bytes: - """Get server public key in urlsafe bytes.""" - - @abc.abstractmethod - def store_client_public_keys(self, public_keys: Set[bytes]) -> None: - """Store a set of client public keys in state.""" - - @abc.abstractmethod - def store_client_public_key(self, public_key: bytes) -> None: - """Retrieve a client public key in state.""" - - @abc.abstractmethod - def get_client_public_keys(self) -> Set[bytes]: - """Retrieve all currently stored client public keys as a set.""" diff --git a/src/py/flwr/server/superlink/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/superlink/state/authentication/in_memory_auth_state.py deleted file mode 100644 index fe10c1301b11..000000000000 --- a/src/py/flwr/server/superlink/state/authentication/in_memory_auth_state.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""In-memory Authentication State implementation.""" - -from typing import Dict, Set - -from flwr.server.state.authentication.authentication_state import AuthenticationState -from flwr.server.state.in_memory_state import InMemoryState - - -class InMemoryAuthState(AuthenticationState, InMemoryState): - """In-memory-based authentication state implementation.""" - - def __init__(self) -> None: - """Init InMemoryAuthState.""" - super().__init__() - self.node_id_public_key_dict: Dict[int, bytes] = {} - self.client_public_keys: Set[bytes] = set() - self.server_public_key: bytes = b"" - self.server_private_key: bytes = b"" - - def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: - """Store `node_id` and `public_key` as key-value pair in state.""" - if node_id not in self.node_ids: - raise ValueError(f"Node {node_id} not found") - if node_id in self.node_id_public_key_dict: - raise ValueError(f"Node {node_id} has already assigned a public key") - self.node_id_public_key_dict[node_id] = public_key - - def get_public_key_from_node_id(self, node_id: int) -> bytes: - """Get client's public key in urlsafe bytes for `node_id`.""" - if node_id in self.node_id_public_key_dict: - return self.node_id_public_key_dict[node_id] - return b"" - - def store_server_public_private_key( - self, public_key: bytes, private_key: bytes - ) -> None: - """Store server's `public_key` and `private_key` in state.""" - self.server_private_key = private_key - self.server_public_key = public_key - - def get_server_private_key(self) -> bytes: - """Get server private key in urlsafe bytes.""" - return self.server_private_key - - def get_server_public_key(self) -> bytes: - """Get server public key in urlsafe bytes.""" - return self.server_public_key - - def store_client_public_keys(self, public_keys: Set[bytes]) -> None: - """Store a set of client public keys in state.""" - self.client_public_keys = public_keys - - def store_client_public_key(self, public_key: bytes) -> None: - """Retrieve a client public key in state.""" - self.client_public_keys.add(public_key) - - def get_client_public_keys(self) -> Set[bytes]: - """Retrieve all currently stored client public keys as a set.""" - return self.client_public_keys diff --git a/src/py/flwr/server/superlink/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/superlink/state/authentication/sqlite_auth_state.py deleted file mode 100644 index 55e4bc73a63b..000000000000 --- a/src/py/flwr/server/superlink/state/authentication/sqlite_auth_state.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""SQLite based implementation of server authentication state.""" - -from typing import Set - -from flwr.server.state.authentication.authentication_state import AuthenticationState -from flwr.server.state.sqlite_state import SqliteState - - -class SqliteAuthState(AuthenticationState, SqliteState): - """SQLite-based authentication state implementation.""" - - def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: - """Store `node_id` and `public_key` as key-value pair in state.""" - query = ( - "INSERT OR REPLACE INTO node_key (node_id, public_key) " - "VALUES (:node_id, :public_key)" - ) - self.query(query, {"node_id": node_id, "public_key": public_key}) - - def get_public_key_from_node_id(self, node_id: int) -> bytes: - """Get client's public key in urlsafe bytes for `node_id`.""" - query = "SELECT public_key FROM node_key WHERE node_id = :node_id" - rows = self.query(query, {"node_id": node_id}) - public_key: bytes = rows[0]["public_key"] - return public_key - - def store_server_public_private_key( - self, public_key: bytes, private_key: bytes - ) -> None: - """Store server's `public_key` and `private_key` in state.""" - query = ( - "INSERT OR REPLACE INTO credential (public_key, private_key) " - "VALUES (:public_key, :private_key)" - ) - self.query(query, {"public_key": public_key, "private_key": private_key}) - - def get_server_private_key(self) -> bytes: - """Get server private key in urlsafe bytes.""" - query = "SELECT private_key FROM credential" - rows = self.query(query) - private_key: bytes = rows[0]["private_key"] - return private_key - - def get_server_public_key(self) -> bytes: - """Get server public key in urlsafe bytes.""" - query = "SELECT public_key FROM credential" - rows = self.query(query) - public_key: bytes = rows[0]["public_key"] - return public_key - - def store_client_public_keys(self, public_keys: Set[bytes]) -> None: - """Store a set of client public keys in state.""" - query = "INSERT INTO public_key (public_key) VALUES (:public_key)" - for public_key in public_keys: - self.query(query, {"public_key": public_key}) - - def store_client_public_key(self, public_key: bytes) -> None: - """Retrieve a client public key in state.""" - query = "INSERT INTO public_key (public_key) VALUES (:public_key)" - self.query(query, {"public_key": public_key}) - - def get_client_public_keys(self) -> Set[bytes]: - """Retrieve all currently stored client public keys as a set.""" - query = "SELECT public_key FROM public_key" - rows = self.query(query) - result: Set[bytes] = {row["public_key"] for row in rows} - return result diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index ecb39f18300a..93ae1cab97a8 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -35,6 +35,10 @@ def __init__(self) -> None: self.run_ids: Set[int] = set() self.task_ins_store: Dict[UUID, TaskIns] = {} self.task_res_store: Dict[UUID, TaskRes] = {} + self.node_id_public_key_dict: Dict[int, bytes] = {} + self.client_public_keys: Set[bytes] = set() + self.server_public_key: bytes = b"" + self.server_private_key: bytes = b"" def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: """Store one TaskIns.""" @@ -221,3 +225,44 @@ def create_run(self) -> int: return run_id log(ERROR, "Unexpected run creation failure.") return 0 + + def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: + """Store `node_id` and `public_key` as key-value pair in state.""" + if node_id not in self.node_ids: + raise ValueError(f"Node {node_id} not found") + if node_id in self.node_id_public_key_dict: + raise ValueError(f"Node {node_id} has already assigned a public key") + self.node_id_public_key_dict[node_id] = public_key + + def get_public_key_from_node_id(self, node_id: int) -> bytes: + """Get client's public key in urlsafe bytes for `node_id`.""" + if node_id in self.node_id_public_key_dict: + return self.node_id_public_key_dict[node_id] + return b"" + + def store_server_public_private_key( + self, public_key: bytes, private_key: bytes + ) -> None: + """Store server's `public_key` and `private_key` in state.""" + self.server_private_key = private_key + self.server_public_key = public_key + + def get_server_private_key(self) -> bytes: + """Get server private key in urlsafe bytes.""" + return self.server_private_key + + def get_server_public_key(self) -> bytes: + """Get server public key in urlsafe bytes.""" + return self.server_public_key + + def store_client_public_keys(self, public_keys: Set[bytes]) -> None: + """Store a set of client public keys in state.""" + self.client_public_keys = public_keys + + def store_client_public_key(self, public_key: bytes) -> None: + """Retrieve a client public key in state.""" + self.client_public_keys.add(public_key) + + def get_client_public_keys(self) -> Set[bytes]: + """Retrieve all currently stored client public keys as a set.""" + return self.client_public_keys diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index e91d8553863c..f671ff12a1e2 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -542,6 +542,63 @@ def create_run(self) -> int: log(ERROR, "Unexpected run creation failure.") return 0 + def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: + """Store `node_id` and `public_key` as key-value pair in state.""" + query = ( + "INSERT OR REPLACE INTO node_key (node_id, public_key) " + "VALUES (:node_id, :public_key)" + ) + self.query(query, {"node_id": node_id, "public_key": public_key}) + + def get_public_key_from_node_id(self, node_id: int) -> bytes: + """Get client's public key in urlsafe bytes for `node_id`.""" + query = "SELECT public_key FROM node_key WHERE node_id = :node_id" + rows = self.query(query, {"node_id": node_id}) + public_key: bytes = rows[0]["public_key"] + return public_key + + def store_server_public_private_key( + self, public_key: bytes, private_key: bytes + ) -> None: + """Store server's `public_key` and `private_key` in state.""" + query = ( + "INSERT OR REPLACE INTO credential (public_key, private_key) " + "VALUES (:public_key, :private_key)" + ) + self.query(query, {"public_key": public_key, "private_key": private_key}) + + def get_server_private_key(self) -> bytes: + """Get server private key in urlsafe bytes.""" + query = "SELECT private_key FROM credential" + rows = self.query(query) + private_key: bytes = rows[0]["private_key"] + return private_key + + def get_server_public_key(self) -> bytes: + """Get server public key in urlsafe bytes.""" + query = "SELECT public_key FROM credential" + rows = self.query(query) + public_key: bytes = rows[0]["public_key"] + return public_key + + def store_client_public_keys(self, public_keys: Set[bytes]) -> None: + """Store a set of client public keys in state.""" + query = "INSERT INTO public_key (public_key) VALUES (:public_key)" + for public_key in public_keys: + self.query(query, {"public_key": public_key}) + + def store_client_public_key(self, public_key: bytes) -> None: + """Retrieve a client public key in state.""" + query = "INSERT INTO public_key (public_key) VALUES (:public_key)" + self.query(query, {"public_key": public_key}) + + def get_client_public_keys(self) -> Set[bytes]: + """Retrieve all currently stored client public keys as a set.""" + query = "SELECT public_key FROM public_key" + rows = self.query(query) + result: Set[bytes] = {row["public_key"] for row in rows} + return result + def dict_factory( cursor: sqlite3.Cursor, diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index 9337ae6d8624..645b3ec1b290 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -152,3 +152,37 @@ def get_nodes(self, run_id: int) -> Set[int]: @abc.abstractmethod def create_run(self) -> int: """Create one run.""" + + @abc.abstractmethod + def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: + """Store `node_id` and `public_key` as key-value pair in state.""" + + @abc.abstractmethod + def get_public_key_from_node_id(self, node_id: int) -> bytes: + """Get client's public key in urlsafe bytes for `node_id`.""" + + @abc.abstractmethod + def store_server_public_private_key( + self, public_key: bytes, private_key: bytes + ) -> None: + """Store server's `public_key` and `private_key` in state.""" + + @abc.abstractmethod + def get_server_private_key(self) -> bytes: + """Get server private key in urlsafe bytes.""" + + @abc.abstractmethod + def get_server_public_key(self) -> bytes: + """Get server public key in urlsafe bytes.""" + + @abc.abstractmethod + def store_client_public_keys(self, public_keys: Set[bytes]) -> None: + """Store a set of client public keys in state.""" + + @abc.abstractmethod + def store_client_public_key(self, public_key: bytes) -> None: + """Retrieve a client public key in state.""" + + @abc.abstractmethod + def get_client_public_keys(self) -> Set[bytes]: + """Retrieve all currently stored client public keys as a set.""" From 21e590a3a4326eac6d453524fa2e47cad6e7b1e0 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 15 Feb 2024 13:07:40 +0100 Subject: [PATCH 23/73] Add symmetric encryption test --- .../crypto/symmetric_encryption_test.py | 40 ++++++++ .../state/authentication/__init__.py | 25 ----- .../authentication_state_test.py | 97 ------------------- .../flwr/server/superlink/state/state_test.py | 28 ++++++ 4 files changed, 68 insertions(+), 122 deletions(-) create mode 100644 src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption_test.py delete mode 100644 src/py/flwr/server/superlink/state/authentication/__init__.py delete mode 100644 src/py/flwr/server/superlink/state/authentication/authentication_state_test.py diff --git a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption_test.py b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption_test.py new file mode 100644 index 000000000000..523b5410f1d0 --- /dev/null +++ b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption_test.py @@ -0,0 +1,40 @@ +# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Symmetric encryption tests.""" + +from .symmetric_encryption import generate_shared_key, generate_key_pairs, compute_hmac, verify_hmac + +def test_generate_shared_key() -> None: + """Test util function generate_shared_key.""" + client_keys = generate_key_pairs() + server_keys = generate_key_pairs() + + client_shared_secret = generate_shared_key(client_keys[0], server_keys[1]) + server_shared_secret = generate_shared_key(server_keys[0], client_keys[1]) + + assert client_shared_secret == server_shared_secret + + +def test_hmac() -> None: + """Test util function compute and verify hmac.""" + client_keys = generate_key_pairs() + server_keys = generate_key_pairs() + client_shared_secret = generate_shared_key(client_keys[0], server_keys[1]) + server_shared_secret = generate_shared_key(server_keys[0], client_keys[1]) + message = b"Flower is the future of AI" + + client_compute_hmac = compute_hmac(client_shared_secret, message) + + assert verify_hmac(server_shared_secret, message, client_compute_hmac) \ No newline at end of file diff --git a/src/py/flwr/server/superlink/state/authentication/__init__.py b/src/py/flwr/server/superlink/state/authentication/__init__.py deleted file mode 100644 index 8f5c0a97ab1f..000000000000 --- a/src/py/flwr/server/superlink/state/authentication/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Flower server authentication state.""" - -from .authentication_state import AuthenticationState as AuthenticationState -from .in_memory_auth_state import InMemoryAuthState as InMemoryAuthState -from .sqlite_auth_state import SqliteAuthState as SqliteAuthState - -__all__ = [ - "AuthenticationState", - "InMemoryAuthState", - "SqliteAuthState", -] diff --git a/src/py/flwr/server/superlink/state/authentication/authentication_state_test.py b/src/py/flwr/server/superlink/state/authentication/authentication_state_test.py deleted file mode 100644 index f18c428d3044..000000000000 --- a/src/py/flwr/server/superlink/state/authentication/authentication_state_test.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Test for authentication state.""" - - -from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - compute_hmac, - generate_key_pairs, - generate_shared_key, - public_key_to_bytes, - verify_hmac, -) - -from .in_memory_auth_state import InMemoryAuthState -from .sqlite_auth_state import SqliteAuthState - - -def test_in_memory_client_public_keys() -> None: - """Test client public keys store and get from state.""" - key_pairs = [generate_key_pairs() for _ in range(3)] - public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} - - in_memory_auth_state = InMemoryAuthState() - in_memory_auth_state.store_client_public_keys(public_keys) - - assert in_memory_auth_state.get_client_public_keys() == public_keys - - -def test_sqlite_client_public_keys() -> None: - """Test client public keys store and get from state.""" - key_pairs = [generate_key_pairs() for _ in range(3)] - public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} - - sqlite_auth_state = SqliteAuthState(":memory:") - sqlite_auth_state.initialize() - sqlite_auth_state.store_client_public_keys(public_keys) - - assert sqlite_auth_state.get_client_public_keys() == public_keys - - -def test_in_memory_node_id_public_key_pair() -> None: - """Test store and get node_id public_key pair.""" - in_memory_auth_state = InMemoryAuthState() - node_id = in_memory_auth_state.create_node() - public_key = public_key_to_bytes(generate_key_pairs()[1]) - - in_memory_auth_state.store_node_id_public_key_pair(node_id, public_key) - - assert in_memory_auth_state.get_public_key_from_node_id(node_id) == public_key - - -def test_sqlite_node_id_public_key_pair() -> None: - """Test store and get node_id public_key pair.""" - sqlite_auth_state = SqliteAuthState(":memory:") - sqlite_auth_state.initialize() - node_id = sqlite_auth_state.create_node() - public_key = public_key_to_bytes(generate_key_pairs()[1]) - - sqlite_auth_state.store_node_id_public_key_pair(node_id, public_key) - - assert sqlite_auth_state.get_public_key_from_node_id(node_id) == public_key - - -def test_generate_shared_key() -> None: - """Test util function generate_shared_key.""" - client_keys = generate_key_pairs() - server_keys = generate_key_pairs() - - client_shared_secret = generate_shared_key(client_keys[0], server_keys[1]) - server_shared_secret = generate_shared_key(server_keys[0], client_keys[1]) - - assert client_shared_secret == server_shared_secret - - -def test_hmac() -> None: - """Test util function compute and verify hmac.""" - client_keys = generate_key_pairs() - server_keys = generate_key_pairs() - client_shared_secret = generate_shared_key(client_keys[0], server_keys[1]) - server_shared_secret = generate_shared_key(server_keys[0], client_keys[1]) - message = b"Flower is the future of AI" - - client_compute_hmac = compute_hmac(client_shared_secret, message) - - assert verify_hmac(server_shared_secret, message, client_compute_hmac) diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 803702bb97bb..4ecac8b17c3c 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -28,6 +28,15 @@ from flwr.server.superlink.state import InMemoryState, SqliteState, State +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + compute_hmac, + generate_key_pairs, + generate_shared_key, + public_key_to_bytes, + verify_hmac, +) + + class StateTest(unittest.TestCase): """Test all state implementations.""" @@ -398,6 +407,25 @@ def test_num_task_res(self) -> None: # Assert assert num == 2 + def test_client_public_keys(self) -> None: + """Test client public keys store and get from state.""" + state: State = self.state_factory() + key_pairs = [generate_key_pairs() for _ in range(3)] + public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} + + state.store_client_public_keys(public_keys) + + assert state.get_client_public_keys() == public_keys + + def test_in_memory_node_id_public_key_pair() -> None: + """Test store and get node_id public_key pair.""" + state: State = self.state_factory() + node_id = state.create_node() + public_key = public_key_to_bytes(generate_key_pairs()[1]) + + state.store_node_id_public_key_pair(node_id, public_key) + + assert state.get_public_key_from_node_id(node_id) == public_key def create_task_ins( consumer_node_id: int, From 6823c8310373227e95f20a13fe7c401bf37c9b3f Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 15 Feb 2024 13:08:06 +0100 Subject: [PATCH 24/73] Add symmetric encryption test --- src/py/flwr/server/superlink/state/state_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 4ecac8b17c3c..ca24901d0bcd 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -417,7 +417,7 @@ def test_client_public_keys(self) -> None: assert state.get_client_public_keys() == public_keys - def test_in_memory_node_id_public_key_pair() -> None: + def test_in_memory_node_id_public_key_pair(self) -> None: """Test store and get node_id public_key pair.""" state: State = self.state_factory() node_id = state.create_node() From 9df829c15cdcb305656afdf4581ac0a3bed98300 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 15 Feb 2024 13:32:12 +0100 Subject: [PATCH 25/73] Format code --- .../crypto/symmetric_encryption_test.py | 10 ++++-- src/py/flwr/server/server_interceptor.py | 31 ++++++++++--------- src/py/flwr/server/server_interceptor_test.py | 8 ++--- .../flwr/server/superlink/state/state_test.py | 14 +++------ 4 files changed, 34 insertions(+), 29 deletions(-) diff --git a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption_test.py b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption_test.py index 523b5410f1d0..c10b04998892 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption_test.py +++ b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption_test.py @@ -14,7 +14,13 @@ # ============================================================================== """Symmetric encryption tests.""" -from .symmetric_encryption import generate_shared_key, generate_key_pairs, compute_hmac, verify_hmac +from .symmetric_encryption import ( + compute_hmac, + generate_key_pairs, + generate_shared_key, + verify_hmac, +) + def test_generate_shared_key() -> None: """Test util function generate_shared_key.""" @@ -37,4 +43,4 @@ def test_hmac() -> None: client_compute_hmac = compute_hmac(client_shared_secret, message) - assert verify_hmac(server_shared_secret, message, client_compute_hmac) \ No newline at end of file + assert verify_hmac(server_shared_secret, message, client_compute_hmac) diff --git a/src/py/flwr/server/server_interceptor.py b/src/py/flwr/server/server_interceptor.py index 77815ee1a113..6b8e749774a2 100644 --- a/src/py/flwr/server/server_interceptor.py +++ b/src/py/flwr/server/server_interceptor.py @@ -41,6 +41,17 @@ _PUBLIC_KEY_HEADER = "public-key" _AUTH_TOKEN_HEADER = "auth-token" +Request = Union[ + CreateNodeRequest, DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest +] + +Response = Union[ + CreateNodeResponse, + DeleteNodeResponse, + PullTaskInsResponse, + PushTaskResResponse, +] + def _get_value_from_tuples( key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]] @@ -64,7 +75,7 @@ def __init__( self._lock = threading.Lock self.server_private_key = private_key self.server_public_key = public_key - self.state_factory = state_factory + self.state = state_factory.state def intercept_service( self, continuation: Callable, handler_call_details: grpc.HandlerCallDetails @@ -77,19 +88,9 @@ def _generic_auth_unary_method_handler( self, existing_handler: grpc.RpcMethodHandler ) -> grpc.RpcMethodHandler: def _generic_method_handler( - request: Union[ - CreateNodeRequest, - DeleteNodeRequest, - PullTaskInsRequest, - PushTaskResRequest, - ], + request: Request, context: grpc.ServicerContext, - ) -> Union[ - CreateNodeResponse, - DeleteNodeResponse, - PullTaskInsResponse, - PushTaskResResponse, - ]: + ) -> Response: with self._lock: if isinstance(request, CreateNodeRequest): client_public_key_bytes = _get_value_from_tuples( @@ -118,7 +119,9 @@ def _generic_method_handler( node_id: int = ( -1 if request.node.anonymous else request.node.node_id ) - client_public_key_bytes = state.get_public_key_from_node_id(node_id) + client_public_key_bytes = self.state.get_public_key_from_node_id( + node_id + ) shared_secret = generate_shared_key( self.server_private_key, bytes_to_public_key(client_public_key_bytes), diff --git a/src/py/flwr/server/server_interceptor_test.py b/src/py/flwr/server/server_interceptor_test.py index bcd9ab5cc7e0..4dd48c9d735b 100644 --- a/src/py/flwr/server/server_interceptor_test.py +++ b/src/py/flwr/server/server_interceptor_test.py @@ -17,10 +17,10 @@ import unittest import grpc -from app import ADDRESS_FLEET_API_GRPC_RERE, _run_fleet_api_grpc_rere -from client.app import _init_connection -from common.constant import TRANSPORT_TYPE_GRPC_RERE -from state.state_factory import StateFactory +from .app import ADDRESS_FLEET_API_GRPC_RERE, _run_fleet_api_grpc_rere +from flwr.client.app import _init_connection +from flwr.common.constant import TRANSPORT_TYPE_GRPC_RERE +from .superlink.state.state_factory import StateFactory from flwr.common import GRPC_MAX_MESSAGE_LENGTH diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index ca24901d0bcd..da0a1881156d 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -22,19 +22,14 @@ from typing import List from uuid import uuid4 -from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 -from flwr.server.superlink.state import InMemoryState, SqliteState, State - - from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - compute_hmac, generate_key_pairs, - generate_shared_key, public_key_to_bytes, - verify_hmac, ) +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 +from flwr.server.superlink.state import InMemoryState, SqliteState, State class StateTest(unittest.TestCase): @@ -427,6 +422,7 @@ def test_in_memory_node_id_public_key_pair(self) -> None: assert state.get_public_key_from_node_id(node_id) == public_key + def create_task_ins( consumer_node_id: int, anonymous: bool, From d2509453da3f36575742a85c01af226426b11ac5 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 15 Feb 2024 17:42:14 +0100 Subject: [PATCH 26/73] Make tests pass --- src/py/flwr/server/app.py | 4 +- src/py/flwr/server/server_interceptor.py | 90 ++++++++------- src/py/flwr/server/server_interceptor_test.py | 109 +++++++++++++++--- .../superlink/fleet/grpc_bidi/grpc_server.py | 4 +- .../server/superlink/state/in_memory_state.py | 15 --- .../server/superlink/state/sqlite_state.py | 23 ---- src/py/flwr/server/superlink/state/state.py | 8 -- .../flwr/server/superlink/state/state_test.py | 14 +-- 8 files changed, 147 insertions(+), 120 deletions(-) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 66adcbdb6b85..3c66bd1e7572 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -24,7 +24,7 @@ from pathlib import Path from signal import SIGINT, SIGTERM, signal from types import FrameType -from typing import List, Optional, Tuple +from typing import List, Optional, Sequence, Tuple import grpc @@ -584,6 +584,7 @@ def _run_fleet_api_grpc_rere( address: str, state_factory: StateFactory, certificates: Optional[Tuple[bytes, bytes, bytes]], + interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None, ) -> grpc.Server: """Run Fleet API (gRPC, request-response).""" # Create Fleet API gRPC server @@ -596,6 +597,7 @@ def _run_fleet_api_grpc_rere( server_address=address, max_message_length=GRPC_MAX_MESSAGE_LENGTH, certificates=certificates, + interceptors=interceptors, ) log(INFO, "Flower ECE: Starting Fleet API (gRPC-rere) on %s", address) diff --git a/src/py/flwr/server/server_interceptor.py b/src/py/flwr/server/server_interceptor.py index 6b8e749774a2..12aaae5bd2ba 100644 --- a/src/py/flwr/server/server_interceptor.py +++ b/src/py/flwr/server/server_interceptor.py @@ -14,19 +14,22 @@ # ============================================================================== """Flower server interceptor.""" +import base64 import threading -from typing import Callable, Sequence, Tuple, Union +from logging import INFO +from typing import Any, Callable, Sequence, Tuple, Union import grpc from cryptography.hazmat.primitives.asymmetric import ec +from flwr.common.logger import log from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( bytes_to_public_key, generate_shared_key, public_key_to_bytes, verify_hmac, ) -from flwr.proto.fleet_pb2 import ( +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, DeleteNodeRequest, @@ -63,7 +66,7 @@ def _get_value_from_tuples( return value -class AuthenticateServerInterceptor(grpc.ServerInterceptor): +class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore """Server interceptor for client authentication.""" def __init__( @@ -72,72 +75,75 @@ def __init__( private_key: ec.EllipticCurvePrivateKey, public_key: ec.EllipticCurvePublicKey, ): - self._lock = threading.Lock + self._lock = threading.Lock() self.server_private_key = private_key self.server_public_key = public_key - self.state = state_factory.state + self.state = state_factory.state() def intercept_service( - self, continuation: Callable, handler_call_details: grpc.HandlerCallDetails + self, + continuation: Callable[[Any], Any], + handler_call_details: grpc.HandlerCallDetails, ) -> grpc.RpcMethodHandler: """Flower server interceptor authentication logic.""" message_handler: grpc.RpcMethodHandler = continuation(handler_call_details) return self._generic_auth_unary_method_handler(message_handler) def _generic_auth_unary_method_handler( - self, existing_handler: grpc.RpcMethodHandler + self, message_handler: grpc.RpcMethodHandler ) -> grpc.RpcMethodHandler: def _generic_method_handler( request: Request, context: grpc.ServicerContext, - ) -> Response: + ) -> Any: with self._lock: - if isinstance(request, CreateNodeRequest): - client_public_key_bytes = _get_value_from_tuples( - _PUBLIC_KEY_HEADER, context.invocation_metadata() - ) - is_public_key_known = ( - client_public_key_bytes in self.state.get_client_public_keys() - ) - if is_public_key_known: - context.set_trailing_metadata( + encoded_bytes = _get_value_from_tuples( + _PUBLIC_KEY_HEADER, context.invocation_metadata() + )[::-1] + log(INFO, "Client public key bytes: %s", encoded_bytes) + client_public_key_bytes = base64.urlsafe_b64decode(encoded_bytes) + is_public_key_known = ( + client_public_key_bytes in self.state.get_client_public_keys() + ) + if is_public_key_known: + if isinstance(request, CreateNodeRequest): + context.send_initial_metadata( ( ( _PUBLIC_KEY_HEADER, - public_key_to_bytes(self.server_public_key), + base64.urlsafe_b64encode( + public_key_to_bytes(self.server_public_key) + ), ), ) ) + elif isinstance(request, (DeleteNodeRequest, PullTaskInsRequest)): + encoded_bytes = _get_value_from_tuples( + _AUTH_TOKEN_HEADER, context.invocation_metadata() + )[::-1] + hmac_value = base64.urlsafe_b64decode(encoded_bytes) + log(INFO, "Client public key bytes: %s", encoded_bytes) + client_public_key = bytes_to_public_key(client_public_key_bytes) + shared_secret = generate_shared_key( + self.server_private_key, + client_public_key, + ) + verify = verify_hmac( + shared_secret, request.SerializeToString(True), hmac_value + ) + if not verify: + context.abort( + grpc.StatusCode.UNAUTHENTICATED, "Access denied!" + ) else: context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") - elif isinstance( - request, (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest) - ): - hmac_value = _get_value_from_tuples( - _AUTH_TOKEN_HEADER, context.invocation_metadata() - ) - node_id: int = ( - -1 if request.node.anonymous else request.node.node_id - ) - client_public_key_bytes = self.state.get_public_key_from_node_id( - node_id - ) - shared_secret = generate_shared_key( - self.server_private_key, - bytes_to_public_key(client_public_key_bytes), - ) - verify = verify_hmac( - shared_secret, request.SerializeToString(True), hmac_value - ) - if not verify: - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") else: context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") - return existing_handler.unary_unary + return message_handler.unary_unary(request, context) return grpc.unary_unary_rpc_method_handler( _generic_method_handler, - request_deserializer=existing_handler.request_deserializer, - response_serializer=existing_handler.response_serializer, + request_deserializer=message_handler.request_deserializer, + response_serializer=message_handler.response_serializer, ) diff --git a/src/py/flwr/server/server_interceptor_test.py b/src/py/flwr/server/server_interceptor_test.py index 4dd48c9d735b..09d524c3fc7c 100644 --- a/src/py/flwr/server/server_interceptor_test.py +++ b/src/py/flwr/server/server_interceptor_test.py @@ -14,40 +14,113 @@ # ============================================================================== """Flower server interceptor tests.""" +import base64 import unittest +from logging import INFO import grpc + +from flwr.common.logger import log +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + compute_hmac, + generate_key_pairs, + generate_shared_key, + public_key_to_bytes, +) +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 + CreateNodeRequest, + CreateNodeResponse, + DeleteNodeRequest, + DeleteNodeResponse, +) + from .app import ADDRESS_FLEET_API_GRPC_RERE, _run_fleet_api_grpc_rere -from flwr.client.app import _init_connection -from flwr.common.constant import TRANSPORT_TYPE_GRPC_RERE +from .server_interceptor import ( + _AUTH_TOKEN_HEADER, + _PUBLIC_KEY_HEADER, + AuthenticateServerInterceptor, +) from .superlink.state.state_factory import StateFactory -from flwr.common import GRPC_MAX_MESSAGE_LENGTH - -class TestServerInterceptor(unittest.TestCase): +class TestServerInterceptor(unittest.TestCase): # pylint: disable=R0902 """Server interceptor tests.""" - def setUp(self): + def setUp(self) -> None: """Initialize mock stub and server interceptor.""" - self._state_factory = StateFactory(":flwr-in-memory-state:") + self._client_private_key, self._client_public_key = generate_key_pairs() + self._server_private_key, self._server_public_key = generate_key_pairs() + + state_factory = StateFactory(":flwr-in-memory-state:") + state_factory.state().store_client_public_key( + public_key_to_bytes(self._client_public_key) + ) + self._server_interceptor = AuthenticateServerInterceptor( + state_factory, self._server_private_key, self._server_public_key + ) self._server: grpc.Server = _run_fleet_api_grpc_rere( - ADDRESS_FLEET_API_GRPC_RERE, self._state_factory + ADDRESS_FLEET_API_GRPC_RERE, state_factory, None, [self._server_interceptor] ) - self._connection, self._address = _init_connection( - TRANSPORT_TYPE_GRPC_RERE, ADDRESS_FLEET_API_GRPC_RERE + + self._client_node_id: int = -1 + + self._channel = grpc.insecure_channel("localhost:9092") + self._create_node = self._channel.unary_unary( + "/flwr.proto.Fleet/CreateNode", + request_serializer=CreateNodeRequest.SerializeToString, + response_deserializer=CreateNodeResponse.FromString, + ) + self._delete_node = self._channel.unary_unary( + "/flwr.proto.Fleet/DeleteNode", + request_serializer=DeleteNodeRequest.SerializeToString, + response_deserializer=DeleteNodeResponse.FromString, ) - with self._connection( - self._address, - True, - GRPC_MAX_MESSAGE_LENGTH, - ) as conn: - self._receive, self._send, self._create_node, self._delete_node = conn - def tearDown(self): + def tearDown(self) -> None: """Clean up grpc server.""" self._server.stop(None) def test_successful_create_node_with_metadata(self) -> None: """Test server interceptor for create node.""" - self._create_node() + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ) + log(INFO, "Client public key bytes: %s", public_key_bytes) + response, call = self._create_node.with_call( + request=CreateNodeRequest(), + metadata=((_PUBLIC_KEY_HEADER, public_key_bytes),), + ) + + expected_metadata = ( + _PUBLIC_KEY_HEADER, + base64.urlsafe_b64encode( + public_key_to_bytes(self._server_public_key) + ).decode(), + ) + + assert call.initial_metadata()[0] == expected_metadata + assert isinstance(response, CreateNodeResponse) + self._client_node_id = response.node.node_id + + def test_successful_delete_node_with_metadata(self) -> None: + """Test server interceptor for create node.""" + request = DeleteNodeRequest() + shared_secret = generate_shared_key( + self._client_private_key, self._server_public_key + ) + hmac_value = base64.urlsafe_b64encode( + compute_hmac(shared_secret, request.SerializeToString(True)) + ) + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ) + log(INFO, "Client public key bytes: %s", public_key_bytes) + response, call = self._delete_node.with_call( + request=request, + metadata=( + (_PUBLIC_KEY_HEADER, public_key_bytes), + (_AUTH_TOKEN_HEADER, hmac_value), + ), + ) + assert isinstance(response, DeleteNodeResponse) + assert grpc.StatusCode.OK == call.code() diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py index 82f049844bd6..6aeaa7ef413f 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py @@ -18,7 +18,7 @@ import concurrent.futures import sys from logging import ERROR -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Optional, Sequence, Tuple, Union import grpc @@ -162,6 +162,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, keepalive_time_ms: int = 210000, certificates: Optional[Tuple[bytes, bytes, bytes]] = None, + interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None, ) -> grpc.Server: """Create a gRPC server with a single servicer. @@ -249,6 +250,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments # returning RESOURCE_EXHAUSTED status, or None to indicate no limit. maximum_concurrent_rpcs=max_concurrent_workers, options=options, + interceptors=interceptors, ) add_servicer_to_server_fn(servicer, server) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 93ae1cab97a8..45459b7e77b2 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -35,7 +35,6 @@ def __init__(self) -> None: self.run_ids: Set[int] = set() self.task_ins_store: Dict[UUID, TaskIns] = {} self.task_res_store: Dict[UUID, TaskRes] = {} - self.node_id_public_key_dict: Dict[int, bytes] = {} self.client_public_keys: Set[bytes] = set() self.server_public_key: bytes = b"" self.server_private_key: bytes = b"" @@ -226,20 +225,6 @@ def create_run(self) -> int: log(ERROR, "Unexpected run creation failure.") return 0 - def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: - """Store `node_id` and `public_key` as key-value pair in state.""" - if node_id not in self.node_ids: - raise ValueError(f"Node {node_id} not found") - if node_id in self.node_id_public_key_dict: - raise ValueError(f"Node {node_id} has already assigned a public key") - self.node_id_public_key_dict[node_id] = public_key - - def get_public_key_from_node_id(self, node_id: int) -> bytes: - """Get client's public key in urlsafe bytes for `node_id`.""" - if node_id in self.node_id_public_key_dict: - return self.node_id_public_key_dict[node_id] - return b"" - def store_server_public_private_key( self, public_key: bytes, private_key: bytes ) -> None: diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index f671ff12a1e2..bf65828ac10e 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -37,13 +37,6 @@ ); """ -SQL_CREATE_TABLE_NODE_KEY = """ -CREATE TABLE IF NOT EXISTS node_key( - node_id INTEGER PRIMARY KEY, - public_key BLOB -); -""" - SQL_CREATE_TABLE_CREDENTIAL = """ CREATE TABLE IF NOT EXISTS credential( public_key BLOB PRIMARY KEY, @@ -144,7 +137,6 @@ def initialize(self, log_queries: bool = False) -> List[Tuple[str]]: cur.execute(SQL_CREATE_TABLE_TASK_RES) cur.execute(SQL_CREATE_TABLE_NODE) cur.execute(SQL_CREATE_TABLE_CREDENTIAL) - cur.execute(SQL_CREATE_TABLE_NODE_KEY) cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY) res = cur.execute("SELECT name FROM sqlite_schema;") @@ -542,21 +534,6 @@ def create_run(self) -> int: log(ERROR, "Unexpected run creation failure.") return 0 - def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: - """Store `node_id` and `public_key` as key-value pair in state.""" - query = ( - "INSERT OR REPLACE INTO node_key (node_id, public_key) " - "VALUES (:node_id, :public_key)" - ) - self.query(query, {"node_id": node_id, "public_key": public_key}) - - def get_public_key_from_node_id(self, node_id: int) -> bytes: - """Get client's public key in urlsafe bytes for `node_id`.""" - query = "SELECT public_key FROM node_key WHERE node_id = :node_id" - rows = self.query(query, {"node_id": node_id}) - public_key: bytes = rows[0]["public_key"] - return public_key - def store_server_public_private_key( self, public_key: bytes, private_key: bytes ) -> None: diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index 645b3ec1b290..b4a26db0afd8 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -153,14 +153,6 @@ def get_nodes(self, run_id: int) -> Set[int]: def create_run(self) -> int: """Create one run.""" - @abc.abstractmethod - def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: - """Store `node_id` and `public_key` as key-value pair in state.""" - - @abc.abstractmethod - def get_public_key_from_node_id(self, node_id: int) -> bytes: - """Get client's public key in urlsafe bytes for `node_id`.""" - @abc.abstractmethod def store_server_public_private_key( self, public_key: bytes, private_key: bytes diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index da0a1881156d..f23855a00e45 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -412,16 +412,6 @@ def test_client_public_keys(self) -> None: assert state.get_client_public_keys() == public_keys - def test_in_memory_node_id_public_key_pair(self) -> None: - """Test store and get node_id public_key pair.""" - state: State = self.state_factory() - node_id = state.create_node() - public_key = public_key_to_bytes(generate_key_pairs()[1]) - - state.store_node_id_public_key_pair(node_id, public_key) - - assert state.get_public_key_from_node_id(node_id) == public_key - def create_task_ins( consumer_node_id: int, @@ -501,7 +491,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 13 + assert len(result) == 12 class SqliteFileBasedTest(StateTest, unittest.TestCase): @@ -526,7 +516,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 13 + assert len(result) == 12 if __name__ == "__main__": From 2f61623d80316f22aaeaa8febf4b00daf93e4ee0 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 15 Feb 2024 17:58:05 +0100 Subject: [PATCH 27/73] Revert commit to only include auth state --- src/py/flwr/server/app.py | 4 +- src/py/flwr/server/server_interceptor.py | 149 ------------------ src/py/flwr/server/server_interceptor_test.py | 126 --------------- .../superlink/fleet/grpc_bidi/grpc_server.py | 4 +- 4 files changed, 2 insertions(+), 281 deletions(-) delete mode 100644 src/py/flwr/server/server_interceptor.py delete mode 100644 src/py/flwr/server/server_interceptor_test.py diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 3c66bd1e7572..66adcbdb6b85 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -24,7 +24,7 @@ from pathlib import Path from signal import SIGINT, SIGTERM, signal from types import FrameType -from typing import List, Optional, Sequence, Tuple +from typing import List, Optional, Tuple import grpc @@ -584,7 +584,6 @@ def _run_fleet_api_grpc_rere( address: str, state_factory: StateFactory, certificates: Optional[Tuple[bytes, bytes, bytes]], - interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None, ) -> grpc.Server: """Run Fleet API (gRPC, request-response).""" # Create Fleet API gRPC server @@ -597,7 +596,6 @@ def _run_fleet_api_grpc_rere( server_address=address, max_message_length=GRPC_MAX_MESSAGE_LENGTH, certificates=certificates, - interceptors=interceptors, ) log(INFO, "Flower ECE: Starting Fleet API (gRPC-rere) on %s", address) diff --git a/src/py/flwr/server/server_interceptor.py b/src/py/flwr/server/server_interceptor.py deleted file mode 100644 index 12aaae5bd2ba..000000000000 --- a/src/py/flwr/server/server_interceptor.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Flower server interceptor.""" - -import base64 -import threading -from logging import INFO -from typing import Any, Callable, Sequence, Tuple, Union - -import grpc -from cryptography.hazmat.primitives.asymmetric import ec - -from flwr.common.logger import log -from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - bytes_to_public_key, - generate_shared_key, - public_key_to_bytes, - verify_hmac, -) -from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 - CreateNodeRequest, - CreateNodeResponse, - DeleteNodeRequest, - DeleteNodeResponse, - PullTaskInsRequest, - PullTaskInsResponse, - PushTaskResRequest, - PushTaskResResponse, -) -from flwr.server.superlink.state import StateFactory - -_PUBLIC_KEY_HEADER = "public-key" -_AUTH_TOKEN_HEADER = "auth-token" - -Request = Union[ - CreateNodeRequest, DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest -] - -Response = Union[ - CreateNodeResponse, - DeleteNodeResponse, - PullTaskInsResponse, - PushTaskResResponse, -] - - -def _get_value_from_tuples( - key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]] -) -> bytes: - value = next((value[::-1] for key, value in tuples if key == key_string), "") - if isinstance(value, str): - return value.encode() - - return value - - -class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore - """Server interceptor for client authentication.""" - - def __init__( - self, - state_factory: StateFactory, - private_key: ec.EllipticCurvePrivateKey, - public_key: ec.EllipticCurvePublicKey, - ): - self._lock = threading.Lock() - self.server_private_key = private_key - self.server_public_key = public_key - self.state = state_factory.state() - - def intercept_service( - self, - continuation: Callable[[Any], Any], - handler_call_details: grpc.HandlerCallDetails, - ) -> grpc.RpcMethodHandler: - """Flower server interceptor authentication logic.""" - message_handler: grpc.RpcMethodHandler = continuation(handler_call_details) - return self._generic_auth_unary_method_handler(message_handler) - - def _generic_auth_unary_method_handler( - self, message_handler: grpc.RpcMethodHandler - ) -> grpc.RpcMethodHandler: - def _generic_method_handler( - request: Request, - context: grpc.ServicerContext, - ) -> Any: - with self._lock: - encoded_bytes = _get_value_from_tuples( - _PUBLIC_KEY_HEADER, context.invocation_metadata() - )[::-1] - log(INFO, "Client public key bytes: %s", encoded_bytes) - client_public_key_bytes = base64.urlsafe_b64decode(encoded_bytes) - is_public_key_known = ( - client_public_key_bytes in self.state.get_client_public_keys() - ) - if is_public_key_known: - if isinstance(request, CreateNodeRequest): - context.send_initial_metadata( - ( - ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode( - public_key_to_bytes(self.server_public_key) - ), - ), - ) - ) - elif isinstance(request, (DeleteNodeRequest, PullTaskInsRequest)): - encoded_bytes = _get_value_from_tuples( - _AUTH_TOKEN_HEADER, context.invocation_metadata() - )[::-1] - hmac_value = base64.urlsafe_b64decode(encoded_bytes) - log(INFO, "Client public key bytes: %s", encoded_bytes) - client_public_key = bytes_to_public_key(client_public_key_bytes) - shared_secret = generate_shared_key( - self.server_private_key, - client_public_key, - ) - verify = verify_hmac( - shared_secret, request.SerializeToString(True), hmac_value - ) - if not verify: - context.abort( - grpc.StatusCode.UNAUTHENTICATED, "Access denied!" - ) - else: - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") - else: - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") - - return message_handler.unary_unary(request, context) - - return grpc.unary_unary_rpc_method_handler( - _generic_method_handler, - request_deserializer=message_handler.request_deserializer, - response_serializer=message_handler.response_serializer, - ) diff --git a/src/py/flwr/server/server_interceptor_test.py b/src/py/flwr/server/server_interceptor_test.py deleted file mode 100644 index 09d524c3fc7c..000000000000 --- a/src/py/flwr/server/server_interceptor_test.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Flower server interceptor tests.""" - -import base64 -import unittest -from logging import INFO - -import grpc - -from flwr.common.logger import log -from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - compute_hmac, - generate_key_pairs, - generate_shared_key, - public_key_to_bytes, -) -from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 - CreateNodeRequest, - CreateNodeResponse, - DeleteNodeRequest, - DeleteNodeResponse, -) - -from .app import ADDRESS_FLEET_API_GRPC_RERE, _run_fleet_api_grpc_rere -from .server_interceptor import ( - _AUTH_TOKEN_HEADER, - _PUBLIC_KEY_HEADER, - AuthenticateServerInterceptor, -) -from .superlink.state.state_factory import StateFactory - - -class TestServerInterceptor(unittest.TestCase): # pylint: disable=R0902 - """Server interceptor tests.""" - - def setUp(self) -> None: - """Initialize mock stub and server interceptor.""" - self._client_private_key, self._client_public_key = generate_key_pairs() - self._server_private_key, self._server_public_key = generate_key_pairs() - - state_factory = StateFactory(":flwr-in-memory-state:") - state_factory.state().store_client_public_key( - public_key_to_bytes(self._client_public_key) - ) - self._server_interceptor = AuthenticateServerInterceptor( - state_factory, self._server_private_key, self._server_public_key - ) - self._server: grpc.Server = _run_fleet_api_grpc_rere( - ADDRESS_FLEET_API_GRPC_RERE, state_factory, None, [self._server_interceptor] - ) - - self._client_node_id: int = -1 - - self._channel = grpc.insecure_channel("localhost:9092") - self._create_node = self._channel.unary_unary( - "/flwr.proto.Fleet/CreateNode", - request_serializer=CreateNodeRequest.SerializeToString, - response_deserializer=CreateNodeResponse.FromString, - ) - self._delete_node = self._channel.unary_unary( - "/flwr.proto.Fleet/DeleteNode", - request_serializer=DeleteNodeRequest.SerializeToString, - response_deserializer=DeleteNodeResponse.FromString, - ) - - def tearDown(self) -> None: - """Clean up grpc server.""" - self._server.stop(None) - - def test_successful_create_node_with_metadata(self) -> None: - """Test server interceptor for create node.""" - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._client_public_key) - ) - log(INFO, "Client public key bytes: %s", public_key_bytes) - response, call = self._create_node.with_call( - request=CreateNodeRequest(), - metadata=((_PUBLIC_KEY_HEADER, public_key_bytes),), - ) - - expected_metadata = ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode( - public_key_to_bytes(self._server_public_key) - ).decode(), - ) - - assert call.initial_metadata()[0] == expected_metadata - assert isinstance(response, CreateNodeResponse) - self._client_node_id = response.node.node_id - - def test_successful_delete_node_with_metadata(self) -> None: - """Test server interceptor for create node.""" - request = DeleteNodeRequest() - shared_secret = generate_shared_key( - self._client_private_key, self._server_public_key - ) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._client_public_key) - ) - log(INFO, "Client public key bytes: %s", public_key_bytes) - response, call = self._delete_node.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - assert isinstance(response, DeleteNodeResponse) - assert grpc.StatusCode.OK == call.code() diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py index 6aeaa7ef413f..82f049844bd6 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py @@ -18,7 +18,7 @@ import concurrent.futures import sys from logging import ERROR -from typing import Any, Callable, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union import grpc @@ -162,7 +162,6 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, keepalive_time_ms: int = 210000, certificates: Optional[Tuple[bytes, bytes, bytes]] = None, - interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None, ) -> grpc.Server: """Create a gRPC server with a single servicer. @@ -250,7 +249,6 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments # returning RESOURCE_EXHAUSTED status, or None to indicate no limit. maximum_concurrent_rpcs=max_concurrent_workers, options=options, - interceptors=interceptors, ) add_servicer_to_server_fn(servicer, server) From 2f7aa4855eabc0e4d9ceb6f51bce3faea356c78c Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Fri, 16 Feb 2024 09:28:23 +0100 Subject: [PATCH 28/73] Remove logging messages --- src/py/flwr/server/server_interceptor.py | 24 +++++++++---------- src/py/flwr/server/server_interceptor_test.py | 4 ---- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/src/py/flwr/server/server_interceptor.py b/src/py/flwr/server/server_interceptor.py index 12aaae5bd2ba..d4a4f4df157e 100644 --- a/src/py/flwr/server/server_interceptor.py +++ b/src/py/flwr/server/server_interceptor.py @@ -16,13 +16,11 @@ import base64 import threading -from logging import INFO from typing import Any, Callable, Sequence, Tuple, Union import grpc from cryptography.hazmat.primitives.asymmetric import ec -from flwr.common.logger import log from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( bytes_to_public_key, generate_shared_key, @@ -59,7 +57,7 @@ def _get_value_from_tuples( key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]] ) -> bytes: - value = next((value[::-1] for key, value in tuples if key == key_string), "") + value = next((value for key, value in tuples if key == key_string), "") if isinstance(value, str): return value.encode() @@ -97,11 +95,11 @@ def _generic_method_handler( context: grpc.ServicerContext, ) -> Any: with self._lock: - encoded_bytes = _get_value_from_tuples( - _PUBLIC_KEY_HEADER, context.invocation_metadata() - )[::-1] - log(INFO, "Client public key bytes: %s", encoded_bytes) - client_public_key_bytes = base64.urlsafe_b64decode(encoded_bytes) + client_public_key_bytes = base64.urlsafe_b64decode( + _get_value_from_tuples( + _PUBLIC_KEY_HEADER, context.invocation_metadata() + ) + ) is_public_key_known = ( client_public_key_bytes in self.state.get_client_public_keys() ) @@ -118,11 +116,11 @@ def _generic_method_handler( ) ) elif isinstance(request, (DeleteNodeRequest, PullTaskInsRequest)): - encoded_bytes = _get_value_from_tuples( - _AUTH_TOKEN_HEADER, context.invocation_metadata() - )[::-1] - hmac_value = base64.urlsafe_b64decode(encoded_bytes) - log(INFO, "Client public key bytes: %s", encoded_bytes) + hmac_value = base64.urlsafe_b64decode( + _get_value_from_tuples( + _AUTH_TOKEN_HEADER, context.invocation_metadata() + ) + ) client_public_key = bytes_to_public_key(client_public_key_bytes) shared_secret = generate_shared_key( self.server_private_key, diff --git a/src/py/flwr/server/server_interceptor_test.py b/src/py/flwr/server/server_interceptor_test.py index 09d524c3fc7c..9284a30838f0 100644 --- a/src/py/flwr/server/server_interceptor_test.py +++ b/src/py/flwr/server/server_interceptor_test.py @@ -16,11 +16,9 @@ import base64 import unittest -from logging import INFO import grpc -from flwr.common.logger import log from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( compute_hmac, generate_key_pairs, @@ -85,7 +83,6 @@ def test_successful_create_node_with_metadata(self) -> None: public_key_bytes = base64.urlsafe_b64encode( public_key_to_bytes(self._client_public_key) ) - log(INFO, "Client public key bytes: %s", public_key_bytes) response, call = self._create_node.with_call( request=CreateNodeRequest(), metadata=((_PUBLIC_KEY_HEADER, public_key_bytes),), @@ -114,7 +111,6 @@ def test_successful_delete_node_with_metadata(self) -> None: public_key_bytes = base64.urlsafe_b64encode( public_key_to_bytes(self._client_public_key) ) - log(INFO, "Client public key bytes: %s", public_key_bytes) response, call = self._delete_node.with_call( request=request, metadata=( From 6ee8a61b4a3e8289b3624b3f3046ae9b6fd5a93b Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 22 Feb 2024 10:25:58 +0100 Subject: [PATCH 29/73] Update server interceptor --- src/py/flwr/server/server_interceptor.py | 16 ++++++++++++++-- src/py/flwr/server/server_interceptor_test.py | 9 +++++---- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/server/server_interceptor.py b/src/py/flwr/server/server_interceptor.py index d4a4f4df157e..6a0b13d679dc 100644 --- a/src/py/flwr/server/server_interceptor.py +++ b/src/py/flwr/server/server_interceptor.py @@ -16,11 +16,13 @@ import base64 import threading -from typing import Any, Callable, Sequence, Tuple, Union +from logging import INFO +from typing import Any, Callable, Sequence, Set, Tuple, Union import grpc from cryptography.hazmat.primitives.asymmetric import ec +from flwr.common.logger import log from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( bytes_to_public_key, generate_shared_key, @@ -70,6 +72,7 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore def __init__( self, state_factory: StateFactory, + client_public_keys: Set[bytes], private_key: ec.EllipticCurvePrivateKey, public_key: ec.EllipticCurvePublicKey, ): @@ -77,6 +80,12 @@ def __init__( self.server_private_key = private_key self.server_public_key = public_key self.state = state_factory.state() + self.state.store_client_public_keys(client_public_keys) + log( + INFO, + "Client authentication enabled with %d known public keys", + len(client_public_keys), + ) def intercept_service( self, @@ -115,7 +124,10 @@ def _generic_method_handler( ), ) ) - elif isinstance(request, (DeleteNodeRequest, PullTaskInsRequest)): + elif isinstance( + request, + (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest), + ): hmac_value = base64.urlsafe_b64decode( _get_value_from_tuples( _AUTH_TOKEN_HEADER, context.invocation_metadata() diff --git a/src/py/flwr/server/server_interceptor_test.py b/src/py/flwr/server/server_interceptor_test.py index 9284a30838f0..2a886bc07cd8 100644 --- a/src/py/flwr/server/server_interceptor_test.py +++ b/src/py/flwr/server/server_interceptor_test.py @@ -50,11 +50,12 @@ def setUp(self) -> None: self._server_private_key, self._server_public_key = generate_key_pairs() state_factory = StateFactory(":flwr-in-memory-state:") - state_factory.state().store_client_public_key( - public_key_to_bytes(self._client_public_key) - ) + self._server_interceptor = AuthenticateServerInterceptor( - state_factory, self._server_private_key, self._server_public_key + state_factory, + {public_key_to_bytes(self._client_public_key)}, + self._server_private_key, + self._server_public_key, ) self._server: grpc.Server = _run_fleet_api_grpc_rere( ADDRESS_FLEET_API_GRPC_RERE, state_factory, None, [self._server_interceptor] From 041482e8ef26f11b633a1d3b50b7fc3a0230a842 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 22 Feb 2024 10:36:33 +0100 Subject: [PATCH 30/73] Merge from add-auth-cli --- src/py/flwr/server/app.py | 92 ++++++++++++++++++++++++++++- src/py/flwr/server/server_test.py | 96 +++++++++++++++++++++++++++++++ 2 files changed, 187 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index a51244b218f5..b36de3298db8 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -16,6 +16,7 @@ import argparse +import csv import importlib.util import sys import threading @@ -24,9 +25,14 @@ from pathlib import Path from signal import SIGINT, SIGTERM, signal from types import FrameType -from typing import List, Optional, Sequence, Tuple +from typing import List, Optional, Sequence, Set, Tuple import grpc +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.serialization import ( + load_ssh_private_key, + load_ssh_public_key, +) from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event from flwr.common.address import parse_address @@ -36,6 +42,9 @@ TRANSPORT_TYPE_REST, ) from flwr.common.logger import log +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + public_key_to_bytes, +) from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611 add_DriverServicer_to_server, ) @@ -47,6 +56,7 @@ from .history import History from .server import Server from .server_config import ServerConfig +from .server_interceptor import AuthenticateServerInterceptor from .strategy import FedAvg, Strategy from .superlink.driver.driver_servicer import DriverServicer from .superlink.fleet.grpc_bidi.grpc_server import ( @@ -395,10 +405,29 @@ def run_superlink() -> None: sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.") host, port, is_v6 = parsed_address address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" + + data = _try_setup_client_authentication(args) + interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None + if data is not None: + ( + client_public_keys, + server_public_key, + server_private_key, + ) = data + interceptors = [ + AuthenticateServerInterceptor( + state_factory, + client_public_keys, + server_private_key, + server_public_key, + ) + ] + fleet_server = _run_fleet_api_grpc_rere( address=address, state_factory=state_factory, certificates=certificates, + interceptors=interceptors, ) grpc_servers.append(fleet_server) else: @@ -420,6 +449,59 @@ def run_superlink() -> None: driver_server.wait_for_termination(timeout=1) +def _try_setup_client_authentication( + args: argparse.Namespace, +) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePublicKey, ec.EllipticCurvePrivateKey]]: + if args.require_client_authentication: + client_keys_file_path = Path(args.require_client_authentication[0]) + if client_keys_file_path.exists(): + client_public_keys: Set[bytes] = set() + public_key = load_ssh_public_key( + Path(args.require_client_authentication[1]).read_bytes() + ) + private_key = load_ssh_private_key( + Path(args.require_client_authentication[2]).read_bytes(), + None, + ) + log(INFO, type(public_key)) + log(INFO, type(private_key)) + if not isinstance(public_key, ec.EllipticCurvePublicKey): + sys.exit( + "An eliptic curve public and private key pair is required for " + "client authentication. Please provide the file path containing " + "valid public and private key to '--require-client-authentication'." + ) + server_public_key = public_key + if not isinstance(private_key, ec.EllipticCurvePrivateKey): + sys.exit( + "An eliptic curve public and private key pair is required for " + "client authentication. Please provide the file path containing " + "valid public and private key to '--require-client-authentication'." + ) + server_private_key = private_key + + with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile: + reader = csv.reader(csvfile) + for row in reader: + for element in row: + public_key = load_ssh_public_key(element.encode()) + if isinstance(public_key, ec.EllipticCurvePublicKey): + client_public_keys.add(public_key_to_bytes(public_key)) + return ( + client_public_keys, + server_public_key, + server_private_key, + ) + else: + sys.exit( + "Client public keys csv file are required for client authentication. " + "Please provide the csv file path containing known client public keys " + "to '--require-client-authentication'." + ) + else: + return None + + def _try_obtain_certificates( args: argparse.Namespace, ) -> Optional[Tuple[bytes, bytes, bytes]]: @@ -687,6 +769,14 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None: "Flower will just create a state in memory.", default=DATABASE, ) + parser.add_argument( + "--require-client-authentication", + nargs=3, + metavar=("CLIENT_KEYS", "SERVER_PUBLIC_KEY", "SERVER_PRIVATE_KEY"), + type=str, + help="Paths to .csv file containing list of known client public keys for " + "authentication, server public key, and server private key, in that order.", + ) def _add_args_driver_api(parser: argparse.ArgumentParser) -> None: diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py index 9b5c03aeeaf9..133a46bffdf9 100644 --- a/src/py/flwr/server/server_test.py +++ b/src/py/flwr/server/server_test.py @@ -15,9 +15,20 @@ """Flower server tests.""" +import argparse +import csv +import tempfile +from pathlib import Path from typing import List, Optional import numpy as np +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.serialization import ( + Encoding, + PublicFormat, + load_ssh_private_key, + load_ssh_public_key, +) from flwr.common import ( Code, @@ -35,8 +46,14 @@ Status, ndarray_to_bytes, ) +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + generate_key_pairs, + private_key_to_bytes, + public_key_to_bytes, +) from flwr.server.client_manager import SimpleClientManager +from .app import _try_setup_client_authentication from .client_proxy import ClientProxy from .server import Server, evaluate_clients, fit_clients @@ -169,3 +186,82 @@ def test_set_max_workers() -> None: # Assert assert server.max_workers == 42 + + +def test_setup_client_auth() -> None: + """Test setup client authentication.""" + # Generate keys + _, first_public_key = generate_key_pairs() + server_public_key = ( + b"ecdsa-sha2-nistp384 " + b"AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBIqtP/EvrgBYukcjRJT9zVLXE" + b"fykvVvT/QcHXuxCNu83SyCwedk3nNZxy5rZ1f8KoU+OSGmum5I9BxnWcLeBC+TGqpifTUSNwa/riV" + b"oJGcN/SxF3euqQg58YePORhos/Ug==" + ) + server_private_key = b"""-----BEGIN OPENSSH PRIVATE KEY----- + b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAiAAAABNlY2RzYS + 1zaGEyLW5pc3RwMzg0AAAACG5pc3RwMzg0AAAAYQSKrT/xL64AWLpHI0SU/c1S1xH8pL1b + 0/0HB17sQjbvN0sgsHnZN5zWccua2dX/CqFPjkhprpuSPQcZ1nC3gQvkxqqYn01EjcGv64 + laCRnDf0sRd3rqkIOfGHjzkYaLP1IAAADw/rbMO/62zDsAAAATZWNkc2Etc2hhMi1uaXN0 + cDM4NAAAAAhuaXN0cDM4NAAAAGEEiq0/8S+uAFi6RyNElP3NUtcR/KS9W9P9Bwde7EI27z + dLILB52Tec1nHLmtnV/wqhT45Iaa6bkj0HGdZwt4EL5MaqmJ9NRI3Br+uJWgkZw39LEXd6 + 6pCDnxh485GGiz9SAAAAMQDQmvP7JeFNBDvo1VXciQF0Wv3/DCcj9x0kUABuX1gxb42Iw3 + v7FOEco/enMaS4URwAAAAnZGFuaWVsbnVncmFoYUBEYW5pZWxzLU1hY0Jvb2stUHJvLmxv + Y2Fs + -----END OPENSSH PRIVATE KEY-----""" + _, second_public_key = generate_key_pairs() + + with tempfile.TemporaryDirectory() as temp_dir: + # Initialize temporary files + client_keys_file_path = Path(temp_dir) / "client_keys.csv" + server_public_key_path = Path(temp_dir) / "server_public_key" + server_private_key_path = Path(temp_dir) / "server_private_key" + + # Fill the files with relevant keys + with open(client_keys_file_path, "w", newline="", encoding="utf-8") as csvfile: + writer = csv.writer(csvfile) + writer.writerow( + [ + first_public_key.public_bytes( + encoding=Encoding.OpenSSH, format=PublicFormat.OpenSSH + ).decode(), + second_public_key.public_bytes( + encoding=Encoding.OpenSSH, format=PublicFormat.OpenSSH + ).decode(), + ] + ) + server_public_key_path.write_bytes(server_public_key) + server_private_key_path.write_bytes(server_private_key) + + # Mock argparse with `require-client-authentication`` flag + mock_args = argparse.Namespace( + require_client_authentication=[ + str(client_keys_file_path), + str(server_public_key_path), + str(server_private_key_path), + ] + ) + + # Run _try_setup_client_authentication + result = _try_setup_client_authentication(mock_args) + + expected_private_key = load_ssh_private_key(server_private_key, None) + expected_public_key = load_ssh_public_key(server_public_key) + + if isinstance(expected_private_key, ec.EllipticCurvePrivateKey) and isinstance( + expected_public_key, ec.EllipticCurvePublicKey + ): + # Assert result with expected values + assert result is not None + assert result[0] == { + public_key_to_bytes(first_public_key), + public_key_to_bytes(second_public_key), + } + assert public_key_to_bytes(result[1]) == public_key_to_bytes( + expected_public_key + ) + assert private_key_to_bytes(result[2]) == private_key_to_bytes( + expected_private_key + ) + else: + raise AssertionError() From 781796e7fc0a3fdc81c69e42f2f114287634d9aa Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 28 Feb 2024 16:29:46 +0100 Subject: [PATCH 31/73] Docstring changes --- .../crypto/symmetric_encryption_test.py | 3 ++- .../flwr/server/superlink/state/in_memory_state.py | 12 ++++++------ src/py/flwr/server/superlink/state/sqlite_state.py | 12 ++++++------ src/py/flwr/server/superlink/state/state.py | 12 ++++++------ 4 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption_test.py b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption_test.py index c10b04998892..7755016eb7af 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption_test.py +++ b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ # ============================================================================== """Symmetric encryption tests.""" + from .symmetric_encryption import ( compute_hmac, generate_key_pairs, diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 45459b7e77b2..32782cd56eb8 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -228,26 +228,26 @@ def create_run(self) -> int: def store_server_public_private_key( self, public_key: bytes, private_key: bytes ) -> None: - """Store server's `public_key` and `private_key` in state.""" + """Store `server_public_key` and `server_private_key` in state.""" self.server_private_key = private_key self.server_public_key = public_key def get_server_private_key(self) -> bytes: - """Get server private key in urlsafe bytes.""" + """Retrieve `server_private_key` in urlsafe bytes.""" return self.server_private_key def get_server_public_key(self) -> bytes: - """Get server public key in urlsafe bytes.""" + """Retrieve `server_public_key` in urlsafe bytes.""" return self.server_public_key def store_client_public_keys(self, public_keys: Set[bytes]) -> None: - """Store a set of client public keys in state.""" + """Store a set of `client_public_keys` in state.""" self.client_public_keys = public_keys def store_client_public_key(self, public_key: bytes) -> None: - """Retrieve a client public key in state.""" + """Store a `client_public_key` in state.""" self.client_public_keys.add(public_key) def get_client_public_keys(self) -> Set[bytes]: - """Retrieve all currently stored client public keys as a set.""" + """Retrieve all currently stored `client_public_keys` as a set.""" return self.client_public_keys diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index bf65828ac10e..bb0b07642a79 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -537,7 +537,7 @@ def create_run(self) -> int: def store_server_public_private_key( self, public_key: bytes, private_key: bytes ) -> None: - """Store server's `public_key` and `private_key` in state.""" + """Store `server_public_key` and `server_private_key` in state.""" query = ( "INSERT OR REPLACE INTO credential (public_key, private_key) " "VALUES (:public_key, :private_key)" @@ -545,32 +545,32 @@ def store_server_public_private_key( self.query(query, {"public_key": public_key, "private_key": private_key}) def get_server_private_key(self) -> bytes: - """Get server private key in urlsafe bytes.""" + """Retrieve `server_private_key` in urlsafe bytes.""" query = "SELECT private_key FROM credential" rows = self.query(query) private_key: bytes = rows[0]["private_key"] return private_key def get_server_public_key(self) -> bytes: - """Get server public key in urlsafe bytes.""" + """Retrieve `server_public_key` in urlsafe bytes.""" query = "SELECT public_key FROM credential" rows = self.query(query) public_key: bytes = rows[0]["public_key"] return public_key def store_client_public_keys(self, public_keys: Set[bytes]) -> None: - """Store a set of client public keys in state.""" + """Store a set of `client_public_keys` in state.""" query = "INSERT INTO public_key (public_key) VALUES (:public_key)" for public_key in public_keys: self.query(query, {"public_key": public_key}) def store_client_public_key(self, public_key: bytes) -> None: - """Retrieve a client public key in state.""" + """Store a `client_public_key` in state.""" query = "INSERT INTO public_key (public_key) VALUES (:public_key)" self.query(query, {"public_key": public_key}) def get_client_public_keys(self) -> Set[bytes]: - """Retrieve all currently stored client public keys as a set.""" + """Retrieve all currently stored `client_public_keys` as a set.""" query = "SELECT public_key FROM public_key" rows = self.query(query) result: Set[bytes] = {row["public_key"] for row in rows} diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index b4a26db0afd8..909b3a37027a 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -157,24 +157,24 @@ def create_run(self) -> int: def store_server_public_private_key( self, public_key: bytes, private_key: bytes ) -> None: - """Store server's `public_key` and `private_key` in state.""" + """Store `server_public_key` and `server_private_key` in state.""" @abc.abstractmethod def get_server_private_key(self) -> bytes: - """Get server private key in urlsafe bytes.""" + """Retrieve `server_private_key` in urlsafe bytes.""" @abc.abstractmethod def get_server_public_key(self) -> bytes: - """Get server public key in urlsafe bytes.""" + """Retrieve `server_public_key` in urlsafe bytes.""" @abc.abstractmethod def store_client_public_keys(self, public_keys: Set[bytes]) -> None: - """Store a set of client public keys in state.""" + """Store a set of `client_public_keys` in state.""" @abc.abstractmethod def store_client_public_key(self, public_key: bytes) -> None: - """Retrieve a client public key in state.""" + """Store a `client_public_key` in state.""" @abc.abstractmethod def get_client_public_keys(self) -> Set[bytes]: - """Retrieve all currently stored client public keys as a set.""" + """Retrieve all currently stored `client_public_keys` as a set.""" From ab5317f5487a09be5eae5b7e3c198becf0cccb87 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 28 Feb 2024 18:23:13 +0100 Subject: [PATCH 32/73] Fix merge conflict interceptors gone --- src/py/flwr/server/app.py | 11 ++++++++++- src/py/flwr/server/server_test.py | 2 +- .../server/superlink/fleet/grpc_bidi/grpc_server.py | 4 +++- src/py/flwr/server/superlink/state/in_memory_state.py | 2 +- 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index cd65c1d241f5..174fb396d7b5 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -366,7 +366,7 @@ def run_superlink() -> None: host, port, is_v6 = parsed_address address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" - data = _try_setup_client_authentication(args) + data = _try_setup_client_authentication(args, certificates) interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None if data is not None: ( @@ -420,8 +420,15 @@ def run_superlink() -> None: def _try_setup_client_authentication( args: argparse.Namespace, + certificates: Optional[Tuple[bytes, bytes, bytes]], ) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePublicKey, ec.EllipticCurvePrivateKey]]: if args.require_client_authentication: + if certificates is None: + sys.exit( + "Certificates are required to enable client authentication. " + "Please provide certificate paths with '--certificates' before " + "enabling '--require-client-authentication'." + ) client_keys_file_path = Path(args.require_client_authentication[0]) if client_keys_file_path.exists(): client_public_keys: Set[bytes] = set() @@ -522,6 +529,7 @@ def _run_fleet_api_grpc_rere( address: str, state_factory: StateFactory, certificates: Optional[Tuple[bytes, bytes, bytes]], + interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None, ) -> grpc.Server: """Run Fleet API (gRPC, request-response).""" # Create Fleet API gRPC server @@ -534,6 +542,7 @@ def _run_fleet_api_grpc_rere( server_address=address, max_message_length=GRPC_MAX_MESSAGE_LENGTH, certificates=certificates, + interceptors=interceptors, ) log(INFO, "Flower ECE: Starting Fleet API (gRPC-rere) on %s", address) diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py index 06614db509c4..288eab4b23f9 100644 --- a/src/py/flwr/server/server_test.py +++ b/src/py/flwr/server/server_test.py @@ -256,7 +256,7 @@ def test_setup_client_auth() -> None: ) # Run _try_setup_client_authentication - result = _try_setup_client_authentication(mock_args) + result = _try_setup_client_authentication(mock_args, (b"", b"", b"")) expected_private_key = load_ssh_private_key(server_private_key, None) expected_public_key = load_ssh_public_key(server_public_key) diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py index 82f049844bd6..6aeaa7ef413f 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py @@ -18,7 +18,7 @@ import concurrent.futures import sys from logging import ERROR -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Optional, Sequence, Tuple, Union import grpc @@ -162,6 +162,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, keepalive_time_ms: int = 210000, certificates: Optional[Tuple[bytes, bytes, bytes]] = None, + interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None, ) -> grpc.Server: """Create a gRPC server with a single servicer. @@ -249,6 +250,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments # returning RESOURCE_EXHAUSTED status, or None to indicate no limit. maximum_concurrent_rpcs=max_concurrent_workers, options=options, + interceptors=interceptors, ) add_servicer_to_server_fn(servicer, server) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 96e2af3e626d..ff289dd2bee9 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -28,7 +28,7 @@ from flwr.server.utils import validate_task_ins_or_res -class InMemoryState(State): +class InMemoryState(State): # pylint: disable=R0902 """In-memory State implementation.""" def __init__(self) -> None: From 5075ab753cf2f5394f3f80d5bd785b6d4de91c80 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 28 Feb 2024 18:27:41 +0100 Subject: [PATCH 33/73] Fix too many instances --- src/py/flwr/server/superlink/state/in_memory_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 96e2af3e626d..ff289dd2bee9 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -28,7 +28,7 @@ from flwr.server.utils import validate_task_ins_or_res -class InMemoryState(State): +class InMemoryState(State): # pylint: disable=R0902 """In-memory State implementation.""" def __init__(self) -> None: From d6238eecd646878ec8ad17a2a8a671d5c2d76482 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 29 Feb 2024 09:24:53 +0100 Subject: [PATCH 34/73] Fix imports merge conflict --- src/py/flwr/server/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 8d3bfc5fa1a3..59222e45daf3 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -15,8 +15,8 @@ """Flower server app.""" import argparse -import csv import asyncio +import csv import importlib.util import sys import threading From 2de5fd5ff6b253869efa6f8e095941ba5b9ba601 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 29 Feb 2024 19:22:06 +0100 Subject: [PATCH 35/73] Add docstring to interceptor --- src/py/flwr/server/server_interceptor.py | 8 +++++++- src/py/flwr/server/server_interceptor_test.py | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/server/server_interceptor.py b/src/py/flwr/server/server_interceptor.py index 6a0b13d679dc..a63fb8b64c65 100644 --- a/src/py/flwr/server/server_interceptor.py +++ b/src/py/flwr/server/server_interceptor.py @@ -14,6 +14,7 @@ # ============================================================================== """Flower server interceptor.""" + import base64 import threading from logging import INFO @@ -92,7 +93,12 @@ def intercept_service( continuation: Callable[[Any], Any], handler_call_details: grpc.HandlerCallDetails, ) -> grpc.RpcMethodHandler: - """Flower server interceptor authentication logic.""" + """Flower server interceptor authentication logic. + + Intercept unary call from client and do authentication process by validating + metadata sent from client. Continue RPC call if client is authenticated, else, + terminate RPC call by setting context to abort. + """ message_handler: grpc.RpcMethodHandler = continuation(handler_call_details) return self._generic_auth_unary_method_handler(message_handler) diff --git a/src/py/flwr/server/server_interceptor_test.py b/src/py/flwr/server/server_interceptor_test.py index 2a886bc07cd8..11fbc49329d2 100644 --- a/src/py/flwr/server/server_interceptor_test.py +++ b/src/py/flwr/server/server_interceptor_test.py @@ -14,6 +14,7 @@ # ============================================================================== """Flower server interceptor tests.""" + import base64 import unittest From 9b62c4f483e92e8ab2a8d622068d6e4054f59768 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Fri, 1 Mar 2024 07:42:12 +0100 Subject: [PATCH 36/73] Format --- src/py/flwr/server/server_interceptor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/server/server_interceptor.py b/src/py/flwr/server/server_interceptor.py index a63fb8b64c65..617e1decd603 100644 --- a/src/py/flwr/server/server_interceptor.py +++ b/src/py/flwr/server/server_interceptor.py @@ -93,7 +93,7 @@ def intercept_service( continuation: Callable[[Any], Any], handler_call_details: grpc.HandlerCallDetails, ) -> grpc.RpcMethodHandler: - """Flower server interceptor authentication logic. + """Flower server interceptor authentication logic. Intercept unary call from client and do authentication process by validating metadata sent from client. Continue RPC call if client is authenticated, else, From e2ad1ef2cb3163310e7b115f0108e3c829af1bc9 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Fri, 1 Mar 2024 16:10:00 +0100 Subject: [PATCH 37/73] Implement feedback --- src/py/flwr/server/app.py | 71 +++++++++++------------- src/py/flwr/server/server_interceptor.py | 58 ++++++++++--------- 2 files changed, 60 insertions(+), 69 deletions(-) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index c2eb0682e2bd..516443870b99 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -429,50 +429,43 @@ def _try_setup_client_authentication( "enabling '--require-client-authentication'." ) client_keys_file_path = Path(args.require_client_authentication[0]) - if client_keys_file_path.exists(): - client_public_keys: Set[bytes] = set() - public_key = load_ssh_public_key( - Path(args.require_client_authentication[1]).read_bytes() - ) - private_key = load_ssh_private_key( - Path(args.require_client_authentication[2]).read_bytes(), - None, - ) - log(INFO, type(public_key)) - log(INFO, type(private_key)) - if not isinstance(public_key, ec.EllipticCurvePublicKey): - sys.exit( - "An eliptic curve public and private key pair is required for " - "client authentication. Please provide the file path containing " - "valid public and private key to '--require-client-authentication'." - ) - server_public_key = public_key - if not isinstance(private_key, ec.EllipticCurvePrivateKey): - sys.exit( - "An eliptic curve public and private key pair is required for " - "client authentication. Please provide the file path containing " - "valid public and private key to '--require-client-authentication'." - ) - server_private_key = private_key - - with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile: - reader = csv.reader(csvfile) - for row in reader: - for element in row: - public_key = load_ssh_public_key(element.encode()) - if isinstance(public_key, ec.EllipticCurvePublicKey): - client_public_keys.add(public_key_to_bytes(public_key)) - return ( - client_public_keys, - server_public_key, - server_private_key, - ) - else: + if not client_keys_file_path.exists(): sys.exit( "Client public keys csv file are required for client authentication. " "Please provide the csv file path containing known client public keys " "to '--require-client-authentication'." ) + client_public_keys: Set[bytes] = set() + public_key = load_ssh_public_key( + Path(args.require_client_authentication[1]).read_bytes() + ) + private_key = load_ssh_private_key( + Path(args.require_client_authentication[2]).read_bytes(), + None, + ) + if not isinstance(public_key, ec.EllipticCurvePublicKey) or not isinstance( + private_key, ec.EllipticCurvePrivateKey + ): + sys.exit( + "An eliptic curve public and private key pair is required for " + "client authentication. Please provide the file path containing " + "valid public and private key to '--require-client-authentication'." + ) + server_public_key = public_key + server_private_key = private_key + + with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile: + reader = csv.reader(csvfile) + for row in reader: + for element in row: + public_key = load_ssh_public_key(element.encode()) + if isinstance(public_key, ec.EllipticCurvePublicKey): + client_public_keys.add(public_key_to_bytes(public_key)) + return ( + client_public_keys, + server_public_key, + server_private_key, + ) else: return None diff --git a/src/py/flwr/server/server_interceptor.py b/src/py/flwr/server/server_interceptor.py index 617e1decd603..e947ac35a697 100644 --- a/src/py/flwr/server/server_interceptor.py +++ b/src/py/flwr/server/server_interceptor.py @@ -118,40 +118,38 @@ def _generic_method_handler( is_public_key_known = ( client_public_key_bytes in self.state.get_client_public_keys() ) - if is_public_key_known: - if isinstance(request, CreateNodeRequest): - context.send_initial_metadata( + if not is_public_key_known: + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") + + if isinstance(request, CreateNodeRequest): + context.send_initial_metadata( + ( ( - ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode( - public_key_to_bytes(self.server_public_key) - ), + _PUBLIC_KEY_HEADER, + base64.urlsafe_b64encode( + public_key_to_bytes(self.server_public_key) ), - ) - ) - elif isinstance( - request, - (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest), - ): - hmac_value = base64.urlsafe_b64decode( - _get_value_from_tuples( - _AUTH_TOKEN_HEADER, context.invocation_metadata() - ) - ) - client_public_key = bytes_to_public_key(client_public_key_bytes) - shared_secret = generate_shared_key( - self.server_private_key, - client_public_key, + ), ) - verify = verify_hmac( - shared_secret, request.SerializeToString(True), hmac_value + ) + elif isinstance( + request, + (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest), + ): + hmac_value = base64.urlsafe_b64decode( + _get_value_from_tuples( + _AUTH_TOKEN_HEADER, context.invocation_metadata() ) - if not verify: - context.abort( - grpc.StatusCode.UNAUTHENTICATED, "Access denied!" - ) - else: + ) + client_public_key = bytes_to_public_key(client_public_key_bytes) + shared_secret = generate_shared_key( + self.server_private_key, + client_public_key, + ) + verify = verify_hmac( + shared_secret, request.SerializeToString(True), hmac_value + ) + if not verify: context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") else: context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") From 28afce571cc9786f880795000b2437b280263dbf Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 4 Apr 2024 07:54:46 +0200 Subject: [PATCH 38/73] Format --- src/py/flwr/server/superlink/state/in_memory_state.py | 2 +- src/py/flwr/server/superlink/state/sqlite_state.py | 2 +- src/py/flwr/server/superlink/state/state.py | 1 + src/py/flwr/server/superlink/state/state_test.py | 6 +++--- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 110b756552df..b6bd6a7c1098 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -279,7 +279,7 @@ def store_client_public_key(self, public_key: bytes) -> None: def get_client_public_keys(self) -> Set[bytes]: """Retrieve all currently stored `client_public_keys` as a set.""" return self.client_public_keys - + def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: """Acknowledge a ping received from a node, serving as a heartbeat.""" with self.lock: diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 77ec85ad325f..911604d5a94e 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -628,7 +628,7 @@ def get_client_public_keys(self) -> Set[bytes]: rows = self.query(query) result: Set[bytes] = {row["public_key"] for row in rows} return result - + def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: """Acknowledge a ping received from a node, serving as a heartbeat.""" # Update `online_until` and `ping_interval` for the given `node_id` diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index b9f15baa0352..3090e00a76ba 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -178,6 +178,7 @@ def store_client_public_key(self, public_key: bytes) -> None: @abc.abstractmethod def get_client_public_keys(self) -> Set[bytes]: """Retrieve all currently stored `client_public_keys` as a set.""" + def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: """Acknowledge a ping received from a node, serving as a heartbeat. diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 635eceea948a..a33f41b1fa17 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -24,12 +24,12 @@ from unittest.mock import patch from uuid import uuid4 +from flwr.common import DEFAULT_TTL +from flwr.common.constant import ErrorCode from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( generate_key_pairs, public_key_to_bytes, ) -from flwr.common import DEFAULT_TTL -from flwr.common.constant import ErrorCode from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -408,7 +408,7 @@ def test_client_public_keys(self) -> None: state.store_client_public_keys(public_keys) assert state.get_client_public_keys() == public_keys - + def test_acknowledge_ping(self) -> None: """Test if acknowledge_ping works and if get_nodes return online nodes.""" # Prepare From 7c098d23a73b86004484e0561275521284c02892 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 4 Apr 2024 08:25:47 +0200 Subject: [PATCH 39/73] Fix error --- src/py/flwr/server/superlink/state/state.py | 1 + src/py/flwr/server/superlink/state/state_test.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index 3090e00a76ba..02209f27957c 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -179,6 +179,7 @@ def store_client_public_key(self, public_key: bytes) -> None: def get_client_public_keys(self) -> Set[bytes]: """Retrieve all currently stored `client_public_keys` as a set.""" + @abc.abstractmethod def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: """Acknowledge a ping received from a node, serving as a heartbeat. diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index a33f41b1fa17..2bb04c941494 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -555,7 +555,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 12 + assert len(result) == 13 class SqliteFileBasedTest(StateTest, unittest.TestCase): @@ -580,7 +580,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 12 + assert len(result) == 13 if __name__ == "__main__": From 100eadb69c7d7262fb8bf730289dd2fc66bd4edb Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 4 Apr 2024 08:26:52 +0200 Subject: [PATCH 40/73] Fix error --- src/py/flwr/server/superlink/state/state.py | 1 + src/py/flwr/server/superlink/state/state_test.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index 3090e00a76ba..02209f27957c 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -179,6 +179,7 @@ def store_client_public_key(self, public_key: bytes) -> None: def get_client_public_keys(self) -> Set[bytes]: """Retrieve all currently stored `client_public_keys` as a set.""" + @abc.abstractmethod def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: """Acknowledge a ping received from a node, serving as a heartbeat. diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index a33f41b1fa17..2bb04c941494 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -555,7 +555,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 12 + assert len(result) == 13 class SqliteFileBasedTest(StateTest, unittest.TestCase): @@ -580,7 +580,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 12 + assert len(result) == 13 if __name__ == "__main__": From 94ee2fe1f2228f0feeddadf2566499fa35cfcb99 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 18 Apr 2024 08:56:31 +0200 Subject: [PATCH 41/73] Add lock to write operations --- src/py/flwr/server/superlink/state/in_memory_state.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index b6bd6a7c1098..95069927658b 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -257,8 +257,9 @@ def store_server_public_private_key( self, public_key: bytes, private_key: bytes ) -> None: """Store `server_public_key` and `server_private_key` in state.""" - self.server_private_key = private_key - self.server_public_key = public_key + with self.lock: + self.server_private_key = private_key + self.server_public_key = public_key def get_server_private_key(self) -> bytes: """Retrieve `server_private_key` in urlsafe bytes.""" @@ -270,11 +271,13 @@ def get_server_public_key(self) -> bytes: def store_client_public_keys(self, public_keys: Set[bytes]) -> None: """Store a set of `client_public_keys` in state.""" - self.client_public_keys = public_keys + with self.lock: + self.client_public_keys = public_keys def store_client_public_key(self, public_key: bytes) -> None: """Store a `client_public_key` in state.""" - self.client_public_keys.add(public_key) + with self.lock: + self.client_public_keys.add(public_key) def get_client_public_keys(self) -> Set[bytes]: """Retrieve all currently stored `client_public_keys` as a set.""" From 0fe27441fc9c204f62f089d32a39c4c88733afa0 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sun, 21 Apr 2024 21:35:41 +0200 Subject: [PATCH 42/73] Fix docstring --- src/py/flwr/server/server_interceptor_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/server_interceptor_test.py b/src/py/flwr/server/server_interceptor_test.py index 11fbc49329d2..20f019e3e283 100644 --- a/src/py/flwr/server/server_interceptor_test.py +++ b/src/py/flwr/server/server_interceptor_test.py @@ -81,7 +81,7 @@ def tearDown(self) -> None: self._server.stop(None) def test_successful_create_node_with_metadata(self) -> None: - """Test server interceptor for create node.""" + """Test server interceptor for creating node.""" public_key_bytes = base64.urlsafe_b64encode( public_key_to_bytes(self._client_public_key) ) @@ -102,7 +102,7 @@ def test_successful_create_node_with_metadata(self) -> None: self._client_node_id = response.node.node_id def test_successful_delete_node_with_metadata(self) -> None: - """Test server interceptor for create node.""" + """Test server interceptor for deleting node.""" request = DeleteNodeRequest() shared_secret = generate_shared_key( self._client_private_key, self._server_public_key From 54b7afa55acb09cae24efb6334b983c62ac7524e Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 12:45:29 +0200 Subject: [PATCH 43/73] Format --- src/py/flwr/server/superlink/state/in_memory_state.py | 2 +- src/py/flwr/server/superlink/state/sqlite_state.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index fad38dca1b40..6496b41a2c98 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -283,7 +283,7 @@ def store_client_public_key(self, public_key: bytes) -> None: def get_client_public_keys(self) -> Set[bytes]: """Retrieve all currently stored `client_public_keys` as a set.""" return self.client_public_keys - + def get_run(self, run_id: int) -> Tuple[int, str, str]: """Retrieve information about the run with the specified `run_id`.""" with self.lock: diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index d2e22212586c..54d574219778 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -630,7 +630,7 @@ def get_client_public_keys(self) -> Set[bytes]: rows = self.query(query) result: Set[bytes] = {row["public_key"] for row in rows} return result - + def get_run(self, run_id: int) -> Tuple[int, str, str]: """Retrieve information about the run with the specified `run_id`.""" query = "SELECT * FROM run WHERE run_id = ?;" From 4ab971bd0298ed2c919becaea67cba92671e5916 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 12:51:22 +0200 Subject: [PATCH 44/73] Format --- src/py/flwr/server/superlink/state/sqlite_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 54d574219778..a835a4097a4a 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -108,7 +108,7 @@ DictOrTuple = Union[Tuple[Any, ...], Dict[str, Any]] -class SqliteState(State): +class SqliteState(State): # pylint: disable=R0904 """SQLite-based state implementation.""" def __init__( From 054bd044f17aecaa6b09b1a0c77c5275643e2bb6 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 14:50:00 +0200 Subject: [PATCH 45/73] Add more tests --- src/py/flwr/server/server_interceptor_test.py | 81 ++++++++++++++++++- 1 file changed, 78 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/server/server_interceptor_test.py b/src/py/flwr/server/server_interceptor_test.py index 20f019e3e283..84ba5d62a4c0 100644 --- a/src/py/flwr/server/server_interceptor_test.py +++ b/src/py/flwr/server/server_interceptor_test.py @@ -31,7 +31,12 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, + PullTaskInsRequest, + PullTaskInsResponse, + PushTaskResRequest, + PushTaskResResponse, ) +from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611 from .app import ADDRESS_FLEET_API_GRPC_RERE, _run_fleet_api_grpc_rere from .server_interceptor import ( @@ -62,8 +67,6 @@ def setUp(self) -> None: ADDRESS_FLEET_API_GRPC_RERE, state_factory, None, [self._server_interceptor] ) - self._client_node_id: int = -1 - self._channel = grpc.insecure_channel("localhost:9092") self._create_node = self._channel.unary_unary( "/flwr.proto.Fleet/CreateNode", @@ -75,6 +78,16 @@ def setUp(self) -> None: request_serializer=DeleteNodeRequest.SerializeToString, response_deserializer=DeleteNodeResponse.FromString, ) + self._pull_task_ins = self._channel.unary_unary( + "/flwr.proto.Fleet/PullTaskIns", + request_serializer=PullTaskInsRequest.SerializeToString, + response_deserializer=PullTaskInsResponse.FromString, + ) + self._push_task_res = self._channel.unary_unary( + "/flwr.proto.Fleet/PushTaskRes", + request_serializer=PushTaskResRequest.SerializeToString, + response_deserializer=PushTaskResResponse.FromString, + ) def tearDown(self) -> None: """Clean up grpc server.""" @@ -82,9 +95,12 @@ def tearDown(self) -> None: def test_successful_create_node_with_metadata(self) -> None: """Test server interceptor for creating node.""" + # Prepare public_key_bytes = base64.urlsafe_b64encode( public_key_to_bytes(self._client_public_key) ) + + # Execute response, call = self._create_node.with_call( request=CreateNodeRequest(), metadata=((_PUBLIC_KEY_HEADER, public_key_bytes),), @@ -97,12 +113,13 @@ def test_successful_create_node_with_metadata(self) -> None: ).decode(), ) + # Assert assert call.initial_metadata()[0] == expected_metadata assert isinstance(response, CreateNodeResponse) - self._client_node_id = response.node.node_id def test_successful_delete_node_with_metadata(self) -> None: """Test server interceptor for deleting node.""" + # Prepare request = DeleteNodeRequest() shared_secret = generate_shared_key( self._client_private_key, self._server_public_key @@ -113,6 +130,8 @@ def test_successful_delete_node_with_metadata(self) -> None: public_key_bytes = base64.urlsafe_b64encode( public_key_to_bytes(self._client_public_key) ) + + # Execute response, call = self._delete_node.with_call( request=request, metadata=( @@ -120,5 +139,61 @@ def test_successful_delete_node_with_metadata(self) -> None: (_AUTH_TOKEN_HEADER, hmac_value), ), ) + + # Assert assert isinstance(response, DeleteNodeResponse) assert grpc.StatusCode.OK == call.code() + + def test_successful_pull_task_ins_with_metadata(self) -> None: + """Test server interceptor for deleting node.""" + # Prepare + request = PullTaskInsRequest() + shared_secret = generate_shared_key( + self._client_private_key, self._server_public_key + ) + hmac_value = base64.urlsafe_b64encode( + compute_hmac(shared_secret, request.SerializeToString(True)) + ) + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ) + + # Execute + response, call = self._pull_task_ins.with_call( + request=request, + metadata=( + (_PUBLIC_KEY_HEADER, public_key_bytes), + (_AUTH_TOKEN_HEADER, hmac_value), + ), + ) + + # Assert + assert isinstance(response, PullTaskInsResponse) + assert grpc.StatusCode.OK == call.code() + + def test_successful_push_task_res_with_metadata(self) -> None: + """Test server interceptor for deleting node.""" + # Prepare + request = PushTaskResRequest(task_res_list=[TaskRes()]) + shared_secret = generate_shared_key( + self._client_private_key, self._server_public_key + ) + hmac_value = base64.urlsafe_b64encode( + compute_hmac(shared_secret, request.SerializeToString(True)) + ) + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ) + + # Execute + response, call = self._push_task_res.with_call( + request=request, + metadata=( + (_PUBLIC_KEY_HEADER, public_key_bytes), + (_AUTH_TOKEN_HEADER, hmac_value), + ), + ) + + # Assert + assert isinstance(response, PushTaskResResponse) + assert grpc.StatusCode.OK == call.code() From aac42ab6df891b1f887b7b5bacede102802e2f84 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 15:11:29 +0200 Subject: [PATCH 46/73] Add failure tests --- src/py/flwr/server/server_interceptor_test.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/py/flwr/server/server_interceptor_test.py b/src/py/flwr/server/server_interceptor_test.py index 84ba5d62a4c0..18c305d1db90 100644 --- a/src/py/flwr/server/server_interceptor_test.py +++ b/src/py/flwr/server/server_interceptor_test.py @@ -143,6 +143,38 @@ def test_successful_delete_node_with_metadata(self) -> None: # Assert assert isinstance(response, DeleteNodeResponse) assert grpc.StatusCode.OK == call.code() + + def test_unsuccessful_delete_node_with_metadata(self) -> None: + """Test server interceptor for deleting node.""" + # Prepare + request = DeleteNodeRequest() + client_private_key, _ = generate_key_pairs() + shared_secret = generate_shared_key( + client_private_key, self._server_public_key + ) + hmac_value = base64.urlsafe_b64encode( + compute_hmac(shared_secret, request.SerializeToString(True)) + ) + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ) + + # Execute + try: + response, call = self._delete_node.with_call( + request=request, + metadata=( + (_PUBLIC_KEY_HEADER, public_key_bytes), + (_AUTH_TOKEN_HEADER, hmac_value), + ), + ) + except grpc.RpcError as e: + # Assert + + print(e) + + # Assert + assert False def test_successful_pull_task_ins_with_metadata(self) -> None: """Test server interceptor for deleting node.""" From c5b3e46b96d678114ca49f71ce7d041e712ecd86 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 15:26:24 +0200 Subject: [PATCH 47/73] Add failure tests --- src/py/flwr/server/server_interceptor_test.py | 80 ++++++++++++++++--- 1 file changed, 71 insertions(+), 9 deletions(-) diff --git a/src/py/flwr/server/server_interceptor_test.py b/src/py/flwr/server/server_interceptor_test.py index 18c305d1db90..bc0141edd48c 100644 --- a/src/py/flwr/server/server_interceptor_test.py +++ b/src/py/flwr/server/server_interceptor_test.py @@ -116,6 +116,22 @@ def test_successful_create_node_with_metadata(self) -> None: # Assert assert call.initial_metadata()[0] == expected_metadata assert isinstance(response, CreateNodeResponse) + + def test_unsuccessful_create_node_with_metadata(self) -> None: + """Test server interceptor for creating node.""" + # Prepare + _, client_public_key = generate_key_pairs() + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(client_public_key) + ) + + # Execute + # Assert + with self.assertRaises(grpc.RpcError): + self._create_node.with_call( + request=CreateNodeRequest(), + metadata=((_PUBLIC_KEY_HEADER, public_key_bytes),), + ) def test_successful_delete_node_with_metadata(self) -> None: """Test server interceptor for deleting node.""" @@ -160,21 +176,15 @@ def test_unsuccessful_delete_node_with_metadata(self) -> None: ) # Execute - try: - response, call = self._delete_node.with_call( + # Assert + with self.assertRaises(grpc.RpcError): + self._delete_node.with_call( request=request, metadata=( (_PUBLIC_KEY_HEADER, public_key_bytes), (_AUTH_TOKEN_HEADER, hmac_value), ), ) - except grpc.RpcError as e: - # Assert - - print(e) - - # Assert - assert False def test_successful_pull_task_ins_with_metadata(self) -> None: """Test server interceptor for deleting node.""" @@ -203,6 +213,32 @@ def test_successful_pull_task_ins_with_metadata(self) -> None: assert isinstance(response, PullTaskInsResponse) assert grpc.StatusCode.OK == call.code() + def test_unsuccessful_pull_task_ins_with_metadata(self) -> None: + """Test server interceptor for deleting node.""" + # Prepare + request = PullTaskInsRequest() + client_private_key, _ = generate_key_pairs() + shared_secret = generate_shared_key( + client_private_key, self._server_public_key + ) + hmac_value = base64.urlsafe_b64encode( + compute_hmac(shared_secret, request.SerializeToString(True)) + ) + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ) + + # Execute + # Assert + with self.assertRaises(grpc.RpcError): + self._pull_task_ins.with_call( + request=request, + metadata=( + (_PUBLIC_KEY_HEADER, public_key_bytes), + (_AUTH_TOKEN_HEADER, hmac_value), + ), + ) + def test_successful_push_task_res_with_metadata(self) -> None: """Test server interceptor for deleting node.""" # Prepare @@ -229,3 +265,29 @@ def test_successful_push_task_res_with_metadata(self) -> None: # Assert assert isinstance(response, PushTaskResResponse) assert grpc.StatusCode.OK == call.code() + + def test_successful_push_task_res_with_metadata(self) -> None: + """Test server interceptor for deleting node.""" + # Prepare + request = PushTaskResRequest(task_res_list=[TaskRes()]) + client_private_key, _ = generate_key_pairs() + shared_secret = generate_shared_key( + client_private_key, self._server_public_key + ) + hmac_value = base64.urlsafe_b64encode( + compute_hmac(shared_secret, request.SerializeToString(True)) + ) + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ) + + # Execute + # Assert + with self.assertRaises(grpc.RpcError): + self._push_task_res.with_call( + request=request, + metadata=( + (_PUBLIC_KEY_HEADER, public_key_bytes), + (_AUTH_TOKEN_HEADER, hmac_value), + ), + ) From e76b9378a48fb79ccd1ec7606ad841c7fffc8ed5 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 15:33:45 +0200 Subject: [PATCH 48/73] Format --- src/py/flwr/server/server_interceptor_test.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/py/flwr/server/server_interceptor_test.py b/src/py/flwr/server/server_interceptor_test.py index bc0141edd48c..27d584935cdc 100644 --- a/src/py/flwr/server/server_interceptor_test.py +++ b/src/py/flwr/server/server_interceptor_test.py @@ -116,7 +116,7 @@ def test_successful_create_node_with_metadata(self) -> None: # Assert assert call.initial_metadata()[0] == expected_metadata assert isinstance(response, CreateNodeResponse) - + def test_unsuccessful_create_node_with_metadata(self) -> None: """Test server interceptor for creating node.""" # Prepare @@ -159,15 +159,13 @@ def test_successful_delete_node_with_metadata(self) -> None: # Assert assert isinstance(response, DeleteNodeResponse) assert grpc.StatusCode.OK == call.code() - + def test_unsuccessful_delete_node_with_metadata(self) -> None: """Test server interceptor for deleting node.""" # Prepare request = DeleteNodeRequest() client_private_key, _ = generate_key_pairs() - shared_secret = generate_shared_key( - client_private_key, self._server_public_key - ) + shared_secret = generate_shared_key(client_private_key, self._server_public_key) hmac_value = base64.urlsafe_b64encode( compute_hmac(shared_secret, request.SerializeToString(True)) ) @@ -218,9 +216,7 @@ def test_unsuccessful_pull_task_ins_with_metadata(self) -> None: # Prepare request = PullTaskInsRequest() client_private_key, _ = generate_key_pairs() - shared_secret = generate_shared_key( - client_private_key, self._server_public_key - ) + shared_secret = generate_shared_key(client_private_key, self._server_public_key) hmac_value = base64.urlsafe_b64encode( compute_hmac(shared_secret, request.SerializeToString(True)) ) @@ -266,14 +262,12 @@ def test_successful_push_task_res_with_metadata(self) -> None: assert isinstance(response, PushTaskResResponse) assert grpc.StatusCode.OK == call.code() - def test_successful_push_task_res_with_metadata(self) -> None: + def test_unsuccessful_push_task_res_with_metadata(self) -> None: """Test server interceptor for deleting node.""" # Prepare request = PushTaskResRequest(task_res_list=[TaskRes()]) client_private_key, _ = generate_key_pairs() - shared_secret = generate_shared_key( - client_private_key, self._server_public_key - ) + shared_secret = generate_shared_key(client_private_key, self._server_public_key) hmac_value = base64.urlsafe_b64encode( compute_hmac(shared_secret, request.SerializeToString(True)) ) From 846373f312b7eb6df742f8fb4965f49c9ae53d75 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 15:35:00 +0200 Subject: [PATCH 49/73] Fix docstring --- src/py/flwr/server/server_interceptor_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/server/server_interceptor_test.py b/src/py/flwr/server/server_interceptor_test.py index 27d584935cdc..5b7fac6fede2 100644 --- a/src/py/flwr/server/server_interceptor_test.py +++ b/src/py/flwr/server/server_interceptor_test.py @@ -118,7 +118,7 @@ def test_successful_create_node_with_metadata(self) -> None: assert isinstance(response, CreateNodeResponse) def test_unsuccessful_create_node_with_metadata(self) -> None: - """Test server interceptor for creating node.""" + """Test server interceptor for creating node unsuccessfully.""" # Prepare _, client_public_key = generate_key_pairs() public_key_bytes = base64.urlsafe_b64encode( @@ -161,7 +161,7 @@ def test_successful_delete_node_with_metadata(self) -> None: assert grpc.StatusCode.OK == call.code() def test_unsuccessful_delete_node_with_metadata(self) -> None: - """Test server interceptor for deleting node.""" + """Test server interceptor for deleting node unsuccessfully.""" # Prepare request = DeleteNodeRequest() client_private_key, _ = generate_key_pairs() @@ -185,7 +185,7 @@ def test_unsuccessful_delete_node_with_metadata(self) -> None: ) def test_successful_pull_task_ins_with_metadata(self) -> None: - """Test server interceptor for deleting node.""" + """Test server interceptor for pull task ins.""" # Prepare request = PullTaskInsRequest() shared_secret = generate_shared_key( @@ -212,7 +212,7 @@ def test_successful_pull_task_ins_with_metadata(self) -> None: assert grpc.StatusCode.OK == call.code() def test_unsuccessful_pull_task_ins_with_metadata(self) -> None: - """Test server interceptor for deleting node.""" + """Test server interceptor for pull task ins unsuccessfully.""" # Prepare request = PullTaskInsRequest() client_private_key, _ = generate_key_pairs() @@ -236,7 +236,7 @@ def test_unsuccessful_pull_task_ins_with_metadata(self) -> None: ) def test_successful_push_task_res_with_metadata(self) -> None: - """Test server interceptor for deleting node.""" + """Test server interceptor for push task res.""" # Prepare request = PushTaskResRequest(task_res_list=[TaskRes()]) shared_secret = generate_shared_key( @@ -263,7 +263,7 @@ def test_successful_push_task_res_with_metadata(self) -> None: assert grpc.StatusCode.OK == call.code() def test_unsuccessful_push_task_res_with_metadata(self) -> None: - """Test server interceptor for deleting node.""" + """Test server interceptor for push task res unsuccessfully.""" # Prepare request = PushTaskResRequest(task_res_list=[TaskRes()]) client_private_key, _ = generate_key_pairs() From a61892e7acdae1cc5b5acf540cf315b27e121444 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 15:41:02 +0200 Subject: [PATCH 50/73] Format prepare, execute & assert --- src/py/flwr/server/server_interceptor_test.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/py/flwr/server/server_interceptor_test.py b/src/py/flwr/server/server_interceptor_test.py index 5b7fac6fede2..5e0ed0be2313 100644 --- a/src/py/flwr/server/server_interceptor_test.py +++ b/src/py/flwr/server/server_interceptor_test.py @@ -125,8 +125,7 @@ def test_unsuccessful_create_node_with_metadata(self) -> None: public_key_to_bytes(client_public_key) ) - # Execute - # Assert + # Execute & Assert with self.assertRaises(grpc.RpcError): self._create_node.with_call( request=CreateNodeRequest(), @@ -173,8 +172,7 @@ def test_unsuccessful_delete_node_with_metadata(self) -> None: public_key_to_bytes(self._client_public_key) ) - # Execute - # Assert + # Execute & Assert with self.assertRaises(grpc.RpcError): self._delete_node.with_call( request=request, @@ -224,8 +222,7 @@ def test_unsuccessful_pull_task_ins_with_metadata(self) -> None: public_key_to_bytes(self._client_public_key) ) - # Execute - # Assert + # Execute & Assert with self.assertRaises(grpc.RpcError): self._pull_task_ins.with_call( request=request, @@ -275,8 +272,7 @@ def test_unsuccessful_push_task_res_with_metadata(self) -> None: public_key_to_bytes(self._client_public_key) ) - # Execute - # Assert + # Execute & Assert with self.assertRaises(grpc.RpcError): self._push_task_res.with_call( request=request, From 8b58c5e631dddb6645f7da7a936f92528340a5b6 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 17:04:16 +0200 Subject: [PATCH 51/73] Add get run --- src/py/flwr/server/server_interceptor.py | 10 +++- src/py/flwr/server/server_interceptor_test.py | 57 +++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/server/server_interceptor.py b/src/py/flwr/server/server_interceptor.py index e947ac35a697..fa9d3820cc82 100644 --- a/src/py/flwr/server/server_interceptor.py +++ b/src/py/flwr/server/server_interceptor.py @@ -35,6 +35,8 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, + GetRunRequest, + GetRunResponse, PullTaskInsRequest, PullTaskInsResponse, PushTaskResRequest, @@ -54,6 +56,7 @@ DeleteNodeResponse, PullTaskInsResponse, PushTaskResResponse, + GetRunResponse, ] @@ -134,7 +137,12 @@ def _generic_method_handler( ) elif isinstance( request, - (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest), + ( + DeleteNodeRequest, + PullTaskInsRequest, + PushTaskResRequest, + GetRunRequest, + ), ): hmac_value = base64.urlsafe_b64decode( _get_value_from_tuples( diff --git a/src/py/flwr/server/server_interceptor_test.py b/src/py/flwr/server/server_interceptor_test.py index 5e0ed0be2313..ef5f1f6bd816 100644 --- a/src/py/flwr/server/server_interceptor_test.py +++ b/src/py/flwr/server/server_interceptor_test.py @@ -31,6 +31,8 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, + GetRunRequest, + GetRunResponse, PullTaskInsRequest, PullTaskInsResponse, PushTaskResRequest, @@ -88,6 +90,11 @@ def setUp(self) -> None: request_serializer=PushTaskResRequest.SerializeToString, response_deserializer=PushTaskResResponse.FromString, ) + self._get_run = self._channel.unary_unary( + "/flwr.proto.Fleet/GetRun", + request_serializer=GetRunRequest.SerializeToString, + response_deserializer=GetRunResponse.FromString, + ) def tearDown(self) -> None: """Clean up grpc server.""" @@ -281,3 +288,53 @@ def test_unsuccessful_push_task_res_with_metadata(self) -> None: (_AUTH_TOKEN_HEADER, hmac_value), ), ) + + def test_successful_get_run_with_metadata(self) -> None: + """Test server interceptor for pull task ins.""" + # Prepare + request = GetRunRequest(run_id=0) + shared_secret = generate_shared_key( + self._client_private_key, self._server_public_key + ) + hmac_value = base64.urlsafe_b64encode( + compute_hmac(shared_secret, request.SerializeToString(True)) + ) + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ) + + # Execute + response, call = self._get_run.with_call( + request=request, + metadata=( + (_PUBLIC_KEY_HEADER, public_key_bytes), + (_AUTH_TOKEN_HEADER, hmac_value), + ), + ) + + # Assert + assert isinstance(response, GetRunResponse) + assert grpc.StatusCode.OK == call.code() + + def test_unsuccessful_get_run_with_metadata(self) -> None: + """Test server interceptor for pull task ins unsuccessfully.""" + # Prepare + request = GetRunRequest(run_id=0) + client_private_key, _ = generate_key_pairs() + shared_secret = generate_shared_key(client_private_key, self._server_public_key) + hmac_value = base64.urlsafe_b64encode( + compute_hmac(shared_secret, request.SerializeToString(True)) + ) + public_key_bytes = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ) + + # Execute & Assert + with self.assertRaises(grpc.RpcError): + self._get_run.with_call( + request=request, + metadata=( + (_PUBLIC_KEY_HEADER, public_key_bytes), + (_AUTH_TOKEN_HEADER, hmac_value), + ), + ) From bd6163bfc4949aaf19e9264a61a15ac80aa6c657 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 17:27:15 +0200 Subject: [PATCH 52/73] Dynamically generate ssh key --- src/py/flwr/server/server_test.py | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py index 288eab4b23f9..be303e52d13c 100644 --- a/src/py/flwr/server/server_test.py +++ b/src/py/flwr/server/server_test.py @@ -25,6 +25,8 @@ from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.serialization import ( Encoding, + NoEncryption, + PrivateFormat, PublicFormat, load_ssh_private_key, load_ssh_public_key, @@ -205,23 +207,14 @@ def test_setup_client_auth() -> None: """Test setup client authentication.""" # Generate keys _, first_public_key = generate_key_pairs() - server_public_key = ( - b"ecdsa-sha2-nistp384 " - b"AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBIqtP/EvrgBYukcjRJT9zVLXE" - b"fykvVvT/QcHXuxCNu83SyCwedk3nNZxy5rZ1f8KoU+OSGmum5I9BxnWcLeBC+TGqpifTUSNwa/riV" - b"oJGcN/SxF3euqQg58YePORhos/Ug==" + private_key, public_key = generate_key_pairs() + + server_public_key = public_key.public_bytes( + encoding=Encoding.OpenSSH, format=PublicFormat.OpenSSH + ) + server_private_key = private_key.private_bytes( + Encoding.PEM, PrivateFormat.OpenSSH, NoEncryption() ) - server_private_key = b"""-----BEGIN OPENSSH PRIVATE KEY----- - b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAiAAAABNlY2RzYS - 1zaGEyLW5pc3RwMzg0AAAACG5pc3RwMzg0AAAAYQSKrT/xL64AWLpHI0SU/c1S1xH8pL1b - 0/0HB17sQjbvN0sgsHnZN5zWccua2dX/CqFPjkhprpuSPQcZ1nC3gQvkxqqYn01EjcGv64 - laCRnDf0sRd3rqkIOfGHjzkYaLP1IAAADw/rbMO/62zDsAAAATZWNkc2Etc2hhMi1uaXN0 - cDM4NAAAAAhuaXN0cDM4NAAAAGEEiq0/8S+uAFi6RyNElP3NUtcR/KS9W9P9Bwde7EI27z - dLILB52Tec1nHLmtnV/wqhT45Iaa6bkj0HGdZwt4EL5MaqmJ9NRI3Br+uJWgkZw39LEXd6 - 6pCDnxh485GGiz9SAAAAMQDQmvP7JeFNBDvo1VXciQF0Wv3/DCcj9x0kUABuX1gxb42Iw3 - v7FOEco/enMaS4URwAAAAnZGFuaWVsbnVncmFoYUBEYW5pZWxzLU1hY0Jvb2stUHJvLmxv - Y2Fs - -----END OPENSSH PRIVATE KEY-----""" _, second_public_key = generate_key_pairs() with tempfile.TemporaryDirectory() as temp_dir: From 6a02ba9cd31cdcb046f644cce0331b4bee3aa3af Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 17:35:25 +0200 Subject: [PATCH 53/73] Encode only once --- src/py/flwr/server/server_interceptor.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/server/server_interceptor.py b/src/py/flwr/server/server_interceptor.py index fa9d3820cc82..c31421a2f3c4 100644 --- a/src/py/flwr/server/server_interceptor.py +++ b/src/py/flwr/server/server_interceptor.py @@ -85,6 +85,9 @@ def __init__( self.server_public_key = public_key self.state = state_factory.state() self.state.store_client_public_keys(client_public_keys) + self.encoded_server_public_key = base64.urlsafe_b64encode( + public_key_to_bytes(self.server_public_key) + ) log( INFO, "Client authentication enabled with %d known public keys", @@ -129,9 +132,7 @@ def _generic_method_handler( ( ( _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode( - public_key_to_bytes(self.server_public_key) - ), + self.encoded_server_public_key, ), ) ) From d7029242518915336683b8d3b2ffad939279e25c Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 17:37:11 +0200 Subject: [PATCH 54/73] Format --- src/py/flwr/server/server_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py index be303e52d13c..4fa964905f8a 100644 --- a/src/py/flwr/server/server_test.py +++ b/src/py/flwr/server/server_test.py @@ -203,7 +203,7 @@ def test_set_max_workers() -> None: assert server.max_workers == 42 -def test_setup_client_auth() -> None: +def test_setup_client_auth() -> None: # pylint: disable=R0914 """Test setup client authentication.""" # Generate keys _, first_public_key = generate_key_pairs() From ac54694e5be7edaf0abcbf5472ad809d1669daf1 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 17:56:52 +0200 Subject: [PATCH 55/73] Add get_run --- src/py/flwr/server/server_interceptor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/server/server_interceptor.py b/src/py/flwr/server/server_interceptor.py index c31421a2f3c4..abffc0c170be 100644 --- a/src/py/flwr/server/server_interceptor.py +++ b/src/py/flwr/server/server_interceptor.py @@ -48,7 +48,11 @@ _AUTH_TOKEN_HEADER = "auth-token" Request = Union[ - CreateNodeRequest, DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest + CreateNodeRequest, + DeleteNodeRequest, + PullTaskInsRequest, + PushTaskResRequest, + GetRunRequest, ] Response = Union[ From 5735549541421e68e086cdc0163b454a486089ec Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 18:30:24 +0200 Subject: [PATCH 56/73] Unindent function --- src/py/flwr/server/app.py | 92 ++++++++++++++++++++------------------- 1 file changed, 47 insertions(+), 45 deletions(-) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 7834fa91f138..68554b4233a2 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -423,53 +423,55 @@ def _try_setup_client_authentication( args: argparse.Namespace, certificates: Optional[Tuple[bytes, bytes, bytes]], ) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePublicKey, ec.EllipticCurvePrivateKey]]: - if args.require_client_authentication: - if certificates is None: - sys.exit( - "Certificates are required to enable client authentication. " - "Please provide certificate paths with '--certificates' before " - "enabling '--require-client-authentication'." - ) - client_keys_file_path = Path(args.require_client_authentication[0]) - if not client_keys_file_path.exists(): - sys.exit( - "Client public keys csv file are required for client authentication. " - "Please provide the csv file path containing known client public keys " - "to '--require-client-authentication'." - ) - client_public_keys: Set[bytes] = set() - public_key = load_ssh_public_key( - Path(args.require_client_authentication[1]).read_bytes() + if not args.require_client_authentication: + return None + + if certificates is None: + sys.exit( + "Certificates are required to enable client authentication. " + "Please provide certificate paths with '--certificates' before " + "enabling '--require-client-authentication'." ) - private_key = load_ssh_private_key( - Path(args.require_client_authentication[2]).read_bytes(), - None, + + client_keys_file_path = Path(args.require_client_authentication[0]) + if not client_keys_file_path.exists(): + sys.exit( + "Client public keys csv file are required for client authentication. " + "Please provide the csv file path containing known client public keys " + "to '--require-client-authentication'." + ) + + client_public_keys: Set[bytes] = set() + public_key = load_ssh_public_key( + Path(args.require_client_authentication[1]).read_bytes() + ) + private_key = load_ssh_private_key( + Path(args.require_client_authentication[2]).read_bytes(), + None, + ) + if not isinstance(public_key, ec.EllipticCurvePublicKey) or not isinstance( + private_key, ec.EllipticCurvePrivateKey + ): + sys.exit( + "An eliptic curve public and private key pair is required for " + "client authentication. Please provide the file path containing " + "valid public and private key to '--require-client-authentication'." + ) + server_public_key = public_key + server_private_key = private_key + + with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile: + reader = csv.reader(csvfile) + for row in reader: + for element in row: + public_key = load_ssh_public_key(element.encode()) + if isinstance(public_key, ec.EllipticCurvePublicKey): + client_public_keys.add(public_key_to_bytes(public_key)) + return ( + client_public_keys, + server_public_key, + server_private_key, ) - if not isinstance(public_key, ec.EllipticCurvePublicKey) or not isinstance( - private_key, ec.EllipticCurvePrivateKey - ): - sys.exit( - "An eliptic curve public and private key pair is required for " - "client authentication. Please provide the file path containing " - "valid public and private key to '--require-client-authentication'." - ) - server_public_key = public_key - server_private_key = private_key - - with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile: - reader = csv.reader(csvfile) - for row in reader: - for element in row: - public_key = load_ssh_public_key(element.encode()) - if isinstance(public_key, ec.EllipticCurvePublicKey): - client_public_keys.add(public_key_to_bytes(public_key)) - return ( - client_public_keys, - server_public_key, - server_private_key, - ) - else: - return None def _try_obtain_certificates( From 1ca16ba08cc0063e5e18a09e9b4b161fd155389b Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 18:31:27 +0200 Subject: [PATCH 57/73] Format --- src/py/flwr/server/app.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 68554b4233a2..953e00fbc52b 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -425,7 +425,7 @@ def _try_setup_client_authentication( ) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePublicKey, ec.EllipticCurvePrivateKey]]: if not args.require_client_authentication: return None - + if certificates is None: sys.exit( "Certificates are required to enable client authentication. " @@ -440,7 +440,7 @@ def _try_setup_client_authentication( "Please provide the csv file path containing known client public keys " "to '--require-client-authentication'." ) - + client_public_keys: Set[bytes] = set() public_key = load_ssh_public_key( Path(args.require_client_authentication[1]).read_bytes() From 64ec5e8141cbab0c79a7b7f92d2890499914dcb1 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 19:42:42 +0200 Subject: [PATCH 58/73] Update src/py/flwr/server/app.py Co-authored-by: Daniel J. Beutel --- src/py/flwr/server/app.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 953e00fbc52b..52d8ae0d2a49 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -428,8 +428,8 @@ def _try_setup_client_authentication( if certificates is None: sys.exit( - "Certificates are required to enable client authentication. " - "Please provide certificate paths with '--certificates' before " + "Client authentication only works over secure connections. " + "Please provide certificate paths using '--certificates' when " "enabling '--require-client-authentication'." ) From 9c03cd6554cda03612b159e2092329ef26fb5f34 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 19:43:01 +0200 Subject: [PATCH 59/73] Update src/py/flwr/server/app.py Co-authored-by: Daniel J. Beutel --- src/py/flwr/server/app.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 52d8ae0d2a49..78bf533d090c 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -436,8 +436,8 @@ def _try_setup_client_authentication( client_keys_file_path = Path(args.require_client_authentication[0]) if not client_keys_file_path.exists(): sys.exit( - "Client public keys csv file are required for client authentication. " - "Please provide the csv file path containing known client public keys " + "The provided path to the client public keys CSV file does not exist. " + "Please provide the CSV file path containing known client public keys " "to '--require-client-authentication'." ) From 4779dc09611a64e7ef3ecd56a4b45100823f2ac9 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 20:26:13 +0200 Subject: [PATCH 60/73] Implement review feedback --- .../crypto/symmetric_encryption.py | 13 +++++++ src/py/flwr/server/app.py | 39 ++++++++++++------- .../fleet/grpc_rere}/server_interceptor.py | 0 .../grpc_rere}/server_interceptor_test.py | 4 +- 4 files changed, 40 insertions(+), 16 deletions(-) rename src/py/flwr/server/{ => superlink/fleet/grpc_rere}/server_interceptor.py (100%) rename src/py/flwr/server/{ => superlink/fleet/grpc_rere}/server_interceptor_test.py (98%) diff --git a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py index 1d004a398ea8..76470fc0b33f 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py +++ b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py @@ -117,3 +117,16 @@ def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool: return True except InvalidSignature: return False + + +def ssh_types_to_elliptic_curve( + private_key: serialization.SSHPrivateKeyTypes, + public_key: serialization.SSHPublicKeyTypes, +) -> Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]: + """Cast SSH key types to elliptic curve.""" + if isinstance(private_key, ec.EllipticCurvePrivateKey) and isinstance( + public_key, ec.EllipticCurvePublicKey + ): + return (private_key, public_key) + + raise TypeError("The provided key is not an EllipticCurvePrivateKey") diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 78bf533d090c..12eabfee1182 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -44,6 +44,7 @@ from flwr.common.logger import log from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( public_key_to_bytes, + ssh_types_to_elliptic_curve, ) from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611 add_FleetServicer_to_server, @@ -53,7 +54,6 @@ from .history import History from .server import Server, init_defaults, run_fl from .server_config import ServerConfig -from .server_interceptor import AuthenticateServerInterceptor from .strategy import Strategy from .superlink.driver.driver_grpc import run_driver_api_grpc from .superlink.fleet.grpc_bidi.grpc_server import ( @@ -61,6 +61,7 @@ start_grpc_server, ) from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer +from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor from .superlink.fleet.vce import start_vce from .superlink.state import StateFactory @@ -436,29 +437,32 @@ def _try_setup_client_authentication( client_keys_file_path = Path(args.require_client_authentication[0]) if not client_keys_file_path.exists(): sys.exit( - "The provided path to the client public keys CSV file does not exist. " + "The provided path to the client public keys CSV file does not exist: " + f"{client_keys_file_path}. " "Please provide the CSV file path containing known client public keys " "to '--require-client-authentication'." ) client_public_keys: Set[bytes] = set() - public_key = load_ssh_public_key( + ssh_public_key = load_ssh_public_key( Path(args.require_client_authentication[1]).read_bytes() ) - private_key = load_ssh_private_key( + ssh_private_key = load_ssh_private_key( Path(args.require_client_authentication[2]).read_bytes(), None, ) - if not isinstance(public_key, ec.EllipticCurvePublicKey) or not isinstance( - private_key, ec.EllipticCurvePrivateKey - ): + + try: + server_private_key, server_public_key = ssh_types_to_elliptic_curve( + ssh_private_key, ssh_public_key + ) + except TypeError: sys.exit( - "An eliptic curve public and private key pair is required for " - "client authentication. Please provide the file path containing " - "valid public and private key to '--require-client-authentication'." + "The file paths provided do not contain a vaild public and private " + "key. Client authentication requires an elliptic curve public and " + "private key pair. Please provide the file path containing elliptic " + "curve public and private key to '--require-client-authentication'." ) - server_public_key = public_key - server_private_key = private_key with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile: reader = csv.reader(csvfile) @@ -467,6 +471,12 @@ def _try_setup_client_authentication( public_key = load_ssh_public_key(element.encode()) if isinstance(public_key, ec.EllipticCurvePublicKey): client_public_keys.add(public_key_to_bytes(public_key)) + else: + sys.exit( + "Error: Unable to parse the public keys in the .csv " + "file. Please ensure that the .csv file contains valid " + "SSH public keys and try again." + ) return ( client_public_keys, server_public_key, @@ -697,8 +707,9 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None: nargs=3, metavar=("CLIENT_KEYS", "SERVER_PUBLIC_KEY", "SERVER_PRIVATE_KEY"), type=str, - help="Paths to .csv file containing list of known client public keys for " - "authentication, server public key, and server private key, in that order.", + help="Provide three file paths: (1) a .csv file containing a list of " + "known client public keys for authentication, (2) the server's public " + "key file, and (3) the server's private key file.", ) diff --git a/src/py/flwr/server/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py similarity index 100% rename from src/py/flwr/server/server_interceptor.py rename to src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py diff --git a/src/py/flwr/server/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py similarity index 98% rename from src/py/flwr/server/server_interceptor_test.py rename to src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index ef5f1f6bd816..b3924391a0f4 100644 --- a/src/py/flwr/server/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -39,14 +39,14 @@ PushTaskResResponse, ) from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611 +from flwr.server.app import ADDRESS_FLEET_API_GRPC_RERE, _run_fleet_api_grpc_rere +from flwr.server.superlink.state.state_factory import StateFactory -from .app import ADDRESS_FLEET_API_GRPC_RERE, _run_fleet_api_grpc_rere from .server_interceptor import ( _AUTH_TOKEN_HEADER, _PUBLIC_KEY_HEADER, AuthenticateServerInterceptor, ) -from .superlink.state.state_factory import StateFactory class TestServerInterceptor(unittest.TestCase): # pylint: disable=R0902 From af03db14098feaa6e8a815710f3010d3225002e3 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 25 Apr 2024 07:31:09 +0200 Subject: [PATCH 61/73] Adapt error string --- .../common/secure_aggregation/crypto/symmetric_encryption.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py index 76470fc0b33f..9856b8b706f9 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py +++ b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py @@ -129,4 +129,6 @@ def ssh_types_to_elliptic_curve( ): return (private_key, public_key) - raise TypeError("The provided key is not an EllipticCurvePrivateKey") + raise TypeError( + "The provided key is not an EllipticCurvePrivateKey or EllipticCurvePublicKey" + ) From 24b7182fa88dd3be3ce935da0cf2cd20dc419b6a Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Fri, 26 Apr 2024 15:41:35 +0200 Subject: [PATCH 62/73] Update src/py/flwr/server/app.py Co-authored-by: Daniel J. Beutel --- src/py/flwr/server/app.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 12eabfee1182..54f058d1b25d 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -458,10 +458,10 @@ def _try_setup_client_authentication( ) except TypeError: sys.exit( - "The file paths provided do not contain a vaild public and private " - "key. Client authentication requires an elliptic curve public and " - "private key pair. Please provide the file path containing elliptic " - "curve public and private key to '--require-client-authentication'." + "The file paths provided could not be read as a public and private " + "key pair. Client authentication requires an elliptic curve public and " + "private key pair. Please provide the file paths containing elliptic " + "curve public and private keys to '--require-client-authentication'." ) with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile: From 16e20d4b54367ee1e1a137626da339493f1c9a26 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Fri, 26 Apr 2024 15:43:30 +0200 Subject: [PATCH 63/73] Change data to maybe_keys --- src/py/flwr/server/app.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 54f058d1b25d..7475c9aac884 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -366,14 +366,14 @@ def run_superlink() -> None: host, port, is_v6 = parsed_address address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}" - data = _try_setup_client_authentication(args, certificates) + maybe_keys = _try_setup_client_authentication(args, certificates) interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None - if data is not None: + if maybe_keys is not None: ( client_public_keys, server_public_key, server_private_key, - ) = data + ) = maybe_keys interceptors = [ AuthenticateServerInterceptor( state_factory, From a742d40419d38b1e570590e7dfd2109eb33167fc Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Fri, 26 Apr 2024 15:55:10 +0200 Subject: [PATCH 64/73] Implement feedback --- src/py/flwr/server/app.py | 2 +- .../superlink/fleet/grpc_rere/server_interceptor.py | 12 ++++++------ .../fleet/grpc_rere/server_interceptor_test.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 7475c9aac884..74a1131a9e29 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -376,7 +376,7 @@ def run_superlink() -> None: ) = maybe_keys interceptors = [ AuthenticateServerInterceptor( - state_factory, + state_factory.state(), client_public_keys, server_private_key, server_public_key, diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index abffc0c170be..b1995f5084f5 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -42,7 +42,7 @@ PushTaskResRequest, PushTaskResResponse, ) -from flwr.server.superlink.state import StateFactory +from flwr.server.superlink.state import State _PUBLIC_KEY_HEADER = "public-key" _AUTH_TOKEN_HEADER = "auth-token" @@ -79,7 +79,7 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore def __init__( self, - state_factory: StateFactory, + state: State, client_public_keys: Set[bytes], private_key: ec.EllipticCurvePrivateKey, public_key: ec.EllipticCurvePublicKey, @@ -87,7 +87,7 @@ def __init__( self._lock = threading.Lock() self.server_private_key = private_key self.server_public_key = public_key - self.state = state_factory.state() + self.state = state self.state.store_client_public_keys(client_public_keys) self.encoded_server_public_key = base64.urlsafe_b64encode( public_key_to_bytes(self.server_public_key) @@ -105,9 +105,9 @@ def intercept_service( ) -> grpc.RpcMethodHandler: """Flower server interceptor authentication logic. - Intercept unary call from client and do authentication process by validating - metadata sent from client. Continue RPC call if client is authenticated, else, - terminate RPC call by setting context to abort. + Intercept all unary calls from clients and authenticate clients by validating + auth metadata sent by the client. Continue RPC call if client is authenticated, + else, terminate RPC call by setting context to abort. """ message_handler: grpc.RpcMethodHandler = continuation(handler_call_details) return self._generic_auth_unary_method_handler(message_handler) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index b3924391a0f4..0877238fbf30 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -60,7 +60,7 @@ def setUp(self) -> None: state_factory = StateFactory(":flwr-in-memory-state:") self._server_interceptor = AuthenticateServerInterceptor( - state_factory, + state_factory.state(), {public_key_to_bytes(self._client_public_key)}, self._server_private_key, self._server_public_key, From 461802a652d709b67069df0b189e22b7b4320335 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Fri, 26 Apr 2024 16:12:02 +0200 Subject: [PATCH 65/73] Private key first then public --- src/py/flwr/server/app.py | 24 ++++++++++++------------ src/py/flwr/server/server_test.py | 12 ++++++------ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 74a1131a9e29..87d1d1fd0c2e 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -371,8 +371,8 @@ def run_superlink() -> None: if maybe_keys is not None: ( client_public_keys, - server_public_key, server_private_key, + server_public_key, ) = maybe_keys interceptors = [ AuthenticateServerInterceptor( @@ -423,7 +423,7 @@ def run_superlink() -> None: def _try_setup_client_authentication( args: argparse.Namespace, certificates: Optional[Tuple[bytes, bytes, bytes]], -) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePublicKey, ec.EllipticCurvePrivateKey]]: +) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]: if not args.require_client_authentication: return None @@ -444,13 +444,13 @@ def _try_setup_client_authentication( ) client_public_keys: Set[bytes] = set() - ssh_public_key = load_ssh_public_key( - Path(args.require_client_authentication[1]).read_bytes() - ) ssh_private_key = load_ssh_private_key( - Path(args.require_client_authentication[2]).read_bytes(), + Path(args.require_client_authentication[1]).read_bytes(), None, ) + ssh_public_key = load_ssh_public_key( + Path(args.require_client_authentication[2]).read_bytes() + ) try: server_private_key, server_public_key = ssh_types_to_elliptic_curve( @@ -458,10 +458,10 @@ def _try_setup_client_authentication( ) except TypeError: sys.exit( - "The file paths provided could not be read as a public and private " + "The file paths provided could not be read as a private and public " "key pair. Client authentication requires an elliptic curve public and " "private key pair. Please provide the file paths containing elliptic " - "curve public and private keys to '--require-client-authentication'." + "curve private and public keys to '--require-client-authentication'." ) with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile: @@ -479,8 +479,8 @@ def _try_setup_client_authentication( ) return ( client_public_keys, - server_public_key, server_private_key, + server_public_key, ) @@ -705,11 +705,11 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--require-client-authentication", nargs=3, - metavar=("CLIENT_KEYS", "SERVER_PUBLIC_KEY", "SERVER_PRIVATE_KEY"), + metavar=("CLIENT_KEYS", "SERVER_PRIVATE_KEY", "SERVER_PUBLIC_KEY"), type=str, help="Provide three file paths: (1) a .csv file containing a list of " - "known client public keys for authentication, (2) the server's public " - "key file, and (3) the server's private key file.", + "known client public keys for authentication, (2) the server's private " + "key file, and (3) the server's public key file.", ) diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py index 4fa964905f8a..429a610af211 100644 --- a/src/py/flwr/server/server_test.py +++ b/src/py/flwr/server/server_test.py @@ -220,8 +220,8 @@ def test_setup_client_auth() -> None: # pylint: disable=R0914 with tempfile.TemporaryDirectory() as temp_dir: # Initialize temporary files client_keys_file_path = Path(temp_dir) / "client_keys.csv" - server_public_key_path = Path(temp_dir) / "server_public_key" server_private_key_path = Path(temp_dir) / "server_private_key" + server_public_key_path = Path(temp_dir) / "server_public_key" # Fill the files with relevant keys with open(client_keys_file_path, "w", newline="", encoding="utf-8") as csvfile: @@ -243,8 +243,8 @@ def test_setup_client_auth() -> None: # pylint: disable=R0914 mock_args = argparse.Namespace( require_client_authentication=[ str(client_keys_file_path), - str(server_public_key_path), str(server_private_key_path), + str(server_public_key_path), ] ) @@ -263,11 +263,11 @@ def test_setup_client_auth() -> None: # pylint: disable=R0914 public_key_to_bytes(first_public_key), public_key_to_bytes(second_public_key), } - assert public_key_to_bytes(result[1]) == public_key_to_bytes( - expected_public_key - ) - assert private_key_to_bytes(result[2]) == private_key_to_bytes( + assert private_key_to_bytes(result[1]) == private_key_to_bytes( expected_private_key ) + assert public_key_to_bytes(result[2]) == public_key_to_bytes( + expected_public_key + ) else: raise AssertionError() From cc61c1cf2def08df040b4d48ef1b4149bd69dfa7 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Mon, 29 Apr 2024 13:09:31 +0200 Subject: [PATCH 66/73] Add context to message handler as comment --- .../flwr/server/superlink/fleet/grpc_rere/server_interceptor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index b1995f5084f5..d2e74c732a44 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -109,6 +109,8 @@ def intercept_service( auth metadata sent by the client. Continue RPC call if client is authenticated, else, terminate RPC call by setting context to abort. """ + + # The default message handler in flwr.server.superlink.fleet.message_handler message_handler: grpc.RpcMethodHandler = continuation(handler_call_details) return self._generic_auth_unary_method_handler(message_handler) From 948d8a79d9e1f4ef89d88311e2350164b6857e3b Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Mon, 29 Apr 2024 13:25:35 +0200 Subject: [PATCH 67/73] Remove state usage --- src/py/flwr/server/app.py | 1 - .../fleet/grpc_rere/server_interceptor.py | 92 +++++++++---------- .../grpc_rere/server_interceptor_test.py | 1 - .../server/superlink/state/in_memory_state.py | 33 ------- .../server/superlink/state/sqlite_state.py | 68 -------------- src/py/flwr/server/superlink/state/state.py | 26 ------ .../flwr/server/superlink/state/state_test.py | 88 +----------------- 7 files changed, 43 insertions(+), 266 deletions(-) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 87d1d1fd0c2e..7e06062311da 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -376,7 +376,6 @@ def run_superlink() -> None: ) = maybe_keys interceptors = [ AuthenticateServerInterceptor( - state_factory.state(), client_public_keys, server_private_key, server_public_key, diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index d2e74c732a44..3a358e6ab88a 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -16,7 +16,6 @@ import base64 -import threading from logging import INFO from typing import Any, Callable, Sequence, Set, Tuple, Union @@ -42,7 +41,6 @@ PushTaskResRequest, PushTaskResResponse, ) -from flwr.server.superlink.state import State _PUBLIC_KEY_HEADER = "public-key" _AUTH_TOKEN_HEADER = "auth-token" @@ -79,18 +77,14 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore def __init__( self, - state: State, client_public_keys: Set[bytes], private_key: ec.EllipticCurvePrivateKey, public_key: ec.EllipticCurvePublicKey, ): - self._lock = threading.Lock() self.server_private_key = private_key - self.server_public_key = public_key - self.state = state - self.state.store_client_public_keys(client_public_keys) + self.client_public_keys = client_public_keys self.encoded_server_public_key = base64.urlsafe_b64encode( - public_key_to_bytes(self.server_public_key) + public_key_to_bytes(public_key) ) log( INFO, @@ -109,7 +103,6 @@ def intercept_service( auth metadata sent by the client. Continue RPC call if client is authenticated, else, terminate RPC call by setting context to abort. """ - # The default message handler in flwr.server.superlink.fleet.message_handler message_handler: grpc.RpcMethodHandler = continuation(handler_call_details) return self._generic_auth_unary_method_handler(message_handler) @@ -121,55 +114,52 @@ def _generic_method_handler( request: Request, context: grpc.ServicerContext, ) -> Any: - with self._lock: - client_public_key_bytes = base64.urlsafe_b64decode( - _get_value_from_tuples( - _PUBLIC_KEY_HEADER, context.invocation_metadata() - ) + client_public_key_bytes = base64.urlsafe_b64decode( + _get_value_from_tuples( + _PUBLIC_KEY_HEADER, context.invocation_metadata() ) - is_public_key_known = ( - client_public_key_bytes in self.state.get_client_public_keys() - ) - if not is_public_key_known: - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") + ) + is_public_key_known = client_public_key_bytes in self.client_public_keys + if not is_public_key_known: + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") - if isinstance(request, CreateNodeRequest): - context.send_initial_metadata( - ( - ( - _PUBLIC_KEY_HEADER, - self.encoded_server_public_key, - ), - ) - ) - elif isinstance( - request, + if isinstance(request, CreateNodeRequest): + context.send_initial_metadata( ( - DeleteNodeRequest, - PullTaskInsRequest, - PushTaskResRequest, - GetRunRequest, - ), - ): - hmac_value = base64.urlsafe_b64decode( - _get_value_from_tuples( - _AUTH_TOKEN_HEADER, context.invocation_metadata() - ) - ) - client_public_key = bytes_to_public_key(client_public_key_bytes) - shared_secret = generate_shared_key( - self.server_private_key, - client_public_key, + ( + _PUBLIC_KEY_HEADER, + self.encoded_server_public_key, + ), ) - verify = verify_hmac( - shared_secret, request.SerializeToString(True), hmac_value + ) + elif isinstance( + request, + ( + DeleteNodeRequest, + PullTaskInsRequest, + PushTaskResRequest, + GetRunRequest, + ), + ): + hmac_value = base64.urlsafe_b64decode( + _get_value_from_tuples( + _AUTH_TOKEN_HEADER, context.invocation_metadata() ) - if not verify: - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") - else: + ) + client_public_key = bytes_to_public_key(client_public_key_bytes) + shared_secret = generate_shared_key( + self.server_private_key, + client_public_key, + ) + verify = verify_hmac( + shared_secret, request.SerializeToString(True), hmac_value + ) + if not verify: context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") + else: + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") - return message_handler.unary_unary(request, context) + return message_handler.unary_unary(request, context) return grpc.unary_unary_rpc_method_handler( _generic_method_handler, diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index 0877238fbf30..b68d41f304a4 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -60,7 +60,6 @@ def setUp(self) -> None: state_factory = StateFactory(":flwr-in-memory-state:") self._server_interceptor = AuthenticateServerInterceptor( - state_factory.state(), {public_key_to_bytes(self._client_public_key)}, self._server_private_key, self._server_public_key, diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index ebccac3509f0..9c05daa9aa93 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -254,39 +254,6 @@ def create_run(self, fab_id: str, fab_version: str) -> int: log(ERROR, "Unexpected run creation failure.") return 0 - def store_server_public_private_key( - self, public_key: bytes, private_key: bytes - ) -> None: - """Store `server_public_key` and `server_private_key` in state.""" - with self.lock: - if self.server_private_key is None and self.server_public_key is None: - self.server_private_key = private_key - self.server_public_key = public_key - else: - raise RuntimeError("Server public and private key already set") - - def get_server_private_key(self) -> Optional[bytes]: - """Retrieve `server_private_key` in urlsafe bytes.""" - return self.server_private_key - - def get_server_public_key(self) -> Optional[bytes]: - """Retrieve `server_public_key` in urlsafe bytes.""" - return self.server_public_key - - def store_client_public_keys(self, public_keys: Set[bytes]) -> None: - """Store a set of `client_public_keys` in state.""" - with self.lock: - self.client_public_keys = public_keys - - def store_client_public_key(self, public_key: bytes) -> None: - """Store a `client_public_key` in state.""" - with self.lock: - self.client_public_keys.add(public_key) - - def get_client_public_keys(self) -> Set[bytes]: - """Retrieve all currently stored `client_public_keys` as a set.""" - return self.client_public_keys - def get_run(self, run_id: int) -> Tuple[int, str, str]: """Retrieve information about the run with the specified `run_id`.""" with self.lock: diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 39ed92637902..d64fc53fbb11 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -40,19 +40,6 @@ ); """ -SQL_CREATE_TABLE_CREDENTIAL = """ -CREATE TABLE IF NOT EXISTS credential( - public_key BLOB PRIMARY KEY, - private_key BLOB -); -""" - -SQL_CREATE_TABLE_PUBLIC_KEY = """ -CREATE TABLE IF NOT EXISTS public_key( - public_key BLOB UNIQUE -); -""" - SQL_CREATE_INDEX_ONLINE_UNTIL = """ CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until); """ @@ -146,8 +133,6 @@ def initialize(self, log_queries: bool = False) -> List[Tuple[str]]: cur.execute(SQL_CREATE_TABLE_TASK_INS) cur.execute(SQL_CREATE_TABLE_TASK_RES) cur.execute(SQL_CREATE_TABLE_NODE) - cur.execute(SQL_CREATE_TABLE_CREDENTIAL) - cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY) cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL) res = cur.execute("SELECT name FROM sqlite_schema;") @@ -589,59 +574,6 @@ def create_run(self, fab_id: str, fab_version: str) -> int: log(ERROR, "Unexpected run creation failure.") return 0 - def store_server_public_private_key( - self, public_key: bytes, private_key: bytes - ) -> None: - """Store `server_public_key` and `server_private_key` in state.""" - query = "SELECT COUNT(*) FROM credential" - count = self.query(query)[0]["COUNT(*)"] - if count < 1: - query = ( - "INSERT OR REPLACE INTO credential (public_key, private_key) " - "VALUES (:public_key, :private_key)" - ) - self.query(query, {"public_key": public_key, "private_key": private_key}) - else: - raise RuntimeError("Server public and private key already set") - - def get_server_private_key(self) -> Optional[bytes]: - """Retrieve `server_private_key` in urlsafe bytes.""" - query = "SELECT private_key FROM credential" - rows = self.query(query) - try: - private_key: Optional[bytes] = rows[0]["private_key"] - except IndexError: - private_key = None - return private_key - - def get_server_public_key(self) -> Optional[bytes]: - """Retrieve `server_public_key` in urlsafe bytes.""" - query = "SELECT public_key FROM credential" - rows = self.query(query) - try: - public_key: Optional[bytes] = rows[0]["public_key"] - except IndexError: - public_key = None - return public_key - - def store_client_public_keys(self, public_keys: Set[bytes]) -> None: - """Store a set of `client_public_keys` in state.""" - query = "INSERT INTO public_key (public_key) VALUES (?)" - data = [(key,) for key in public_keys] - self.query(query, data) - - def store_client_public_key(self, public_key: bytes) -> None: - """Store a `client_public_key` in state.""" - query = "INSERT INTO public_key (public_key) VALUES (:public_key)" - self.query(query, {"public_key": public_key}) - - def get_client_public_keys(self) -> Set[bytes]: - """Retrieve all currently stored `client_public_keys` as a set.""" - query = "SELECT public_key FROM public_key" - rows = self.query(query) - result: Set[bytes] = {row["public_key"] for row in rows} - return result - def get_run(self, run_id: int) -> Tuple[int, str, str]: """Retrieve information about the run with the specified `run_id`.""" query = "SELECT * FROM run WHERE run_id = ?;" diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index 7992aa2345a1..8b087e3d644a 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -171,32 +171,6 @@ def get_run(self, run_id: int) -> Tuple[int, str, str]: - `fab_version`: The version of the FAB used in the specified run. """ - @abc.abstractmethod - def store_server_public_private_key( - self, public_key: bytes, private_key: bytes - ) -> None: - """Store `server_public_key` and `server_private_key` in state.""" - - @abc.abstractmethod - def get_server_private_key(self) -> Optional[bytes]: - """Retrieve `server_private_key` in urlsafe bytes.""" - - @abc.abstractmethod - def get_server_public_key(self) -> Optional[bytes]: - """Retrieve `server_public_key` in urlsafe bytes.""" - - @abc.abstractmethod - def store_client_public_keys(self, public_keys: Set[bytes]) -> None: - """Store a set of `client_public_keys` in state.""" - - @abc.abstractmethod - def store_client_public_key(self, public_key: bytes) -> None: - """Store a `client_public_key` in state.""" - - @abc.abstractmethod - def get_client_public_keys(self) -> Set[bytes]: - """Retrieve all currently stored `client_public_keys` as a set.""" - @abc.abstractmethod def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: """Acknowledge a ping received from a node, serving as a heartbeat. diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 0aeb7b064ad6..281707e16be0 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -26,11 +26,6 @@ from flwr.common import DEFAULT_TTL from flwr.common.constant import ErrorCode -from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - generate_key_pairs, - private_key_to_bytes, - public_key_to_bytes, -) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -414,85 +409,6 @@ def test_num_task_res(self) -> None: # Assert assert num == 2 - def test_server_public_private_key(self) -> None: - """Test get server public and private key after inserting.""" - # Prepare - state: State = self.state_factory() - private_key, public_key = generate_key_pairs() - private_key_bytes = private_key_to_bytes(private_key) - public_key_bytes = public_key_to_bytes(public_key) - - # Execute - state.store_server_public_private_key(public_key_bytes, private_key_bytes) - server_private_key = state.get_server_private_key() - server_public_key = state.get_server_public_key() - - # Assert - assert server_private_key == private_key_bytes - assert server_public_key == public_key_bytes - - def test_server_public_private_key_none(self) -> None: - """Test get server public and private key without inserting.""" - # Prepare - state: State = self.state_factory() - - # Execute - server_private_key = state.get_server_private_key() - server_public_key = state.get_server_public_key() - - # Assert - assert server_private_key is None - assert server_public_key is None - - def test_store_server_public_private_key_twice(self) -> None: - """Test inserting public and private key twice.""" - # Prepare - state: State = self.state_factory() - private_key, public_key = generate_key_pairs() - private_key_bytes = private_key_to_bytes(private_key) - public_key_bytes = public_key_to_bytes(public_key) - new_private_key, new_public_key = generate_key_pairs() - new_private_key_bytes = private_key_to_bytes(new_private_key) - new_public_key_bytes = public_key_to_bytes(new_public_key) - - # Execute - state.store_server_public_private_key(public_key_bytes, private_key_bytes) - - # Assert - with self.assertRaises(RuntimeError): - state.store_server_public_private_key( - new_public_key_bytes, new_private_key_bytes - ) - - def test_client_public_keys(self) -> None: - """Test store_client_public_keys and get_client_public_keys from state.""" - # Prepare - state: State = self.state_factory() - key_pairs = [generate_key_pairs() for _ in range(3)] - public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} - - # Execute - state.store_client_public_keys(public_keys) - client_public_keys = state.get_client_public_keys() - - # Assert - assert client_public_keys == public_keys - - def test_client_public_key(self) -> None: - """Test store_client_public_key and get_client_public_keys from state.""" - # Prepare - state: State = self.state_factory() - key_pairs = [generate_key_pairs() for _ in range(3)] - public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} - - # Execute - for public_key in public_keys: - state.store_client_public_key(public_key) - client_public_keys = state.get_client_public_keys() - - # Assert - assert client_public_keys == public_keys - def test_acknowledge_ping(self) -> None: """Test if acknowledge_ping works and if get_nodes return online nodes.""" # Prepare @@ -639,7 +555,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 13 + assert len(result) == 9 class SqliteFileBasedTest(StateTest, unittest.TestCase): @@ -664,7 +580,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 13 + assert len(result) == 9 if __name__ == "__main__": From 9cde4a8ccfb80b42a14801867ad71be32cfe8a08 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Mon, 29 Apr 2024 13:28:58 +0200 Subject: [PATCH 68/73] Fix server_test --- src/py/flwr/server/server_test.py | 34 +++++++++++++++---------------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py index 429a610af211..51071c13f895 100644 --- a/src/py/flwr/server/server_test.py +++ b/src/py/flwr/server/server_test.py @@ -205,7 +205,7 @@ def test_set_max_workers() -> None: def test_setup_client_auth() -> None: # pylint: disable=R0914 """Test setup client authentication.""" - # Generate keys + # Prepare _, first_public_key = generate_key_pairs() private_key, public_key = generate_key_pairs() @@ -217,6 +217,7 @@ def test_setup_client_auth() -> None: # pylint: disable=R0914 ) _, second_public_key = generate_key_pairs() + # Execute with tempfile.TemporaryDirectory() as temp_dir: # Initialize temporary files client_keys_file_path = Path(temp_dir) / "client_keys.csv" @@ -254,20 +255,17 @@ def test_setup_client_auth() -> None: # pylint: disable=R0914 expected_private_key = load_ssh_private_key(server_private_key, None) expected_public_key = load_ssh_public_key(server_public_key) - if isinstance(expected_private_key, ec.EllipticCurvePrivateKey) and isinstance( - expected_public_key, ec.EllipticCurvePublicKey - ): - # Assert result with expected values - assert result is not None - assert result[0] == { - public_key_to_bytes(first_public_key), - public_key_to_bytes(second_public_key), - } - assert private_key_to_bytes(result[1]) == private_key_to_bytes( - expected_private_key - ) - assert public_key_to_bytes(result[2]) == public_key_to_bytes( - expected_public_key - ) - else: - raise AssertionError() + # Assert + assert isinstance(expected_private_key, ec.EllipticCurvePrivateKey) + assert isinstance(expected_public_key, ec.EllipticCurvePublicKey) + assert result is not None + assert result[0] == { + public_key_to_bytes(first_public_key), + public_key_to_bytes(second_public_key), + } + assert private_key_to_bytes(result[1]) == private_key_to_bytes( + expected_private_key + ) + assert public_key_to_bytes(result[2]) == public_key_to_bytes( + expected_public_key + ) From 4d2f17ceb673aa29a6a054d370ca4c154175c367 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Mon, 29 Apr 2024 13:43:35 +0200 Subject: [PATCH 69/73] Apply suggestions from code review --- .../server/superlink/fleet/grpc_rere/server_interceptor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index 3a358e6ab88a..f78ceeeabe81 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -121,7 +121,7 @@ def _generic_method_handler( ) is_public_key_known = client_public_key_bytes in self.client_public_keys if not is_public_key_known: - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied") if isinstance(request, CreateNodeRequest): context.send_initial_metadata( @@ -155,9 +155,9 @@ def _generic_method_handler( shared_secret, request.SerializeToString(True), hmac_value ) if not verify: - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied") else: - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied!") + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied") return message_handler.unary_unary(request, context) From b7357bc5bbe6919c279b464b6960264db4c2f37c Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Mon, 29 Apr 2024 14:20:59 +0200 Subject: [PATCH 70/73] Update src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py --- .../server/superlink/fleet/grpc_rere/server_interceptor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index f78ceeeabe81..0ae3a86be93c 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -103,8 +103,8 @@ def intercept_service( auth metadata sent by the client. Continue RPC call if client is authenticated, else, terminate RPC call by setting context to abort. """ - # The default message handler in flwr.server.superlink.fleet.message_handler - message_handler: grpc.RpcMethodHandler = continuation(handler_call_details) + # One of the method handlers in `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer` + method_handler: grpc.RpcMethodHandler = continuation(handler_call_details) return self._generic_auth_unary_method_handler(message_handler) def _generic_auth_unary_method_handler( From 499dbdc64e28b4e45b940e04ffba47bae453e514 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Mon, 29 Apr 2024 14:55:05 +0200 Subject: [PATCH 71/73] Revert delete state --- .../server/superlink/state/in_memory_state.py | 33 +++++++ .../server/superlink/state/sqlite_state.py | 68 ++++++++++++++ src/py/flwr/server/superlink/state/state.py | 26 ++++++ .../flwr/server/superlink/state/state_test.py | 88 ++++++++++++++++++- 4 files changed, 213 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 9c05daa9aa93..ebccac3509f0 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -254,6 +254,39 @@ def create_run(self, fab_id: str, fab_version: str) -> int: log(ERROR, "Unexpected run creation failure.") return 0 + def store_server_public_private_key( + self, public_key: bytes, private_key: bytes + ) -> None: + """Store `server_public_key` and `server_private_key` in state.""" + with self.lock: + if self.server_private_key is None and self.server_public_key is None: + self.server_private_key = private_key + self.server_public_key = public_key + else: + raise RuntimeError("Server public and private key already set") + + def get_server_private_key(self) -> Optional[bytes]: + """Retrieve `server_private_key` in urlsafe bytes.""" + return self.server_private_key + + def get_server_public_key(self) -> Optional[bytes]: + """Retrieve `server_public_key` in urlsafe bytes.""" + return self.server_public_key + + def store_client_public_keys(self, public_keys: Set[bytes]) -> None: + """Store a set of `client_public_keys` in state.""" + with self.lock: + self.client_public_keys = public_keys + + def store_client_public_key(self, public_key: bytes) -> None: + """Store a `client_public_key` in state.""" + with self.lock: + self.client_public_keys.add(public_key) + + def get_client_public_keys(self) -> Set[bytes]: + """Retrieve all currently stored `client_public_keys` as a set.""" + return self.client_public_keys + def get_run(self, run_id: int) -> Tuple[int, str, str]: """Retrieve information about the run with the specified `run_id`.""" with self.lock: diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index d64fc53fbb11..39ed92637902 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -40,6 +40,19 @@ ); """ +SQL_CREATE_TABLE_CREDENTIAL = """ +CREATE TABLE IF NOT EXISTS credential( + public_key BLOB PRIMARY KEY, + private_key BLOB +); +""" + +SQL_CREATE_TABLE_PUBLIC_KEY = """ +CREATE TABLE IF NOT EXISTS public_key( + public_key BLOB UNIQUE +); +""" + SQL_CREATE_INDEX_ONLINE_UNTIL = """ CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until); """ @@ -133,6 +146,8 @@ def initialize(self, log_queries: bool = False) -> List[Tuple[str]]: cur.execute(SQL_CREATE_TABLE_TASK_INS) cur.execute(SQL_CREATE_TABLE_TASK_RES) cur.execute(SQL_CREATE_TABLE_NODE) + cur.execute(SQL_CREATE_TABLE_CREDENTIAL) + cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY) cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL) res = cur.execute("SELECT name FROM sqlite_schema;") @@ -574,6 +589,59 @@ def create_run(self, fab_id: str, fab_version: str) -> int: log(ERROR, "Unexpected run creation failure.") return 0 + def store_server_public_private_key( + self, public_key: bytes, private_key: bytes + ) -> None: + """Store `server_public_key` and `server_private_key` in state.""" + query = "SELECT COUNT(*) FROM credential" + count = self.query(query)[0]["COUNT(*)"] + if count < 1: + query = ( + "INSERT OR REPLACE INTO credential (public_key, private_key) " + "VALUES (:public_key, :private_key)" + ) + self.query(query, {"public_key": public_key, "private_key": private_key}) + else: + raise RuntimeError("Server public and private key already set") + + def get_server_private_key(self) -> Optional[bytes]: + """Retrieve `server_private_key` in urlsafe bytes.""" + query = "SELECT private_key FROM credential" + rows = self.query(query) + try: + private_key: Optional[bytes] = rows[0]["private_key"] + except IndexError: + private_key = None + return private_key + + def get_server_public_key(self) -> Optional[bytes]: + """Retrieve `server_public_key` in urlsafe bytes.""" + query = "SELECT public_key FROM credential" + rows = self.query(query) + try: + public_key: Optional[bytes] = rows[0]["public_key"] + except IndexError: + public_key = None + return public_key + + def store_client_public_keys(self, public_keys: Set[bytes]) -> None: + """Store a set of `client_public_keys` in state.""" + query = "INSERT INTO public_key (public_key) VALUES (?)" + data = [(key,) for key in public_keys] + self.query(query, data) + + def store_client_public_key(self, public_key: bytes) -> None: + """Store a `client_public_key` in state.""" + query = "INSERT INTO public_key (public_key) VALUES (:public_key)" + self.query(query, {"public_key": public_key}) + + def get_client_public_keys(self) -> Set[bytes]: + """Retrieve all currently stored `client_public_keys` as a set.""" + query = "SELECT public_key FROM public_key" + rows = self.query(query) + result: Set[bytes] = {row["public_key"] for row in rows} + return result + def get_run(self, run_id: int) -> Tuple[int, str, str]: """Retrieve information about the run with the specified `run_id`.""" query = "SELECT * FROM run WHERE run_id = ?;" diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index 8b087e3d644a..7992aa2345a1 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -171,6 +171,32 @@ def get_run(self, run_id: int) -> Tuple[int, str, str]: - `fab_version`: The version of the FAB used in the specified run. """ + @abc.abstractmethod + def store_server_public_private_key( + self, public_key: bytes, private_key: bytes + ) -> None: + """Store `server_public_key` and `server_private_key` in state.""" + + @abc.abstractmethod + def get_server_private_key(self) -> Optional[bytes]: + """Retrieve `server_private_key` in urlsafe bytes.""" + + @abc.abstractmethod + def get_server_public_key(self) -> Optional[bytes]: + """Retrieve `server_public_key` in urlsafe bytes.""" + + @abc.abstractmethod + def store_client_public_keys(self, public_keys: Set[bytes]) -> None: + """Store a set of `client_public_keys` in state.""" + + @abc.abstractmethod + def store_client_public_key(self, public_key: bytes) -> None: + """Store a `client_public_key` in state.""" + + @abc.abstractmethod + def get_client_public_keys(self) -> Set[bytes]: + """Retrieve all currently stored `client_public_keys` as a set.""" + @abc.abstractmethod def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: """Acknowledge a ping received from a node, serving as a heartbeat. diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 281707e16be0..0aeb7b064ad6 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -26,6 +26,11 @@ from flwr.common import DEFAULT_TTL from flwr.common.constant import ErrorCode +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + generate_key_pairs, + private_key_to_bytes, + public_key_to_bytes, +) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -409,6 +414,85 @@ def test_num_task_res(self) -> None: # Assert assert num == 2 + def test_server_public_private_key(self) -> None: + """Test get server public and private key after inserting.""" + # Prepare + state: State = self.state_factory() + private_key, public_key = generate_key_pairs() + private_key_bytes = private_key_to_bytes(private_key) + public_key_bytes = public_key_to_bytes(public_key) + + # Execute + state.store_server_public_private_key(public_key_bytes, private_key_bytes) + server_private_key = state.get_server_private_key() + server_public_key = state.get_server_public_key() + + # Assert + assert server_private_key == private_key_bytes + assert server_public_key == public_key_bytes + + def test_server_public_private_key_none(self) -> None: + """Test get server public and private key without inserting.""" + # Prepare + state: State = self.state_factory() + + # Execute + server_private_key = state.get_server_private_key() + server_public_key = state.get_server_public_key() + + # Assert + assert server_private_key is None + assert server_public_key is None + + def test_store_server_public_private_key_twice(self) -> None: + """Test inserting public and private key twice.""" + # Prepare + state: State = self.state_factory() + private_key, public_key = generate_key_pairs() + private_key_bytes = private_key_to_bytes(private_key) + public_key_bytes = public_key_to_bytes(public_key) + new_private_key, new_public_key = generate_key_pairs() + new_private_key_bytes = private_key_to_bytes(new_private_key) + new_public_key_bytes = public_key_to_bytes(new_public_key) + + # Execute + state.store_server_public_private_key(public_key_bytes, private_key_bytes) + + # Assert + with self.assertRaises(RuntimeError): + state.store_server_public_private_key( + new_public_key_bytes, new_private_key_bytes + ) + + def test_client_public_keys(self) -> None: + """Test store_client_public_keys and get_client_public_keys from state.""" + # Prepare + state: State = self.state_factory() + key_pairs = [generate_key_pairs() for _ in range(3)] + public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} + + # Execute + state.store_client_public_keys(public_keys) + client_public_keys = state.get_client_public_keys() + + # Assert + assert client_public_keys == public_keys + + def test_client_public_key(self) -> None: + """Test store_client_public_key and get_client_public_keys from state.""" + # Prepare + state: State = self.state_factory() + key_pairs = [generate_key_pairs() for _ in range(3)] + public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} + + # Execute + for public_key in public_keys: + state.store_client_public_key(public_key) + client_public_keys = state.get_client_public_keys() + + # Assert + assert client_public_keys == public_keys + def test_acknowledge_ping(self) -> None: """Test if acknowledge_ping works and if get_nodes return online nodes.""" # Prepare @@ -555,7 +639,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 9 + assert len(result) == 13 class SqliteFileBasedTest(StateTest, unittest.TestCase): @@ -580,7 +664,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 9 + assert len(result) == 13 if __name__ == "__main__": From cd4451cfc9ffefa7c21676b7e4e31cee6e102f8d Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Mon, 29 Apr 2024 14:58:46 +0200 Subject: [PATCH 72/73] Rename message handler to method handler --- .../superlink/fleet/grpc_rere/server_interceptor.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index 0ae3a86be93c..cc3993e45362 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -103,12 +103,13 @@ def intercept_service( auth metadata sent by the client. Continue RPC call if client is authenticated, else, terminate RPC call by setting context to abort. """ - # One of the method handlers in `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer` + # One of the method handlers in + # `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer` method_handler: grpc.RpcMethodHandler = continuation(handler_call_details) - return self._generic_auth_unary_method_handler(message_handler) + return self._generic_auth_unary_method_handler(method_handler) def _generic_auth_unary_method_handler( - self, message_handler: grpc.RpcMethodHandler + self, method_handler: grpc.RpcMethodHandler ) -> grpc.RpcMethodHandler: def _generic_method_handler( request: Request, @@ -159,10 +160,10 @@ def _generic_method_handler( else: context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied") - return message_handler.unary_unary(request, context) + return method_handler.unary_unary(request, context) return grpc.unary_unary_rpc_method_handler( _generic_method_handler, - request_deserializer=message_handler.request_deserializer, - response_serializer=message_handler.response_serializer, + request_deserializer=method_handler.request_deserializer, + response_serializer=method_handler.response_serializer, ) From 1329328474761de17d9a29a5cf396dd6710e7ccc Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Mon, 29 Apr 2024 15:06:40 +0200 Subject: [PATCH 73/73] Change Any to Response type --- .../server/superlink/fleet/grpc_rere/server_interceptor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index cc3993e45362..7532364336a7 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -114,7 +114,7 @@ def _generic_auth_unary_method_handler( def _generic_method_handler( request: Request, context: grpc.ServicerContext, - ) -> Any: + ) -> Response: client_public_key_bytes = base64.urlsafe_b64decode( _get_value_from_tuples( _PUBLIC_KEY_HEADER, context.invocation_metadata() @@ -160,7 +160,7 @@ def _generic_method_handler( else: context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied") - return method_handler.unary_unary(request, context) + return method_handler.unary_unary(request, context) # type: ignore return grpc.unary_unary_rpc_method_handler( _generic_method_handler,