From d24ebe5c55c97e5c71e2d5fc1e4d8ed5d917b029 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 22 Jul 2024 20:32:33 +0200 Subject: [PATCH] feat(framework) Support non-`str` config value types (#3746) --- src/proto/flwr/proto/driver.proto | 3 +- src/proto/flwr/proto/exec.proto | 4 +- src/proto/flwr/proto/run.proto | 4 +- src/proto/flwr/proto/task.proto | 1 - src/py/flwr/cli/config_utils.py | 10 +- src/py/flwr/cli/run/run.py | 5 +- src/py/flwr/client/app.py | 6 +- .../client/grpc_rere_client/connection.py | 8 +- src/py/flwr/client/node_state.py | 6 +- src/py/flwr/client/rest_client/connection.py | 8 +- src/py/flwr/common/config.py | 35 +++--- src/py/flwr/common/config_test.py | 41 ++++--- src/py/flwr/common/context.py | 14 +-- src/py/flwr/common/serde.py | 45 +++++++ src/py/flwr/common/typing.py | 6 +- src/py/flwr/proto/common_pb2.py | 14 ++- src/py/flwr/proto/common_pb2.pyi | 114 ++++++++++++++++++ src/py/flwr/proto/driver_pb2.py | 43 +++---- src/py/flwr/proto/driver_pb2.pyi | 11 +- src/py/flwr/proto/exec_pb2.py | 27 +++-- src/py/flwr/proto/exec_pb2.pyi | 11 +- src/py/flwr/proto/run_pb2.py | 19 +-- src/py/flwr/proto/run_pb2.pyi | 11 +- src/py/flwr/proto/task_pb2.py | 15 ++- src/py/flwr/server/driver/grpc_driver.py | 8 +- src/py/flwr/server/run_serverapp.py | 5 +- .../superlink/driver/driver_servicer.py | 17 ++- .../fleet/message_handler/message_handler.py | 15 ++- .../server/superlink/state/in_memory_state.py | 4 +- .../server/superlink/state/sqlite_state.py | 4 +- src/py/flwr/server/superlink/state/state.py | 6 +- .../ray_transport/ray_client_proxy.py | 2 +- .../ray_transport/ray_client_proxy_test.py | 2 +- src/py/flwr/simulation/run_simulation.py | 14 +-- src/py/flwr/superexec/deployment.py | 26 ++-- src/py/flwr/superexec/exec_grpc.py | 5 +- src/py/flwr/superexec/exec_servicer.py | 4 +- src/py/flwr/superexec/executor.py | 12 +- src/py/flwr/superexec/simulation.py | 24 ++-- 39 files changed, 433 insertions(+), 176 deletions(-) diff --git a/src/proto/flwr/proto/driver.proto b/src/proto/flwr/proto/driver.proto index 77dc52b3258b..531d18b4f3ad 100644 --- a/src/proto/flwr/proto/driver.proto +++ b/src/proto/flwr/proto/driver.proto @@ -20,6 +20,7 @@ package flwr.proto; import "flwr/proto/node.proto"; import "flwr/proto/task.proto"; import "flwr/proto/run.proto"; +import "flwr/proto/transport.proto"; service Driver { // Request run_id @@ -42,7 +43,7 @@ service Driver { message CreateRunRequest { string fab_id = 1; string fab_version = 2; - map override_config = 3; + map override_config = 3; } message CreateRunResponse { sint64 run_id = 1; } diff --git a/src/proto/flwr/proto/exec.proto b/src/proto/flwr/proto/exec.proto index d0d8dfcbb273..0968857bdd71 100644 --- a/src/proto/flwr/proto/exec.proto +++ b/src/proto/flwr/proto/exec.proto @@ -17,6 +17,8 @@ syntax = "proto3"; package flwr.proto; +import "flwr/proto/transport.proto"; + service Exec { // Start run upon request rpc StartRun(StartRunRequest) returns (StartRunResponse) {} @@ -27,7 +29,7 @@ service Exec { message StartRunRequest { bytes fab_file = 1; - map override_config = 2; + map override_config = 2; } message StartRunResponse { sint64 run_id = 1; } message StreamLogsRequest { sint64 run_id = 1; } diff --git a/src/proto/flwr/proto/run.proto b/src/proto/flwr/proto/run.proto index e41748381cab..f46f7146c846 100644 --- a/src/proto/flwr/proto/run.proto +++ b/src/proto/flwr/proto/run.proto @@ -17,11 +17,13 @@ syntax = "proto3"; package flwr.proto; +import "flwr/proto/transport.proto"; + message Run { sint64 run_id = 1; string fab_id = 2; string fab_version = 3; - map override_config = 4; + map override_config = 4; } message GetRunRequest { sint64 run_id = 1; } message GetRunResponse { Run run = 1; } diff --git a/src/proto/flwr/proto/task.proto b/src/proto/flwr/proto/task.proto index cf77d110acab..936b8120e495 100644 --- a/src/proto/flwr/proto/task.proto +++ b/src/proto/flwr/proto/task.proto @@ -19,7 +19,6 @@ package flwr.proto; import "flwr/proto/node.proto"; import "flwr/proto/recordset.proto"; -import "flwr/proto/transport.proto"; import "flwr/proto/error.proto"; message Task { diff --git a/src/py/flwr/cli/config_utils.py b/src/py/flwr/cli/config_utils.py index f46a53857dfc..d150e1b5f53d 100644 --- a/src/py/flwr/cli/config_utils.py +++ b/src/py/flwr/cli/config_utils.py @@ -17,11 +17,12 @@ import zipfile from io import BytesIO from pathlib import Path -from typing import IO, Any, Dict, List, Optional, Tuple, Union +from typing import IO, Any, Dict, List, Optional, Tuple, Union, get_args import tomli from flwr.common import object_ref +from flwr.common.typing import UserConfigValue def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]: @@ -112,8 +113,11 @@ def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None for key, value in config_dict.items(): if isinstance(value, dict): _validate_run_config(config_dict[key], errors) - elif not isinstance(value, str): - errors.append(f"Config value of key {key} is not of type `str`.") + elif not isinstance(value, get_args(UserConfigValue)): + raise ValueError( + f"The value for key {key} needs to be of type `int`, `float`, " + "`bool, `str`, or a `dict` of those.", + ) # pylint: disable=too-many-branches diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py index a8dd5a59a627..00588fec4224 100644 --- a/src/py/flwr/cli/run/run.py +++ b/src/py/flwr/cli/run/run.py @@ -28,6 +28,7 @@ from flwr.common.config import parse_config_args from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel from flwr.common.logger import log +from flwr.common.serde import user_config_to_proto from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611 from flwr.proto.exec_pb2_grpc import ExecStub @@ -164,7 +165,9 @@ def on_channel_state_change(channel_connectivity: str) -> None: req = StartRunRequest( fab_file=Path(fab_path).read_bytes(), - override_config=parse_config_args(config_overrides, separator=","), + override_config=user_config_to_proto( + parse_config_args(config_overrides, separator=",") + ), ) res = stub.StartRun(req) typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 127bb423851f..526f26cb8cc3 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -42,7 +42,7 @@ from flwr.common.logger import log, warn_deprecated_feature from flwr.common.message import Error from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential -from flwr.common.typing import Run +from flwr.common.typing import Run, UserConfig from .grpc_adapter_client.connection import grpc_adapter from .grpc_client.connection import grpc_connection @@ -182,7 +182,7 @@ class `flwr.client.Client` (default: None) def _start_client_internal( *, server_address: str, - node_config: Dict[str, str], + node_config: UserConfig, load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None, client_fn: Optional[ClientFnExt] = None, client: Optional[Client] = None, @@ -205,7 +205,7 @@ def _start_client_internal( The IPv4 or IPv6 address of the server. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"[::]:8080"`. - node_config: Dict[str, str] + node_config: UserConfig The configuration of the node. load_client_app_fn : Optional[Callable[[], ClientApp]] (default: None) A function that can be used to load a `ClientApp` instance. diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index e573df6854bc..64543626e695 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -40,7 +40,11 @@ from flwr.common.logger import log from flwr.common.message import Message, Metadata from flwr.common.retry_invoker import RetryInvoker -from flwr.common.serde import message_from_taskins, message_to_taskres +from flwr.common.serde import ( + message_from_taskins, + message_to_taskres, + user_config_from_proto, +) from flwr.common.typing import Run from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, @@ -281,7 +285,7 @@ def get_run(run_id: int) -> Run: run_id, get_run_response.run.fab_id, get_run_response.run.fab_version, - dict(get_run_response.run.override_config.items()), + user_config_from_proto(get_run_response.run.override_config), ) try: diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py index 08c19967ea3d..3320c90cb8cc 100644 --- a/src/py/flwr/client/node_state.py +++ b/src/py/flwr/client/node_state.py @@ -21,7 +21,7 @@ from flwr.common import Context, RecordSet from flwr.common.config import get_fused_config, get_fused_config_from_dir -from flwr.common.typing import Run +from flwr.common.typing import Run, UserConfig @dataclass() @@ -29,7 +29,7 @@ class RunInfo: """Contains the Context and initial run_config of a Run.""" context: Context - initial_run_config: Dict[str, str] + initial_run_config: UserConfig class NodeState: @@ -38,7 +38,7 @@ class NodeState: def __init__( self, node_id: int, - node_config: Dict[str, str], + node_config: UserConfig, ) -> None: self.node_id = node_id self.node_config = node_config diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index 3e81969d898c..e2bb1f62bc43 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -40,7 +40,11 @@ from flwr.common.logger import log from flwr.common.message import Message, Metadata from flwr.common.retry_invoker import RetryInvoker -from flwr.common.serde import message_from_taskins, message_to_taskres +from flwr.common.serde import ( + message_from_taskins, + message_to_taskres, + user_config_from_proto, +) from flwr.common.typing import Run from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, @@ -359,7 +363,7 @@ def get_run(run_id: int) -> Run: run_id, res.run.fab_id, res.run.fab_version, - dict(res.run.override_config.items()), + user_config_from_proto(res.run.override_config), ) try: diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index 789433a287e7..c915a3ef1621 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -16,13 +16,13 @@ import os from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, cast, get_args import tomli from flwr.cli.config_utils import validate_fields from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME -from flwr.common.typing import Run +from flwr.common.typing import Run, UserConfig, UserConfigValue def get_flwr_dir(provided_path: Optional[str] = None) -> Path: @@ -75,8 +75,9 @@ def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]: def _fuse_dicts( - main_dict: Dict[str, str], override_dict: Dict[str, str] -) -> Dict[str, str]: + main_dict: UserConfig, + override_dict: UserConfig, +) -> UserConfig: fused_dict = main_dict.copy() for key, value in override_dict.items(): @@ -87,8 +88,8 @@ def _fuse_dicts( def get_fused_config_from_dir( - project_dir: Path, override_config: Dict[str, str] -) -> Dict[str, str]: + project_dir: Path, override_config: UserConfig +) -> UserConfig: """Merge the overrides from a given dict with the config from a Flower App.""" default_config = get_project_config(project_dir)["tool"]["flwr"]["app"].get( "config", {} @@ -98,7 +99,7 @@ def get_fused_config_from_dir( return _fuse_dicts(flat_default_config, override_config) -def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> Dict[str, str]: +def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig: """Merge the overrides from a `Run` with the config from a FAB. Get the config using the fab_id and the fab_version, remove the nesting by adding @@ -112,19 +113,20 @@ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> Dict[str, str]: return get_fused_config_from_dir(project_dir, run.override_config) -def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, str]: +def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> UserConfig: """Flatten dict by joining nested keys with a given separator.""" - items: List[Tuple[str, str]] = [] + items: List[Tuple[str, UserConfigValue]] = [] separator: str = "." for k, v in raw_dict.items(): new_key = f"{parent_key}{separator}{k}" if parent_key else k if isinstance(v, dict): items.extend(flatten_dict(v, parent_key=new_key).items()) - elif isinstance(v, str): - items.append((new_key, v)) + elif isinstance(v, get_args(UserConfigValue)): + items.append((new_key, cast(UserConfigValue, v))) else: raise ValueError( - f"The value for key {k} needs to be a `str` or a `dict`.", + f"The value for key {k} needs to be of type `int`, `float`, " + "`bool, `str`, or a `dict` of those.", ) return dict(items) @@ -132,9 +134,9 @@ def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, st def parse_config_args( config: Optional[List[str]], separator: str = ",", -) -> Dict[str, str]: +) -> UserConfig: """Parse separator separated list of key-value pairs separated by '='.""" - overrides: Dict[str, str] = {} + overrides: UserConfig = {} if config is None: return overrides @@ -150,8 +152,7 @@ def parse_config_args( with Path(overrides_list[0]).open("rb") as config_file: overrides = flatten_dict(tomli.load(config_file)) else: - for kv_pair in overrides_list: - key, value = kv_pair.split("=") - overrides[key] = value + toml_str = "\n".join(overrides_list) + overrides.update(tomli.loads(toml_str)) return overrides diff --git a/src/py/flwr/common/config_test.py b/src/py/flwr/common/config_test.py index e1597aa5a2ec..52dcc0f9121e 100644 --- a/src/py/flwr/common/config_test.py +++ b/src/py/flwr/common/config_test.py @@ -21,6 +21,8 @@ import pytest +from flwr.common.typing import UserConfig + from .config import ( _fuse_dicts, flatten_dict, @@ -101,23 +103,25 @@ def test_get_fused_config_valid(tmp_path: Path) -> None: clientapp = "fedgpt.client:app" [tool.flwr.app.config] - num_server_rounds = "10" - momentum = "0.1" - lr = "0.01" + num_server_rounds = 10 + momentum = 0.1 + lr = 0.01 + progress_bar = true serverapp.test = "key" [tool.flwr.app.config.clientapp] test = "key" """ - overrides = { - "num_server_rounds": "5", - "lr": "0.2", + overrides: UserConfig = { + "num_server_rounds": 5, + "lr": 0.2, "serverapp.test": "overriden", } expected_config = { - "num_server_rounds": "5", - "momentum": "0.1", - "lr": "0.2", + "num_server_rounds": 5, + "momentum": 0.1, + "lr": 0.2, + "progress_bar": True, "serverapp.test": "overriden", "clientapp.test": "key", } @@ -168,8 +172,9 @@ def test_get_project_config_file_valid(tmp_path: Path) -> None: clientapp = "fedgpt.client:app" [tool.flwr.app.config] - num_server_rounds = "10" - momentum = "0.1" + num_server_rounds = 10 + momentum = 0.1 + progress_bar = true lr = "0.01" """ expected_config = { @@ -190,8 +195,9 @@ def test_get_project_config_file_valid(tmp_path: Path) -> None: "clientapp": "fedgpt.client:app", }, "config": { - "num_server_rounds": "10", - "momentum": "0.1", + "num_server_rounds": 10, + "momentum": 0.1, + "progress_bar": True, "lr": "0.01", }, }, @@ -231,11 +237,12 @@ def test_parse_config_args_none() -> None: def test_parse_config_args_overrides() -> None: """Test parse_config_args with key-value pairs.""" assert parse_config_args( - ["key1=value1,key2=value2", "key3=value3", "key4=value4,key5=value5"] + ["key1='value1',key2='value2'", "key3=1", "key4=2.0,key5=true,key6='value6'"] ) == { "key1": "value1", "key2": "value2", - "key3": "value3", - "key4": "value4", - "key5": "value5", + "key3": 1, + "key4": 2.0, + "key5": True, + "key6": "value6", } diff --git a/src/py/flwr/common/context.py b/src/py/flwr/common/context.py index 4da52ba44481..1544b96d3fa3 100644 --- a/src/py/flwr/common/context.py +++ b/src/py/flwr/common/context.py @@ -16,9 +16,9 @@ from dataclasses import dataclass -from typing import Dict from .record import RecordSet +from .typing import UserConfig @dataclass @@ -29,7 +29,7 @@ class Context: ---------- node_id : int The ID that identifies the node. - node_config : Dict[str, str] + node_config : UserConfig A config (key/value mapping) unique to the node and independent of the `run_config`. This config persists across all runs this node participates in. state : RecordSet @@ -39,23 +39,23 @@ class Context: executing mods. It can also be used as a memory to access at different points during the lifecycle of this entity (e.g. across multiple rounds) - run_config : Dict[str, str] + run_config : UserConfig A config (key/value mapping) held by the entity in a given run and that will stay local. It can be used at any point during the lifecycle of this entity (e.g. across multiple rounds) """ node_id: int - node_config: Dict[str, str] + node_config: UserConfig state: RecordSet - run_config: Dict[str, str] + run_config: UserConfig def __init__( # pylint: disable=too-many-arguments self, node_id: int, - node_config: Dict[str, str], + node_config: UserConfig, state: RecordSet, - run_config: Dict[str, str], + run_config: UserConfig, ) -> None: self.node_id = node_id self.node_config = node_config diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index 84932b806aff..5e34b2b4b5f8 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -671,3 +671,48 @@ def message_from_taskres(taskres: TaskRes) -> Message: ) message.metadata.created_at = taskres.task.created_at return message + + +# === User configs === + + +def user_config_to_proto(user_config: typing.UserConfig) -> Any: + """Serialize `UserConfig` to ProtoBuf.""" + proto = {} + for key, value in user_config.items(): + proto[key] = user_config_value_to_proto(value) + return proto + + +def user_config_from_proto(proto: Any) -> typing.UserConfig: + """Deserialize `UserConfig` from ProtoBuf.""" + metrics = {} + for key, value in proto.items(): + metrics[key] = user_config_value_from_proto(value) + return metrics + + +def user_config_value_to_proto(user_config_value: typing.UserConfigValue) -> Scalar: + """Serialize `UserConfigValue` to ProtoBuf.""" + if isinstance(user_config_value, bool): + return Scalar(bool=user_config_value) + + if isinstance(user_config_value, float): + return Scalar(double=user_config_value) + + if isinstance(user_config_value, int): + return Scalar(sint64=user_config_value) + + if isinstance(user_config_value, str): + return Scalar(string=user_config_value) + + raise ValueError( + f"Accepted types: {bool, float, int, str} (but not {type(user_config_value)})" + ) + + +def user_config_value_from_proto(scalar_msg: Scalar) -> typing.UserConfigValue: + """Deserialize `UserConfigValue` from ProtoBuf.""" + scalar_field = scalar_msg.WhichOneof("scalar") + scalar = getattr(scalar_msg, cast(str, scalar_field)) + return cast(typing.UserConfigValue, scalar) diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index 04d2cf5bbf7f..c050fe6d4a13 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -60,6 +60,10 @@ Config = Dict[str, Scalar] Properties = Dict[str, Scalar] +# Value type for user configs +UserConfigValue = Union[bool, float, int, str] +UserConfig = Dict[str, UserConfigValue] + class Code(Enum): """Client status codes.""" @@ -194,4 +198,4 @@ class Run: run_id: int fab_id: str fab_version: str - override_config: Dict[str, str] + override_config: UserConfig diff --git a/src/py/flwr/proto/common_pb2.py b/src/py/flwr/proto/common_pb2.py index 8a6430137f05..1025aa862933 100644 --- a/src/py/flwr/proto/common_pb2.py +++ b/src/py/flwr/proto/common_pb2.py @@ -14,11 +14,23 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/common.proto\x12\nflwr.protob\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/common.proto\x12\nflwr.proto\"\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\"\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\"\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\"\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\"\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\"\xd9\x02\n\x12\x43onfigsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x12)\n\tbool_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x19 \x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05valueb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.common_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None + _globals['_DOUBLELIST']._serialized_start=39 + _globals['_DOUBLELIST']._serialized_end=65 + _globals['_SINT64LIST']._serialized_start=67 + _globals['_SINT64LIST']._serialized_end=93 + _globals['_BOOLLIST']._serialized_start=95 + _globals['_BOOLLIST']._serialized_end=119 + _globals['_STRINGLIST']._serialized_start=121 + _globals['_STRINGLIST']._serialized_end=147 + _globals['_BYTESLIST']._serialized_start=149 + _globals['_BYTESLIST']._serialized_end=174 + _globals['_CONFIGSRECORDVALUE']._serialized_start=177 + _globals['_CONFIGSRECORDVALUE']._serialized_end=522 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/common_pb2.pyi b/src/py/flwr/proto/common_pb2.pyi index e08fa11c2caa..e2539a7300a9 100644 --- a/src/py/flwr/proto/common_pb2.pyi +++ b/src/py/flwr/proto/common_pb2.pyi @@ -2,6 +2,120 @@ @generated by mypy-protobuf. Do not edit manually! isort:skip_file """ +import builtins import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import typing +import typing_extensions DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class DoubleList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.float]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___DoubleList = DoubleList + +class Sint64List(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.int]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___Sint64List = Sint64List + +class BoolList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bool]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.bool]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___BoolList = BoolList + +class StringList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[typing.Text]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___StringList = StringList + +class BytesList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.bytes]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___BytesList = BytesList + +class ConfigsRecordValue(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + DOUBLE_FIELD_NUMBER: builtins.int + SINT64_FIELD_NUMBER: builtins.int + BOOL_FIELD_NUMBER: builtins.int + STRING_FIELD_NUMBER: builtins.int + BYTES_FIELD_NUMBER: builtins.int + DOUBLE_LIST_FIELD_NUMBER: builtins.int + SINT64_LIST_FIELD_NUMBER: builtins.int + BOOL_LIST_FIELD_NUMBER: builtins.int + STRING_LIST_FIELD_NUMBER: builtins.int + BYTES_LIST_FIELD_NUMBER: builtins.int + double: builtins.float + """Single element""" + + sint64: builtins.int + bool: builtins.bool + string: typing.Text + bytes: builtins.bytes + @property + def double_list(self) -> global___DoubleList: + """List types""" + pass + @property + def sint64_list(self) -> global___Sint64List: ... + @property + def bool_list(self) -> global___BoolList: ... + @property + def string_list(self) -> global___StringList: ... + @property + def bytes_list(self) -> global___BytesList: ... + def __init__(self, + *, + double: builtins.float = ..., + sint64: builtins.int = ..., + bool: builtins.bool = ..., + string: typing.Text = ..., + bytes: builtins.bytes = ..., + double_list: typing.Optional[global___DoubleList] = ..., + sint64_list: typing.Optional[global___Sint64List] = ..., + bool_list: typing.Optional[global___BoolList] = ..., + string_list: typing.Optional[global___StringList] = ..., + bytes_list: typing.Optional[global___BytesList] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["value",b"value"]) -> typing.Optional[typing_extensions.Literal["double","sint64","bool","string","bytes","double_list","sint64_list","bool_list","string_list","bytes_list"]]: ... +global___ConfigsRecordValue = ConfigsRecordValue diff --git a/src/py/flwr/proto/driver_pb2.py b/src/py/flwr/proto/driver_pb2.py index 07975937328d..6359e2f7d5fa 100644 --- a/src/py/flwr/proto/driver_pb2.py +++ b/src/py/flwr/proto/driver_pb2.py @@ -15,9 +15,10 @@ from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2 from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2 from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 +from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\"\xb9\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x1a\x35\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\x84\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xcd\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\x84\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -26,24 +27,24 @@ DESCRIPTOR._options = None _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._options = None _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' - _globals['_CREATERUNREQUEST']._serialized_start=108 - _globals['_CREATERUNREQUEST']._serialized_end=293 - _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=240 - _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=293 - _globals['_CREATERUNRESPONSE']._serialized_start=295 - _globals['_CREATERUNRESPONSE']._serialized_end=330 - _globals['_GETNODESREQUEST']._serialized_start=332 - _globals['_GETNODESREQUEST']._serialized_end=365 - _globals['_GETNODESRESPONSE']._serialized_start=367 - _globals['_GETNODESRESPONSE']._serialized_end=418 - _globals['_PUSHTASKINSREQUEST']._serialized_start=420 - _globals['_PUSHTASKINSREQUEST']._serialized_end=484 - _globals['_PUSHTASKINSRESPONSE']._serialized_start=486 - _globals['_PUSHTASKINSRESPONSE']._serialized_end=525 - _globals['_PULLTASKRESREQUEST']._serialized_start=527 - _globals['_PULLTASKRESREQUEST']._serialized_end=597 - _globals['_PULLTASKRESRESPONSE']._serialized_start=599 - _globals['_PULLTASKRESRESPONSE']._serialized_end=664 - _globals['_DRIVER']._serialized_start=667 - _globals['_DRIVER']._serialized_end=1055 + _globals['_CREATERUNREQUEST']._serialized_start=136 + _globals['_CREATERUNREQUEST']._serialized_end=341 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=268 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=341 + _globals['_CREATERUNRESPONSE']._serialized_start=343 + _globals['_CREATERUNRESPONSE']._serialized_end=378 + _globals['_GETNODESREQUEST']._serialized_start=380 + _globals['_GETNODESREQUEST']._serialized_end=413 + _globals['_GETNODESRESPONSE']._serialized_start=415 + _globals['_GETNODESRESPONSE']._serialized_end=466 + _globals['_PUSHTASKINSREQUEST']._serialized_start=468 + _globals['_PUSHTASKINSREQUEST']._serialized_end=532 + _globals['_PUSHTASKINSRESPONSE']._serialized_start=534 + _globals['_PUSHTASKINSRESPONSE']._serialized_end=573 + _globals['_PULLTASKRESREQUEST']._serialized_start=575 + _globals['_PULLTASKRESREQUEST']._serialized_end=645 + _globals['_PULLTASKRESRESPONSE']._serialized_start=647 + _globals['_PULLTASKRESRESPONSE']._serialized_end=712 + _globals['_DRIVER']._serialized_start=715 + _globals['_DRIVER']._serialized_end=1103 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/driver_pb2.pyi b/src/py/flwr/proto/driver_pb2.pyi index 95d4c9785ff1..748399be4e6b 100644 --- a/src/py/flwr/proto/driver_pb2.pyi +++ b/src/py/flwr/proto/driver_pb2.pyi @@ -5,6 +5,7 @@ isort:skip_file import builtins import flwr.proto.node_pb2 import flwr.proto.task_pb2 +import flwr.proto.transport_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers import google.protobuf.message @@ -21,12 +22,14 @@ class CreateRunRequest(google.protobuf.message.Message): KEY_FIELD_NUMBER: builtins.int VALUE_FIELD_NUMBER: builtins.int key: typing.Text - value: typing.Text + @property + def value(self) -> flwr.proto.transport_pb2.Scalar: ... def __init__(self, *, key: typing.Text = ..., - value: typing.Text = ..., + value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ..., ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... FAB_ID_FIELD_NUMBER: builtins.int @@ -35,12 +38,12 @@ class CreateRunRequest(google.protobuf.message.Message): fab_id: typing.Text fab_version: typing.Text @property - def override_config(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ... + def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ... def __init__(self, *, fab_id: typing.Text = ..., fab_version: typing.Text = ..., - override_config: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ..., + override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ..., ) -> None: ... def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config"]) -> None: ... global___CreateRunRequest = CreateRunRequest diff --git a/src/py/flwr/proto/exec_pb2.py b/src/py/flwr/proto/exec_pb2.py index 4aee0f4a882f..5f3a9f1e9f7d 100644 --- a/src/py/flwr/proto/exec_pb2.py +++ b/src/py/flwr/proto/exec_pb2.py @@ -12,9 +12,10 @@ _sym_db = _symbol_database.Default() +from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\"\xa4\x01\n\x0fStartRunRequest\x12\x10\n\x08\x66\x61\x62_file\x18\x01 \x01(\x0c\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x1a\x35\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"#\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"(\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t2\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/exec.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xb8\x01\n\x0fStartRunRequest\x12\x10\n\x08\x66\x61\x62_file\x18\x01 \x01(\x0c\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\"\n\x10StartRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"#\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"(\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t2\xa0\x01\n\x04\x45xec\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -23,16 +24,16 @@ DESCRIPTOR._options = None _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._options = None _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' - _globals['_STARTRUNREQUEST']._serialized_start=38 - _globals['_STARTRUNREQUEST']._serialized_end=202 - _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=149 - _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=202 - _globals['_STARTRUNRESPONSE']._serialized_start=204 - _globals['_STARTRUNRESPONSE']._serialized_end=238 - _globals['_STREAMLOGSREQUEST']._serialized_start=240 - _globals['_STREAMLOGSREQUEST']._serialized_end=275 - _globals['_STREAMLOGSRESPONSE']._serialized_start=277 - _globals['_STREAMLOGSRESPONSE']._serialized_end=317 - _globals['_EXEC']._serialized_start=320 - _globals['_EXEC']._serialized_end=480 + _globals['_STARTRUNREQUEST']._serialized_start=66 + _globals['_STARTRUNREQUEST']._serialized_end=250 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=177 + _globals['_STARTRUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=250 + _globals['_STARTRUNRESPONSE']._serialized_start=252 + _globals['_STARTRUNRESPONSE']._serialized_end=286 + _globals['_STREAMLOGSREQUEST']._serialized_start=288 + _globals['_STREAMLOGSREQUEST']._serialized_end=323 + _globals['_STREAMLOGSRESPONSE']._serialized_start=325 + _globals['_STREAMLOGSRESPONSE']._serialized_end=365 + _globals['_EXEC']._serialized_start=368 + _globals['_EXEC']._serialized_end=528 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/exec_pb2.pyi b/src/py/flwr/proto/exec_pb2.pyi index 8065fc1de1b4..fc8a615a6b65 100644 --- a/src/py/flwr/proto/exec_pb2.pyi +++ b/src/py/flwr/proto/exec_pb2.pyi @@ -3,6 +3,7 @@ isort:skip_file """ import builtins +import flwr.proto.transport_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers import google.protobuf.message @@ -18,23 +19,25 @@ class StartRunRequest(google.protobuf.message.Message): KEY_FIELD_NUMBER: builtins.int VALUE_FIELD_NUMBER: builtins.int key: typing.Text - value: typing.Text + @property + def value(self) -> flwr.proto.transport_pb2.Scalar: ... def __init__(self, *, key: typing.Text = ..., - value: typing.Text = ..., + value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ..., ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... FAB_FILE_FIELD_NUMBER: builtins.int OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int fab_file: builtins.bytes @property - def override_config(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ... + def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ... def __init__(self, *, fab_file: builtins.bytes = ..., - override_config: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ..., + override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ..., ) -> None: ... def ClearField(self, field_name: typing_extensions.Literal["fab_file",b"fab_file","override_config",b"override_config"]) -> None: ... global___StartRunRequest = StartRunRequest diff --git a/src/py/flwr/proto/run_pb2.py b/src/py/flwr/proto/run_pb2.py index d6531201f647..c4bf382f1cf9 100644 --- a/src/py/flwr/proto/run_pb2.py +++ b/src/py/flwr/proto/run_pb2.py @@ -12,9 +12,10 @@ _sym_db = _symbol_database.Default() +from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\"\xaf\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x1a\x35\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xc3\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -23,12 +24,12 @@ DESCRIPTOR._options = None _globals['_RUN_OVERRIDECONFIGENTRY']._options = None _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' - _globals['_RUN']._serialized_start=37 - _globals['_RUN']._serialized_end=212 - _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=159 - _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=212 - _globals['_GETRUNREQUEST']._serialized_start=214 - _globals['_GETRUNREQUEST']._serialized_end=245 - _globals['_GETRUNRESPONSE']._serialized_start=247 - _globals['_GETRUNRESPONSE']._serialized_end=293 + _globals['_RUN']._serialized_start=65 + _globals['_RUN']._serialized_end=260 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=187 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=260 + _globals['_GETRUNREQUEST']._serialized_start=262 + _globals['_GETRUNREQUEST']._serialized_end=293 + _globals['_GETRUNRESPONSE']._serialized_start=295 + _globals['_GETRUNRESPONSE']._serialized_end=341 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/run_pb2.pyi b/src/py/flwr/proto/run_pb2.pyi index 3c58c04c1734..4db1645da5e2 100644 --- a/src/py/flwr/proto/run_pb2.pyi +++ b/src/py/flwr/proto/run_pb2.pyi @@ -3,6 +3,7 @@ isort:skip_file """ import builtins +import flwr.proto.transport_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers import google.protobuf.message @@ -18,12 +19,14 @@ class Run(google.protobuf.message.Message): KEY_FIELD_NUMBER: builtins.int VALUE_FIELD_NUMBER: builtins.int key: typing.Text - value: typing.Text + @property + def value(self) -> flwr.proto.transport_pb2.Scalar: ... def __init__(self, *, key: typing.Text = ..., - value: typing.Text = ..., + value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ..., ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... RUN_ID_FIELD_NUMBER: builtins.int @@ -34,13 +37,13 @@ class Run(google.protobuf.message.Message): fab_id: typing.Text fab_version: typing.Text @property - def override_config(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ... + def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ... def __init__(self, *, run_id: builtins.int = ..., fab_id: typing.Text = ..., fab_version: typing.Text = ..., - override_config: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ..., + override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ..., ) -> None: ... def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config","run_id",b"run_id"]) -> None: ... global___Run = Run diff --git a/src/py/flwr/proto/task_pb2.py b/src/py/flwr/proto/task_pb2.py index 5f6e9e7be583..3e044f9ec846 100644 --- a/src/py/flwr/proto/task_pb2.py +++ b/src/py/flwr/proto/task_pb2.py @@ -14,21 +14,20 @@ from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2 from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2 -from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\x89\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\x01\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x11\n\tpushed_at\x18\x05 \x01(\x01\x12\x0b\n\x03ttl\x18\x06 \x01(\x01\x12\x10\n\x08\x61ncestry\x18\x07 \x03(\t\x12\x11\n\ttask_type\x18\x08 \x01(\t\x12(\n\trecordset\x18\t \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\n \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x16\x66lwr/proto/error.proto\"\x89\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\x01\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x11\n\tpushed_at\x18\x05 \x01(\x01\x12\x0b\n\x03ttl\x18\x06 \x01(\x01\x12\x10\n\x08\x61ncestry\x18\x07 \x03(\t\x12\x11\n\ttask_type\x18\x08 \x01(\t\x12(\n\trecordset\x18\t \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\n \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.task_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_TASK']._serialized_start=141 - _globals['_TASK']._serialized_end=406 - _globals['_TASKINS']._serialized_start=408 - _globals['_TASKINS']._serialized_end=500 - _globals['_TASKRES']._serialized_start=502 - _globals['_TASKRES']._serialized_end=594 + _globals['_TASK']._serialized_start=113 + _globals['_TASK']._serialized_end=378 + _globals['_TASKINS']._serialized_start=380 + _globals['_TASKINS']._serialized_end=472 + _globals['_TASKRES']._serialized_start=474 + _globals['_TASKRES']._serialized_end=566 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 84da5882eb73..60439892d946 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -24,7 +24,11 @@ from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event from flwr.common.grpc import create_channel from flwr.common.logger import log -from flwr.common.serde import message_from_taskres, message_to_taskins +from flwr.common.serde import ( + message_from_taskres, + message_to_taskins, + user_config_from_proto, +) from flwr.common.typing import Run from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 GetNodesRequest, @@ -127,7 +131,7 @@ def _init_run(self) -> None: run_id=res.run.run_id, fab_id=res.run.fab_id, fab_version=res.run.fab_version, - override_config=dict(res.run.override_config.items()), + override_config=user_config_from_proto(res.run.override_config), ) @property diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index 0169946e237d..b6baca0dff54 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -19,7 +19,7 @@ import sys from logging import DEBUG, INFO, WARN from pathlib import Path -from typing import Dict, Optional +from typing import Optional from flwr.common import Context, EventType, RecordSet, event from flwr.common.config import ( @@ -30,6 +30,7 @@ ) from flwr.common.logger import log, update_console_handler, warn_deprecated_feature from flwr.common.object_ref import load_app +from flwr.common.typing import UserConfig from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 CreateRunRequest, CreateRunResponse, @@ -45,7 +46,7 @@ def run( driver: Driver, server_app_dir: str, - server_app_run_config: Dict[str, str], + server_app_run_config: UserConfig, server_app_attr: Optional[str] = None, loaded_server_app: Optional[ServerApp] = None, ) -> None: diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index 7f8ded3bdb85..0741138d2dd1 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -23,6 +23,7 @@ import grpc from flwr.common.logger import log +from flwr.common.serde import user_config_from_proto, user_config_to_proto from flwr.proto import driver_pb2_grpc # pylint: disable=E0611 from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 CreateRunRequest, @@ -72,7 +73,7 @@ def CreateRun( run_id = state.create_run( request.fab_id, request.fab_version, - dict(request.override_config.items()), + user_config_from_proto(request.override_config), ) return CreateRunResponse(run_id=run_id) @@ -149,8 +150,18 @@ def GetRun( # Retrieve run information run = state.get_run(request.run_id) - run_proto = None if run is None else Run(**vars(run)) - return GetRunResponse(run=run_proto) + + if run is None: + return GetRunResponse() + + return GetRunResponse( + run=Run( + run_id=run.run_id, + fab_id=run.fab_id, + fab_version=run.fab_version, + override_config=user_config_to_proto(run.override_config), + ) + ) def _raise_if(validation_error: bool, detail: str) -> None: diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index b70cd54035fe..30865f04d373 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -19,6 +19,7 @@ from typing import List, Optional from uuid import UUID +from flwr.common.serde import user_config_to_proto from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, @@ -113,5 +114,15 @@ def get_run( ) -> GetRunResponse: """Get run information.""" run = state.get_run(request.run_id) - run_proto = None if run is None else Run(**vars(run)) - return GetRunResponse(run=run_proto) + + if run is None: + return GetRunResponse() + + return GetRunResponse( + run=Run( + run_id=run.run_id, + fab_id=run.fab_id, + fab_version=run.fab_version, + override_config=user_config_to_proto(run.override_config), + ) + ) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index bc4bd4478a23..beb25ba4e84f 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -23,7 +23,7 @@ from flwr.common import log, now from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES -from flwr.common.typing import Run +from flwr.common.typing import Run, UserConfig from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from flwr.server.superlink.state.state import State from flwr.server.utils import validate_task_ins_or_res @@ -279,7 +279,7 @@ def create_run( self, fab_id: str, fab_version: str, - override_config: Dict[str, str], + override_config: UserConfig, ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" # Sample a random int64 as run_id diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index ea6f349b9f9a..bd3b6ebabd83 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -25,7 +25,7 @@ from flwr.common import log, now from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES -from flwr.common.typing import Run +from flwr.common.typing import Run, UserConfig from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -619,7 +619,7 @@ def create_run( self, fab_id: str, fab_version: str, - override_config: Dict[str, str], + override_config: UserConfig, ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" # Sample a random int64 as run_id diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index c93f6ba756b8..23c95805948e 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -16,10 +16,10 @@ import abc -from typing import Dict, List, Optional, Set +from typing import List, Optional, Set from uuid import UUID -from flwr.common.typing import Run +from flwr.common.typing import Run, UserConfig from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 @@ -161,7 +161,7 @@ def create_run( self, fab_id: str, fab_version: str, - override_config: Dict[str, str], + override_config: UserConfig, ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 895272c2fd79..90e932aa8015 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -82,7 +82,7 @@ def _submit_job(self, message: Message, timeout: Optional[float]) -> Message: # Retrieve context context = self.proxy_state.retrieve_context(run_id=run_id) - partition_id_str = context.node_config[PARTITION_ID_KEY] + partition_id_str = str(context.node_config[PARTITION_ID_KEY]) try: self.actor_pool.submit_client_job( diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 62e0cfd61c99..a0df3fc1eb8e 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -231,7 +231,7 @@ def _load_app() -> ClientApp: # register and retrieve context node_states[node_id].register_context(run_id=run_id) context = node_states[node_id].retrieve_context(run_id=run_id) - partition_id_str = context.node_config[PARTITION_ID_KEY] + partition_id_str = str(context.node_config[PARTITION_ID_KEY]) pool.submit_client_job( lambda a, c_fn, j_fn, nid_, state: a.run.remote(c_fn, j_fn, nid_, state), (_load_app, message, partition_id_str, context), diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 1b7c6f87591e..7cebb90451d6 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -25,7 +25,7 @@ from logging import DEBUG, ERROR, INFO, WARNING from pathlib import Path from time import sleep -from typing import Dict, List, Optional +from typing import List, Optional from flwr.cli.config_utils import load_and_validate from flwr.client import ClientApp @@ -33,7 +33,7 @@ from flwr.common.config import get_fused_config_from_dir, parse_config_args from flwr.common.constant import RUN_ID_NUM_BYTES from flwr.common.logger import set_logger_propagation, update_console_handler -from flwr.common.typing import Run +from flwr.common.typing import Run, UserConfig from flwr.server.driver import Driver, InMemoryDriver from flwr.server.run_serverapp import run as run_server_app from flwr.server.server_app import ServerApp @@ -238,7 +238,7 @@ def run_simulation( def run_serverapp_th( server_app_attr: Optional[str], server_app: Optional[ServerApp], - server_app_run_config: Dict[str, str], + server_app_run_config: UserConfig, driver: Driver, app_dir: str, f_stop: threading.Event, @@ -254,7 +254,7 @@ def server_th_with_start_checks( exception_event: threading.Event, _driver: Driver, _server_app_dir: str, - _server_app_run_config: Dict[str, str], + _server_app_run_config: UserConfig, _server_app_attr: Optional[str], _server_app: Optional[ServerApp], ) -> None: @@ -319,7 +319,7 @@ def _main_loop( client_app_attr: Optional[str] = None, server_app: Optional[ServerApp] = None, server_app_attr: Optional[str] = None, - server_app_run_config: Optional[Dict[str, str]] = None, + server_app_run_config: Optional[UserConfig] = None, ) -> None: """Launch SuperLink with Simulation Engine, then ServerApp on a separate thread.""" # Initialize StateFactory @@ -395,7 +395,7 @@ def _run_simulation( backend_config: Optional[BackendConfig] = None, client_app_attr: Optional[str] = None, server_app_attr: Optional[str] = None, - server_app_run_config: Optional[Dict[str, str]] = None, + server_app_run_config: Optional[UserConfig] = None, app_dir: str = "", flwr_dir: Optional[str] = None, run: Optional[Run] = None, @@ -438,7 +438,7 @@ def _run_simulation( A path to a `ServerApp` module to be loaded: For example: `server:app` or `project.package.module:wrapper.app`." - server_app_run_config : Optional[Dict[str, str]] + server_app_run_config : Optional[UserConfig] Config dictionary that parameterizes the run config. It will be made accesible to the ServerApp. diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index d012d408a9ff..2eb40a7464c9 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -17,7 +17,7 @@ import subprocess from logging import ERROR, INFO from pathlib import Path -from typing import Dict, Optional +from typing import Optional from typing_extensions import override @@ -25,6 +25,8 @@ from flwr.cli.install import install_from_fab from flwr.common.grpc import create_channel from flwr.common.logger import log +from flwr.common.serde import user_config_to_proto +from flwr.common.typing import UserConfig from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611 from flwr.proto.driver_pb2_grpc import DriverStub from flwr.server.driver.grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER @@ -65,13 +67,13 @@ def __init__( @override def set_config( self, - config: Dict[str, str], + config: UserConfig, ) -> None: """Set executor config arguments. Parameters ---------- - config : Dict[str, str] + config : UserConfig A dictionary for configuration values. Supported configuration key/value pairs: - "superlink": str @@ -84,12 +86,20 @@ def set_config( if not config: return if superlink_address := config.get("superlink"): + if not isinstance(superlink_address, str): + raise ValueError("The `superlink` value should be of type `str`.") self.superlink = superlink_address if root_certificates := config.get("root-certificates"): + if not isinstance(root_certificates, str): + raise ValueError( + "The `root-certificates` value should be of type `str`." + ) self.root_certificates = root_certificates - self.root_certificates_bytes = Path(root_certificates).read_bytes() + self.root_certificates_bytes = Path(str(root_certificates)).read_bytes() if flwr_dir := config.get("flwr-dir"): - self.flwr_dir = flwr_dir + if not isinstance(flwr_dir, str): + raise ValueError("The `flwr-dir` value should be of type `str`.") + self.flwr_dir = str(flwr_dir) def _connect(self) -> None: if self.stub is not None: @@ -105,7 +115,7 @@ def _create_run( self, fab_id: str, fab_version: str, - override_config: Dict[str, str], + override_config: UserConfig, ) -> int: if self.stub is None: self._connect() @@ -115,7 +125,7 @@ def _create_run( req = CreateRunRequest( fab_id=fab_id, fab_version=fab_version, - override_config=override_config, + override_config=user_config_to_proto(override_config), ) res = self.stub.CreateRun(request=req) return int(res.run_id) @@ -124,7 +134,7 @@ def _create_run( def start_run( self, fab_file: bytes, - override_config: Dict[str, str], + override_config: UserConfig, ) -> Optional[RunTracker]: """Start run using the Flower Deployment Engine.""" try: diff --git a/src/py/flwr/superexec/exec_grpc.py b/src/py/flwr/superexec/exec_grpc.py index d90cec3e47cd..a32ebc1b3e35 100644 --- a/src/py/flwr/superexec/exec_grpc.py +++ b/src/py/flwr/superexec/exec_grpc.py @@ -15,12 +15,13 @@ """SuperExec gRPC API.""" from logging import INFO -from typing import Dict, Optional, Tuple +from typing import Optional, Tuple import grpc from flwr.common import GRPC_MAX_MESSAGE_LENGTH from flwr.common.logger import log +from flwr.common.typing import UserConfig from flwr.proto.exec_pb2_grpc import add_ExecServicer_to_server from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server @@ -32,7 +33,7 @@ def run_superexec_api_grpc( address: str, executor: Executor, certificates: Optional[Tuple[bytes, bytes, bytes]], - config: Dict[str, str], + config: UserConfig, ) -> grpc.Server: """Run SuperExec API (gRPC, request-response).""" executor.set_config(config) diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index 61a7bc289af3..fa54590d3b7b 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -21,6 +21,7 @@ import grpc from flwr.common.logger import log +from flwr.common.serde import user_config_from_proto from flwr.proto import exec_pb2_grpc # pylint: disable=E0611 from flwr.proto.exec_pb2 import ( # pylint: disable=E0611 StartRunRequest, @@ -46,8 +47,7 @@ def StartRun( log(INFO, "ExecServicer.StartRun") run = self.executor.start_run( - request.fab_file, - dict(request.override_config.items()), + request.fab_file, user_config_from_proto(request.override_config) ) if run is None: diff --git a/src/py/flwr/superexec/executor.py b/src/py/flwr/superexec/executor.py index 62d64f366cec..ed941d47e764 100644 --- a/src/py/flwr/superexec/executor.py +++ b/src/py/flwr/superexec/executor.py @@ -17,7 +17,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from subprocess import Popen -from typing import Dict, Optional +from typing import Optional + +from flwr.common.typing import UserConfig @dataclass @@ -34,13 +36,13 @@ class Executor(ABC): @abstractmethod def set_config( self, - config: Dict[str, str], + config: UserConfig, ) -> None: """Register provided config as class attributes. Parameters ---------- - config : Optional[Dict[str, str]] + config : UserConfig A dictionary for configuration values. """ @@ -48,7 +50,7 @@ def set_config( def start_run( self, fab_file: bytes, - override_config: Dict[str, str], + override_config: UserConfig, ) -> Optional[RunTracker]: """Start a run using the given Flower FAB ID and version. @@ -59,7 +61,7 @@ def start_run( ---------- fab_file : bytes The Flower App Bundle file bytes. - override_config: Dict[str, str] + override_config: UserConfig The config overrides dict sent by the user (using `flwr run`). Returns diff --git a/src/py/flwr/superexec/simulation.py b/src/py/flwr/superexec/simulation.py index 58cc194a16d4..737c037375d7 100644 --- a/src/py/flwr/superexec/simulation.py +++ b/src/py/flwr/superexec/simulation.py @@ -18,7 +18,7 @@ import subprocess import sys from logging import ERROR, INFO, WARN -from typing import Dict, Optional +from typing import Optional from typing_extensions import override @@ -26,6 +26,7 @@ from flwr.cli.install import install_from_fab from flwr.common.constant import RUN_ID_NUM_BYTES from flwr.common.logger import log +from flwr.common.typing import UserConfig from flwr.server.superlink.state.utils import generate_rand_int_from_bytes from .executor import Executor, RunTracker @@ -42,43 +43,46 @@ class SimulationEngine(Executor): def __init__( self, - num_supernodes: Optional[str] = None, + num_supernodes: Optional[int] = None, ) -> None: self.num_supernodes = num_supernodes @override def set_config( self, - config: Dict[str, str], + config: UserConfig, ) -> None: """Set executor config arguments. Parameters ---------- - config : Dict[str, str] + config : UserConfig A dictionary for configuration values. Supported configuration key/value pairs: - - "num-supernodes": str + - "num-supernodes": int Number of nodes to register for the simulation. """ if not config: return if num_supernodes := config.get("num-supernodes"): + if not isinstance(num_supernodes, int): + raise ValueError("The `num-supernodes` value should be of type `int`.") self.num_supernodes = num_supernodes - - # Validate config - if self.num_supernodes is None: + else: log( ERROR, "To start a run with the simulation plugin, please specify " "the number of SuperNodes. This can be done by using the " "`--executor-config` argument when launching the SuperExec.", ) - raise ValueError("`num-supernodes` must not be `None`") + raise ValueError( + "`num-supernodes` must not be `None`, it must be a valid " + "positive integer." + ) @override def start_run( - self, fab_file: bytes, override_config: Dict[str, str] + self, fab_file: bytes, override_config: UserConfig ) -> Optional[RunTracker]: """Start run using the Flower Simulation Engine.""" try: