Skip to content

Commit

Permalink
feat(framework) Support non-str config value types (#3746)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll authored Jul 22, 2024
1 parent be83ae8 commit d24ebe5
Show file tree
Hide file tree
Showing 39 changed files with 433 additions and 176 deletions.
3 changes: 2 additions & 1 deletion src/proto/flwr/proto/driver.proto
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package flwr.proto;
import "flwr/proto/node.proto";
import "flwr/proto/task.proto";
import "flwr/proto/run.proto";
import "flwr/proto/transport.proto";

service Driver {
// Request run_id
Expand All @@ -42,7 +43,7 @@ service Driver {
message CreateRunRequest {
string fab_id = 1;
string fab_version = 2;
map<string, string> override_config = 3;
map<string, Scalar> override_config = 3;
}
message CreateRunResponse { sint64 run_id = 1; }

Expand Down
4 changes: 3 additions & 1 deletion src/proto/flwr/proto/exec.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ syntax = "proto3";

package flwr.proto;

import "flwr/proto/transport.proto";

service Exec {
// Start run upon request
rpc StartRun(StartRunRequest) returns (StartRunResponse) {}
Expand All @@ -27,7 +29,7 @@ service Exec {

message StartRunRequest {
bytes fab_file = 1;
map<string, string> override_config = 2;
map<string, Scalar> override_config = 2;
}
message StartRunResponse { sint64 run_id = 1; }
message StreamLogsRequest { sint64 run_id = 1; }
Expand Down
4 changes: 3 additions & 1 deletion src/proto/flwr/proto/run.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ syntax = "proto3";

package flwr.proto;

import "flwr/proto/transport.proto";

message Run {
sint64 run_id = 1;
string fab_id = 2;
string fab_version = 3;
map<string, string> override_config = 4;
map<string, Scalar> override_config = 4;
}
message GetRunRequest { sint64 run_id = 1; }
message GetRunResponse { Run run = 1; }
1 change: 0 additions & 1 deletion src/proto/flwr/proto/task.proto
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package flwr.proto;

import "flwr/proto/node.proto";
import "flwr/proto/recordset.proto";
import "flwr/proto/transport.proto";
import "flwr/proto/error.proto";

message Task {
Expand Down
10 changes: 7 additions & 3 deletions src/py/flwr/cli/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
import zipfile
from io import BytesIO
from pathlib import Path
from typing import IO, Any, Dict, List, Optional, Tuple, Union
from typing import IO, Any, Dict, List, Optional, Tuple, Union, get_args

import tomli

from flwr.common import object_ref
from flwr.common.typing import UserConfigValue


def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
Expand Down Expand Up @@ -112,8 +113,11 @@ def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None
for key, value in config_dict.items():
if isinstance(value, dict):
_validate_run_config(config_dict[key], errors)
elif not isinstance(value, str):
errors.append(f"Config value of key {key} is not of type `str`.")
elif not isinstance(value, get_args(UserConfigValue)):
raise ValueError(
f"The value for key {key} needs to be of type `int`, `float`, "
"`bool, `str`, or a `dict` of those.",
)


# pylint: disable=too-many-branches
Expand Down
5 changes: 4 additions & 1 deletion src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from flwr.common.config import parse_config_args
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
from flwr.common.logger import log
from flwr.common.serde import user_config_to_proto
from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
from flwr.proto.exec_pb2_grpc import ExecStub

Expand Down Expand Up @@ -164,7 +165,9 @@ def on_channel_state_change(channel_connectivity: str) -> None:

req = StartRunRequest(
fab_file=Path(fab_path).read_bytes(),
override_config=parse_config_args(config_overrides, separator=","),
override_config=user_config_to_proto(
parse_config_args(config_overrides, separator=",")
),
)
res = stub.StartRun(req)
typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from flwr.common.logger import log, warn_deprecated_feature
from flwr.common.message import Error
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
from flwr.common.typing import Run
from flwr.common.typing import Run, UserConfig

from .grpc_adapter_client.connection import grpc_adapter
from .grpc_client.connection import grpc_connection
Expand Down Expand Up @@ -182,7 +182,7 @@ class `flwr.client.Client` (default: None)
def _start_client_internal(
*,
server_address: str,
node_config: Dict[str, str],
node_config: UserConfig,
load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None,
client_fn: Optional[ClientFnExt] = None,
client: Optional[Client] = None,
Expand All @@ -205,7 +205,7 @@ def _start_client_internal(
The IPv4 or IPv6 address of the server. If the Flower
server runs on the same machine on port 8080, then `server_address`
would be `"[::]:8080"`.
node_config: Dict[str, str]
node_config: UserConfig
The configuration of the node.
load_client_app_fn : Optional[Callable[[], ClientApp]] (default: None)
A function that can be used to load a `ClientApp` instance.
Expand Down
8 changes: 6 additions & 2 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@
from flwr.common.logger import log
from flwr.common.message import Message, Metadata
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.serde import message_from_taskins, message_to_taskres
from flwr.common.serde import (
message_from_taskins,
message_to_taskres,
user_config_from_proto,
)
from flwr.common.typing import Run
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
Expand Down Expand Up @@ -281,7 +285,7 @@ def get_run(run_id: int) -> Run:
run_id,
get_run_response.run.fab_id,
get_run_response.run.fab_version,
dict(get_run_response.run.override_config.items()),
user_config_from_proto(get_run_response.run.override_config),
)

try:
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/client/node_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@

from flwr.common import Context, RecordSet
from flwr.common.config import get_fused_config, get_fused_config_from_dir
from flwr.common.typing import Run
from flwr.common.typing import Run, UserConfig


@dataclass()
class RunInfo:
"""Contains the Context and initial run_config of a Run."""

context: Context
initial_run_config: Dict[str, str]
initial_run_config: UserConfig


class NodeState:
Expand All @@ -38,7 +38,7 @@ class NodeState:
def __init__(
self,
node_id: int,
node_config: Dict[str, str],
node_config: UserConfig,
) -> None:
self.node_id = node_id
self.node_config = node_config
Expand Down
8 changes: 6 additions & 2 deletions src/py/flwr/client/rest_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@
from flwr.common.logger import log
from flwr.common.message import Message, Metadata
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.serde import message_from_taskins, message_to_taskres
from flwr.common.serde import (
message_from_taskins,
message_to_taskres,
user_config_from_proto,
)
from flwr.common.typing import Run
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
Expand Down Expand Up @@ -359,7 +363,7 @@ def get_run(run_id: int) -> Run:
run_id,
res.run.fab_id,
res.run.fab_version,
dict(res.run.override_config.items()),
user_config_from_proto(res.run.override_config),
)

try:
Expand Down
35 changes: 18 additions & 17 deletions src/py/flwr/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@

import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union, cast, get_args

import tomli

from flwr.cli.config_utils import validate_fields
from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
from flwr.common.typing import Run
from flwr.common.typing import Run, UserConfig, UserConfigValue


def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
Expand Down Expand Up @@ -75,8 +75,9 @@ def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]:


def _fuse_dicts(
main_dict: Dict[str, str], override_dict: Dict[str, str]
) -> Dict[str, str]:
main_dict: UserConfig,
override_dict: UserConfig,
) -> UserConfig:
fused_dict = main_dict.copy()

for key, value in override_dict.items():
Expand All @@ -87,8 +88,8 @@ def _fuse_dicts(


def get_fused_config_from_dir(
project_dir: Path, override_config: Dict[str, str]
) -> Dict[str, str]:
project_dir: Path, override_config: UserConfig
) -> UserConfig:
"""Merge the overrides from a given dict with the config from a Flower App."""
default_config = get_project_config(project_dir)["tool"]["flwr"]["app"].get(
"config", {}
Expand All @@ -98,7 +99,7 @@ def get_fused_config_from_dir(
return _fuse_dicts(flat_default_config, override_config)


def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> Dict[str, str]:
def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig:
"""Merge the overrides from a `Run` with the config from a FAB.
Get the config using the fab_id and the fab_version, remove the nesting by adding
Expand All @@ -112,29 +113,30 @@ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> Dict[str, str]:
return get_fused_config_from_dir(project_dir, run.override_config)


def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, str]:
def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> UserConfig:
"""Flatten dict by joining nested keys with a given separator."""
items: List[Tuple[str, str]] = []
items: List[Tuple[str, UserConfigValue]] = []
separator: str = "."
for k, v in raw_dict.items():
new_key = f"{parent_key}{separator}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(flatten_dict(v, parent_key=new_key).items())
elif isinstance(v, str):
items.append((new_key, v))
elif isinstance(v, get_args(UserConfigValue)):
items.append((new_key, cast(UserConfigValue, v)))
else:
raise ValueError(
f"The value for key {k} needs to be a `str` or a `dict`.",
f"The value for key {k} needs to be of type `int`, `float`, "
"`bool, `str`, or a `dict` of those.",
)
return dict(items)


def parse_config_args(
config: Optional[List[str]],
separator: str = ",",
) -> Dict[str, str]:
) -> UserConfig:
"""Parse separator separated list of key-value pairs separated by '='."""
overrides: Dict[str, str] = {}
overrides: UserConfig = {}

if config is None:
return overrides
Expand All @@ -150,8 +152,7 @@ def parse_config_args(
with Path(overrides_list[0]).open("rb") as config_file:
overrides = flatten_dict(tomli.load(config_file))
else:
for kv_pair in overrides_list:
key, value = kv_pair.split("=")
overrides[key] = value
toml_str = "\n".join(overrides_list)
overrides.update(tomli.loads(toml_str))

return overrides
41 changes: 24 additions & 17 deletions src/py/flwr/common/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import pytest

from flwr.common.typing import UserConfig

from .config import (
_fuse_dicts,
flatten_dict,
Expand Down Expand Up @@ -101,23 +103,25 @@ def test_get_fused_config_valid(tmp_path: Path) -> None:
clientapp = "fedgpt.client:app"
[tool.flwr.app.config]
num_server_rounds = "10"
momentum = "0.1"
lr = "0.01"
num_server_rounds = 10
momentum = 0.1
lr = 0.01
progress_bar = true
serverapp.test = "key"
[tool.flwr.app.config.clientapp]
test = "key"
"""
overrides = {
"num_server_rounds": "5",
"lr": "0.2",
overrides: UserConfig = {
"num_server_rounds": 5,
"lr": 0.2,
"serverapp.test": "overriden",
}
expected_config = {
"num_server_rounds": "5",
"momentum": "0.1",
"lr": "0.2",
"num_server_rounds": 5,
"momentum": 0.1,
"lr": 0.2,
"progress_bar": True,
"serverapp.test": "overriden",
"clientapp.test": "key",
}
Expand Down Expand Up @@ -168,8 +172,9 @@ def test_get_project_config_file_valid(tmp_path: Path) -> None:
clientapp = "fedgpt.client:app"
[tool.flwr.app.config]
num_server_rounds = "10"
momentum = "0.1"
num_server_rounds = 10
momentum = 0.1
progress_bar = true
lr = "0.01"
"""
expected_config = {
Expand All @@ -190,8 +195,9 @@ def test_get_project_config_file_valid(tmp_path: Path) -> None:
"clientapp": "fedgpt.client:app",
},
"config": {
"num_server_rounds": "10",
"momentum": "0.1",
"num_server_rounds": 10,
"momentum": 0.1,
"progress_bar": True,
"lr": "0.01",
},
},
Expand Down Expand Up @@ -231,11 +237,12 @@ def test_parse_config_args_none() -> None:
def test_parse_config_args_overrides() -> None:
"""Test parse_config_args with key-value pairs."""
assert parse_config_args(
["key1=value1,key2=value2", "key3=value3", "key4=value4,key5=value5"]
["key1='value1',key2='value2'", "key3=1", "key4=2.0,key5=true,key6='value6'"]
) == {
"key1": "value1",
"key2": "value2",
"key3": "value3",
"key4": "value4",
"key5": "value5",
"key3": 1,
"key4": 2.0,
"key5": True,
"key6": "value6",
}
Loading

0 comments on commit d24ebe5

Please sign in to comment.