Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add server auth interceptor #2948

Merged
merged 102 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
102 commits
Select commit Hold shift + click to select a range
d0d4ecd
Add authentication state and test
danielnugraha Feb 10, 2024
91a2f18
Fix isort
danielnugraha Feb 10, 2024
db16c10
Fix isort
danielnugraha Feb 10, 2024
28876bc
Run format.sh
danielnugraha Feb 10, 2024
42a7d38
Add init.py
danielnugraha Feb 10, 2024
8ec63c9
Fix line too long
danielnugraha Feb 10, 2024
8f04e25
Fix line too long
danielnugraha Feb 10, 2024
e8813fc
Fix line too long
danielnugraha Feb 10, 2024
d9f3fb0
Fix subclassing
danielnugraha Feb 10, 2024
caf6695
Fix subclassing
danielnugraha Feb 10, 2024
fa217ae
Fix subclassing
danielnugraha Feb 10, 2024
6edddd6
Fix subclassing
danielnugraha Feb 10, 2024
8bb15a5
Fix subclassing
danielnugraha Feb 10, 2024
c5bac4f
fixes
jafermarq Feb 11, 2024
c856b7c
Fix state tests
danielnugraha Feb 11, 2024
4758507
Fix too broad exception
danielnugraha Feb 11, 2024
e666da5
Add sqlite auth state test
danielnugraha Feb 11, 2024
151a619
Merge remote-tracking branch 'origin' into add-auth-state
danielnugraha Feb 11, 2024
5c49a55
Add server interceptor
danielnugraha Feb 12, 2024
e443bf9
Merge remote-tracking branch 'origin' into add-server-auth-interceptor
danielnugraha Feb 14, 2024
986961e
Move state to superlink
danielnugraha Feb 14, 2024
dddbbc9
Move state to superlink
danielnugraha Feb 14, 2024
fbbcb2a
Fix server interceptor
danielnugraha Feb 15, 2024
77e5c3c
Fix authentication state
danielnugraha Feb 15, 2024
21e590a
Add symmetric encryption test
danielnugraha Feb 15, 2024
6823c83
Add symmetric encryption test
danielnugraha Feb 15, 2024
9df829c
Format code
danielnugraha Feb 15, 2024
d250945
Make tests pass
danielnugraha Feb 15, 2024
de0f041
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Feb 15, 2024
2f61623
Revert commit to only include auth state
danielnugraha Feb 15, 2024
2f7aa48
Remove logging messages
danielnugraha Feb 16, 2024
d96879c
Merge remote-tracking branch 'refs/remotes/origin/add-server-auth-int…
danielnugraha Feb 16, 2024
f9a21b4
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Feb 21, 2024
6ee8a61
Update server interceptor
danielnugraha Feb 22, 2024
041482e
Merge from add-auth-cli
danielnugraha Feb 22, 2024
781796e
Docstring changes
danielnugraha Feb 28, 2024
651f665
Merge branch 'main' into add-auth-state
danielnugraha Feb 28, 2024
013582f
Merge auth state
danielnugraha Feb 28, 2024
700f6dd
Merge main
danielnugraha Feb 28, 2024
ab5317f
Fix merge conflict interceptors gone
danielnugraha Feb 28, 2024
5075ab7
Fix too many instances
danielnugraha Feb 28, 2024
7f49d81
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Feb 28, 2024
d6238ee
Fix imports merge conflict
danielnugraha Feb 29, 2024
578fd96
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Feb 29, 2024
2de5fd5
Add docstring to interceptor
danielnugraha Feb 29, 2024
9b62c4f
Format
danielnugraha Mar 1, 2024
e2ad1ef
Implement feedback
danielnugraha Mar 1, 2024
811d8e8
Merge remote-tracking branch 'origin' into add-server-auth-interceptor
danielnugraha Mar 1, 2024
7a5b6f0
Fix merge conflicts
danielnugraha Apr 3, 2024
28afce5
Format
danielnugraha Apr 4, 2024
78a6697
Fix merge conflicts
danielnugraha Apr 4, 2024
7c098d2
Fix error
danielnugraha Apr 4, 2024
100eadb
Fix error
danielnugraha Apr 4, 2024
5c9f6c3
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Apr 4, 2024
974d2b2
Merge branch 'main' into add-auth-state
danielnugraha Apr 4, 2024
f10333d
Merge branch 'main' into add-auth-state
danielnugraha Apr 15, 2024
74f7036
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Apr 15, 2024
d2b60f6
Merge remote-tracking branch 'origin' into add-auth-state
danielnugraha Apr 18, 2024
94ee2fe
Add lock to write operations
danielnugraha Apr 18, 2024
08c07b5
Merge branch 'add-auth-state' into add-server-auth-interceptor
danielnugraha Apr 18, 2024
0fe2744
Fix docstring
danielnugraha Apr 21, 2024
cc682f3
Fix merge conflict
danielnugraha Apr 24, 2024
54b7afa
Format
danielnugraha Apr 24, 2024
4ab971b
Format
danielnugraha Apr 24, 2024
054bd04
Add more tests
danielnugraha Apr 24, 2024
aac42ab
Add failure tests
danielnugraha Apr 24, 2024
c5b3e46
Add failure tests
danielnugraha Apr 24, 2024
e76b937
Format
danielnugraha Apr 24, 2024
846373f
Fix docstring
danielnugraha Apr 24, 2024
a61892e
Format prepare, execute & assert
danielnugraha Apr 24, 2024
5212653
Merge from main
danielnugraha Apr 24, 2024
8b58c5e
Add get run
danielnugraha Apr 24, 2024
bd6163b
Dynamically generate ssh key
danielnugraha Apr 24, 2024
1dc292a
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Apr 24, 2024
6a02ba9
Encode only once
danielnugraha Apr 24, 2024
e709e99
Merge remote-tracking branch 'refs/remotes/origin/add-server-auth-int…
danielnugraha Apr 24, 2024
d702924
Format
danielnugraha Apr 24, 2024
ac54694
Add get_run
danielnugraha Apr 24, 2024
5735549
Unindent function
danielnugraha Apr 24, 2024
1ca16ba
Format
danielnugraha Apr 24, 2024
5c2f805
Merge branch 'main' into add-server-auth-interceptor
danieljanes Apr 24, 2024
64ec5e8
Update src/py/flwr/server/app.py
danielnugraha Apr 24, 2024
9c03cd6
Update src/py/flwr/server/app.py
danielnugraha Apr 24, 2024
4779dc0
Implement review feedback
danielnugraha Apr 24, 2024
af03db1
Adapt error string
danielnugraha Apr 25, 2024
af1773b
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Apr 26, 2024
24b7182
Update src/py/flwr/server/app.py
danielnugraha Apr 26, 2024
16e20d4
Change data to maybe_keys
danielnugraha Apr 26, 2024
79728df
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Apr 26, 2024
a742d40
Implement feedback
danielnugraha Apr 26, 2024
461802a
Private key first then public
danielnugraha Apr 26, 2024
cc61c1c
Add context to message handler as comment
danielnugraha Apr 29, 2024
948d8a7
Remove state usage
danielnugraha Apr 29, 2024
9cde4a8
Fix server_test
danielnugraha Apr 29, 2024
bbab863
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Apr 29, 2024
4d2f17c
Apply suggestions from code review
danieljanes Apr 29, 2024
ff5f3b8
Merge branch 'main' into add-server-auth-interceptor
danielnugraha Apr 29, 2024
b7357bc
Update src/py/flwr/server/superlink/fleet/grpc_rere/server_intercepto…
danieljanes Apr 29, 2024
499dbdc
Revert delete state
danielnugraha Apr 29, 2024
7e06ca1
Merge remote-tracking branch 'refs/remotes/origin/add-server-auth-int…
danielnugraha Apr 29, 2024
cd4451c
Rename message handler to method handler
danielnugraha Apr 29, 2024
1329328
Change Any to Response type
danielnugraha Apr 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,18 @@ def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
return True
except InvalidSignature:
return False


