Commit c927579

Merge branch 'main' into HE_KM
ZiyueXu77 authored Jan 8, 2024
2 parents 484cf3d + fdcf578 commit c927579
Showing 13 changed files with 40 additions and 26 deletions.
examples/hello-world/step-by-step/cifar10/code/fl/train.py (4 changes: 2 additions & 2 deletions)
@@ -93,7 +93,7 @@ def evaluate(input_weights):
 
     # (4) receive FLModel from NVFlare
     input_model = flare.receive()
-    client_id = flare.system_info().get("site_name", None)
+    client_id = flare.get_site_name()
 
     # Based on different "task" we will do different things
     # for "train" task (flare.is_train()) we use the received model to do training and/or evaluation
@@ -104,7 +104,7 @@ def evaluate(input_weights):
     # for "submit_model" task (flare.is_submit_model()) we just need to send back the local model
     # (5) performing train task on received model
     if flare.is_train():
-        print(f"({client_id}) round={input_model.current_round}/{input_model.total_rounds-1}")
+        print(f"({client_id}) current_round={input_model.current_round}, total_rounds={input_model.total_rounds}")
 
         # (5.1) loads model from NVFlare
         net.load_state_dict(input_model.params)
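For reference, a condensed sketch of the Client API pattern this example uses (a sketch only: flare is nvflare.client, and the tiny linear net is a stand-in for the example's CIFAR-10 model):

import nvflare.client as flare
import torch.nn as nn

flare.init()  # one-time setup, done once at startup in the full example
net = nn.Linear(16, 10)  # stand-in for the example's CIFAR-10 network

input_model = flare.receive()  # blocks until the server sends an FLModel
client_id = flare.get_site_name()

if flare.is_train():
    print(f"({client_id}) current_round={input_model.current_round}, total_rounds={input_model.total_rounds}")
    net.load_state_dict(input_model.params)
    # ... run local training here, then return the update via flare.send(...)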
job_templates/sklearn_linear/config_fed_client.conf (5 changes: 0 additions & 5 deletions)
@@ -44,11 +44,6 @@
       # the custom code need to send back both the trained parameters and the evaluation metric
       # otherwise only trained parameters are expected
       train_with_evaluation = true
-
-      # if launch_once is true, the executor will only call launcher.launch_task once
-      # for the whole job, if launch_once is false, the executor will call launcher.launch_task
-      # everytime it receives a task from server
-      launch_once = true
     }
   }
 }
job_templates/sklearn_svm/config_fed_client.conf (5 changes: 0 additions & 5 deletions)
@@ -44,11 +44,6 @@
       # the custom code need to send back both the trained parameters and the evaluation metric
       # otherwise only trained parameters are expected
       train_with_evaluation = true
-
-      # if launch_once is true, the executor will only call launcher.launch_task once
-      # for the whole job, if launch_once is false, the executor will call launcher.launch_task
-      # everytime it receives a task from server
-      launch_once = true
     }
   }
 }
job_templates/swarm_cse_pt/config_fed_client.conf (4 changes: 4 additions & 0 deletions)
@@ -1,4 +1,8 @@
 format_version = 2
+# This is the application script which will be invoked. Client can replace this script with user's own training script.
+app_script = "train.py"
+# Additional arguments needed by the training code.
+app_config = ""
 # Client Computing Executors.
 executors = [
   {
nvflare/app_common/abstract/params_converter.py (14 changes: 9 additions & 5 deletions)
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from typing import Any
+from typing import Any, List
 
 from nvflare.apis.dxo import from_shareable
 from nvflare.apis.filter import Filter
@@ -22,10 +22,14 @@
 
 
 class ParamsConverter(Filter, ABC):
-    def process(self, shareable: Shareable, fl_ctx: FLContext) -> Shareable:
-        dxo = from_shareable(shareable)
-        dxo.data = self.convert(dxo.data, fl_ctx)
-        dxo.update_shareable(shareable)
+    def __init__(self, supported_tasks: List[str] = None):
+        self.supported_tasks = supported_tasks
+
+    def process(self, task_name: str, shareable: Shareable, fl_ctx: FLContext) -> Shareable:
+        if not self.supported_tasks or task_name in self.supported_tasks:
+            dxo = from_shareable(shareable)
+            dxo.data = self.convert(dxo.data, fl_ctx)
+            dxo.update_shareable(shareable)
         return shareable
 
     @abstractmethod
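To see what the new task gating does, here is a minimal runnable sketch; the DoublingConverter subclass and its doubling logic are invented for illustration, while the rest uses the classes shown above:

from typing import Any

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.abstract.params_converter import ParamsConverter


class DoublingConverter(ParamsConverter):
    # Hypothetical converter: doubles every parameter value.
    def convert(self, params: Any, fl_ctx: FLContext) -> Any:
        return {k: v * 2 for k, v in params.items()}


converter = DoublingConverter(supported_tasks=["train"])
fl_ctx = FLContext()

# "train" is in supported_tasks, so convert() runs: prints {'w': 2.0}
s = DXO(data_kind=DataKind.WEIGHTS, data={"w": 1.0}).to_shareable()
print(from_shareable(converter.process("train", s, fl_ctx)).data)

# "evaluate" is not in supported_tasks, so the payload passes through: {'w': 1.0}
s = DXO(data_kind=DataKind.WEIGHTS, data={"w": 1.0}).to_shareable()
print(from_shareable(converter.process("evaluate", s, fl_ctx)).data)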
nvflare/app_common/ccwf/cse_client_ctl.py (2 changes: 2 additions & 0 deletions)
@@ -169,6 +169,7 @@ def do_eval(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal)
 
         model_to_validate = reply
         model_to_validate.set_header(AppConstants.VALIDATE_TYPE, ValidateType.MODEL_VALIDATE)
+        model_to_validate.set_header(FLContextKey.TASK_NAME, self.validation_task_name)
         if model_type == ModelType.LOCAL:
             model_to_validate.set_header(AppConstants.MODEL_OWNER, model_owner)
 
@@ -218,6 +219,7 @@ def _do_process_get_model_request(self, request: Shareable, fl_ctx: FLContext) -
         if not self.local_model:
             task_data = Shareable()
             task_data.set_header(AppConstants.SUBMIT_MODEL_NAME, model_name)
+            task_data.set_header(FLContextKey.TASK_NAME, self.submit_model_task_name)
 
             abort_signal = Signal()
             try:
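The TASK_NAME header added here lets downstream components recover the real task name from the payload; a tiny sketch of the header mechanism itself, with an illustrative task name:

from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.shareable import Shareable

s = Shareable()
s.set_header(FLContextKey.TASK_NAME, "validate")  # "validate" is illustrative
assert s.get_header(FLContextKey.TASK_NAME) == "validate"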
nvflare/app_common/ccwf/cyclic_client_ctl.py (4 changes: 3 additions & 1 deletion)
@@ -13,7 +13,7 @@
 # limitations under the License.
 import random
 
-from nvflare.apis.fl_constant import ReturnCode
+from nvflare.apis.fl_constant import FLContextKey, ReturnCode
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.shareable import Shareable, make_reply
 from nvflare.apis.signal import Signal
@@ -88,6 +88,8 @@ def do_learn_task(self, name: str, data: Shareable, fl_ctx: FLContext, abort_sig
         global_weights = self.shareable_generator.shareable_to_learnable(data, fl_ctx)
         fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, global_weights, private=True, sticky=True)
 
+        data.set_header(FLContextKey.TASK_NAME, name)
+
         # execute the task
         result = self.execute_learn_task(data, fl_ctx, abort_signal)
 
nvflare/app_common/ccwf/swarm_client_ctl.py (2 changes: 2 additions & 0 deletions)
@@ -554,6 +554,8 @@ def do_learn_task(self, name: str, task_data: Shareable, fl_ctx: FLContext, abor
 
         self.log_info(fl_ctx, f"Round {current_round} started.")
 
+        task_data.set_header(FLContextKey.TASK_NAME, name)
+
         # Some shareable generators assume the base model (GLOBAL_MODEL) is always available, which is true for
         # server-controlled fed-avg. But this is not true for swarm learning.
         # To make these generators happy, we create an empty global model here if not present.
nvflare/app_common/executors/launcher_executor.py (4 changes: 2 additions & 2 deletions)
@@ -149,7 +149,7 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort
             return make_reply(ReturnCode.EXECUTION_EXCEPTION)
 
         if self._from_nvflare_converter is not None:
-            shareable = self._from_nvflare_converter.process(shareable, fl_ctx)
+            shareable = self._from_nvflare_converter.process(task_name, shareable, fl_ctx)
 
         result = super().execute(task_name, shareable, fl_ctx, abort_signal)
 
@@ -161,7 +161,7 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort
             return make_reply(ReturnCode.EXECUTION_EXCEPTION)
 
         if self._to_nvflare_converter is not None:
-            result = self._to_nvflare_converter.process(result, fl_ctx)
+            result = self._to_nvflare_converter.process(task_name, result, fl_ctx)
 
         self._finalize_external_execution(task_name, shareable, fl_ctx, abort_signal)
 
nvflare/app_opt/pt/client_api_launcher_executor.py (9 changes: 7 additions & 2 deletions)
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from nvflare.apis.fl_context import FLContext
+from nvflare.app_common.app_constant import AppConstants
 from nvflare.app_common.executors.client_api_launcher_executor import ClientAPILauncherExecutor
 from nvflare.app_opt.pt.decomposers import TensorDecomposer
 from nvflare.app_opt.pt.params_converter import NumpyToPTParamsConverter, PTToNumpyParamsConverter
@@ -26,6 +27,10 @@ def initialize(self, fl_ctx: FLContext) -> None:
         self._params_exchange_format = ExchangeFormat.PYTORCH
         super().initialize(fl_ctx)
         if self._from_nvflare_converter is None:
-            self._from_nvflare_converter = NumpyToPTParamsConverter()
+            self._from_nvflare_converter = NumpyToPTParamsConverter(
+                [AppConstants.TASK_TRAIN, AppConstants.TASK_VALIDATION]
+            )
         if self._to_nvflare_converter is None:
-            self._to_nvflare_converter = PTToNumpyParamsConverter()
+            self._to_nvflare_converter = PTToNumpyParamsConverter(
+                [AppConstants.TASK_TRAIN, AppConstants.TASK_SUBMIT_MODEL]
+            )
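A hedged sketch of what these defaults amount to, building the same converters directly (all names come from the diff above): inbound numpy weights are converted to torch tensors only for train/validate payloads, and outbound tensors back to numpy only for train/submit_model results.

from nvflare.app_common.app_constant import AppConstants
from nvflare.app_opt.pt.params_converter import NumpyToPTParamsConverter, PTToNumpyParamsConverter

# NVFlare -> PyTorch, applied only to train/validate task payloads
from_nvflare = NumpyToPTParamsConverter([AppConstants.TASK_TRAIN, AppConstants.TASK_VALIDATION])
# PyTorch -> NVFlare, applied only to train/submit_model results
to_nvflare = PTToNumpyParamsConverter([AppConstants.TASK_TRAIN, AppConstants.TASK_SUBMIT_MODEL])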
nvflare/fuel/hci/client/api.py (9 changes: 7 additions & 2 deletions)
@@ -249,11 +249,12 @@ def execute(self, **kwargs):
 
 
 class _Operate(State):
-    def __init__(self, api, sess_check_interval):
+    def __init__(self, api, sess_check_interval, auto_login_delay):
         State.__init__(self, _STATE_NAME_OPERATE)
         self.api = api
         self.last_sess_check_time = None
         self.sess_check_interval = sess_check_interval
+        self.auto_login_delay = auto_login_delay
 
     def enter(self):
         self.api.server_sess_active = True
@@ -273,6 +274,7 @@ def execute(self, **kwargs):
 
         if new_host != cur_host or new_port != cur_port or cur_ssid != new_ssid:
             # need to re-login
+            time.sleep(self.auto_login_delay)
             api.fire_session_event(EventType.SP_ADDR_CHANGED, f"Server address changed to {new_host}:{new_port}")
             return _STATE_NAME_LOGIN
 
@@ -310,6 +312,7 @@ def __init__(
         debug: bool = False,
         session_timeout_interval=None,
         session_status_check_interval=None,
+        auto_login_delay: int = 5,
         auto_login_max_tries: int = 5,
         event_handlers=None,
     ):
@@ -418,11 +421,13 @@ def __init__(
         # create the FSM for session monitoring
         if auto_login_max_tries < 0 or auto_login_max_tries > MAX_AUTO_LOGIN_TRIES:
             raise ValueError(f"auto_login_max_tries is out of range: [0, {MAX_AUTO_LOGIN_TRIES}]")
+        if auto_login_delay < 5.0:
+            raise ValueError(f"auto_login_delay must be more than 5.0. Got value: {auto_login_delay}]")
         self.auto_login_max_tries = auto_login_max_tries
         fsm = FSM("session monitor")
         fsm.add_state(_WaitForServerAddress(self))
         fsm.add_state(_TryLogin(self))
-        fsm.add_state(_Operate(self, session_status_check_interval))
+        fsm.add_state(_Operate(self, session_status_check_interval, auto_login_delay))
         self.fsm = fsm
 
         self.session_timeout_interval = session_timeout_interval
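Taken on its own, the pacing this change introduces looks roughly like the sketch below; the function and names are invented, and the intent (pausing before re-login after the service-provider address changes) is inferred from the diff:

import time

AUTO_LOGIN_DELAY = 5  # __init__ above rejects values below 5

def on_sp_address_changed(enter_login_state):
    # Mirrors _Operate.execute(): pause before returning to the login state
    # instead of re-logging in immediately after the address change.
    time.sleep(AUTO_LOGIN_DELAY)
    return enter_login_state()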
nvflare/private/fed/client/client_runner.py (2 changes: 1 addition & 1 deletion)
@@ -561,7 +561,7 @@ def _check_task_once(self, task_id: str, fl_ctx: FLContext) -> int:
             self.log_error(fl_ctx, f"server rejected task_check: {rc}")
             return _TASK_CHECK_RESULT_TRY_AGAIN
         elif rc == ReturnCode.TASK_UNKNOWN:
-            self.log_error(fl_ctx, f"task no longer exists on server: {rc}")
+            self.log_debug(fl_ctx, f"task no longer exists on server: {rc}")
             return _TASK_CHECK_RESULT_TASK_GONE
         else:
             # this should never happen
nvflare/private/fed/server/training_cmds.py (2 changes: 1 addition & 1 deletion)
@@ -222,7 +222,7 @@ def restart(self, conn: Connection, args: List[str]):
 
         # ask the admin client to shut down since its current session will become invalid after
         # the server is restarted.
-        conn.append_shutdown("Goodbye!")
+        # conn.append_shutdown("Goodbye!")
     elif target_type == self.TARGET_TYPE_CLIENT:
         clients = conn.get_prop(self.TARGET_CLIENT_TOKENS)
         if not clients:
