Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto3DSeg continue training (skip trained algos) #6310

Merged
Merged: 11 commits were merged into the base branch on Apr 6, 2023.
112 changes: 65 additions & 47 deletions monai/apps/auto3dseg/auto_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def __init__(
# determine if we need to analyze, algo_gen or train from cache, unless manually provided
self.analyze = not self.cache["analyze"] if analyze is None else analyze
self.algo_gen = not self.cache["algo_gen"] if algo_gen is None else algo_gen
self.train = not self.cache["train"] if train is None else train
self.train = train
self.ensemble = ensemble # last step, no need to check

self.set_training_params()
Expand Down Expand Up @@ -635,13 +635,15 @@ def _train_algo_in_sequence(self, history: list[dict[str, Any]]) -> None:
folders under the working directory. The results include the model checkpoints, a
progress.yaml, accuracies in CSV and a pickle file of the Algo object.
"""
for task in history:
for _, algo in task.items():
algo.train(self.train_params)
acc = algo.get_score()
algo_to_pickle(algo, template_path=algo.template_path, best_metrics=acc)
for algo_dict in history:
algo = algo_dict[AlgoEnsembleKeys.ALGO]
algo.train(self.train_params)
acc = algo.get_score()

def _train_algo_in_nni(self, history):
algo_meta_data = {str(AlgoEnsembleKeys.SCORE): acc}
algo_to_pickle(algo, template_path=algo.template_path, **algo_meta_data)

def _train_algo_in_nni(self, history: list[dict[str, Any]]) -> None:
"""
Train the Algos using HPO.

Expand Down Expand Up @@ -672,40 +674,41 @@ def _train_algo_in_nni(self, history):

last_total_tasks = len(import_bundle_algo_history(self.work_dir, only_trained=True))
mode_dry_run = self.hpo_params.pop("nni_dry_run", False)
for task in history:
for name, algo in task.items():
nni_gen = NNIGen(algo=algo, params=self.hpo_params)
obj_filename = nni_gen.get_obj_filename()
nni_config = deepcopy(default_nni_config)
# override the default nni config with the same key in hpo_params
for key in self.hpo_params:
if key in nni_config:
nni_config[key] = self.hpo_params[key]
nni_config.update({"experimentName": name})
nni_config.update({"search_space": self.search_space})
trial_cmd = "python -m monai.apps.auto3dseg NNIGen run_algo " + obj_filename + " " + self.work_dir
nni_config.update({"trialCommand": trial_cmd})
nni_config_filename = os.path.abspath(os.path.join(self.work_dir, f"{name}_nni_config.yaml"))
ConfigParser.export_config_file(nni_config, nni_config_filename, fmt="yaml", default_flow_style=None)

max_trial = min(self.hpo_tasks, cast(int, default_nni_config["maxTrialNumber"]))
cmd = "nnictl create --config " + nni_config_filename + " --port 8088"

if mode_dry_run:
logger.info(f"AutoRunner HPO is in dry-run mode. Please manually launch: {cmd}")
continue

subprocess.run(cmd.split(), check=True)

for algo_dict in history:
name = algo_dict[AlgoEnsembleKeys.ID]
algo = algo_dict[AlgoEnsembleKeys.ALGO]
nni_gen = NNIGen(algo=algo, params=self.hpo_params)
obj_filename = nni_gen.get_obj_filename()
nni_config = deepcopy(default_nni_config)
# override the default nni config with the same key in hpo_params
for key in self.hpo_params:
if key in nni_config:
nni_config[key] = self.hpo_params[key]
nni_config.update({"experimentName": name})
nni_config.update({"search_space": self.search_space})
trial_cmd = "python -m monai.apps.auto3dseg NNIGen run_algo " + obj_filename + " " + self.work_dir
nni_config.update({"trialCommand": trial_cmd})
nni_config_filename = os.path.abspath(os.path.join(self.work_dir, f"{name}_nni_config.yaml"))
ConfigParser.export_config_file(nni_config, nni_config_filename, fmt="yaml", default_flow_style=None)

max_trial = min(self.hpo_tasks, cast(int, default_nni_config["maxTrialNumber"]))
cmd = "nnictl create --config " + nni_config_filename + " --port 8088"

if mode_dry_run:
logger.info(f"AutoRunner HPO is in dry-run mode. Please manually launch: {cmd}")
continue

subprocess.run(cmd.split(), check=True)

n_trainings = len(import_bundle_algo_history(self.work_dir, only_trained=True))
while n_trainings - last_total_tasks < max_trial:
sleep(1)
n_trainings = len(import_bundle_algo_history(self.work_dir, only_trained=True))
while n_trainings - last_total_tasks < max_trial:
sleep(1)
n_trainings = len(import_bundle_algo_history(self.work_dir, only_trained=True))

cmd = "nnictl stop --all"
subprocess.run(cmd.split(), check=True)
logger.info(f"NNI completes HPO on {name}")
last_total_tasks = n_trainings
cmd = "nnictl stop --all"
subprocess.run(cmd.split(), check=True)
logger.info(f"NNI completes HPO on {name}")
last_total_tasks = n_trainings

def run(self):
"""
Expand Down Expand Up @@ -758,7 +761,8 @@ def run(self):
logger.info("Skipping algorithm generation...")

# step 3: algo training
if self.train:
auto_train_choice = self.train is None
if self.train or (auto_train_choice and not self.cache["train"]):
history = import_bundle_algo_history(self.work_dir, only_trained=False)

if len(history) == 0:
Expand All @@ -767,20 +771,34 @@ def run(self):
"Possibly the required algorithms generation step was not completed."
)

if not self.hpo:
self._train_algo_in_sequence(history)
else:
self._train_algo_in_nni(history)
if auto_train_choice:
history = [h for h in history if not h["is_trained"]] # skip trained
[Review thread: myron marked this conversation as resolved.]

if len(history) > 0:
if not self.hpo:
self._train_algo_in_sequence(history)
else:
self._train_algo_in_nni(history)

self.export_cache(train=True)
else:
logger.info("Skipping algorithm training...")

# step 4: model ensemble and write the prediction to disks.
if self.ensemble:
history = import_bundle_algo_history(self.work_dir, only_trained=True)
history = import_bundle_algo_history(self.work_dir, only_trained=False)

history_untrained = [h for h in history if not h["is_trained"]]
if len(history_untrained) > 0:
warnings.warn(
f"Ensembling step will skip {[h['name'] for h in history_untrained]} untrained algos"
"Generally it means these algos did not complete training"
[Review thread: myron marked this conversation as resolved.]
)
history = [h for h in history if h["is_trained"]]

if len(history) == 0:
raise ValueError(
f"Could not find the trained results in {self.work_dir}. "
f"Could not find any trained algos in {self.work_dir}. "
"Possibly the required training step was not completed."
)

Expand All @@ -798,4 +816,4 @@ def run(self):
self.save_image(pred)
logger.info(f"Auto3Dseg ensemble prediction outputs are saved in {self.output_dir}.")

logger.info("Auto3Dseg pipeline is complete successfully.")
logger.info("Auto3Dseg pipeline is completed successfully.")
5 changes: 4 additions & 1 deletion monai/apps/auto3dseg/bundle_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from monai.auto3dseg.utils import algo_to_pickle
from monai.bundle.config_parser import ConfigParser
from monai.utils import ensure_tuple
from monai.utils.enums import AlgoEnsembleKeys

logger = get_logger(module_name=__name__)
ALGO_HASH = os.environ.get("MONAI_ALGO_HASH", "7758ad1")
Expand Down Expand Up @@ -537,4 +538,6 @@ def generate(
gen_algo.export_to_disk(output_folder, name, fold=f_id)

algo_to_pickle(gen_algo, template_path=algo.template_path)
self.history.append({name: gen_algo}) # track the previous, may create a persistent history
self.history.append(
{AlgoEnsembleKeys.ID: name, AlgoEnsembleKeys.ALGO: gen_algo}
) # track the previous, may create a persistent history
12 changes: 5 additions & 7 deletions monai/apps/auto3dseg/ensemble_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,22 +266,20 @@ class AlgoEnsembleBuilder:

"""

def __init__(self, history: Sequence[dict], data_src_cfg_filename: str | None = None):
def __init__(self, history: Sequence[dict[str, Any]], data_src_cfg_filename: str | None = None):
self.infer_algos: list[dict[AlgoEnsembleKeys, Any]] = []
self.ensemble: AlgoEnsemble
self.data_src_cfg = ConfigParser(globals=False)

if data_src_cfg_filename is not None and os.path.exists(str(data_src_cfg_filename)):
self.data_src_cfg.read_config(data_src_cfg_filename)

for h in history:
for algo_dict in history:
# load inference_config_paths
# raise warning/error if not found
if len(h) > 1:
raise ValueError(f"{h} should only contain one set of genAlgo key-value")

name = list(h.keys())[0]
gen_algo = h[name]
name = algo_dict[AlgoEnsembleKeys.ID]
gen_algo = algo_dict[AlgoEnsembleKeys.ALGO]

best_metric = gen_algo.get_score()
algo_path = gen_algo.output_path
infer_path = os.path.join(algo_path, "scripts", "infer.py")
Expand Down
16 changes: 10 additions & 6 deletions monai/apps/auto3dseg/hpo_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from monai.bundle.config_parser import ConfigParser
from monai.config import PathLike
from monai.utils import optional_import
from monai.utils.enums import AlgoEnsembleKeys

nni, has_nni = optional_import("nni")
optuna, has_optuna = optional_import("optuna")
Expand Down Expand Up @@ -98,8 +99,8 @@ class NNIGen(HPOGen):
# Bundle Algorithms are already generated by BundleGen in work_dir
import_bundle_algo_history(work_dir, only_trained=False)
algo_dict = self.history[0] # pick the first algorithm
algo_name = list(algo_dict.keys())[0]
onealgo = algo_dict[algo_name]
algo_name = algo_dict[AlgoEnsembleKeys.ID]
onealgo = algo_dict[AlgoEnsembleKeys.ALGO]
nni_gen = NNIGen(algo=onealgo)
nni_gen.print_bundle_algo_instruction()

Expand Down Expand Up @@ -237,10 +238,12 @@ def run_algo(self, obj_filename: str, output_folder: str = ".", template_path: P
self.algo.train(self.params)
# step 4 report validation acc to controller
acc = self.algo.get_score()
algo_meta_data = {str(AlgoEnsembleKeys.SCORE): acc}

if isinstance(self.algo, BundleAlgo):
algo_to_pickle(self.algo, template_path=self.algo.template_path, best_metrics=acc)
algo_to_pickle(self.algo, template_path=self.algo.template_path, **algo_meta_data)
else:
algo_to_pickle(self.algo, best_metrics=acc)
algo_to_pickle(self.algo, **algo_meta_data)
self.set_score(acc)


Expand Down Expand Up @@ -408,8 +411,9 @@ def run_algo(self, obj_filename: str, output_folder: str = ".", template_path: P
self.algo.train(self.params)
# step 4 report validation acc to controller
acc = self.algo.get_score()
algo_meta_data = {str(AlgoEnsembleKeys.SCORE): acc}
if isinstance(self.algo, BundleAlgo):
algo_to_pickle(self.algo, template_path=self.algo.template_path, best_metrics=acc)
algo_to_pickle(self.algo, template_path=self.algo.template_path, **algo_meta_data)
else:
algo_to_pickle(self.algo, best_metrics=acc)
algo_to_pickle(self.algo, **algo_meta_data)
self.set_score(acc)
27 changes: 18 additions & 9 deletions monai/apps/auto3dseg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@

from monai.apps.auto3dseg.bundle_gen import BundleAlgo
from monai.auto3dseg import algo_from_pickle, algo_to_pickle
from monai.utils.enums import AlgoEnsembleKeys


def import_bundle_algo_history(
output_folder: str = ".", template_path: str | None = None, only_trained: bool = True
) -> list:
"""
import the history of the bundleAlgo object with their names/identifiers
import the history of the bundleAlgo objects as a list of algo dicts.
each algo_dict has keys name (folder name), algo (bundleAlgo), is_trained (bool),

Args:
output_folder: the root path of the algorithms templates.
Expand All @@ -47,11 +49,18 @@ def import_bundle_algo_history(
if isinstance(algo, BundleAlgo): # algo's template path needs override
algo.template_path = algo_meta_data["template_path"]

if only_trained:
if "best_metrics" in algo_meta_data:
history.append({name: algo})
else:
history.append({name: algo})
best_metric = algo_meta_data.get(AlgoEnsembleKeys.SCORE, None)
is_trained = best_metric is not None

if (only_trained and is_trained) or not only_trained:
history.append(
{
AlgoEnsembleKeys.ID: name,
AlgoEnsembleKeys.ALGO: algo,
AlgoEnsembleKeys.SCORE: best_metric,
"is_trained": is_trained,
}
)

return history

Expand All @@ -63,6 +72,6 @@ def export_bundle_algo_history(history: list[dict[str, BundleAlgo]]) -> None:
Args:
history: a List of Bundle. Typically, the history can be obtained from BundleGen get_history method
"""
for task in history:
for _, algo in task.items():
algo_to_pickle(algo, template_path=algo.template_path)
for algo_dict in history:
algo = algo_dict[AlgoEnsembleKeys.ALGO]
[Review thread: myron marked this conversation as resolved.]
algo_to_pickle(algo, template_path=algo.template_path)
20 changes: 10 additions & 10 deletions tests/test_auto3dseg_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,16 +134,16 @@ def test_ensemble(self) -> None:
bundle_generator.generate(work_dir, num_fold=1)
history = bundle_generator.get_history()

for h in history:
self.assertEqual(len(h.keys()), 1, "each record should have one model")
for name, algo in h.items():
_train_param = train_param.copy()
if name.startswith("segresnet"):
_train_param["network#init_filters"] = 8
_train_param["pretrained_ckpt_name"] = ""
elif name.startswith("swinunetr"):
_train_param["network#feature_size"] = 12
algo.train(_train_param)
for algo_dict in history:
name = algo_dict[AlgoEnsembleKeys.ID]
algo = algo_dict[AlgoEnsembleKeys.ALGO]
_train_param = train_param.copy()
if name.startswith("segresnet"):
_train_param["network#init_filters"] = 8
_train_param["pretrained_ckpt_name"] = ""
elif name.startswith("swinunetr"):
_train_param["network#feature_size"] = 12
algo.train(_train_param)

builder = AlgoEnsembleBuilder(history, data_src_cfg)
builder.set_ensemble_method(AlgoEnsembleBestN(n_best=1))
Expand Down
10 changes: 4 additions & 6 deletions tests/test_auto3dseg_hpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from monai.bundle.config_parser import ConfigParser
from monai.data import create_test_image_3d
from monai.utils import optional_import
from monai.utils.enums import AlgoEnsembleKeys
from tests.utils import (
SkipIfBeforePyTorchVersion,
get_testing_algo_template_path,
Expand Down Expand Up @@ -139,8 +140,7 @@ def setUp(self) -> None:
@skip_if_no_cuda
def test_run_algo(self) -> None:
algo_dict = self.history[0]
algo_name = list(algo_dict.keys())[0]
algo = algo_dict[algo_name]
algo = algo_dict[AlgoEnsembleKeys.ALGO]
nni_gen = NNIGen(algo=algo, params=override_param)
obj_filename = nni_gen.get_obj_filename()
# this function will be used in HPO via Python Fire
Expand All @@ -150,8 +150,7 @@ def test_run_algo(self) -> None:
@skip_if_no_optuna
def test_run_optuna(self) -> None:
algo_dict = self.history[0]
algo_name = list(algo_dict.keys())[0]
algo = algo_dict[algo_name]
algo = algo_dict[AlgoEnsembleKeys.ALGO]

class OptunaGenLearningRate(OptunaGen):
def get_hyperparameters(self):
Expand All @@ -173,8 +172,7 @@ def get_hyperparameters(self):
@skip_if_no_cuda
def test_get_history(self) -> None:
algo_dict = self.history[0]
algo_name = list(algo_dict.keys())[0]
algo = algo_dict[algo_name]
algo = algo_dict[AlgoEnsembleKeys.ALGO]
nni_gen = NNIGen(algo=algo, params=override_param)
obj_filename = nni_gen.get_obj_filename()

Expand Down
7 changes: 3 additions & 4 deletions tests/test_integration_gpu_customization.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,9 @@ def test_ensemble_gpu_customization(self) -> None:
)
history = bundle_generator.get_history()

for h in history:
self.assertEqual(len(h.keys()), 1, "each record should have one model")
for _, algo in h.items():
algo.train(train_param)
for algo_dict in history:
algo = algo_dict[AlgoEnsembleKeys.ALGO]
algo.train(train_param)

builder = AlgoEnsembleBuilder(history, data_src_cfg)
builder.set_ensemble_method(AlgoEnsembleBestN(n_best=2))
Expand Down