Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Support non-str config value types #3746

Merged
merged 98 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from 94 commits
Commits
Show all changes
98 commits
Select commit Hold shift + click to select a range
d0cc112
feat(framework) Add proto files for override_config with specific type
charlesbvll Jul 8, 2024
72eb30a
Compile proto files
charlesbvll Jul 8, 2024
bd74c73
feat(framework) Add overrides to SuperExec with specific type
charlesbvll Jul 8, 2024
be141c3
Merge branch 'add-overrides-proto-configrecordvalue' into add-overrid…
charlesbvll Jul 8, 2024
8c0d1b9
Update serde
charlesbvll Jul 8, 2024
519f9f9
Merge branch 'add-overrides-proto-configrecordvalue' into add-overrid…
charlesbvll Jul 8, 2024
0e24f5e
Add necessary import
charlesbvll Jul 8, 2024
3fafa71
feat(framework) Extend ConfigRecordValue serde
charlesbvll Jul 8, 2024
a938ba5
Fix test
charlesbvll Jul 8, 2024
cc01992
Merge branch 'update-serde-to-support-configrecordvalue' into add-ove…
charlesbvll Jul 8, 2024
9ad4fad
Merge branch 'main' into add-overrides-proto-configrecordvalue
charlesbvll Jul 8, 2024
8320c3c
feat(framework) Add override_config to Run with specific type
charlesbvll Jul 8, 2024
78fac27
Fix file
charlesbvll Jul 8, 2024
98f2ede
Add correct pylint disable
charlesbvll Jul 8, 2024
5148329
Merge branch 'main' into add-overrides-to-run-configrecordvalue
charlesbvll Jul 8, 2024
56faf84
Add override_config
charlesbvll Jul 8, 2024
d933ba3
Merge branch 'update-serde-to-support-configrecordvalue' into add-ove…
charlesbvll Jul 8, 2024
62b8e15
Merge branch 'main' into add-override-config-superexec-configrecordvalue
charlesbvll Jul 8, 2024
c526d0b
Merge branch 'main' into add-override-config-superexec-configrecordvalue
charlesbvll Jul 8, 2024
5a15a17
Merge branch 'main' into add-override-config-superexec-configrecordvalue
charlesbvll Jul 8, 2024
624f3a3
Fix imports
charlesbvll Jul 8, 2024
28b787e
Fix proto test
charlesbvll Jul 8, 2024
0274360
Fix test
charlesbvll Jul 8, 2024
3d7b61f
Merge branch 'main' into add-override-config-superexec-configrecordvalue
charlesbvll Jul 8, 2024
e5bed16
Merge branch 'main' into add-override-config-superexec-configrecordvalue
charlesbvll Jul 8, 2024
b244389
Use serde func
charlesbvll Jul 8, 2024
4dbac7d
Merge branch 'main' into add-override-config-superexec-configrecordvalue
charlesbvll Jul 8, 2024
ec1a62a
Merge branch 'add-overrides-proto-configrecordvalue' into change-over…
charlesbvll Jul 8, 2024
119e957
Merge branch 'update-serde-to-support-configrecordvalue' into change-…
charlesbvll Jul 8, 2024
4a81569
Merge branch 'add-overrides-to-run-configrecordvalue' into change-ove…
charlesbvll Jul 8, 2024
7712971
Merge branch 'add-override-config-superexec-configrecordvalue' into c…
charlesbvll Jul 8, 2024
7329b63
Remove unused import
charlesbvll Jul 8, 2024
ba8ee59
Merge branch 'main' into change-override-config-type
charlesbvll Jul 10, 2024
066986d
Add changes
charlesbvll Jul 10, 2024
3b34df6
Merge branch 'main' into change-override-config-type
charlesbvll Jul 11, 2024
631df02
Use correct type
charlesbvll Jul 11, 2024
7d5883b
Merge branch 'main' into change-override-config-type
charlesbvll Jul 11, 2024
4d70e29
Fix tests
charlesbvll Jul 11, 2024
a680613
Fix types
charlesbvll Jul 11, 2024
3b65202
Fix imports
charlesbvll Jul 11, 2024
3e5f549
Fix imports
charlesbvll Jul 11, 2024
c9719ae
Fix utils
charlesbvll Jul 11, 2024
9bb359e
Use Value instead of ConfigsRecordValues
charlesbvll Jul 11, 2024
b10ca8d
Fix imports
charlesbvll Jul 11, 2024
e5a8adb
Merge branch 'main' into change-override-config-type
danieljanes Jul 13, 2024
8c8c3b6
Merge branch 'main' into change-override-config-type
charlesbvll Jul 14, 2024
d397ae0
Use correct type for node_config
charlesbvll Jul 14, 2024
56574da
Use custom type
charlesbvll Jul 14, 2024
11a2e97
fix(framework:skip) Use correct arguments
charlesbvll Jul 14, 2024
5780dd6
Merge branch 'fix-deployment-engine' into change-override-config-type
charlesbvll Jul 14, 2024
5650959
Fix imports
charlesbvll Jul 14, 2024
63ed4b7
Remove unused import
charlesbvll Jul 14, 2024
490b16c
Merge branch 'main' into change-override-config-type
charlesbvll Jul 14, 2024
aab7c0c
Merge branch 'main' into change-override-config-type
charlesbvll Jul 14, 2024
8a3c97f
Fix types for simulation
charlesbvll Jul 14, 2024
d96c3e3
Merge branch 'change-override-config-type' of https://github.com/adap…
charlesbvll Jul 14, 2024
7f5fc56
Fix imports
charlesbvll Jul 14, 2024
162e8dc
Fix error
charlesbvll Jul 14, 2024
f26b645
Fix messagehandler
charlesbvll Jul 14, 2024
5a504a7
Fix imports
charlesbvll Jul 14, 2024
61a4c86
Remove unused list
charlesbvll Jul 14, 2024
16f7c54
Revert some proto changes
charlesbvll Jul 14, 2024
f0bf536
Fix imports
charlesbvll Jul 14, 2024
da09e2f
Fix imports
charlesbvll Jul 14, 2024
34a73c9
Fix imports again
charlesbvll Jul 14, 2024
31c0915
Fix imports
charlesbvll Jul 14, 2024
3171cbb
Fix serde
charlesbvll Jul 14, 2024
056927b
Fix imports
charlesbvll Jul 14, 2024
03d1e4f
Fix proto test
charlesbvll Jul 14, 2024
2b0e4b2
Revert some changes
charlesbvll Jul 14, 2024
a8c5c55
Fix serde
charlesbvll Jul 14, 2024
4712530
Merge branch 'main' into change-override-config-type
charlesbvll Jul 15, 2024
8a74a1d
Fix formatting
charlesbvll Jul 15, 2024
2441381
Fix types
charlesbvll Jul 15, 2024
1ddd379
Fix type mismatch
charlesbvll Jul 15, 2024
72b4f13
Make num-supernodes an int
charlesbvll Jul 15, 2024
6814d79
Merge branch 'main' into change-override-config-type
charlesbvll Jul 15, 2024
b8b9937
Merge branch 'main' into change-override-config-type
charlesbvll Jul 16, 2024
36a46ba
Merge branch 'main' into change-override-config-type
charlesbvll Jul 16, 2024
de789b0
Merge branch 'main' into change-override-config-type
charlesbvll Jul 16, 2024
2948d04
Merge branch 'main' into change-override-config-type
charlesbvll Jul 16, 2024
50c921e
Merge branch 'main' into change-override-config-type
charlesbvll Jul 16, 2024
355c223
Remove unused import
charlesbvll Jul 16, 2024
590354a
Merge branch 'main' into change-override-config-type
charlesbvll Jul 17, 2024
d38a4ac
Fix
charlesbvll Jul 17, 2024
b040fbe
Merge branch 'main' into change-override-config-type
charlesbvll Jul 17, 2024
0ba8cfe
Merge branch 'main' into change-override-config-type
charlesbvll Jul 17, 2024
712e6a2
Merge branch 'main' into change-override-config-type
danieljanes Jul 17, 2024
9f6c8d4
Merge branch 'main' into change-override-config-type
danieljanes Jul 22, 2024
0986e4b
Apply suggestions
charlesbvll Jul 22, 2024
e06bd67
Merge branch 'main' into change-override-config-type
charlesbvll Jul 22, 2024
5ba487a
Merge branch 'main' into change-override-config-type
charlesbvll Jul 22, 2024
f58b1d2
Update src/py/flwr/superexec/deployment.py
danieljanes Jul 22, 2024
6000b54
Update src/py/flwr/superexec/deployment.py
danieljanes Jul 22, 2024
5f29147
Update src/py/flwr/superexec/deployment.py
danieljanes Jul 22, 2024
21b4a50
Update src/py/flwr/superexec/simulation.py
danieljanes Jul 22, 2024
5b789c6
Merge branch 'main' into change-override-config-type
danieljanes Jul 22, 2024
8f45924
Format
danieljanes Jul 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/proto/flwr/proto/driver.proto
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package flwr.proto;
import "flwr/proto/node.proto";
import "flwr/proto/task.proto";
import "flwr/proto/run.proto";
import "flwr/proto/transport.proto";

service Driver {
// Request run_id
Expand All @@ -42,7 +43,7 @@ service Driver {
message CreateRunRequest {
string fab_id = 1;
string fab_version = 2;
map<string, string> override_config = 3;
map<string, Scalar> override_config = 3;
}
message CreateRunResponse { sint64 run_id = 1; }

Expand Down
4 changes: 3 additions & 1 deletion src/proto/flwr/proto/exec.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ syntax = "proto3";

package flwr.proto;

import "flwr/proto/transport.proto";

service Exec {
// Start run upon request
rpc StartRun(StartRunRequest) returns (StartRunResponse) {}
Expand All @@ -27,7 +29,7 @@ service Exec {

message StartRunRequest {
bytes fab_file = 1;
map<string, string> override_config = 2;
map<string, Scalar> override_config = 2;
}
message StartRunResponse { sint64 run_id = 1; }
message StreamLogsRequest { sint64 run_id = 1; }
Expand Down
4 changes: 3 additions & 1 deletion src/proto/flwr/proto/run.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ syntax = "proto3";

package flwr.proto;

import "flwr/proto/transport.proto";

message Run {
sint64 run_id = 1;
string fab_id = 2;
string fab_version = 3;
map<string, string> override_config = 4;
map<string, Scalar> override_config = 4;
}
message GetRunRequest { sint64 run_id = 1; }
message GetRunResponse { Run run = 1; }
1 change: 0 additions & 1 deletion src/proto/flwr/proto/task.proto
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package flwr.proto;

import "flwr/proto/node.proto";
import "flwr/proto/recordset.proto";
import "flwr/proto/transport.proto";
import "flwr/proto/error.proto";

message Task {
Expand Down
10 changes: 7 additions & 3 deletions src/py/flwr/cli/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
import zipfile
from io import BytesIO
from pathlib import Path
from typing import IO, Any, Dict, List, Optional, Tuple, Union
from typing import IO, Any, Dict, List, Optional, Tuple, Union, get_args

import tomli

from flwr.common import object_ref
from flwr.common.typing import UserConfigValue


def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
Expand Down Expand Up @@ -112,8 +113,11 @@ def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None
for key, value in config_dict.items():
if isinstance(value, dict):
_validate_run_config(config_dict[key], errors)
elif not isinstance(value, str):
errors.append(f"Config value of key {key} is not of type `str`.")
elif not isinstance(value, get_args(UserConfigValue)):
raise ValueError(
f"The value for key {key} needs to be of type `int`, `float`, "
"`bool, `str`, or a `dict` of those.",
)


# pylint: disable=too-many-branches
Expand Down
5 changes: 4 additions & 1 deletion src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from flwr.common.config import parse_config_args
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
from flwr.common.logger import log
from flwr.common.serde import user_config_to_proto
from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
from flwr.proto.exec_pb2_grpc import ExecStub

Expand Down Expand Up @@ -164,7 +165,9 @@ def on_channel_state_change(channel_connectivity: str) -> None:

req = StartRunRequest(
fab_file=Path(fab_path).read_bytes(),
override_config=parse_config_args(config_overrides, separator=","),
override_config=user_config_to_proto(
parse_config_args(config_overrides, separator=",")
),
)
res = stub.StartRun(req)
typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from flwr.common.logger import log, warn_deprecated_feature
from flwr.common.message import Error
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
from flwr.common.typing import Run
from flwr.common.typing import Run, UserConfig

from .grpc_adapter_client.connection import grpc_adapter
from .grpc_client.connection import grpc_connection
Expand Down Expand Up @@ -182,7 +182,7 @@ class `flwr.client.Client` (default: None)
def _start_client_internal(
*,
server_address: str,
node_config: Dict[str, str],
node_config: UserConfig,
load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None,
client_fn: Optional[ClientFnExt] = None,
client: Optional[Client] = None,
Expand All @@ -205,7 +205,7 @@ def _start_client_internal(
The IPv4 or IPv6 address of the server. If the Flower
server runs on the same machine on port 8080, then `server_address`
would be `"[::]:8080"`.
node_config: Dict[str, str]
node_config: UserConfig
The configuration of the node.
load_client_app_fn : Optional[Callable[[], ClientApp]] (default: None)
A function that can be used to load a `ClientApp` instance.
Expand Down
8 changes: 6 additions & 2 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@
from flwr.common.logger import log
from flwr.common.message import Message, Metadata
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.serde import message_from_taskins, message_to_taskres
from flwr.common.serde import (
message_from_taskins,
message_to_taskres,
user_config_from_proto,
)
from flwr.common.typing import Run
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
Expand Down Expand Up @@ -281,7 +285,7 @@ def get_run(run_id: int) -> Run:
run_id,
get_run_response.run.fab_id,
get_run_response.run.fab_version,
dict(get_run_response.run.override_config.items()),
user_config_from_proto(get_run_response.run.override_config),
)

try:
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/client/node_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@

from flwr.common import Context, RecordSet
from flwr.common.config import get_fused_config, get_fused_config_from_dir
from flwr.common.typing import Run
from flwr.common.typing import Run, UserConfig


@dataclass()
class RunInfo:
"""Contains the Context and initial run_config of a Run."""

context: Context
initial_run_config: Dict[str, str]
initial_run_config: UserConfig


class NodeState:
Expand All @@ -38,7 +38,7 @@ class NodeState:
def __init__(
self,
node_id: int,
node_config: Dict[str, str],
node_config: UserConfig,
) -> None:
self.node_id = node_id
self.node_config = node_config
Expand Down
8 changes: 6 additions & 2 deletions src/py/flwr/client/rest_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@
from flwr.common.logger import log
from flwr.common.message import Message, Metadata
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.serde import message_from_taskins, message_to_taskres
from flwr.common.serde import (
message_from_taskins,
message_to_taskres,
user_config_from_proto,
)
from flwr.common.typing import Run
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
Expand Down Expand Up @@ -359,7 +363,7 @@ def get_run(run_id: int) -> Run:
run_id,
res.run.fab_id,
res.run.fab_version,
dict(res.run.override_config.items()),
user_config_from_proto(res.run.override_config),
)

try:
Expand Down
35 changes: 18 additions & 17 deletions src/py/flwr/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@

import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union, cast, get_args

import tomli

from flwr.cli.config_utils import validate_fields
from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
from flwr.common.typing import Run
from flwr.common.typing import Run, UserConfig, UserConfigValue


def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
Expand Down Expand Up @@ -75,8 +75,9 @@ def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]:


def _fuse_dicts(
main_dict: Dict[str, str], override_dict: Dict[str, str]
) -> Dict[str, str]:
main_dict: UserConfig,
override_dict: UserConfig,
) -> UserConfig:
fused_dict = main_dict.copy()

for key, value in override_dict.items():
Expand All @@ -87,8 +88,8 @@ def _fuse_dicts(


def get_fused_config_from_dir(
project_dir: Path, override_config: Dict[str, str]
) -> Dict[str, str]:
project_dir: Path, override_config: UserConfig
) -> UserConfig:
"""Merge the overrides from a given dict with the config from a Flower App."""
default_config = get_project_config(project_dir)["tool"]["flwr"]["app"].get(
"config", {}
Expand All @@ -98,7 +99,7 @@ def get_fused_config_from_dir(
return _fuse_dicts(flat_default_config, override_config)


def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> Dict[str, str]:
def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig:
"""Merge the overrides from a `Run` with the config from a FAB.

Get the config using the fab_id and the fab_version, remove the nesting by adding
Expand All @@ -112,29 +113,30 @@ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> Dict[str, str]:
return get_fused_config_from_dir(project_dir, run.override_config)


def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, str]:
def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> UserConfig:
"""Flatten dict by joining nested keys with a given separator."""
items: List[Tuple[str, str]] = []
items: List[Tuple[str, UserConfigValue]] = []
separator: str = "."
for k, v in raw_dict.items():
new_key = f"{parent_key}{separator}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(flatten_dict(v, parent_key=new_key).items())
elif isinstance(v, str):
items.append((new_key, v))
elif isinstance(v, get_args(UserConfigValue)):
items.append((new_key, cast(UserConfigValue, v)))
else:
raise ValueError(
f"The value for key {k} needs to be a `str` or a `dict`.",
f"The value for key {k} needs to be of type `int`, `float`, "
"`bool, `str`, or a `dict` of those.",
)
return dict(items)


def parse_config_args(
config: Optional[List[str]],
separator: str = ",",
) -> Dict[str, str]:
) -> UserConfig:
"""Parse separator separated list of key-value pairs separated by '='."""
overrides: Dict[str, str] = {}
overrides: UserConfig = {}

if config is None:
return overrides
Expand All @@ -150,8 +152,7 @@ def parse_config_args(
with Path(overrides_list[0]).open("rb") as config_file:
overrides = flatten_dict(tomli.load(config_file))
else:
for kv_pair in overrides_list:
key, value = kv_pair.split("=")
overrides[key] = value
toml_str = "\n".join(overrides_list)
overrides.update(tomli.loads(toml_str))

return overrides
41 changes: 24 additions & 17 deletions src/py/flwr/common/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import pytest

from flwr.common.typing import UserConfig

from .config import (
_fuse_dicts,
flatten_dict,
Expand Down Expand Up @@ -101,23 +103,25 @@ def test_get_fused_config_valid(tmp_path: Path) -> None:
clientapp = "fedgpt.client:app"

[tool.flwr.app.config]
num_server_rounds = "10"
momentum = "0.1"
lr = "0.01"
num_server_rounds = 10
momentum = 0.1
lr = 0.01
charlesbvll marked this conversation as resolved.
Show resolved Hide resolved
progress_bar = true
serverapp.test = "key"

[tool.flwr.app.config.clientapp]
test = "key"
"""
overrides = {
"num_server_rounds": "5",
"lr": "0.2",
overrides: UserConfig = {
"num_server_rounds": 5,
"lr": 0.2,
"serverapp.test": "overriden",
}
expected_config = {
"num_server_rounds": "5",
"momentum": "0.1",
"lr": "0.2",
"num_server_rounds": 5,
"momentum": 0.1,
"lr": 0.2,
"progress_bar": True,
"serverapp.test": "overriden",
"clientapp.test": "key",
}
Expand Down Expand Up @@ -168,8 +172,9 @@ def test_get_project_config_file_valid(tmp_path: Path) -> None:
clientapp = "fedgpt.client:app"

[tool.flwr.app.config]
num_server_rounds = "10"
momentum = "0.1"
num_server_rounds = 10
momentum = 0.1
charlesbvll marked this conversation as resolved.
Show resolved Hide resolved
progress_bar = true
lr = "0.01"
"""
expected_config = {
Expand All @@ -190,8 +195,9 @@ def test_get_project_config_file_valid(tmp_path: Path) -> None:
"clientapp": "fedgpt.client:app",
},
"config": {
"num_server_rounds": "10",
"momentum": "0.1",
"num_server_rounds": 10,
"momentum": 0.1,
"progress_bar": True,
"lr": "0.01",
},
},
Expand Down Expand Up @@ -231,11 +237,12 @@ def test_parse_config_args_none() -> None:
def test_parse_config_args_overrides() -> None:
"""Test parse_config_args with key-value pairs."""
assert parse_config_args(
["key1=value1,key2=value2", "key3=value3", "key4=value4,key5=value5"]
["key1='value1',key2='value2'", "key3=1", "key4=2.0,key5=true,key6='value6'"]
) == {
"key1": "value1",
"key2": "value2",
"key3": "value3",
"key4": "value4",
"key5": "value5",
"key3": 1,
"key4": 2.0,
"key5": True,
"key6": "value6",
}
Loading
Loading