def ssh_types_to_elliptic_curve(
private_key: serialization.SSHPrivateKeyTypes,
public_key: serialization.SSHPublicKeyTypes,
) -> Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]:
"""Cast SSH key types to elliptic curve."""
if isinstance(private_key, ec.EllipticCurvePrivateKey) and isinstance(
public_key, ec.EllipticCurvePublicKey
):
return (private_key, public_key)

raise TypeError(
"The provided key is not an EllipticCurvePrivateKey or EllipticCurvePublicKey"
)
106 changes: 105 additions & 1 deletion src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@

import argparse
import asyncio
import csv
import importlib.util
import sys
import threading
from logging import ERROR, INFO, WARN
from os.path import isfile
from pathlib import Path
from typing import List, Optional, Tuple
from typing import List, Optional, Sequence, Set, Tuple

import grpc
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.serialization import (
load_ssh_private_key,
load_ssh_public_key,
)

from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
from flwr.common.address import parse_address
Expand All @@ -36,6 +42,10 @@
)
from flwr.common.exit_handlers import register_exit_handlers
from flwr.common.logger import log
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
public_key_to_bytes,
ssh_types_to_elliptic_curve,
)
from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
add_FleetServicer_to_server,
)
Expand All @@ -51,6 +61,7 @@
start_grpc_server,
)
from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
from .superlink.fleet.vce import start_vce
from .superlink.state import StateFactory

Expand Down Expand Up @@ -354,10 +365,28 @@ def run_superlink() -> None:
sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
host, port, is_v6 = parsed_address
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"

maybe_keys = _try_setup_client_authentication(args, certificates)
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
if maybe_keys is not None:
(
client_public_keys,
server_private_key,
server_public_key,
) = maybe_keys
interceptors = [
AuthenticateServerInterceptor(
client_public_keys,
server_private_key,
server_public_key,
)
]

fleet_server = _run_fleet_api_grpc_rere(
address=address,
state_factory=state_factory,
certificates=certificates,
interceptors=interceptors,
)
grpc_servers.append(fleet_server)
elif args.fleet_api_type == TRANSPORT_TYPE_VCE:
Expand Down Expand Up @@ -390,6 +419,70 @@ def run_superlink() -> None:
driver_server.wait_for_termination(timeout=1)


def _try_setup_client_authentication(
args: argparse.Namespace,
certificates: Optional[Tuple[bytes, bytes, bytes]],
) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
if not args.require_client_authentication:
return None

if certificates is None:
sys.exit(
"Client authentication only works over secure connections. "
"Please provide certificate paths using '--certificates' when "
"enabling '--require-client-authentication'."
)

client_keys_file_path = Path(args.require_client_authentication[0])
if not client_keys_file_path.exists():
sys.exit(
"The provided path to the client public keys CSV file does not exist: "
f"{client_keys_file_path}. "
"Please provide the CSV file path containing known client public keys "
"to '--require-client-authentication'."
)

client_public_keys: Set[bytes] = set()
ssh_private_key = load_ssh_private_key(
Path(args.require_client_authentication[1]).read_bytes(),
None,
)
ssh_public_key = load_ssh_public_key(
Path(args.require_client_authentication[2]).read_bytes()
)

try:
server_private_key, server_public_key = ssh_types_to_elliptic_curve(
ssh_private_key, ssh_public_key
)
except TypeError:
sys.exit(
"The file paths provided could not be read as a private and public "
"key pair. Client authentication requires an elliptic curve public and "
"private key pair. Please provide the file paths containing elliptic "
"curve private and public keys to '--require-client-authentication'."
)

with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile:
reader = csv.reader(csvfile)
for row in reader:
for element in row:
public_key = load_ssh_public_key(element.encode())
if isinstance(public_key, ec.EllipticCurvePublicKey):
client_public_keys.add(public_key_to_bytes(public_key))
danielnugraha marked this conversation as resolved.
Show resolved Hide resolved
else:
sys.exit(
"Error: Unable to parse the public keys in the .csv "
"file. Please ensure that the .csv file contains valid "
"SSH public keys and try again."
)
return (
client_public_keys,
server_private_key,
server_public_key,
)


def _try_obtain_certificates(
args: argparse.Namespace,
) -> Optional[Tuple[bytes, bytes, bytes]]:
Expand Down Expand Up @@ -417,6 +510,7 @@ def _run_fleet_api_grpc_rere(
address: str,
state_factory: StateFactory,
certificates: Optional[Tuple[bytes, bytes, bytes]],
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
) -> grpc.Server:
"""Run Fleet API (gRPC, request-response)."""
# Create Fleet API gRPC server
Expand All @@ -429,6 +523,7 @@ def _run_fleet_api_grpc_rere(
server_address=address,
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
certificates=certificates,
interceptors=interceptors,
)

