From 3be55b1795cafd0e11ce7904855c5bb2926297ec Mon Sep 17 00:00:00 2001 From: Yan Cheng <58191769+yanchengnv@users.noreply.github.com> Date: Tue, 1 Oct 2024 20:06:15 -0400 Subject: [PATCH 1/2] [Main] Support object reuse (#2975) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * support object reuse * fix formatting --------- Co-authored-by: Sean Yang Co-authored-by: Yuan-Ting Hsieh (謝沅廷) --- nvflare/app_common/ccwf/ccwf_job.py | 21 +-------- nvflare/job_config/api.py | 72 ++++++++++++++++++++--------- 2 files changed, 52 insertions(+), 41 deletions(-) diff --git a/nvflare/app_common/ccwf/ccwf_job.py b/nvflare/app_common/ccwf/ccwf_job.py index 13631ab1d3..0f66af4567 100644 --- a/nvflare/app_common/ccwf/ccwf_job.py +++ b/nvflare/app_common/ccwf/ccwf_job.py @@ -20,8 +20,7 @@ from nvflare.app_common.abstract.shareable_generator import ShareableGenerator from nvflare.app_common.app_constant import AppConstants from nvflare.app_common.ccwf.common import Constant, CyclicOrder -from nvflare.fuel.utils.validation_utils import check_object_type -from nvflare.job_config.api import FedJob, has_add_to_job_method +from nvflare.job_config.api import FedJob, validate_object_for_job from nvflare.widgets.widget import Widget from .cse_client_ctl import CrossSiteEvalClientController @@ -318,21 +317,3 @@ def add_cross_site_eval( get_model_timeout=cse_config.get_model_timeout, ) self.to_clients(client_controller, tasks=["cse_*"]) - - -def validate_object_for_job(name, obj, obj_type): - """Check whether the specified object is valid for job. - The object must either have the add_to_fed_job method or is valid object type. - - Args: - name: name of the object - obj: the object to be checked - obj_type: the object type that the object should be, if it doesn't have the add_to_fed_job method. - - Returns: None - - """ - if has_add_to_job_method(obj): - return - - check_object_type(name, obj, obj_type) diff --git a/nvflare/job_config/api.py b/nvflare/job_config/api.py index 3fd0c888bd..495b808f94 100644 --- a/nvflare/job_config/api.py +++ b/nvflare/job_config/api.py @@ -14,14 +14,14 @@ import os.path import re import uuid -from typing import Any, List, Optional +from typing import Any, List, Optional, Union from nvflare.apis.executor import Executor from nvflare.apis.filter import Filter from nvflare.apis.impl.controller import Controller from nvflare.apis.job_def import ALL_SITES, SERVER_SITE_NAME from nvflare.fuel.utils.class_utils import get_component_init_parameters -from nvflare.fuel.utils.validation_utils import check_positive_int +from nvflare.fuel.utils.validation_utils import check_object_type, check_positive_int from nvflare.job_config.fed_app_config import ClientAppConfig, FedAppConfig, ServerAppConfig from nvflare.job_config.fed_job_config import FedJobConfig @@ -33,25 +33,39 @@ class FedApp: - def __init__(self): + def __init__(self, app_config: Union[ClientAppConfig, ServerAppConfig]): """FedApp handles `ClientAppConfig` and `ServerAppConfig` and allows setting task result or task data filters.""" - self.app = None # Union[ClientAppConfig, ServerAppConfig] + self.app_config = app_config self._used_ids = [] + # obj_id => comp_id + # obj_id is the Python's object ID; comp_id is the component ID for job config + # _oid_to_cid keeps the mapping between obj_id and comp_id. + # this is to make sure that when the same object is used, it is configured only once in the job. + self._oid_to_cid = {} + def get_app_config(self): - return self.app + return self.app_config def add_task_result_filter(self, tasks: List[str], task_filter: Filter): - self.app.add_task_result_filter(tasks, task_filter) + self.app_config.add_task_result_filter(tasks, task_filter) def add_task_data_filter(self, tasks: List[str], task_filter: Filter): - self.app.add_task_data_filter(tasks, task_filter) - - def add_component(self, component, id=None): - if id is None: - id = "component" - final_id = self.generate_tracked_id(id) - self.app.add_component(final_id, component) + self.app_config.add_task_data_filter(tasks, task_filter) + + def add_component(self, component, comp_id=None): + # is the component already configured? + oid = id(component) + cid = self._oid_to_cid.get(oid) + if cid: + # the component is already configured + return cid + + if comp_id is None: + comp_id = "component" + final_id = self.generate_tracked_id(comp_id) + self.app_config.add_component(final_id, component) + self._oid_to_cid[oid] = final_id return final_id def _generate_id(self, id: str = "") -> str: @@ -79,7 +93,7 @@ def add_external_script(self, ext_script: str): Args: ext_script: List of external scripts that need to be deployed to the client/server. """ - self.app.add_ext_script(ext_script) + self.app_config.add_ext_script(ext_script) def add_external_dir(self, ext_dir: str): """Register external folder to include them in custom directory. @@ -87,7 +101,7 @@ def add_external_dir(self, ext_dir: str): Args: ext_dir: external folder that need to be deployed to the client/server. """ - self.app.add_ext_dir(ext_dir) + self.app_config.add_ext_dir(ext_dir) def _add_resource(self, resource: str): if not isinstance(resource, str): @@ -122,26 +136,24 @@ def __init__(self, obj: Any, target: str, comp_id: str): class ClientApp(FedApp): def __init__(self): """Wrapper around `ClientAppConfig`.""" - super().__init__() - self.app = ClientAppConfig() + super().__init__(ClientAppConfig()) def add_executor(self, executor: Executor, tasks=None): if not tasks: tasks = ["*"] # Add executor for any task by default - self.app.add_executor(tasks, executor) + self.app_config.add_executor(tasks, executor) class ServerApp(FedApp): """Wrapper around `ServerAppConfig`.""" def __init__(self): - super().__init__() - self.app: ServerAppConfig = ServerAppConfig() + super().__init__(ServerAppConfig()) def add_controller(self, controller: Controller, id=None): if not id: id = "controller" - self.app.add_workflow(self.generate_tracked_id(id), controller) + self.app_config.add_workflow(self.generate_tracked_id(id), controller) class FedJob: @@ -571,3 +583,21 @@ def check_kwargs(args_to_check: dict, args_expected: dict): def has_add_to_job_method(obj: Any) -> bool: add_to_job_method = getattr(obj, _ADD_TO_JOB_METHOD_NAME, None) return add_to_job_method is not None and callable(add_to_job_method) + + +def validate_object_for_job(name, obj, obj_type): + """Check whether the specified object is valid for job. + The object must either have the add_to_fed_job method or is valid object type. + + Args: + name: name of the object + obj: the object to be checked + obj_type: the object type that the object should be, if it doesn't have the add_to_fed_job method. + + Returns: None + + """ + if has_add_to_job_method(obj): + return + + check_object_type(name, obj, obj_type) From 9bfd99a20c0431a3c4e23c05ca813c9fe50edcef Mon Sep 17 00:00:00 2001 From: Yan Cheng <58191769+yanchengnv@users.noreply.github.com> Date: Wed, 2 Oct 2024 16:17:43 -0400 Subject: [PATCH 2/2] Allow multiple workflows in CCWF (#2980) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * support object reuse * fix formatting * allow multiple workflows in ccwf * allow multiple workflows in ccwf --------- Co-authored-by: Sean Yang Co-authored-by: Yuan-Ting Hsieh (謝沅廷) --- nvflare/app_common/ccwf/ccwf_job.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/nvflare/app_common/ccwf/ccwf_job.py b/nvflare/app_common/ccwf/ccwf_job.py index 0f66af4567..36ef6a1300 100644 --- a/nvflare/app_common/ccwf/ccwf_job.py +++ b/nvflare/app_common/ccwf/ccwf_job.py @@ -30,6 +30,8 @@ from .swarm_client_ctl import SwarmClientController from .swarm_server_ctl import SwarmServerController +_EXECUTOR_TASKS = ["train", "validate", "submit_model"] + class SwarmServerConfig: def __init__( @@ -189,6 +191,7 @@ def __init__( name: str = "fed_job", min_clients: int = 1, mandatory_clients: Optional[List[str]] = None, + executor_tasks: Optional[List[str]] = None, external_resources: Optional[str] = None, ): """Client-Controlled Workflow Job. @@ -199,9 +202,19 @@ def __init__( name (name, optional): name of the job. Defaults to "fed_job" min_clients (int, optional): the minimum number of clients for the job. Defaults to 1. mandatory_clients (List[str], optional): mandatory clients to run the job. Default None. + executor_tasks (List[str], optional): tasks for the executor external_resources (str, optional): External resources directory or filename. Defaults to None. """ super().__init__(name, min_clients, mandatory_clients) + + # A CCWF job can have multiple workflows (swarm, cyclic, etc.), but can only have one executor for training! + # This executor can be added by any workflow. + self.executor = None + + self.executor_tasks = executor_tasks + if not executor_tasks: + self.executor_tasks = _EXECUTOR_TASKS + if external_resources: self.to_server(external_resources) self.to_clients(external_resources) @@ -249,7 +262,10 @@ def add_swarm( wait_time_after_min_resps_received=client_config.wait_time_after_min_resps_received, ) self.to_clients(client_controller, tasks=["swarm_*"]) - self.to_clients(client_config.executor, tasks=["train", "validate", "submit_model"]) + if not self.executor: + # We add the executor only if it's not added yet. + self.to_clients(client_config.executor, tasks=self.executor_tasks) + self.executor = client_config.executor if client_config.model_selector: self.to_clients(client_config.model_selector, id="model_selector") @@ -287,7 +303,11 @@ def add_cyclic( final_result_ack_timeout=client_config.final_result_ack_timeout, ) self.to_clients(client_controller, tasks=["cyclic_*"]) - self.to_clients(client_config.executor, tasks=["train", "validate", "submit_model"]) + + if not self.executor: + # We add the executor only if it's not added yet. + self.to_clients(client_config.executor, tasks=self.executor_tasks) + self.executor = client_config.executor if cse_config: self.add_cross_site_eval(cse_config, persistor_id)