Skip to content

Commit

Permalink
Add SubprocessLauncher + FilePipe model exchange example
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh committed Aug 16, 2023
1 parent 191ac1e commit 50e0797
Show file tree
Hide file tree
Showing 83 changed files with 4,257 additions and 42 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"format_version": 2,

"executors": [
{
"tasks": ["train"],
"executor": {
"path": "custom_executor.CustomExecutor",
"args": {
"data_exchanger_id": "data_exchanger",
"data_exchange_path": "/tmp/nvflare/av_cn"
}
}
}
],
"task_result_filters": [
],
"task_data_filters": [
],
"components": [
{
"id": "data_exchanger",
"path": "nvflare.app_opt.h5.data_exchanger.H5DataExchanger",
"args": {
"pipe_role": "x",
"heartbeat_timeout": 0
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.executors.launcher_executor import (
LauncherExecutor,
model_data_to_shareable,
shareable_to_model_data,
)


class CustomExecutor(LauncherExecutor):
    """Launcher executor that trades model weights with an external training script.

    Before each launch the incoming weights and a ``round_starts`` marker are
    written through the data exchanger. Afterwards the executor waits for the
    trained weights and a ``round_ends`` marker whose round number must match
    the executor's own round counter.
    """

    def __init__(
        self,
        data_exchanger_id: str,
        data_exchange_path: str,
        file_accessor_id: Optional[str] = None,
        get_timeout: Optional[float] = 10000,
    ):
        """CustomExecutor for autonomous vehicle training.

        Args:
            data_exchanger_id: forwarded unchanged to ``LauncherExecutor``.
            data_exchange_path: forwarded unchanged to ``LauncherExecutor``.
            file_accessor_id: forwarded unchanged to ``LauncherExecutor``.
            get_timeout: timeout handed to every ``data_exchanger.get`` call
                made while collecting the result (reported as seconds in the
                error message).
        """
        super().__init__(
            data_exchanger_id=data_exchanger_id,
            file_accessor_id=file_accessor_id,
            data_exchange_path=data_exchange_path,
        )

        self._timeout = get_timeout
        # Number of rounds completed so far; also the round index announced
        # to, and expected back from, the external script.
        self._rounds = 0

    def _prepare_for_launch(self, shareable: Shareable):
        """Publish the received weights and the current round index.

        Args:
            shareable: the task data whose model weights are handed to the
                external script.
        """
        # dump weights for outer script to read
        # NOTE(review): self._from_nvflare and self.data_exchanger are not set
        # in this class -- presumably provided by LauncherExecutor; confirm.
        model_data = shareable_to_model_data(shareable=shareable)
        self.data_exchanger.put(self._from_nvflare, data=model_data)
        self.data_exchanger.put("round_starts", {"round": self._rounds})

    def _get_result(self, task_name: str, fl_ctx: FLContext) -> Optional[Shareable]:
        """Collect the externally trained weights for the current round.

        Args:
            task_name: name of the task being executed (not used here).
            fl_ctx: FL context used for logging and for raising a system panic.

        Returns:
            The trained model converted to a ``Shareable``, or ``None`` after
            a system panic has been triggered because anything went wrong.
        """
        try:
            output_model_data = self.data_exchanger.get(self._to_nvflare, self._timeout)
            round_ends = self.data_exchanger.get("round_ends", self._timeout)
            external_round = round_ends["round"]
            if external_round != self._rounds:
                # Include both numbers so the desynchronisation can be diagnosed.
                raise RuntimeError(
                    f"rounds mismatch: expected round {self._rounds}, "
                    f"external script reported round {external_round}."
                )
            self._rounds += 1
            result = model_data_to_shareable(output_model_data)
            return result
        except Exception as e:
            # This handler also catches the round mismatch above and any
            # conversion error, so the message must not claim that every
            # failure is a timeout.
            err_msg = (
                "Failed to get the result of external training "
                f"(get timeout: {self._timeout} seconds): {e}"
            )
            self.log_exception(fl_ctx, err_msg)
            self.system_panic(err_msg, fl_ctx)
            return None
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
{
"format_version": 2,

"server": {
"heart_beat_timeout": 600
},
"task_data_filters": [],
"task_result_filters": [],
"components": [
{
"id": "persistor",
"name": "PTFileModelPersistor",
"args": {
"model": {
"path": "simple_network.SimpleNetwork"
}
}
},
{
"id": "shareable_generator",
"path": "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator",
"args": {}
},
{
"id": "aggregator",
"path": "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator",
"args": {
"expected_data_kind": "WEIGHTS"
}
},
{
"id": "tb_analytics_receiver",
"name": "TBAnalyticsReceiver",
"args": {"events": ["fed.analytix_log_stats"]}
}
],
"workflows": [
{
"id": "scatter_and_gather",
"name": "ScatterAndGather",
"args": {
"min_clients" : 2,
"num_rounds" : 10,
"start_round": 0,
"wait_time_after_min_received": 0,
"aggregator_id": "aggregator",
"persistor_id": "persistor",
"shareable_generator_id": "shareable_generator",
"train_task_name": "train",
"train_timeout": 0
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleNetwork(nn.Module):
    """Small convolutional classifier: two conv/pool stages, then three linear layers.

    ``fc1`` takes 16 feature maps of 5x5, which is what a 3-channel 32x32
    input produces after the two 5x5 convolutions and 2x2 poolings. The final
    layer emits 10 values per sample.
    """

    def __init__(self):
        super().__init__()

        # Layers are registered in the same order and under the same attribute
        # names as before, so state_dict keys and parameter initialisation
        # order are unchanged.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Run the network on a batch ``x`` and return the raw output of ``fc3``."""
        features = x
        # Each convolution is followed by ReLU and the shared max-pool layer.
        for conv in (self.conv1, self.conv2):
            features = self.pool(F.relu(conv(features)))

        # flatten all dimensions except batch
        hidden = torch.flatten(features, start_dim=1)
        for dense in (self.fc1, self.fc2):
            hidden = F.relu(dense(hidden))

        # No activation on the last layer.
        return self.fc3(hidden)
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"format_version": 2,

"executors": [
{
"tasks": ["train"],
"executor": {
"path": "custom_executor.CustomExecutor",
"args": {
"data_exchanger_id": "data_exchanger",
"data_exchange_path": "/tmp/nvflare/av_us"
}
}
}
],
"task_result_filters": [
],
"task_data_filters": [
],
"components": [
{
"id": "data_exchanger",
"path": "nvflare.app_opt.h5.data_exchanger.H5DataExchanger",
"args": {
"pipe_role": "x",
"heartbeat_timeout": 0
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.executors.launcher_executor import (
LauncherExecutor,
model_data_to_shareable,
shareable_to_model_data,
)


class CustomExecutor(LauncherExecutor):
    """Launcher executor that trades model weights with an external training script.

    Before each launch the incoming weights and a ``round_starts`` marker are
    written through the data exchanger. Afterwards the executor waits for the
    trained weights and a ``round_ends`` marker whose round number must match
    the executor's own round counter.
    """

    def __init__(
        self,
        data_exchanger_id: str,
        data_exchange_path: str,
        file_accessor_id: Optional[str] = None,
        get_timeout: Optional[float] = 10000,
    ):
        """CustomExecutor for autonomous vehicle training.

        Args:
            data_exchanger_id: forwarded unchanged to ``LauncherExecutor``.
            data_exchange_path: forwarded unchanged to ``LauncherExecutor``.
            file_accessor_id: forwarded unchanged to ``LauncherExecutor``.
            get_timeout: timeout handed to every ``data_exchanger.get`` call
                made while collecting the result (reported as seconds in the
                error message).
        """
        super().__init__(
            data_exchanger_id=data_exchanger_id,
            file_accessor_id=file_accessor_id,
            data_exchange_path=data_exchange_path,
        )

        self._timeout = get_timeout
        # Number of rounds completed so far; also the round index announced
        # to, and expected back from, the external script.
        self._rounds = 0

    def _prepare_for_launch(self, shareable: Shareable):
        """Publish the received weights and the current round index.

        Args:
            shareable: the task data whose model weights are handed to the
                external script.
        """
        # dump weights for outer script to read
        # NOTE(review): self._from_nvflare and self.data_exchanger are not set
        # in this class -- presumably provided by LauncherExecutor; confirm.
        model_data = shareable_to_model_data(shareable=shareable)
        self.data_exchanger.put(self._from_nvflare, data=model_data)
        self.data_exchanger.put("round_starts", {"round": self._rounds})

    def _get_result(self, task_name: str, fl_ctx: FLContext) -> Optional[Shareable]:
        """Collect the externally trained weights for the current round.

        Args:
            task_name: name of the task being executed (not used here).
            fl_ctx: FL context used for logging and for raising a system panic.

        Returns:
            The trained model converted to a ``Shareable``, or ``None`` after
            a system panic has been triggered because anything went wrong.
        """
        try:
            output_model_data = self.data_exchanger.get(self._to_nvflare, self._timeout)
            round_ends = self.data_exchanger.get("round_ends", self._timeout)
            external_round = round_ends["round"]
            if external_round != self._rounds:
                # Include both numbers so the desynchronisation can be diagnosed.
                raise RuntimeError(
                    f"rounds mismatch: expected round {self._rounds}, "
                    f"external script reported round {external_round}."
                )
            self._rounds += 1
            result = model_data_to_shareable(output_model_data)
            return result
        except Exception as e:
            # This handler also catches the round mismatch above and any
            # conversion error, so the message must not claim that every
            # failure is a timeout.
            err_msg = (
                "Failed to get the result of external training "
                f"(get timeout: {self._timeout} seconds): {e}"
            )
            self.log_exception(fl_ctx, err_msg)
            self.system_panic(err_msg, fl_ctx)
            return None
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"name": "av",
"resource_spec": {},
"min_clients" : 2,
"deploy_map": {
"app_server": [
"server"
],
"app_us": [
"site-1"
],
"app_cn": [
"site-2"
]
}
}
Loading

0 comments on commit 50e0797

Please sign in to comment.