diff --git a/monai/bundle/properties.py b/monai/bundle/properties.py index 456d84a3b3..db79101290 100644 --- a/monai/bundle/properties.py +++ b/monai/bundle/properties.py @@ -159,6 +159,11 @@ BundleProperty.REQUIRED: True, BundlePropertyConfig.ID: "device", }, + "evaluator": { + BundleProperty.DESC: "inference / evaluation workflow engine.", + BundleProperty.REQUIRED: True, + BundlePropertyConfig.ID: "evaluator", + }, "network_def": { BundleProperty.DESC: "network module for the inference.", BundleProperty.REQUIRED: True, diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 82bab73fe2..5a5380057e 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -44,15 +44,18 @@ class BundleWorkflow(ABC): """ + supported_train_type: tuple = ("train", "training") + supported_infer_type: tuple = ("infer", "inference", "eval", "evaluation") + def __init__(self, workflow: str | None = None): if workflow is None: self.properties = None self.workflow = None return - if workflow.lower() in ("train", "training"): + if workflow.lower() in self.supported_train_type: self.properties = TrainProperties self.workflow = "train" - elif workflow.lower() in ("infer", "inference", "eval", "evaluation"): + elif workflow.lower() in self.supported_infer_type: self.properties = InferProperties self.workflow = "infer" else: @@ -215,6 +218,7 @@ def __init__( else: settings_ = ConfigParser.load_config_files(tracking) self.patch_bundle_tracking(parser=self.parser, settings=settings_) + self._is_initialized: bool = False def initialize(self) -> Any: """ @@ -223,6 +227,7 @@ def initialize(self) -> Any: """ # reset the "reference_resolver" buffer at initialization stage self.parser.parse(reset=True) + self._is_initialized = True return self._run_expr(id=self.init_id) def run(self) -> Any: @@ -284,7 +289,7 @@ def _get_property(self, name: str, property: dict) -> Any: property: other information for the target property, defined in `TrainProperties` or `InferProperties`. 
""" - if not self.parser.ref_resolver.is_resolved(): + if not self._is_initialized: raise RuntimeError("Please execute 'initialize' before getting any parsed content.") prop_id = self._get_prop_id(name, property) return self.parser.get_parsed_content(id=prop_id) if prop_id is not None else None @@ -303,6 +308,7 @@ def _set_property(self, name: str, property: dict, value: Any) -> None: if prop_id is not None: self.parser[prop_id] = value # must parse the config again after changing the content + self._is_initialized = False self.parser.ref_resolver.reset() def _check_optional_id(self, name: str, property: dict) -> bool: diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py index 031143c69b..a846c99d01 100644 --- a/monai/fl/client/monai_algo.py +++ b/monai/fl/client/monai_algo.py @@ -11,7 +11,6 @@ from __future__ import annotations -import logging import os from collections.abc import Mapping, MutableMapping from typing import Any, cast @@ -19,23 +18,20 @@ import torch import torch.distributed as dist -import monai from monai.apps.auto3dseg.data_analyzer import DataAnalyzer from monai.apps.utils import get_logger from monai.auto3dseg import SegSummarizer -from monai.bundle import DEFAULT_EXP_MGMT_SETTINGS, ConfigComponent, ConfigItem, ConfigParser, ConfigWorkflow -from monai.engines import SupervisedTrainer, Trainer -from monai.fl.client import ClientAlgo, ClientAlgoStats -from monai.fl.utils.constants import ( - BundleKeys, - ExtraItems, - FiltersType, - FlPhase, - FlStatistics, - ModelType, - RequiredBundleKeys, - WeightType, +from monai.bundle import ( + DEFAULT_EXP_MGMT_SETTINGS, + BundleWorkflow, + ConfigComponent, + ConfigItem, + ConfigParser, + ConfigWorkflow, ) +from monai.engines import SupervisedEvaluator, SupervisedTrainer, Trainer +from monai.fl.client import ClientAlgo, ClientAlgoStats +from monai.fl.utils.constants import ExtraItems, FiltersType, FlPhase, FlStatistics, ModelType, WeightType from monai.fl.utils.exchange_object import 
ExchangeObject from monai.networks.utils import copy_model_state, get_state_dict from monai.utils import min_version, require_pkg @@ -70,25 +66,23 @@ def compute_weight_diff(global_weights, local_var_dict): raise ValueError("Cannot compute weight differences if `local_var_dict` is None!") # compute delta model, global model has the primary key set weight_diff = {} + n_diff = 0 for name in global_weights: if name not in local_var_dict: continue # returned weight diff will be on the cpu weight_diff[name] = local_var_dict[name].cpu() - global_weights[name].cpu() + n_diff += 1 if torch.any(torch.isnan(weight_diff[name])): raise ValueError(f"Weights for {name} became NaN...") + if n_diff == 0: + raise RuntimeError("No weight differences computed!") return weight_diff -def check_bundle_config(parser): - for k in RequiredBundleKeys: - if parser.get(k, None) is None: - raise KeyError(f"Bundle config misses required key `{k}`") - - -def disable_ckpt_loaders(parser): - if BundleKeys.VALIDATE_HANDLERS in parser: - for h in parser[BundleKeys.VALIDATE_HANDLERS]: +def disable_ckpt_loaders(parser: ConfigParser) -> None: + if "validate#handlers" in parser: + for h in parser["validate#handlers"]: if ConfigComponent.is_instantiable(h): if "CheckpointLoader" in h["_target_"]: h["_disabled_"] = True @@ -99,36 +93,43 @@ class MonaiAlgoStats(ClientAlgoStats): Implementation of ``ClientAlgoStats`` to allow federated learning with MONAI bundle configurations. Args: - bundle_root: path of bundle. + bundle_root: directory path of the bundle. + workflow: the bundle workflow to execute, usually it's training, evaluation or inference. + if None, will create an `ConfigWorkflow` based on `config_train_filename`. config_train_filename: bundle training config path relative to bundle_root. Can be a list of files; - defaults to "configs/train.json". + defaults to "configs/train.json". only necessary when `workflow` is None. config_filters_filename: filter configuration file. 
Can be a list of files; defaults to `None`. + data_stats_transform_list: transforms to apply for the data stats result. histogram_only: whether to only compute histograms. Defaults to False. """ def __init__( self, bundle_root: str, + workflow: BundleWorkflow | None = None, config_train_filename: str | list | None = "configs/train.json", config_filters_filename: str | list | None = None, - train_data_key: str | None = BundleKeys.TRAIN_DATA, - eval_data_key: str | None = BundleKeys.VALID_DATA, data_stats_transform_list: list | None = None, histogram_only: bool = False, ): self.logger = logger self.bundle_root = bundle_root + self.workflow = None + if workflow is not None: + if not isinstance(workflow, BundleWorkflow): + raise ValueError("workflow must be a subclass of BundleWorkflow.") + if workflow.get_workflow_type() is None: + raise ValueError("workflow doesn't specify the type.") + self.workflow = workflow self.config_train_filename = config_train_filename self.config_filters_filename = config_filters_filename - self.train_data_key = train_data_key - self.eval_data_key = eval_data_key + self.train_data_key = "train" + self.eval_data_key = "eval" self.data_stats_transform_list = data_stats_transform_list self.histogram_only = histogram_only self.client_name: str | None = None self.app_root: str = "" - self.train_parser: ConfigParser | None = None - self.filter_parser: ConfigParser | None = None self.post_statistics_filters: Any = None self.phase = FlPhase.IDLE self.dataset_root: Any = None @@ -149,35 +150,26 @@ def initialize(self, extra=None): # FL platform needs to provide filepath to configuration files self.app_root = extra.get(ExtraItems.APP_ROOT, "") - - # Read bundle config files self.bundle_root = os.path.join(self.app_root, self.bundle_root) - config_train_files = self._add_config_files(self.config_train_filename) - config_filter_files = self._add_config_files(self.config_filters_filename) + if self.workflow is None: + config_train_files = 
self._add_config_files(self.config_train_filename) + self.workflow = ConfigWorkflow( + config_file=config_train_files, meta_file=None, logging_file=None, workflow="train" + ) + self.workflow.initialize() + self.workflow.bundle_root = self.bundle_root + # initialize the workflow as the content changed + self.workflow.initialize() - # Parse - self.train_parser = ConfigParser() - self.filter_parser = ConfigParser() - if len(config_train_files) > 0: - self.train_parser.read_config(config_train_files) - check_bundle_config(self.train_parser) + config_filter_files = self._add_config_files(self.config_filters_filename) + filter_parser = ConfigParser() if len(config_filter_files) > 0: - self.filter_parser.read_config(config_filter_files) - - # override some config items - self.train_parser[RequiredBundleKeys.BUNDLE_ROOT] = self.bundle_root - - # Get data location - self.dataset_root = self.train_parser.get_parsed_content( - BundleKeys.DATASET_DIR, default=ConfigItem(None, BundleKeys.DATASET_DIR) - ) - - # Get filters - self.post_statistics_filters = self.filter_parser.get_parsed_content( - FiltersType.POST_STATISTICS_FILTERS, default=ConfigItem(None, FiltersType.POST_STATISTICS_FILTERS) - ) - + filter_parser.read_config(config_filter_files) + # Get filters + self.post_statistics_filters = filter_parser.get_parsed_content( + FiltersType.POST_STATISTICS_FILTERS, default=ConfigItem(None, FiltersType.POST_STATISTICS_FILTERS) + ) self.logger.info(f"Initialized {self.client_name}.") def get_data_stats(self, extra: dict | None = None) -> ExchangeObject: @@ -195,9 +187,9 @@ def get_data_stats(self, extra: dict | None = None) -> ExchangeObject: if extra is None: raise ValueError("`extra` has to be set") - if self.dataset_root: + if self.workflow.dataset_dir: # type: ignore self.phase = FlPhase.GET_DATA_STATS - self.logger.info(f"Computing statistics on {self.dataset_root}") + self.logger.info(f"Computing statistics on {self.workflow.dataset_dir}") # type: ignore if 
FlStatistics.HIST_BINS not in extra: raise ValueError("FlStatistics.NUM_OF_BINS not specified in `extra`") @@ -212,7 +204,7 @@ def get_data_stats(self, extra: dict | None = None) -> ExchangeObject: # train data stats train_summary_stats, train_case_stats = self._get_data_key_stats( - parser=self.train_parser, + data=self.workflow.train_dataset_data, # type: ignore data_key=self.train_data_key, hist_bins=hist_bins, hist_range=hist_range, @@ -223,13 +215,18 @@ def get_data_stats(self, extra: dict | None = None) -> ExchangeObject: stats_dict.update({self.train_data_key: train_summary_stats}) # eval data stats - eval_summary_stats, eval_case_stats = self._get_data_key_stats( - parser=self.train_parser, - data_key=self.eval_data_key, - hist_bins=hist_bins, - hist_range=hist_range, - output_path=os.path.join(self.app_root, "eval_data_stats.yaml"), - ) + eval_summary_stats = None + eval_case_stats = None + if self.workflow.val_dataset_data is not None: # type: ignore + eval_summary_stats, eval_case_stats = self._get_data_key_stats( + data=self.workflow.val_dataset_data, # type: ignore + data_key=self.eval_data_key, + hist_bins=hist_bins, + hist_range=hist_range, + output_path=os.path.join(self.app_root, "eval_data_stats.yaml"), + ) + else: + self.logger.warning("the datalist doesn't contain validation section.") if eval_summary_stats: # Only return summary statistics to FL server stats_dict.update({self.eval_data_key: eval_summary_stats}) @@ -252,17 +249,10 @@ def get_data_stats(self, extra: dict | None = None) -> ExchangeObject: else: raise ValueError("data_root not set!") - def _get_data_key_stats(self, parser, data_key, hist_bins, hist_range, output_path=None): - if data_key not in parser: - self.logger.warning(f"Data key {data_key} not available in bundle configs.") - return None, None - data = parser.get_parsed_content(data_key) - - datalist = {data_key: data} - + def _get_data_key_stats(self, data, data_key, hist_bins, hist_range, output_path=None): analyzer = 
DataAnalyzer( - datalist=datalist, - dataroot=self.dataset_root, + datalist={data_key: data}, + dataroot=self.workflow.dataset_dir, # type: ignore hist_bins=hist_bins, hist_range=hist_range, output_path=output_path, @@ -325,16 +315,21 @@ def _add_config_files(self, config_files): class MonaiAlgo(ClientAlgo, MonaiAlgoStats): """ Implementation of ``ClientAlgo`` to allow federated learning with MONAI bundle configurations. - FIXME: reimplement this class based on the bundle "ConfigWorkflow". Args: - bundle_root: path of bundle. + bundle_root: directory path of the bundle. + train_workflow: the bundle workflow to execute training. + eval_workflow: the bundle workflow to execute evaluation. local_epochs: number of local epochs to execute during each round of local training; defaults to 1. send_weight_diff: whether to send weight differences rather than full weights; defaults to `True`. - config_train_filename: bundle training config path relative to bundle_root. Can be a list of files; - defaults to "configs/train.json". - config_evaluate_filename: bundle evaluation config path relative to bundle_root. Can be a list of files. - If "default", config_evaluate_filename = ["configs/train.json", "configs/evaluate.json"] will be used; + config_train_filename: bundle training config path relative to bundle_root. can be a list of files. + defaults to "configs/train.json". only useful when `train_workflow` is None. + config_evaluate_filename: bundle evaluation config path relative to bundle_root. can be a list of files. + if "default", ["configs/train.json", "configs/evaluate.json"] will be used. + this arg is only useful when `eval_workflow` is None. + eval_workflow_name: the workflow name corresponding to the "config_evaluate_filename", default to "train" + as the default "config_evaluate_filename" overrides the train workflow config. + this arg is only useful when `eval_workflow` is None. config_filters_filename: filter configuration file. 
Can be a list of files; defaults to `None`. disable_ckpt_loading: do not use any CheckpointLoader if defined in train/evaluate configs; defaults to `True`. best_model_filepath: location of best model checkpoint; defaults "models/model.pt" relative to `bundle_root`. @@ -342,15 +337,9 @@ class MonaiAlgo(ClientAlgo, MonaiAlgoStats): save_dict_key: If a model checkpoint contains several state dicts, the one defined by `save_dict_key` will be returned by `get_weights`; defaults to "model". If all state dicts should be returned, set `save_dict_key` to None. - seed: set random seed for modules to enable or disable deterministic training; defaults to `None`, - i.e., non-deterministic training. - benchmark: set benchmark to `False` for full deterministic behavior in cuDNN components. - Note, full determinism in federated learning depends also on deterministic behavior of other FL components, - e.g., the aggregator, which is not controlled by this class. - multi_gpu: whether to run MonaiAlgo in a multi-GPU setting; defaults to `False`. - backend: backend to use for torch.distributed; defaults to "nccl". - init_method: init_method for torch.distributed; defaults to "env://". + data_stats_transform_list: transforms to apply for the data stats result. tracking: if not None, enable the experiment tracking at runtime with optionally configurable and extensible. + it expects the `train_workflow` or `eval_workflow` to be `ConfigWorkflow`, not customized `BundleWorkflow`. if "mlflow", will add `MLFlowHandler` to the parsed bundle with default tracking settings, if other string, treat it as file path to load the tracking settings. if `dict`, treat it as tracking settings. 
@@ -363,60 +352,60 @@ class MonaiAlgo(ClientAlgo, MonaiAlgoStats): def __init__( self, bundle_root: str, + train_workflow: BundleWorkflow | None = None, + eval_workflow: BundleWorkflow | None = None, local_epochs: int = 1, send_weight_diff: bool = True, config_train_filename: str | list | None = "configs/train.json", config_evaluate_filename: str | list | None = "default", + eval_workflow_name: str = "train", config_filters_filename: str | list | None = None, disable_ckpt_loading: bool = True, best_model_filepath: str | None = "models/model.pt", final_model_filepath: str | None = "models/model_final.pt", save_dict_key: str | None = "model", - seed: int | None = None, - benchmark: bool = True, - multi_gpu: bool = False, - backend: str = "nccl", - init_method: str = "env://", - train_data_key: str | None = BundleKeys.TRAIN_DATA, - eval_data_key: str | None = BundleKeys.VALID_DATA, data_stats_transform_list: list | None = None, tracking: str | dict | None = None, ): self.logger = logger - if config_evaluate_filename == "default": - # by default, evaluator needs both training and evaluate to be instantiated. - config_evaluate_filename = ["configs/train.json", "configs/evaluate.json"] self.bundle_root = bundle_root + self.train_workflow = None + self.eval_workflow = None + if train_workflow is not None: + if not isinstance(train_workflow, BundleWorkflow) or train_workflow.get_workflow_type() != "train": + raise ValueError( + f"train workflow must be BundleWorkflow and set type in {BundleWorkflow.supported_train_type}." 
+ ) + self.train_workflow = train_workflow + if eval_workflow is not None: + # evaluation workflow can be "train" type or "infer" type + if not isinstance(eval_workflow, BundleWorkflow) or eval_workflow.get_workflow_type() is None: + raise ValueError("eval workflow must be BundleWorkflow and set type.") + self.eval_workflow = eval_workflow self.local_epochs = local_epochs self.send_weight_diff = send_weight_diff self.config_train_filename = config_train_filename + if config_evaluate_filename == "default": + # by default, evaluator needs both training and evaluate to be instantiated + config_evaluate_filename = ["configs/train.json", "configs/evaluate.json"] self.config_evaluate_filename = config_evaluate_filename + self.eval_workflow_name = eval_workflow_name self.config_filters_filename = config_filters_filename self.disable_ckpt_loading = disable_ckpt_loading self.model_filepaths = {ModelType.BEST_MODEL: best_model_filepath, ModelType.FINAL_MODEL: final_model_filepath} self.save_dict_key = save_dict_key - self.seed = seed - self.benchmark = benchmark - self.multi_gpu = multi_gpu - self.backend = backend - self.init_method = init_method - self.train_data_key = train_data_key - self.eval_data_key = eval_data_key self.data_stats_transform_list = data_stats_transform_list self.tracking = tracking self.app_root = "" - self.train_parser: ConfigParser | None = None - self.eval_parser: ConfigParser | None = None self.filter_parser: ConfigParser | None = None self.trainer: SupervisedTrainer | None = None - self.evaluator: Any | None = None + self.evaluator: SupervisedEvaluator | None = None self.pre_filters = None self.post_weight_filters = None self.post_evaluate_filters = None self.iter_of_start_time = 0 self.global_weights: Mapping | None = None - self.rank = 0 self.phase = FlPhase.IDLE self.client_name = None @@ -431,75 +420,63 @@ def initialize(self, extra=None): i.e., `ExtraItems.CLIENT_NAME` and `ExtraItems.APP_ROOT`. 
""" + self._set_cuda_device() if extra is None: extra = {} self.client_name = extra.get(ExtraItems.CLIENT_NAME, "noname") self.logger.info(f"Initializing {self.client_name} ...") - - if self.multi_gpu: - dist.init_process_group(backend=self.backend, init_method=self.init_method) - self._set_cuda_device() - self.logger.info( - f"Using multi-gpu training on rank {self.rank} (available devices: {torch.cuda.device_count()})" - ) - if self.rank > 0: - self.logger.setLevel(logging.WARNING) - - if self.seed: - monai.utils.set_determinism(seed=self.seed) - torch.backends.cudnn.benchmark = self.benchmark - # FL platform needs to provide filepath to configuration files self.app_root = extra.get(ExtraItems.APP_ROOT, "") - - # Read bundle config files self.bundle_root = os.path.join(self.app_root, self.bundle_root) - config_train_files = self._add_config_files(self.config_train_filename) - config_eval_files = self._add_config_files(self.config_evaluate_filename) - config_filter_files = self._add_config_files(self.config_filters_filename) - - # Parse - self.train_parser = ConfigParser() - self.eval_parser = ConfigParser() - self.filter_parser = ConfigParser() - if len(config_train_files) > 0: - self.train_parser.read_config(config_train_files) - check_bundle_config(self.train_parser) - if len(config_eval_files) > 0: - self.eval_parser.read_config(config_eval_files) - check_bundle_config(self.eval_parser) - if len(config_filter_files) > 0: - self.filter_parser.read_config(config_filter_files) - - # override some config items - self.train_parser[RequiredBundleKeys.BUNDLE_ROOT] = self.bundle_root - self.eval_parser[RequiredBundleKeys.BUNDLE_ROOT] = self.bundle_root - # number of training epochs for each round - if BundleKeys.TRAIN_TRAINER_MAX_EPOCHS in self.train_parser: - self.train_parser[BundleKeys.TRAIN_TRAINER_MAX_EPOCHS] = self.local_epochs - - # remove checkpoint loaders - if self.disable_ckpt_loading: - disable_ckpt_loaders(self.train_parser) - 
disable_ckpt_loaders(self.eval_parser) - # set tracking configs for experiment management if self.tracking is not None: if isinstance(self.tracking, str) and self.tracking in DEFAULT_EXP_MGMT_SETTINGS: settings_ = DEFAULT_EXP_MGMT_SETTINGS[self.tracking] else: settings_ = ConfigParser.load_config_files(self.tracking) - ConfigWorkflow.patch_bundle_tracking(parser=self.train_parser, settings=settings_) - ConfigWorkflow.patch_bundle_tracking(parser=self.eval_parser, settings=settings_) - # Get trainer, evaluator - self.trainer = self.train_parser.get_parsed_content( - BundleKeys.TRAINER, default=ConfigItem(None, BundleKeys.TRAINER) - ) - self.evaluator = self.eval_parser.get_parsed_content( - BundleKeys.EVALUATOR, default=ConfigItem(None, BundleKeys.EVALUATOR) - ) + if self.train_workflow is None and self.config_train_filename is not None: + config_train_files = self._add_config_files(self.config_train_filename) + self.train_workflow = ConfigWorkflow( + config_file=config_train_files, meta_file=None, logging_file=None, workflow="train" + ) + if self.train_workflow is not None: + self.train_workflow.initialize() + self.train_workflow.bundle_root = self.bundle_root + self.train_workflow.max_epochs = self.local_epochs + if self.tracking is not None and isinstance(self.train_workflow, ConfigWorkflow): + ConfigWorkflow.patch_bundle_tracking(parser=self.train_workflow.parser, settings=settings_) + if self.disable_ckpt_loading and isinstance(self.train_workflow, ConfigWorkflow): + disable_ckpt_loaders(parser=self.train_workflow.parser) + # initialize the workflow as the content changed + self.train_workflow.initialize() + self.trainer = self.train_workflow.trainer + if not isinstance(self.trainer, SupervisedTrainer): + raise ValueError(f"trainer must be SupervisedTrainer, but got: {type(self.trainer)}.") + + if self.eval_workflow is None and self.config_evaluate_filename is not None: + config_eval_files = self._add_config_files(self.config_evaluate_filename) + 
self.eval_workflow = ConfigWorkflow( + config_file=config_eval_files, meta_file=None, logging_file=None, workflow=self.eval_workflow_name + ) + if self.eval_workflow is not None: + self.eval_workflow.initialize() + self.eval_workflow.bundle_root = self.bundle_root + if self.tracking is not None and isinstance(self.eval_workflow, ConfigWorkflow): + ConfigWorkflow.patch_bundle_tracking(parser=self.eval_workflow.parser, settings=settings_) + if self.disable_ckpt_loading and isinstance(self.eval_workflow, ConfigWorkflow): + disable_ckpt_loaders(parser=self.eval_workflow.parser) + # initialize the workflow as the content changed + self.eval_workflow.initialize() + self.evaluator = self.eval_workflow.evaluator + if not isinstance(self.evaluator, SupervisedEvaluator): + raise ValueError(f"evaluator must be SupervisedEvaluator, but got: {type(self.evaluator)}.") + + config_filter_files = self._add_config_files(self.config_filters_filename) + self.filter_parser = ConfigParser() + if len(config_filter_files) > 0: + self.filter_parser.read_config(config_filter_files) # Get filters self.pre_filters = self.filter_parser.get_parsed_content( @@ -514,17 +491,6 @@ def initialize(self, extra=None): self.post_statistics_filters = self.filter_parser.get_parsed_content( FiltersType.POST_STATISTICS_FILTERS, default=ConfigItem(None, FiltersType.POST_STATISTICS_FILTERS) ) - - # Get data location - self.dataset_root = self.train_parser.get_parsed_content( - BundleKeys.DATASET_DIR, default=ConfigItem(None, BundleKeys.DATASET_DIR) - ) - - if self.multi_gpu: - if self.rank > 0 and self.trainer: - self.trainer.logger.setLevel(logging.WARNING) - if self.rank > 0 and self.evaluator: - self.evaluator.logger.setLevel(logging.WARNING) self.logger.info(f"Initialized {self.client_name}.") def train(self, data: ExchangeObject, extra: dict | None = None) -> None: @@ -536,8 +502,8 @@ def train(self, data: ExchangeObject, extra: dict | None = None) -> None: extra: Dict with additional information that 
can be provided by the FL system. """ - self._set_cuda_device() + self._set_cuda_device() if extra is None: extra = {} if not isinstance(data, ExchangeObject): @@ -579,8 +545,8 @@ def get_weights(self, extra=None): or load requested model type from disk (`ModelType.BEST_MODEL` or `ModelType.FINAL_MODEL`). """ - self._set_cuda_device() + self._set_cuda_device() if extra is None: extra = {} @@ -658,8 +624,8 @@ def evaluate(self, data: ExchangeObject, extra: dict | None = None) -> ExchangeO return_metrics: `ExchangeObject` containing evaluation metrics. """ - self._set_cuda_device() + self._set_cuda_device() if extra is None: extra = {} if not isinstance(data, ExchangeObject): @@ -721,13 +687,14 @@ def finalize(self, extra: dict | None = None) -> None: if isinstance(self.evaluator, Trainer): self.logger.info(f"Terminating {self.client_name} evaluator...") self.evaluator.terminate() - - if self.multi_gpu: - dist.destroy_process_group() + if self.train_workflow is not None: + self.train_workflow.finalize() + if self.eval_workflow is not None: + self.eval_workflow.finalize() def _check_converted(self, global_weights, local_var_dict, n_converted): if n_converted == 0: - self.logger.warning( + raise RuntimeError( f"No global weights converted! 
Received weight dict keys are {list(global_weights.keys())}" ) else: @@ -736,6 +703,6 @@ def _check_converted(self, global_weights, local_var_dict, n_converted): ) def _set_cuda_device(self): - if self.multi_gpu: + if dist.is_initialized(): self.rank = int(os.environ["LOCAL_RANK"]) torch.cuda.set_device(self.rank) diff --git a/monai/fl/utils/constants.py b/monai/fl/utils/constants.py index fbd18b364c..3f229d6ecc 100644 --- a/monai/fl/utils/constants.py +++ b/monai/fl/utils/constants.py @@ -51,20 +51,6 @@ class FlStatistics(StrEnum): FEATURE_NAMES = "feature_names" -class RequiredBundleKeys(StrEnum): - BUNDLE_ROOT = "bundle_root" - - -class BundleKeys(StrEnum): - TRAINER = "train#trainer" - EVALUATOR = "validate#evaluator" - TRAIN_TRAINER_MAX_EPOCHS = "train#trainer#max_epochs" - VALIDATE_HANDLERS = "validate#handlers" - DATASET_DIR = "dataset_dir" - TRAIN_DATA = "train#dataset#data" - VALID_DATA = "validate#dataset#data" - - class FiltersType(StrEnum): PRE_FILTERS = "pre_filters" POST_WEIGHT_FILTERS = "post_weight_filters" diff --git a/monai/fl/utils/filters.py b/monai/fl/utils/filters.py index 15acabd9a2..56e94246ef 100644 --- a/monai/fl/utils/filters.py +++ b/monai/fl/utils/filters.py @@ -38,7 +38,7 @@ def __call__(self, data: ExchangeObject, extra: dict | None = None) -> ExchangeO class SummaryFilter(Filter): """ - Summary filter to content of ExchangeObject. + Summary filter to show content of ExchangeObject. 
""" def __call__(self, data: ExchangeObject, extra: dict | None = None) -> ExchangeObject: diff --git a/tests/nonconfig_workflow.py b/tests/nonconfig_workflow.py index 63562af868..150531abc7 100644 --- a/tests/nonconfig_workflow.py +++ b/tests/nonconfig_workflow.py @@ -99,6 +99,8 @@ def _get_property(self, name, property): return self._bundle_root if name == "device": return self._device + if name == "evaluator": + return self._evaluator if name == "network_def": return self._network_def if name == "inferer": @@ -115,6 +117,8 @@ def _set_property(self, name, property, value): self._bundle_root = value elif name == "device": self._device = value + elif name == "evaluator": + self._evaluator = value elif name == "network_def": self._network_def = value elif name == "inferer": diff --git a/tests/test_fl_monai_algo.py b/tests/test_fl_monai_algo.py index c4c5da00bb..cf8d3254ed 100644 --- a/tests/test_fl_monai_algo.py +++ b/tests/test_fl_monai_algo.py @@ -13,12 +13,13 @@ import os import shutil -import tempfile import unittest +from copy import deepcopy +from os.path import join as pathjoin from parameterized import parameterized -from monai.bundle import ConfigParser +from monai.bundle import ConfigParser, ConfigWorkflow from monai.bundle.utils import DEFAULT_HANDLERS_ID from monai.fl.client.monai_algo import MonaiAlgo from monai.fl.utils.constants import ExtraItems @@ -28,11 +29,14 @@ _root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__))) _data_dir = os.path.join(_root_dir, "testing_data") +_logging_file = pathjoin(_data_dir, "logging.conf") TEST_TRAIN_1 = [ { "bundle_root": _data_dir, - "config_train_filename": os.path.join(_data_dir, "config_fl_train.json"), + "train_workflow": ConfigWorkflow( + config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file + ), "config_evaluate_filename": None, "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"), } @@ -48,15 +52,34 @@ TEST_TRAIN_3 = [ 
{ "bundle_root": _data_dir, - "config_train_filename": [ - os.path.join(_data_dir, "config_fl_train.json"), - os.path.join(_data_dir, "config_fl_train.json"), - ], + "train_workflow": ConfigWorkflow( + config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file + ), "config_evaluate_filename": None, - "config_filters_filename": [ - os.path.join(_data_dir, "config_fl_filters.json"), - os.path.join(_data_dir, "config_fl_filters.json"), - ], + "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"), + } +] + +TEST_TRAIN_4 = [ + { + "bundle_root": _data_dir, + "train_workflow": ConfigWorkflow( + config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file + ), + "config_evaluate_filename": None, + "tracking": { + "handlers_id": DEFAULT_HANDLERS_ID, + "configs": { + "execute_config": f"{_data_dir}/config_executed.json", + "trainer": { + "_target_": "MLFlowHandler", + "tracking_uri": path_to_uri(_data_dir) + "/mlflow_override", + "output_transform": "$monai.handlers.from_engine(['loss'], first=True)", + "close_on_complete": True, + }, + }, + }, + "config_filters_filename": None, } ] @@ -64,7 +87,14 @@ { "bundle_root": _data_dir, "config_train_filename": None, - "config_evaluate_filename": os.path.join(_data_dir, "config_fl_evaluate.json"), + "eval_workflow": ConfigWorkflow( + config_file=[ + os.path.join(_data_dir, "config_fl_train.json"), + os.path.join(_data_dir, "config_fl_evaluate.json"), + ], + workflow="train", + logging_file=_logging_file, + ), "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"), } ] @@ -72,7 +102,11 @@ { "bundle_root": _data_dir, "config_train_filename": None, - "config_evaluate_filename": os.path.join(_data_dir, "config_fl_evaluate.json"), + "config_evaluate_filename": [ + os.path.join(_data_dir, "config_fl_train.json"), + os.path.join(_data_dir, "config_fl_evaluate.json"), + ], + "eval_workflow_name": 
"training", "config_filters_filename": None, } ] @@ -80,36 +114,30 @@ { "bundle_root": _data_dir, "config_train_filename": None, - "config_evaluate_filename": [ - os.path.join(_data_dir, "config_fl_evaluate.json"), - os.path.join(_data_dir, "config_fl_evaluate.json"), - ], - "config_filters_filename": [ - os.path.join(_data_dir, "config_fl_filters.json"), - os.path.join(_data_dir, "config_fl_filters.json"), - ], + "eval_workflow": ConfigWorkflow( + config_file=[ + os.path.join(_data_dir, "config_fl_train.json"), + os.path.join(_data_dir, "config_fl_evaluate.json"), + ], + workflow="train", + logging_file=_logging_file, + ), + "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"), } ] TEST_GET_WEIGHTS_1 = [ { "bundle_root": _data_dir, - "config_train_filename": os.path.join(_data_dir, "config_fl_train.json"), + "train_workflow": ConfigWorkflow( + config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file + ), "config_evaluate_filename": None, "send_weight_diff": False, "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"), } ] TEST_GET_WEIGHTS_2 = [ - { - "bundle_root": _data_dir, - "config_train_filename": None, - "config_evaluate_filename": None, - "send_weight_diff": False, - "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"), - } -] -TEST_GET_WEIGHTS_3 = [ { "bundle_root": _data_dir, "config_train_filename": os.path.join(_data_dir, "config_fl_train.json"), @@ -118,19 +146,15 @@ "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"), } ] -TEST_GET_WEIGHTS_4 = [ +TEST_GET_WEIGHTS_3 = [ { "bundle_root": _data_dir, - "config_train_filename": [ - os.path.join(_data_dir, "config_fl_train.json"), - os.path.join(_data_dir, "config_fl_train.json"), - ], + "train_workflow": ConfigWorkflow( + config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file + ), "config_evaluate_filename": None, 
"send_weight_diff": True, - "config_filters_filename": [ - os.path.join(_data_dir, "config_fl_filters.json"), - os.path.join(_data_dir, "config_fl_filters.json"), - ], + "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"), } ] @@ -138,39 +162,15 @@ @SkipIfNoModule("ignite") @SkipIfNoModule("mlflow") class TestFLMonaiAlgo(unittest.TestCase): - @parameterized.expand([TEST_TRAIN_1, TEST_TRAIN_2, TEST_TRAIN_3]) + @parameterized.expand([TEST_TRAIN_1, TEST_TRAIN_2, TEST_TRAIN_3, TEST_TRAIN_4]) def test_train(self, input_params): - # get testing data dir and update train config; using the first to define data dir - if isinstance(input_params["config_train_filename"], list): - config_train_filename = [ - os.path.join(input_params["bundle_root"], x) for x in input_params["config_train_filename"] - ] - else: - config_train_filename = os.path.join(input_params["bundle_root"], input_params["config_train_filename"]) - - data_dir = tempfile.mkdtemp() - # test experiment management - input_params["tracking"] = { - "handlers_id": DEFAULT_HANDLERS_ID, - "configs": { - "execute_config": f"{data_dir}/config_executed.json", - "trainer": { - "_target_": "MLFlowHandler", - "tracking_uri": path_to_uri(data_dir) + "/mlflow_override", - "output_transform": "$monai.handlers.from_engine(['loss'], first=True)", - "close_on_complete": True, - }, - }, - } - # initialize algo algo = MonaiAlgo(**input_params) algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"}) algo.abort() # initialize model - parser = ConfigParser() - parser.read_config(config_train_filename) + parser = ConfigParser(config=deepcopy(algo.train_workflow.parser.get())) parser.parse() network = parser.get_parsed_content("network") @@ -179,27 +179,22 @@ def test_train(self, input_params): # test train algo.train(data=data, extra={}) algo.finalize() - self.assertTrue(os.path.exists(f"{data_dir}/mlflow_override")) - self.assertTrue(os.path.exists(f"{data_dir}/config_executed.json")) - 
shutil.rmtree(data_dir) + + # test experiment management + if "execute_config" in algo.train_workflow.parser: + self.assertTrue(os.path.exists(f"{_data_dir}/mlflow_override")) + shutil.rmtree(f"{_data_dir}/mlflow_override") + self.assertTrue(os.path.exists(f"{_data_dir}/config_executed.json")) + os.remove(f"{_data_dir}/config_executed.json") @parameterized.expand([TEST_EVALUATE_1, TEST_EVALUATE_2, TEST_EVALUATE_3]) def test_evaluate(self, input_params): - # get testing data dir and update train config; using the first to define data dir - if isinstance(input_params["config_evaluate_filename"], list): - config_eval_filename = [ - os.path.join(input_params["bundle_root"], x) for x in input_params["config_evaluate_filename"] - ] - else: - config_eval_filename = os.path.join(input_params["bundle_root"], input_params["config_evaluate_filename"]) - # initialize algo algo = MonaiAlgo(**input_params) algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"}) # initialize model - parser = ConfigParser() - parser.read_config(config_eval_filename) + parser = ConfigParser(config=deepcopy(algo.eval_workflow.parser.get())) parser.parse() network = parser.get_parsed_content("network") @@ -208,7 +203,7 @@ def test_evaluate(self, input_params): # test evaluate algo.evaluate(data=data, extra={}) - @parameterized.expand([TEST_GET_WEIGHTS_1, TEST_GET_WEIGHTS_2, TEST_GET_WEIGHTS_3, TEST_GET_WEIGHTS_4]) + @parameterized.expand([TEST_GET_WEIGHTS_1, TEST_GET_WEIGHTS_2, TEST_GET_WEIGHTS_3]) def test_get_weights(self, input_params): # initialize algo algo = MonaiAlgo(**input_params) diff --git a/tests/test_fl_monai_algo_dist.py b/tests/test_fl_monai_algo_dist.py index 36c2f419b3..1bf599a0fa 100644 --- a/tests/test_fl_monai_algo_dist.py +++ b/tests/test_fl_monai_algo_dist.py @@ -16,83 +16,95 @@ from os.path import join as pathjoin import torch.distributed as dist -from parameterized import parameterized -from monai.bundle import ConfigParser +from monai.bundle import ConfigParser, 
ConfigWorkflow from monai.fl.client.monai_algo import MonaiAlgo from monai.fl.utils.constants import ExtraItems from monai.fl.utils.exchange_object import ExchangeObject +from monai.networks import get_state_dict from tests.utils import DistCall, DistTestCase, SkipIfBeforePyTorchVersion, SkipIfNoModule, skip_if_no_cuda _root_dir = os.path.abspath(pathjoin(os.path.dirname(__file__))) _data_dir = pathjoin(_root_dir, "testing_data") -TEST_TRAIN_1 = [ - { - "bundle_root": _data_dir, - "config_train_filename": [ - pathjoin(_data_dir, "config_fl_train.json"), - pathjoin(_data_dir, "multi_gpu_train.json"), - ], - "config_evaluate_filename": None, - "config_filters_filename": pathjoin(_root_dir, "testing_data", "config_fl_filters.json"), - "multi_gpu": True, - } -] - -TEST_EVALUATE_1 = [ - { - "bundle_root": _data_dir, - "config_train_filename": None, - "config_evaluate_filename": [ - pathjoin(_data_dir, "config_fl_evaluate.json"), - pathjoin(_data_dir, "multi_gpu_evaluate.json"), - ], - "config_filters_filename": pathjoin(_data_dir, "config_fl_filters.json"), - "multi_gpu": True, - } -] +_logging_file = pathjoin(_data_dir, "logging.conf") @SkipIfNoModule("ignite") @SkipIfBeforePyTorchVersion((1, 11, 1)) class TestFLMonaiAlgo(DistTestCase): - @parameterized.expand([TEST_TRAIN_1]) @DistCall(nnodes=1, nproc_per_node=2, init_method="no_init") @skip_if_no_cuda - def test_train(self, input_params): + def test_train(self): + train_configs = [pathjoin(_data_dir, "config_fl_train.json"), pathjoin(_data_dir, "multi_gpu_train.json")] + eval_configs = [ + pathjoin(_data_dir, "config_fl_train.json"), + pathjoin(_data_dir, "config_fl_evaluate.json"), + pathjoin(_data_dir, "multi_gpu_evaluate.json"), + ] # initialize algo - algo = MonaiAlgo(**input_params) + algo = MonaiAlgo( + bundle_root=_data_dir, + train_workflow=ConfigWorkflow(config_file=train_configs, workflow="train", logging_file=_logging_file), + eval_workflow=ConfigWorkflow(config_file=eval_configs, workflow="train", 
logging_file=_logging_file), + config_filters_filename=pathjoin(_root_dir, "testing_data", "config_fl_filters.json"), + ) algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"}) self.assertTrue(dist.get_rank() in (0, 1)) # initialize model parser = ConfigParser() - parser.read_config([pathjoin(input_params["bundle_root"], x) for x in input_params["config_train_filename"]]) + parser.read_config(train_configs) parser.parse() network = parser.get_parsed_content("network") - data = ExchangeObject(weights=network.state_dict()) + data = ExchangeObject(weights=get_state_dict(network)) # test train - algo.train(data=data, extra={}) + for i in range(2): + print(f"Testing round {i+1} of {2}...") + # test evaluate + metric_eo = algo.evaluate(data=data, extra={}) + self.assertIsInstance(metric_eo, ExchangeObject) + metric = metric_eo.metrics + self.assertIsInstance(metric["accuracy"], float) + + # test train + algo.train(data=data, extra={}) + weights_eo = algo.get_weights() + self.assertIsInstance(weights_eo, ExchangeObject) + self.assertTrue(weights_eo.is_valid_weights()) + self.assertIsInstance(weights_eo.weights, dict) + self.assertTrue(len(weights_eo.weights) > 0) - @parameterized.expand([TEST_EVALUATE_1]) @DistCall(nnodes=1, nproc_per_node=2, init_method="no_init") @skip_if_no_cuda - def test_evaluate(self, input_params): + def test_evaluate(self): + config_file = [ + pathjoin(_data_dir, "config_fl_train.json"), + pathjoin(_data_dir, "config_fl_evaluate.json"), + pathjoin(_data_dir, "multi_gpu_evaluate.json"), + ] # initialize algo - algo = MonaiAlgo(**input_params) + algo = MonaiAlgo( + bundle_root=_data_dir, + config_train_filename=None, + eval_workflow=ConfigWorkflow(config_file=config_file, workflow="train", logging_file=_logging_file), + config_filters_filename=pathjoin(_data_dir, "config_fl_filters.json"), + ) algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"}) self.assertTrue(dist.get_rank() in (0, 1)) # initialize model parser = ConfigParser() 
parser.read_config( - [os.path.join(input_params["bundle_root"], x) for x in input_params["config_evaluate_filename"]] + [pathjoin(_data_dir, "config_fl_train.json"), pathjoin(_data_dir, "config_fl_evaluate.json")] ) parser.parse() network = parser.get_parsed_content("network") - data = ExchangeObject(weights=network.state_dict()) + data = ExchangeObject(weights=get_state_dict(network)) # test evaluate - algo.evaluate(data=data, extra={}) + metric_eo = algo.evaluate(data=data, extra={}) + self.assertIsInstance(metric_eo, ExchangeObject) + metric = metric_eo.metrics + self.assertIsInstance(metric["accuracy"], float) if __name__ == "__main__": diff --git a/tests/test_fl_monai_algo_stats.py b/tests/test_fl_monai_algo_stats.py index 1955c35b36..e46b6b899a 100644 --- a/tests/test_fl_monai_algo_stats.py +++ b/tests/test_fl_monai_algo_stats.py @@ -16,6 +16,7 @@ from parameterized import parameterized +from monai.bundle import ConfigWorkflow from monai.fl.client import MonaiAlgoStats from monai.fl.utils.constants import ExtraItems, FlStatistics from monai.fl.utils.exchange_object import ExchangeObject @@ -23,11 +24,17 @@ _root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__))) _data_dir = os.path.join(_root_dir, "testing_data") +_logging_file = os.path.join(_data_dir, "logging.conf") TEST_GET_DATA_STATS_1 = [ { "bundle_root": _data_dir, - "config_train_filename": os.path.join(_data_dir, "config_fl_stats_1.json"), + "workflow": ConfigWorkflow( + workflow="train", + config_file=os.path.join(_data_dir, "config_fl_stats_1.json"), + logging_file=_logging_file, + meta_file=None, + ), "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"), } ] @@ -41,14 +48,16 @@ TEST_GET_DATA_STATS_3 = [ { "bundle_root": _data_dir, - "config_train_filename": [ - os.path.join(_data_dir, "config_fl_stats_1.json"), - os.path.join(_data_dir, "config_fl_stats_2.json"), - ], - "config_filters_filename": [ - os.path.join(_data_dir, "config_fl_filters.json"), - 
os.path.join(_data_dir, "config_fl_filters.json"), - ], + "workflow": ConfigWorkflow( + workflow="train", + config_file=[ + os.path.join(_data_dir, "config_fl_stats_1.json"), + os.path.join(_data_dir, "config_fl_stats_2.json"), + ], + logging_file=_logging_file, + meta_file=None, + ), + "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"), } ] diff --git a/tests/testing_data/config_fl_evaluate.json b/tests/testing_data/config_fl_evaluate.json index 113596070a..917e762736 100644 --- a/tests/testing_data/config_fl_evaluate.json +++ b/tests/testing_data/config_fl_evaluate.json @@ -1,87 +1,18 @@ { - "bundle_root": "tests/testing_data", - "dataset_dir": "@bundle_root", - "imports": [ - "$import os" - ], - "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", - "network_def": { - "_target_": "DenseNet121", - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 6 - }, - "network": "$@network_def.to(@device)", - "validate": { - "val_transforms": [ - { - "_target_": "LoadImaged", - "keys": [ - "image" - ], - "image_only": true - }, - { - "_target_": "EnsureChannelFirstD", - "keys": [ - "image" - ] - }, - { - "_target_": "ScaleIntensityd", - "keys": [ - "image" - ] - }, - { - "_target_": "ToTensord", - "keys": [ - "image", - "label" - ] + "validate#handlers": [ + { + "_target_": "CheckpointLoader", + "load_path": "$@bundle_root + '/models/model.pt'", + "load_dict": { + "model": "@network" } - ], - "preprocessing": { - "_target_": "Compose", - "transforms": "$@validate#val_transforms" - }, - "dataset": { - "_target_": "Dataset", - "data": [ - { - "image": "$os.path.join(@dataset_dir, 'image0.jpeg')", - "label": 0 - }, - { - "image": "$os.path.join(@dataset_dir, 'image1.jpeg')", - "label": 1 - } - ], - "transform": "@validate#preprocessing" - }, - "dataloader": { - "_target_": "DataLoader", - "dataset": "@validate#dataset", - "batch_size": 3, - "shuffle": false, - "num_workers": 4 - }, - "inferer": { - "_target_": 
"SimpleInferer" }, - "key_metric": { - "accuracy": { - "_target_": "ignite.metrics.Accuracy", - "output_transform": "$monai.handlers.from_engine(['pred', 'label'])" - } - }, - "evaluator": { - "_target_": "SupervisedEvaluator", - "device": "@device", - "val_data_loader": "@validate#dataloader", - "network": "@network", - "inferer": "@validate#inferer", - "key_val_metric": "@validate#key_metric" + { + "_target_": "StatsHandler", + "iteration_log": false } - } + ], + "run": [ + "$@validate#evaluator.run()" + ] } diff --git a/tests/testing_data/config_fl_stats_1.json b/tests/testing_data/config_fl_stats_1.json index 41b42eb3bb..80773139c2 100644 --- a/tests/testing_data/config_fl_stats_1.json +++ b/tests/testing_data/config_fl_stats_1.json @@ -2,7 +2,7 @@ "imports": [ "$import os" ], - "bundle_root": "tests/testing_data", + "bundle_root": "", "dataset_dir": "@bundle_root", "train": { "dataset": { diff --git a/tests/testing_data/config_fl_stats_2.json b/tests/testing_data/config_fl_stats_2.json index bf55673f67..8d24bc6a8b 100644 --- a/tests/testing_data/config_fl_stats_2.json +++ b/tests/testing_data/config_fl_stats_2.json @@ -2,7 +2,7 @@ "imports": [ "$import os" ], - "bundle_root": "tests/testing_data", + "bundle_root": "", "dataset_dir": "@bundle_root", "train": { "dataset": { diff --git a/tests/testing_data/config_fl_train.json b/tests/testing_data/config_fl_train.json index bdb9792fce..5b7fb6608e 100644 --- a/tests/testing_data/config_fl_train.json +++ b/tests/testing_data/config_fl_train.json @@ -165,7 +165,7 @@ "_target_": "DataLoader", "dataset": "@validate#dataset", "batch_size": 1, - "shuffle": true, + "shuffle": false, "num_workers": 2 }, "inferer": { @@ -181,13 +181,27 @@ } ] }, + "key_metric": { + "accuracy": { + "_target_": "ignite.metrics.Accuracy", + "output_transform": "$monai.handlers.from_engine(['pred', 'label'])" + } + }, + "handlers": [ + { + "_target_": "StatsHandler", + "iteration_log": false + } + ], "evaluator": { "_target_": 
"SupervisedEvaluator", "device": "@device", "val_data_loader": "@validate#dataloader", "network": "@network", "inferer": "@validate#inferer", - "postprocessing": "@validate#postprocessing" + "val_handlers": "@validate#handlers", + "postprocessing": "@validate#postprocessing", + "key_val_metric": "@validate#key_metric" } }, "initialize": [ diff --git a/tests/testing_data/multi_gpu_evaluate.json b/tests/testing_data/multi_gpu_evaluate.json index 7af24a6b2e..37286cfb7a 100644 --- a/tests/testing_data/multi_gpu_evaluate.json +++ b/tests/testing_data/multi_gpu_evaluate.json @@ -14,14 +14,18 @@ "shuffle": false }, "validate#dataloader#sampler": "@validate#sampler", - "evaluating": [ + "initialize": [ "$import torch.distributed as dist", - "$dist.init_process_group(backend='nccl')", + "$dist.is_initialized() or dist.init_process_group(backend='nccl')", "$torch.cuda.set_device(@device)", "$setattr(torch.backends.cudnn, 'benchmark', True)", "$import logging", - "$@validate#evaluator.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)", - "$@validate#evaluator.run()", + "$@validate#evaluator.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)" + ], + "run": [ + "$@validate#evaluator.run()" + ], + "finalize": [ "$dist.destroy_process_group()" ] } diff --git a/tests/testing_data/multi_gpu_train.json b/tests/testing_data/multi_gpu_train.json index 41fd7698db..a617e53dfd 100644 --- a/tests/testing_data/multi_gpu_train.json +++ b/tests/testing_data/multi_gpu_train.json @@ -15,16 +15,20 @@ }, "train#dataloader#sampler": "@train#sampler", "train#dataloader#shuffle": false, - "train#trainer#train_handlers": "$@train#handlers[: -2 if dist.get_rank() > 0 else None]", - "training": [ + "initialize": [ "$import torch.distributed as dist", - "$dist.init_process_group(backend='nccl')", + "$dist.is_initialized() or dist.init_process_group(backend='nccl')", "$torch.cuda.set_device(@device)", "$monai.utils.set_determinism(seed=123)", 
"$setattr(torch.backends.cudnn, 'benchmark', True)", "$import logging", "$@train#trainer.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)", - "$@train#trainer.run()", + "$@validate#evaluator.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)" + ], + "run": [ + "$@train#trainer.run()" + ], + "finalize": [ "$dist.destroy_process_group()" ] }