Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[2.5] Job api update #2991

Merged
merged 2 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 23 additions & 22 deletions nvflare/app_common/ccwf/ccwf_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
from nvflare.app_common.abstract.shareable_generator import ShareableGenerator
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.ccwf.common import Constant, CyclicOrder
from nvflare.fuel.utils.validation_utils import check_object_type
from nvflare.job_config.api import FedJob, has_add_to_job_method
from nvflare.job_config.api import FedJob, validate_object_for_job
from nvflare.widgets.widget import Widget

from .cse_client_ctl import CrossSiteEvalClientController
Expand All @@ -31,6 +30,8 @@
from .swarm_client_ctl import SwarmClientController
from .swarm_server_ctl import SwarmServerController

_EXECUTOR_TASKS = ["train", "validate", "submit_model"]


class SwarmServerConfig:
def __init__(
Expand Down Expand Up @@ -190,6 +191,7 @@ def __init__(
name: str = "fed_job",
min_clients: int = 1,
mandatory_clients: Optional[List[str]] = None,
executor_tasks: Optional[List[str]] = None,
external_resources: Optional[str] = None,
):
"""Client-Controlled Workflow Job.
Expand All @@ -200,9 +202,19 @@ def __init__(
name (name, optional): name of the job. Defaults to "fed_job"
min_clients (int, optional): the minimum number of clients for the job. Defaults to 1.
mandatory_clients (List[str], optional): mandatory clients to run the job. Default None.
executor_tasks (List[str], optional): tasks for the executor
external_resources (str, optional): External resources directory or filename. Defaults to None.
"""
super().__init__(name, min_clients, mandatory_clients)

# A CCWF job can have multiple workflows (swarm, cyclic, etc.), but can only have one executor for training!
# This executor can be added by any workflow.
self.executor = None

self.executor_tasks = executor_tasks
if not executor_tasks:
self.executor_tasks = _EXECUTOR_TASKS

if external_resources:
self.to_server(external_resources)
self.to_clients(external_resources)
Expand Down Expand Up @@ -250,7 +262,10 @@ def add_swarm(
wait_time_after_min_resps_received=client_config.wait_time_after_min_resps_received,
)
self.to_clients(client_controller, tasks=["swarm_*"])
self.to_clients(client_config.executor, tasks=["train", "validate", "submit_model"])
if not self.executor:
# We add the executor only if it's not added yet.
self.to_clients(client_config.executor, tasks=self.executor_tasks)
self.executor = client_config.executor

if client_config.model_selector:
self.to_clients(client_config.model_selector, id="model_selector")
Expand Down Expand Up @@ -288,7 +303,11 @@ def add_cyclic(
final_result_ack_timeout=client_config.final_result_ack_timeout,
)
self.to_clients(client_controller, tasks=["cyclic_*"])
self.to_clients(client_config.executor, tasks=["train", "validate", "submit_model"])

if not self.executor:
# We add the executor only if it's not added yet.
self.to_clients(client_config.executor, tasks=self.executor_tasks)
self.executor = client_config.executor

if cse_config:
self.add_cross_site_eval(cse_config, persistor_id)
Expand Down Expand Up @@ -318,21 +337,3 @@ def add_cross_site_eval(
get_model_timeout=cse_config.get_model_timeout,
)
self.to_clients(client_controller, tasks=["cse_*"])


def validate_object_for_job(name, obj, obj_type):
"""Check whether the specified object is valid for job.
The object must either have the add_to_fed_job method or is valid object type.
Args:
name: name of the object
obj: the object to be checked
obj_type: the object type that the object should be, if it doesn't have the add_to_fed_job method.
Returns: None
"""
if has_add_to_job_method(obj):
return

check_object_type(name, obj, obj_type)
72 changes: 51 additions & 21 deletions nvflare/job_config/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
import os.path
import re
import uuid
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from nvflare.apis.executor import Executor
from nvflare.apis.filter import Filter
from nvflare.apis.impl.controller import Controller
from nvflare.apis.job_def import ALL_SITES, SERVER_SITE_NAME
from nvflare.fuel.utils.class_utils import get_component_init_parameters
from nvflare.fuel.utils.validation_utils import check_positive_int
from nvflare.fuel.utils.validation_utils import check_object_type, check_positive_int
from nvflare.job_config.fed_app_config import ClientAppConfig, FedAppConfig, ServerAppConfig
from nvflare.job_config.fed_job_config import FedJobConfig

Expand All @@ -33,25 +33,39 @@


class FedApp:
def __init__(self):
def __init__(self, app_config: Union[ClientAppConfig, ServerAppConfig]):
"""FedApp handles `ClientAppConfig` and `ServerAppConfig` and allows setting task result or task data filters."""
self.app = None # Union[ClientAppConfig, ServerAppConfig]
self.app_config = app_config
self._used_ids = []

# obj_id => comp_id
# obj_id is the Python's object ID; comp_id is the component ID for job config
# _oid_to_cid keeps the mapping between obj_id and comp_id.
# this is to make sure that when the same object is used, it is configured only once in the job.
self._oid_to_cid = {}

def get_app_config(self):
return self.app
return self.app_config

def add_task_result_filter(self, tasks: List[str], task_filter: Filter):
self.app.add_task_result_filter(tasks, task_filter)
self.app_config.add_task_result_filter(tasks, task_filter)

def add_task_data_filter(self, tasks: List[str], task_filter: Filter):
self.app.add_task_data_filter(tasks, task_filter)

def add_component(self, component, id=None):
if id is None:
id = "component"
final_id = self.generate_tracked_id(id)
self.app.add_component(final_id, component)
self.app_config.add_task_data_filter(tasks, task_filter)

def add_component(self, component, comp_id=None):
# is the component already configured?
oid = id(component)
cid = self._oid_to_cid.get(oid)
if cid:
# the component is already configured
return cid

if comp_id is None:
comp_id = "component"
final_id = self.generate_tracked_id(comp_id)
self.app_config.add_component(final_id, component)
self._oid_to_cid[oid] = final_id
return final_id

def _generate_id(self, id: str = "") -> str:
Expand Down Expand Up @@ -79,15 +93,15 @@ def add_external_script(self, ext_script: str):
Args:
ext_script: List of external scripts that need to be deployed to the client/server.
"""
self.app.add_ext_script(ext_script)
self.app_config.add_ext_script(ext_script)

def add_external_dir(self, ext_dir: str):
"""Register external folder to include them in custom directory.
Args:
ext_dir: external folder that need to be deployed to the client/server.
"""
self.app.add_ext_dir(ext_dir)
self.app_config.add_ext_dir(ext_dir)

def _add_resource(self, resource: str):
if not isinstance(resource, str):
Expand Down Expand Up @@ -122,26 +136,24 @@ def __init__(self, obj: Any, target: str, comp_id: str):
class ClientApp(FedApp):
def __init__(self):
"""Wrapper around `ClientAppConfig`."""
super().__init__()
self.app = ClientAppConfig()
super().__init__(ClientAppConfig())

def add_executor(self, executor: Executor, tasks=None):
if not tasks:
tasks = ["*"] # Add executor for any task by default
self.app.add_executor(tasks, executor)
self.app_config.add_executor(tasks, executor)


class ServerApp(FedApp):
"""Wrapper around `ServerAppConfig`."""

def __init__(self):
super().__init__()
self.app: ServerAppConfig = ServerAppConfig()
super().__init__(ServerAppConfig())

def add_controller(self, controller: Controller, id=None):
if not id:
id = "controller"
self.app.add_workflow(self.generate_tracked_id(id), controller)
self.app_config.add_workflow(self.generate_tracked_id(id), controller)


class FedJob:
Expand Down Expand Up @@ -571,3 +583,21 @@ def check_kwargs(args_to_check: dict, args_expected: dict):
def has_add_to_job_method(obj: Any) -> bool:
add_to_job_method = getattr(obj, _ADD_TO_JOB_METHOD_NAME, None)
return add_to_job_method is not None and callable(add_to_job_method)


def validate_object_for_job(name, obj, obj_type):
"""Check whether the specified object is valid for job.
The object must either have the add_to_fed_job method or is valid object type.
Args:
name: name of the object
obj: the object to be checked
obj_type: the object type that the object should be, if it doesn't have the add_to_fed_job method.
Returns: None
"""
if has_add_to_job_method(obj):
return

check_object_type(name, obj, obj_type)
Loading