Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(framework) Make federation_options in Exec API use ConfigsRecords #4453

Merged
merged 8 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/proto/flwr/proto/exec.proto
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package flwr.proto;

import "flwr/proto/fab.proto";
import "flwr/proto/transport.proto";
import "flwr/proto/recordset.proto";

service Exec {
// Start run upon request
Expand All @@ -31,7 +32,7 @@ service Exec {
message StartRunRequest {
Fab fab = 1;
map<string, Scalar> override_config = 2;
map<string, Scalar> federation_config = 3;
ConfigsRecord federation_options = 3;
}
message StartRunResponse { uint64 run_id = 1; }
message StreamLogsRequest {
Expand Down
21 changes: 16 additions & 5 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,18 @@
validate_federation_in_project_config,
validate_project_config,
)
from flwr.common.config import flatten_dict, parse_config_args
from flwr.common.config import (
flatten_dict,
parse_config_args,
user_config_to_configsrecord,
)
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
from flwr.common.logger import log
from flwr.common.serde import fab_to_proto, user_config_to_proto
from flwr.common.serde import (
configs_record_to_proto,
fab_to_proto,
user_config_to_proto,
)
from flwr.common.typing import Fab
from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
from flwr.proto.exec_pb2_grpc import ExecStub
Expand Down Expand Up @@ -94,6 +102,7 @@ def run(
_run_without_exec_api(app, federation_config, config_overrides, federation)


# pylint: disable-next=too-many-locals
def _run_with_exec_api(
app: Path,
federation_config: dict[str, Any],
Expand All @@ -118,12 +127,14 @@ def _run_with_exec_api(
content = Path(fab_path).read_bytes()
fab = Fab(fab_hash, content)

# Construct a `ConfigsRecord` out of a flattened `UserConfig`
fed_conf = flatten_dict(federation_config.get("options", {}))
c_record = user_config_to_configsrecord(fed_conf)

req = StartRunRequest(
fab=fab_to_proto(fab),
override_config=user_config_to_proto(parse_config_args(config_overrides)),
federation_config=user_config_to_proto(
flatten_dict(federation_config.get("options"))
),
federation_options=configs_record_to_proto(c_record),
)
res = stub.StartRun(req)

Expand Down
10 changes: 10 additions & 0 deletions src/py/flwr/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tomli

from flwr.cli.config_utils import get_fab_config, validate_fields
from flwr.common import ConfigsRecord
from flwr.common.constant import (
APP_DIR,
FAB_CONFIG_FILE,
Expand Down Expand Up @@ -229,3 +230,12 @@ def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]:
config["project"]["version"],
f"{config['tool']['flwr']['app']['publisher']}/{config['project']['name']}",
)


def user_config_to_configsrecord(config: UserConfig) -> ConfigsRecord:
"""Construct a `ConfigsRecord` out of a `UserConfig`."""
c_record = ConfigsRecord()
for k, v in config.items():
c_record[k] = v

return c_record
31 changes: 14 additions & 17 deletions src/py/flwr/proto/exec_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 6 additions & 20 deletions src/py/flwr/proto/exec_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ isort:skip_file
"""
import builtins
import flwr.proto.fab_pb2
import flwr.proto.recordset_pb2
import flwr.proto.transport_pb2
import google.protobuf.descriptor
import google.protobuf.internal.containers
Expand All @@ -30,38 +31,23 @@ class StartRunRequest(google.protobuf.message.Message):
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...

class FederationConfigEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: typing.Text
@property
def value(self) -> flwr.proto.transport_pb2.Scalar: ...
def __init__(self,
*,
key: typing.Text = ...,
value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...

FAB_FIELD_NUMBER: builtins.int
OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int
FEDERATION_CONFIG_FIELD_NUMBER: builtins.int
FEDERATION_OPTIONS_FIELD_NUMBER: builtins.int
@property
def fab(self) -> flwr.proto.fab_pb2.Fab: ...
@property
def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
@property
def federation_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
def federation_options(self) -> flwr.proto.recordset_pb2.ConfigsRecord: ...
def __init__(self,
*,
fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ...,
override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
federation_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
federation_options: typing.Optional[flwr.proto.recordset_pb2.ConfigsRecord] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["fab",b"fab"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["fab",b"fab","federation_config",b"federation_config","override_config",b"override_config"]) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["fab",b"fab","federation_options",b"federation_options"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["fab",b"fab","federation_options",b"federation_options","override_config",b"override_config"]) -> None: ...
global___StartRunRequest = StartRunRequest

class StartRunResponse(google.protobuf.message.Message):
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/superexec/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def start_run(
self,
fab_file: bytes,
override_config: UserConfig,
federation_config: UserConfig,
federation_options: ConfigsRecord,
) -> Optional[int]:
"""Start run using the Flower Deployment Engine."""
run_id = None
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/superexec/exec_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from flwr.common.constant import LOG_STREAM_INTERVAL, Status
from flwr.common.logger import log
from flwr.common.serde import user_config_from_proto
from flwr.common.serde import configs_record_from_proto, user_config_from_proto
from flwr.proto import exec_pb2_grpc # pylint: disable=E0611
from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
StartRunRequest,
Expand Down Expand Up @@ -61,7 +61,7 @@ def StartRun(
run_id = self.executor.start_run(
request.fab.content,
user_config_from_proto(request.override_config),
user_config_from_proto(request.federation_config),
configs_record_from_proto(request.federation_options),
)

if run_id is None:
Expand Down
7 changes: 4 additions & 3 deletions src/py/flwr/superexec/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from subprocess import Popen
from typing import Optional

from flwr.common import ConfigsRecord
from flwr.common.typing import UserConfig
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.linkstate import LinkStateFactory
Expand Down Expand Up @@ -71,7 +72,7 @@ def start_run(
self,
fab_file: bytes,
override_config: UserConfig,
federation_config: UserConfig,
federation_options: ConfigsRecord,
) -> Optional[int]:
"""Start a run using the given Flower FAB ID and version.

Expand All @@ -84,8 +85,8 @@ def start_run(
The Flower App Bundle file bytes.
override_config: UserConfig
The config overrides dict sent by the user (using `flwr run`).
federation_config: UserConfig
The federation options dict sent by the user (using `flwr run`).
federation_options: ConfigsRecord
The federation options sent by the user (using `flwr run`).

Returns
-------
Expand Down
16 changes: 8 additions & 8 deletions src/py/flwr/superexec/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from flwr.cli.config_utils import load_and_validate
from flwr.cli.install import install_from_fab
from flwr.common import ConfigsRecord
from flwr.common.config import unflatten_dict
from flwr.common.constant import RUN_ID_NUM_BYTES
from flwr.common.logger import log
Expand Down Expand Up @@ -124,7 +125,7 @@ def start_run(
self,
fab_file: bytes,
override_config: UserConfig,
federation_config: UserConfig,
federation_options: ConfigsRecord,
) -> Optional[int]:
"""Start run using the Flower Simulation Engine."""
if self.num_supernodes is None:
Expand Down Expand Up @@ -163,14 +164,13 @@ def start_run(
"Config extracted from FAB's pyproject.toml is not valid"
)

# Flatten federated config
federation_config_flat = unflatten_dict(federation_config)
# Unflatten underlaying dict
fed_opt = unflatten_dict({**federation_options})

num_supernodes = federation_config_flat.get(
"num-supernodes", self.num_supernodes
)
backend_cfg = federation_config_flat.get("backend", {})
verbose: Optional[bool] = federation_config_flat.get("verbose")
# Read data
num_supernodes = fed_opt.get("num-supernodes", self.num_supernodes)
backend_cfg = fed_opt.get("backend", {})
verbose: Optional[bool] = fed_opt.get("verbose")

# In Simulation there is no SuperLink, still we create a run_id
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
Expand Down
Loading