Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(framework) Enforce strong typing for user auth code for ExecServicer #4702

Draft
wants to merge 3 commits into
base: strong-typing-user-auth-proto
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions src/py/flwr/common/auth_plugin/auth_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,27 @@

from flwr.proto.exec_pb2_grpc import ExecStub

from ..typing import UserAuthCredentials, UserAuthLoginDetails


class ExecAuthPlugin(ABC):
"""Abstract Flower Auth Plugin class for ExecServicer.

Parameters
----------
config : dict[str, Any]
The authentication configuration loaded from a YAML file.
user_auth_config_path : Path
Path to the YAML file containing the authentication configuration.
"""

@abstractmethod
def __init__(self, config: dict[str, Any]):
def __init__(
self,
user_auth_config_path: Path,
):
"""Abstract constructor."""

@abstractmethod
def get_login_details(self) -> dict[str, str]:
def get_login_details(self) -> Optional[UserAuthLoginDetails]:
"""Get the login details."""

@abstractmethod
Expand All @@ -47,7 +52,7 @@ def validate_tokens_in_metadata(
"""Validate authentication tokens in the provided metadata."""

@abstractmethod
def get_auth_tokens(self, auth_details: dict[str, str]) -> dict[str, str]:
def get_auth_tokens(self, device_code: str) -> Optional[UserAuthCredentials]:
"""Get authentication tokens."""

@abstractmethod
Expand Down
20 changes: 20 additions & 0 deletions src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,23 @@ class InvalidRunStatusException(BaseException):
def __init__(self, message: str) -> None:
super().__init__(message)
self.message = message


# OIDC user authentication types
@dataclass
class UserAuthLoginDetails:
"""User authentication login details."""

auth_type: str
device_code: str
verification_uri_complete: str
expires_in: int
interval: int


@dataclass
class UserAuthCredentials:
"""User authentication tokens."""

access_token: str
refresh_token: str
24 changes: 11 additions & 13 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,10 @@ def run_superlink() -> None:
# Obtain certificates
certificates = try_obtain_server_certificates(args, args.fleet_api_type)

user_auth_config = _try_obtain_user_auth_config(args)
auth_plugin: Optional[ExecAuthPlugin] = None
# user_auth_config is None only if the args.user_auth_config is not provided
if user_auth_config is not None:
auth_plugin = _try_obtain_exec_auth_plugin(user_auth_config)
# Load the auth plugin if the args.user_auth_config is provided
if cfg_path := getattr(args, "user_auth_config", None):
auth_plugin = _try_obtain_exec_auth_plugin(Path(cfg_path))

# Initialize StateFactory
state_factory = LinkStateFactory(args.database)
Expand Down Expand Up @@ -584,21 +583,20 @@ def _try_setup_node_authentication(
)


def _try_obtain_user_auth_config(args: argparse.Namespace) -> Optional[dict[str, Any]]:
if getattr(args, "user_auth_config", None) is not None:
with open(args.user_auth_config, encoding="utf-8") as file:
config: dict[str, Any] = yaml.safe_load(file)
return config
return None
def _try_obtain_exec_auth_plugin(config_path: Path) -> Optional[ExecAuthPlugin]:
# Load YAML file
with config_path.open("r", encoding="utf-8") as file:
config: dict[str, Any] = yaml.safe_load(file)


def _try_obtain_exec_auth_plugin(config: dict[str, Any]) -> Optional[ExecAuthPlugin]:
# Load authentication configuration
auth_config: dict[str, Any] = config.get("authentication", {})
auth_type: str = auth_config.get(AUTH_TYPE, "")

# Load authentication plugin
try:
all_plugins: dict[str, type[ExecAuthPlugin]] = get_exec_auth_plugins()
auth_plugin_class = all_plugins[auth_type]
return auth_plugin_class(config=auth_config)
return auth_plugin_class(user_auth_config_path=config_path)
except KeyError:
if auth_type != "":
sys.exit(
Expand Down
25 changes: 23 additions & 2 deletions src/py/flwr/superexec/exec_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,20 @@ def GetLoginDetails(
"ExecServicer initialized without user authentication",
)
raise grpc.RpcError() # This line is unreachable

# Get login details
details = self.auth_plugin.get_login_details()

# Return empty response if details is None
if details is None:
return GetLoginDetailsResponse()

return GetLoginDetailsResponse(
login_details=self.auth_plugin.get_login_details()
auth_type=details.auth_type,
device_code=details.device_code,
verification_uri_complete=details.verification_uri_complete,
expires_in=details.expires_in,
interval=details.interval,
)

def GetAuthTokens(
Expand All @@ -196,8 +208,17 @@ def GetAuthTokens(
"ExecServicer initialized without user authentication",
)
raise grpc.RpcError() # This line is unreachable

# Get auth tokens
credentials = self.auth_plugin.get_auth_tokens(request.device_code)

# Return empty response if credentials is None
if credentials is None:
return GetAuthTokensResponse()

return GetAuthTokensResponse(
auth_tokens=self.auth_plugin.get_auth_tokens(dict(request.auth_details))
access_token=credentials.access_token,
refresh_token=credentials.refresh_token,
)


Expand Down