log(INFO, "Flower ECE: Starting Fleet API (gRPC-rere) on %s", address)
Expand Down Expand Up @@ -606,6 +701,15 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
"Flower will just create a state in memory.",
default=DATABASE,
)
parser.add_argument(
"--require-client-authentication",
nargs=3,
metavar=("CLIENT_KEYS", "SERVER_PRIVATE_KEY", "SERVER_PUBLIC_KEY"),
type=str,
help="Provide three file paths: (1) a .csv file containing a list of "
"known client public keys for authentication, (2) the server's private "
"key file, and (3) the server's public key file.",
)


def _add_args_driver_api(parser: argparse.ArgumentParser) -> None:
Expand Down
87 changes: 87 additions & 0 deletions src/py/flwr/server/server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,22 @@
"""Flower server tests."""


import argparse
import csv
import tempfile
from pathlib import Path
from typing import List, Optional

import numpy as np
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.serialization import (
Encoding,
NoEncryption,
PrivateFormat,
PublicFormat,
load_ssh_private_key,
load_ssh_public_key,
)

from flwr.common import (
Code,
Expand All @@ -35,8 +48,14 @@
Status,
ndarray_to_bytes,
)
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
generate_key_pairs,
private_key_to_bytes,
public_key_to_bytes,
)
from flwr.server.client_manager import SimpleClientManager

from .app import _try_setup_client_authentication
from .client_proxy import ClientProxy
from .server import Server, evaluate_clients, fit_clients

Expand Down Expand Up @@ -182,3 +201,71 @@ def test_set_max_workers() -> None:

# Assert
assert server.max_workers == 42


def test_setup_client_auth() -> None: # pylint: disable=R0914
"""Test setup client authentication."""
# Prepare
_, first_public_key = generate_key_pairs()
private_key, public_key = generate_key_pairs()

server_public_key = public_key.public_bytes(
encoding=Encoding.OpenSSH, format=PublicFormat.OpenSSH
)
server_private_key = private_key.private_bytes(
Encoding.PEM, PrivateFormat.OpenSSH, NoEncryption()
)
_, second_public_key = generate_key_pairs()

# Execute
with tempfile.TemporaryDirectory() as temp_dir:
# Initialize temporary files
client_keys_file_path = Path(temp_dir) / "client_keys.csv"
server_private_key_path = Path(temp_dir) / "server_private_key"
server_public_key_path = Path(temp_dir) / "server_public_key"

# Fill the files with relevant keys
with open(client_keys_file_path, "w", newline="", encoding="utf-8") as csvfile:
writer = csv.writer(csvfile)
writer.writerow(
[
first_public_key.public_bytes(
encoding=Encoding.OpenSSH, format=PublicFormat.OpenSSH
).decode(),
second_public_key.public_bytes(
encoding=Encoding.OpenSSH, format=PublicFormat.OpenSSH
).decode(),
]
)
server_public_key_path.write_bytes(server_public_key)
server_private_key_path.write_bytes(server_private_key)

# Mock argparse with `require-client-authentication`` flag
mock_args = argparse.Namespace(
require_client_authentication=[
str(client_keys_file_path),
str(server_private_key_path),
str(server_public_key_path),
]
)

# Run _try_setup_client_authentication
result = _try_setup_client_authentication(mock_args, (b"", b"", b""))

expected_private_key = load_ssh_private_key(server_private_key, None)
expected_public_key = load_ssh_public_key(server_public_key)

# Assert
assert isinstance(expected_private_key, ec.EllipticCurvePrivateKey)
assert isinstance(expected_public_key, ec.EllipticCurvePublicKey)
assert result is not None
assert result[0] == {
public_key_to_bytes(first_public_key),
public_key_to_bytes(second_public_key),
}
assert private_key_to_bytes(result[1]) == private_key_to_bytes(
expected_private_key
)
assert public_key_to_bytes(result[2]) == public_key_to_bytes(
expected_public_key
)
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import concurrent.futures
import sys
from logging import ERROR
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import grpc

Expand Down Expand Up @@ -162,6 +162,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
keepalive_time_ms: int = 210000,
certificates: Optional[Tuple[bytes, bytes, bytes]] = None,
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
) -> grpc.Server:
"""Create a gRPC server with a single servicer.

Expand Down Expand Up @@ -249,6 +250,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
# returning RESOURCE_EXHAUSTED status, or None to indicate no limit.
maximum_concurrent_rpcs=max_concurrent_workers,
options=options,
interceptors=interceptors,
)
add_servicer_to_server_fn(servicer, server)

Expand Down
Loading