
Workflow communication APIs and Simplified ML Algorithms #2250

Closed
wants to merge 42 commits
Changes from 40 commits
0029993
controller workflow APIs and simplified FedAvg and Fed Kaplan-Meier e…
chesterxgchen Dec 29, 2023
fce9c61
update
chesterxgchen Dec 29, 2023
e47ff4a
update
chesterxgchen Dec 29, 2023
660bdcf
update
chesterxgchen Dec 29, 2023
8733f60
Add Fed Cyclic example
chesterxgchen Dec 29, 2023
f1ad53d
Add Fed Cyclic example
chesterxgchen Dec 29, 2023
b1fa51b
addres PR comments
chesterxgchen Dec 29, 2023
d9ad8b2
1. Remove Base Class ErrorHandleController, instead move the function…
chesterxgchen Dec 31, 2023
4322602
add header
chesterxgchen Dec 31, 2023
3cab712
add header
chesterxgchen Dec 31, 2023
6221dff
code style format
chesterxgchen Dec 31, 2023
f535c62
make better user experience
chesterxgchen Dec 31, 2023
c6e19c7
code format and import
chesterxgchen Dec 31, 2023
a3fb099
remove comment
chesterxgchen Dec 31, 2023
6b57dc9
remove used method
chesterxgchen Dec 31, 2023
eb2272f
1. add intime aggregate version of fedavg
chesterxgchen Jan 2, 2024
373ce19
update README.md
chesterxgchen Jan 2, 2024
6ac0888
add ask all clients to end run when server in exception
chesterxgchen Jan 3, 2024
c5994c3
rebase and remove extra command
chesterxgchen Jan 5, 2024
4b84310
wip
chesterxgchen Jan 12, 2024
4f3abe5
remove WF dependency
chesterxgchen Jan 13, 2024
e635ea4
1. remove ctrl_msg_Queue, use controller directly.
chesterxgchen Jan 13, 2024
c9ee619
update README.md and cleanup
chesterxgchen Jan 13, 2024
76c3c43
change comm_msg_pull_interval to result_pull_interval
chesterxgchen Jan 13, 2024
a52d27a
1. fix message_bus
chesterxgchen Jan 13, 2024
8343392
design change, broken commit
chesterxgchen Jan 18, 2024
d914592
everything works now
chesterxgchen Jan 20, 2024
30552e4
everything works now
chesterxgchen Jan 20, 2024
c297073
merge with new data bus changes. The code is broken now.
chesterxgchen Jan 28, 2024
b4da21f
fix the lock issue.
chesterxgchen Jan 28, 2024
ae657f6
define strategy.py in case it is needed.
chesterxgchen Jan 28, 2024
f0ae6e3
define strategy.py in case it is needed.
chesterxgchen Jan 28, 2024
207c13a
make sure the publish in parallel instead of sequential
chesterxgchen Jan 29, 2024
a274a21
ADD CODE TO ADDRESS THE NEW DESIGN CHANGES.
chesterxgchen Jan 30, 2024
7c4f62e
format code
chesterxgchen Jan 30, 2024
5fdd1c0
fix the issue with return result
chesterxgchen Jan 30, 2024
5d09732
cleanup, fix original controller parsing
SYangster Feb 6, 2024
8f35c4a
Merge branch 'main' into wl_controller
SYangster Feb 13, 2024
5b8420c
fix format
SYangster Feb 14, 2024
a5f63ad
databus updates
SYangster Feb 14, 2024
da30df2
add docstrings, address comments
SYangster Feb 22, 2024
d37cd23
fix communicator pairing, remove temp example
SYangster Feb 23, 2024
172 changes: 172 additions & 0 deletions examples/hello-world/hello-fedavg/README.md
@@ -0,0 +1,172 @@
# FedAvg: simplified

This example illustrates how to use the new Workflow Communication API to construct a workflow: no need to write a controller.

## FLARE Workflow Communicator API

The FLARE Workflow Communicator API has only a small set of methods:

```

from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple

from nvflare.app_common.abstract.fl_model import FLModel


class WFCommAPISpec(ABC):
@abstractmethod
def broadcast_and_wait(self, msg_payload: Dict):
pass

@abstractmethod
def send_and_wait(self, msg_payload: Dict):
pass

@abstractmethod
def relay_and_wait(self, msg_payload: Dict):
pass

@abstractmethod
def broadcast(self, msg_payload: Dict):
pass

@abstractmethod
def send(self, msg_payload: Dict):
pass

@abstractmethod
def relay(self, msg_payload: Dict):
pass

@abstractmethod
def get_site_names(self) -> List[str]:
pass

@abstractmethod
def wait_all(self, min_responses: int, resp_max_wait_time: Optional[float]) -> Dict[str, Dict[str, FLModel]]:
pass

@abstractmethod
def wait_one(self, resp_max_wait_time: Optional[float] = None) -> Tuple[str, str, FLModel]:
pass

```


## Writing a new Workflow

With this new API, writing a new workflow is straightforward:

* Workflow (Server)

```
from nvflare.app_common.workflows import wf_comm as flare

class FedAvg:
def __init__(
self,
min_clients: int,
num_rounds: int,
output_path: str,
start_round: int = 1,
stop_cond: str = None,
model_selection_rule: str = None,
):
super(FedAvg, self).__init__()

<skip init code>

self.flare_comm = flare.get_wf_comm_api()

def run(self):
self.logger.info("start Fed Avg Workflow\n \n")

start = self.start_round
end = self.start_round + self.num_rounds

model = self.init_model()
for current_round in range(start, end):

self.logger.info(f"Round {current_round}/{self.num_rounds} started. {start=}, {end=}")
self.current_round = current_round

sag_results = self.scatter_and_gather(model, current_round)

aggr_result = self.aggr_fn(sag_results)

self.logger.info(f"aggregate metrics = {aggr_result.metrics}")

model = update_model(model, aggr_result)

self.select_best_model(model)

self.save_model(self.best_model, self.output_path)

self.logger.info("end Fed Avg Workflow\n \n")


```
Scatter and Gather (SAG):

SAG simply asks the WFController to broadcast the model to all clients and wait for their results:

```
def scatter_and_gather(self, model: FLModel, current_round):
msg_payload = {"min_responses": self.min_clients,
"current_round": current_round,
"num_round": self.num_rounds,
"start_round": self.start_round,
"data": model}

# (2) broadcast and wait
results = self.flare_comm.broadcast_and_wait(msg_payload)
return results
```
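For illustration, the aggregation step performed by `aggr_fn` can be thought of as a weighted average of the client results. The sketch below is self-contained and hypothetical: the `weighted_average` helper and the flat dict-of-lists result shape are assumptions for illustration, not part of the FLARE API.

```
def weighted_average(site_params, site_weights):
    """Average per-site parameter dicts, weighted by sample counts.

    site_params: {site_name: {param_name: list of floats}}
    site_weights: {site_name: number of training examples}
    """
    total = sum(site_weights.values())
    avg = {}
    for site, params in site_params.items():
        w = site_weights[site] / total
        for name, values in params.items():
            acc = avg.setdefault(name, [0.0] * len(values))
            for i, v in enumerate(values):
                acc[i] += w * v
    return avg


params = {"site-1": {"fc.w": [1.0, 2.0]}, "site-2": {"fc.w": [3.0, 4.0]}}
counts = {"site-1": 1, "site-2": 3}
result = weighted_average(params, counts)
print(result)  # {'fc.w': [2.5, 3.5]}
```

Real FedAvg aggregation operates on `FLModel` parameter tensors, but the arithmetic is the same.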

## Configurations

### client-side configuration

This is the same as the FLARE Client API configuration.

### server-side configuration

The server-side controller is simple: all we need is to use the WFController with the newly defined workflow class.


```
{
# version of the configuration
format_version = 2
task_data_filters =[]
task_result_filters = []

workflows = [
{
id = "fed_avg"
path = "nvflare.app_opt.pt.wf_controller.PTWFController"
args {
comm_msg_pull_interval = 5
task_name = "train"
wf_class_path = "fedavg_pt.PTFedAvg",
wf_args {
min_clients = 2
num_rounds = 10
output_path = "/tmp/nvflare/fedavg/mode.pth"
stop_cond = "accuracy >= 55"
model_selection_rule = "accuracy >="
}
}
}
]

components = []

}

```
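The `stop_cond` string above (e.g. `"accuracy >= 55"`) pairs a metric name with a comparison operator and a threshold. A minimal sketch of how such a condition could be parsed and evaluated; the `parse_stop_cond` helper is a hypothetical illustration, not the FLARE implementation:

```
import operator

# supported comparison operators, assuming space-separated conditions
OPS = {">=": operator.ge, "<=": operator.le, ">": operator.gt,
       "<": operator.lt, "=": operator.eq}


def parse_stop_cond(cond: str):
    """Split 'metric op target' into (metric_key, compare_fn, target)."""
    key, op_str, target = cond.split()
    return key, OPS[op_str], float(target)


key, compare, target = parse_stop_cond("accuracy >= 55")
print(key, compare(56.0, target))  # accuracy True
```

Each round, the workflow can apply `compare(metrics[key], target)` to decide whether to stop early.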


## Run the job

Assuming the current working directory is the ```hello-fedavg``` directory:

```
nvflare simulator -n 2 -t 2 jobs/fedavg -w /tmp/fedavg
```
@@ -0,0 +1,77 @@
{
format_version = 2
app_script = "train.py"
app_config = ""
executors = [
{
tasks = [
"train"
]
executor {
path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor"
args {
launcher_id = "launcher"
pipe_id = "pipe"
heartbeat_timeout = 60
params_exchange_format = "pytorch"
params_transfer_type = "DIFF"
train_with_evaluation = true
}
}
}
]
task_data_filters = []
task_result_filters = []
components = [
{
id = "launcher"
path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"
args {
script = "python3 custom/{app_script} {app_config} "
launch_once = true
}
}
{
id = "pipe"
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
args {
mode = "PASSIVE"
site_name = "{SITE_NAME}"
token = "{JOB_ID}"
root_url = "{ROOT_URL}"
secure_mode = "{SECURE_MODE}"
workspace_dir = "{WORKSPACE}"
}
}
{
id = "metrics_pipe"
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
args {
mode = "PASSIVE"
site_name = "{SITE_NAME}"
token = "{JOB_ID}"
root_url = "{ROOT_URL}"
secure_mode = "{SECURE_MODE}"
workspace_dir = "{WORKSPACE}"
}
}
{
id = "metric_relay"
path = "nvflare.app_common.widgets.metric_relay.MetricRelay"
args {
pipe_id = "metrics_pipe"
event_type = "fed.analytix_log_stats"
read_interval = 0.1
}
}
{
id = "config_preparer"
path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator"
args {
component_ids = [
"metric_relay"
]
}
}
]
}
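Placeholders such as `{app_script}` and `{SITE_NAME}` in the launcher script are substituted before the subprocess is started. Conceptually, the command line is built like this (a simplified sketch of the substitution, not the actual FLARE code):

```
# launcher script template as it appears in the config above
template = "python3 custom/{app_script} {app_config} "

# app_script / app_config come from the config; an empty app_config
# simply leaves no extra arguments
cmd = template.format(app_script="train.py", app_config="")
print(cmd.strip())  # python3 custom/train.py
```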
@@ -0,0 +1,22 @@
{
# version of the configuration
format_version = 2
task_data_filters =[]
task_result_filters = []

workflows = [
{
id = "fed_avg"
path = "nvflare.app_common.workflows.fed_avg_pt.PTFedAvg"
args {
min_clients = 2
num_rounds = 2
output_path = "/tmp/nvflare/fedavg/mode.pth"
# stop_cond = "accuracy >= 55"
}
}
]

components = [
]
}
37 changes: 37 additions & 0 deletions examples/hello-world/hello-fedavg/jobs/fedavg/app/custom/net.py
@@ -0,0 +1,37 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1) # flatten all dimensions except batch
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
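The `16 * 5 * 5` input size of `fc1` follows from the CIFAR-10 input shape (3x32x32): each 5x5 valid convolution shrinks the spatial size by 4, and each 2x2 max pool halves it. A quick check in plain Python (helper names are just for this sketch):

```
def conv_out(size, kernel):
    # valid convolution, stride 1, no padding
    return size - kernel + 1


def pool_out(size, window):
    # non-overlapping pooling
    return size // window


s = 32                              # CIFAR-10 images are 32x32
s = pool_out(conv_out(s, 5), 2)     # conv1 -> 28, pool -> 14
s = pool_out(conv_out(s, 5), 2)     # conv2 -> 10, pool -> 5
print(16 * s * s)  # 400, matching nn.Linear(16 * 5 * 5, 120)
```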