diff --git a/ignite/contrib/engines/common.py b/ignite/contrib/engines/common.py
index 76c25e2b2a48..a4833e502b36 100644
--- a/ignite/contrib/engines/common.py
+++ b/ignite/contrib/engines/common.py
@@ -4,6 +4,7 @@ from functools import partial
 
 import torch
+import torch.nn as nn
 from torch.utils.data.distributed import DistributedSampler
 
 import ignite.distributed as idist
@@ -40,6 +41,8 @@ def setup_common_training_handlers(
     device=None,
     stop_on_nan=True,
     clear_cuda_cache=True,
+    save_handler=None,
+    **kwargs
 ):
     """Helper method to setup trainer with common handlers (it also supports distributed configuration):
         - :class:`~ignite.handlers.TerminateOnNan`
@@ -57,7 +60,8 @@
             :class:`~ignite.handlers.Checkpoint` instance.
         save_every_iters (int, optional): saving interval. By default, `to_save` objects are stored
             each 1000 iterations.
-        output_path (str, optional): output path to indicate where `to_save` objects are stored.
+        output_path (str, optional): output path to indicate where `to_save` objects are stored. Argument is mutually
+            exclusive with ``save_handler``.
         lr_scheduler (ParamScheduler or subclass of `torch.optim.lr_scheduler._LRScheduler`): learning rate scheduler
             as native torch LRScheduler or ignite's parameter scheduler.
         with_gpu_stats (bool, optional): if True, :class:`~ignite.contrib.metrics.handlers.GpuInfo` is attached to the
@@ -73,12 +77,16 @@
             Default, True.
         clear_cuda_cache (bool, optional): if True, `torch.cuda.empty_cache()` is called every end of epoch.
             Default, True.
+        save_handler (callable or :class:`~ignite.handlers.checkpoint.BaseSaveHandler`, optional): Method or callable
+            class to use to store ``to_save``. See :class:`~ignite.handlers.checkpoint.Checkpoint` for more details.
+            Argument is mutually exclusive with ``output_path``.
+        **kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`.
         device (str of torch.device, optional): deprecated argument, it will be removed in v0.5.0.
     """
     if device is not None:
         warnings.warn("Argument device is unused and deprecated. It will be removed in v0.5.0")
 
-    kwargs = dict(
+    _kwargs = dict(
         to_save=to_save,
         save_every_iters=save_every_iters,
         output_path=output_path,
@@ -90,10 +98,12 @@
         log_every_iters=log_every_iters,
         stop_on_nan=stop_on_nan,
         clear_cuda_cache=clear_cuda_cache,
+        save_handler=save_handler,
     )
+    _kwargs.update(kwargs)
 
     if idist.get_world_size() > 1:
-        _setup_common_distrib_training_handlers(trainer, train_sampler=train_sampler, **kwargs)
+        _setup_common_distrib_training_handlers(trainer, train_sampler=train_sampler, **_kwargs)
     else:
         if train_sampler is not None and isinstance(train_sampler, DistributedSampler):
             warnings.warn(
@@ -102,7 +112,7 @@
                 "Train sampler argument will be ignored",
                 UserWarning,
             )
-        _setup_common_training_handlers(trainer, **kwargs)
+        _setup_common_training_handlers(trainer, **_kwargs)
 
 
 setup_common_distrib_training_handlers = setup_common_training_handlers
@@ -121,7 +131,14 @@ def _setup_common_training_handlers(
     log_every_iters=100,
     stop_on_nan=True,
     clear_cuda_cache=True,
+    save_handler=None,
+    **kwargs
 ):
+    if output_path is not None and save_handler is not None:
+        raise ValueError(
+            "Arguments output_path and save_handler are mutually exclusive. Please define only one of them"
+        )
+
     if stop_on_nan:
         trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
 
@@ -137,11 +154,15 @@ def _setup_common_training_handlers(
         trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)
 
     if to_save is not None:
-        if output_path is None:
-            raise ValueError("If to_save argument is provided then output_path argument should be also defined")
-        checkpoint_handler = Checkpoint(
-            to_save, DiskSaver(dirname=output_path, require_empty=False), filename_prefix="training",
-        )
+
+        if output_path is None and save_handler is None:
+            raise ValueError(
+                "If to_save argument is provided then output_path or save_handler arguments should be also defined"
+            )
+        if output_path is not None:
+            save_handler = DiskSaver(dirname=output_path, require_empty=False)
+
+        checkpoint_handler = Checkpoint(to_save, save_handler, filename_prefix="training", **kwargs)
         trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler)
 
     if with_gpu_stats:
@@ -192,6 +213,8 @@ def _setup_common_distrib_training_handlers(
     log_every_iters=100,
     stop_on_nan=True,
     clear_cuda_cache=True,
+    save_handler=None,
+    **kwargs
 ):
 
     _setup_common_training_handlers(
@@ -207,6 +230,8 @@
         log_every_iters=log_every_iters,
         stop_on_nan=stop_on_nan,
         clear_cuda_cache=clear_cuda_cache,
+        save_handler=save_handler,
+        **kwargs,
     )
 
     if train_sampler is not None:
@@ -450,19 +475,29 @@ def wrapper(engine):
     return wrapper
 
 
-def save_best_model_by_val_score(output_path, evaluator, model, metric_name, n_saved=3, trainer=None, tag="val"):
-    """Method adds a handler to `evaluator` to save best models based on the score (named by `metric_name`)
-    provided by `evaluator`.
+def gen_save_best_models_by_val_score(
+    save_handler, evaluator, models, metric_name, n_saved=3, trainer=None, tag="val", **kwargs
+):
+    """Method adds a handler to ``evaluator`` to save the ``n_saved`` best models based on the metric
+    (named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).
+    The logic of how to store objects is delegated to ``save_handler``.
 
     Args:
-        output_path (str): output path to indicate where to save best models
+        save_handler (callable or :class:`~ignite.handlers.checkpoint.BaseSaveHandler`): Method or callable class to
+            use to save engine and other provided objects. Function receives two objects: checkpoint as a dictionary
+            and filename. If ``save_handler`` is a callable class, it can inherit from
+            :class:`~ignite.handlers.checkpoint.BaseSaveHandler` and optionally implement a ``remove`` method to keep
+            a fixed number of saved checkpoints. If the user needs to save the engine's checkpoint to disk,
+            ``save_handler`` can be defined with :class:`~ignite.handlers.DiskSaver`.
         evaluator (Engine): evaluation engine used to provide the score
-        model (nn.Module): model to store
+        models (nn.Module or Mapping): model or dictionary with the objects to save. Objects should have
+            implemented ``state_dict`` and ``load_state_dict`` methods.
         metric_name (str): metric name to use for score evaluation. This metric should be present in
            `evaluator.state.metrics`.
         n_saved (int, optional): number of best models to store
         trainer (Engine, optional): trainer engine to fetch the epoch when saving the best model.
         tag (str, optional): score name prefix: `{tag}_{metric_name}`. By default, tag is "val".
+        **kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`.
 
     Returns:
         A :class:`~ignite.handlers.checkpoint.Checkpoint` handler.
@@ -471,14 +506,19 @@ def save_best_model_by_val_score(output_path, evaluator, model, metric_name, n_s
     if trainer is not None:
         global_step_transform = global_step_from_engine(trainer)
 
+    to_save = models
+    if isinstance(models, nn.Module):
+        to_save = {"model": models}
+
     best_model_handler = Checkpoint(
-        {"model": model,},
-        DiskSaver(dirname=output_path, require_empty=False),
+        to_save,
+        save_handler,
         filename_prefix="best",
         n_saved=n_saved,
         global_step_transform=global_step_transform,
         score_name="{}_{}".format(tag, metric_name.lower()),
         score_function=get_default_score_fn(metric_name),
+        **kwargs,
     )
     evaluator.add_event_handler(
         Events.COMPLETED, best_model_handler,
@@ -487,6 +527,38 @@
     )
     return best_model_handler
 
 
+def save_best_model_by_val_score(
+    output_path, evaluator, model, metric_name, n_saved=3, trainer=None, tag="val", **kwargs
+):
+    """Method adds a handler to ``evaluator`` to save to disk the ``n_saved`` best models based on the metric
+    (named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).
+
+    Args:
+        output_path (str): output path to indicate where to save best models
+        evaluator (Engine): evaluation engine used to provide the score
+        model (nn.Module): model to store
+        metric_name (str): metric name to use for score evaluation. This metric should be present in
+            `evaluator.state.metrics`.
+        n_saved (int, optional): number of best models to store
+        trainer (Engine, optional): trainer engine to fetch the epoch when saving the best model.
+        tag (str, optional): score name prefix: `{tag}_{metric_name}`. By default, tag is "val".
+        **kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`.
+
+    Returns:
+        A :class:`~ignite.handlers.checkpoint.Checkpoint` handler.
+    """
+    return gen_save_best_models_by_val_score(
+        save_handler=DiskSaver(dirname=output_path, require_empty=False),
+        evaluator=evaluator,
+        models=model,
+        metric_name=metric_name,
+        n_saved=n_saved,
+        trainer=trainer,
+        tag=tag,
+        **kwargs,
+    )
+
+
 def add_early_stopping_by_val_score(patience, evaluator, trainer, metric_name):
     """Method setups early stopping handler based on the score (named by `metric_name`) provided by `evaluator`.
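
Usage note (not part of the patch): a minimal sketch of how the new ``save_handler`` argument could be used in place of ``output_path``. Only ``setup_common_training_handlers``, ``DiskSaver`` and the ``**kwargs`` forwarding to ``Checkpoint`` are taken from the diff above; the model, optimizer, update function and output directory are illustrative assumptions.

# Sketch only (not part of the patch): iteration-wise "training" checkpoints are
# stored through an explicit save handler instead of output_path.
import torch
import torch.nn as nn

from ignite.contrib.engines.common import setup_common_training_handlers
from ignite.engine import Engine
from ignite.handlers import DiskSaver

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)


def train_step(engine, batch):
    # placeholder update function for the sketch
    return 0.0


trainer = Engine(train_step)

# output_path and save_handler are mutually exclusive: pass one or the other.
setup_common_training_handlers(
    trainer,
    to_save={"model": model, "optimizer": optimizer},
    save_every_iters=500,
    save_handler=DiskSaver(dirname="/tmp/checkpoints", require_empty=False),  # assumed path
    n_saved=2,  # extra keyword args are forwarded to Checkpoint
)

Passing ``output_path`` keeps the previous behaviour (a ``DiskSaver`` is created internally), while ``save_handler`` lets callers plug in any ``BaseSaveHandler``-compatible object; supplying both raises a ``ValueError``.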
diff --git a/tests/ignite/contrib/engines/test_common.py b/tests/ignite/contrib/engines/test_common.py
index 2b17af874294..8d9a8ee73154 100644
--- a/tests/ignite/contrib/engines/test_common.py
+++ b/tests/ignite/contrib/engines/test_common.py
@@ -1,6 +1,6 @@
 import os
 import sys
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, call
 
 import pytest
 import torch
@@ -12,6 +12,7 @@
 from ignite.contrib.engines.common import (
     _setup_logging,
     add_early_stopping_by_val_score,
+    gen_save_best_models_by_val_score,
     save_best_model_by_val_score,
     setup_any_logging,
     setup_common_training_handlers,
@@ -24,7 +25,7 @@
     setup_wandb_logging,
 )
 from ignite.engine import Engine, Events
-from ignite.handlers import TerminateOnNan
+from ignite.handlers import DiskSaver, TerminateOnNan
 
 
 class DummyModel(nn.Module):
@@ -36,7 +37,9 @@ def forward(self, x):
         return self.net(x)
 
 
-def _test_setup_common_training_handlers(dirname, device, rank=0, local_rank=0, distributed=False, lr_scheduler=None):
+def _test_setup_common_training_handlers(
+    dirname, device, rank=0, local_rank=0, distributed=False, lr_scheduler=None, save_handler=None
+):
 
     lr = 0.01
     step_size = 100
@@ -86,6 +89,7 @@ def update_fn(engine, batch):
         to_save={"model": model, "optimizer": optimizer},
         save_every_iters=75,
         output_path=dirname,
+        save_handler=save_handler,
         lr_scheduler=lr_scheduler,
         with_gpu_stats=False,
         output_names=["batch_loss",],
@@ -107,6 +111,8 @@ def update_fn(engine, batch):
 
     # Check saved checkpoint
     if rank == 0:
+        if save_handler is not None:
+            dirname = save_handler.dirname
         checkpoints = list(os.listdir(dirname))
         assert len(checkpoints) == 1
         for v in [
@@ -124,10 +130,14 @@ def test_asserts_setup_common_training_handlers():
     trainer = Engine(lambda e, b: None)
 
     with pytest.raises(
-        ValueError, match=r"If to_save argument is provided then output_path argument should be also defined"
+        ValueError,
+        match=r"If to_save argument is provided then output_path or save_handler arguments should be also defined",
    ):
         setup_common_training_handlers(trainer, to_save={})
 
+    with pytest.raises(ValueError, match=r"Arguments output_path and save_handler are mutually exclusive"):
+        setup_common_training_handlers(trainer, to_save={}, output_path="abc", save_handler=lambda c, f, m: None)
+
     with pytest.warns(UserWarning, match=r"Argument train_sampler is a distributed sampler"):
         train_sampler = MagicMock(spec=DistributedSampler)
         setup_common_training_handlers(trainer, train_sampler=train_sampler)
@@ -167,10 +177,23 @@ def test_setup_common_training_handlers(dirname, capsys):
     out = captured.err.split("\r")
     out = list(map(lambda x: x.strip(), out))
     out = list(filter(None, out))
-    assert "Epoch:" in out[-1], "{}".format(out[-1])
+    assert "Epoch" in out[-1] or "Epoch" in out[-2], "{}, {}".format(out[-2], out[-1])
+
+
+def test_setup_common_training_handlers_using_save_handler(dirname, capsys):
+
+    save_handler = DiskSaver(dirname=dirname, require_empty=False)
+    _test_setup_common_training_handlers(dirname=None, device="cpu", save_handler=save_handler)
+
+    # Check epoch-wise pbar
+    captured = capsys.readouterr()
+    out = captured.err.split("\r")
+    out = list(map(lambda x: x.strip(), out))
+    out = list(filter(None, out))
+    assert "Epoch" in out[-1] or "Epoch" in out[-2], "{}, {}".format(out[-2], out[-1])
 
 
-def test_save_best_model_by_val_score(dirname, capsys):
+def test_save_best_model_by_val_score(dirname):
 
     trainer = Engine(lambda e, b: None)
     evaluator = Engine(lambda e, b: None)
@@ -180,9 +203,7 @@ def test_save_best_model_by_val_score(dirname, capsys):
 
     @trainer.on(Events.EPOCH_COMPLETED)
     def validate(engine):
-        evaluator.run(
-            [0,]
-        )
+        evaluator.run([0, 1])
 
     @evaluator.on(Events.EPOCH_COMPLETED)
     def set_eval_metric(engine):
@@ -190,12 +211,49 @@ def set_eval_metric(engine):
 
     save_best_model_by_val_score(dirname, evaluator, model, metric_name="acc", n_saved=2, trainer=trainer)
 
-    data = [
-        0,
-    ]
-    trainer.run(data, max_epochs=len(acc_scores))
+    trainer.run([0, 1], max_epochs=len(acc_scores))
 
-    assert set(os.listdir(dirname)) == set(["best_model_8_val_acc=0.6100.pt", "best_model_9_val_acc=0.7000.pt"])
+    assert set(os.listdir(dirname)) == {"best_model_8_val_acc=0.6100.pt", "best_model_9_val_acc=0.7000.pt"}
+
+
+def test_gen_save_best_models_by_val_score():
+
+    trainer = Engine(lambda e, b: None)
+    evaluator = Engine(lambda e, b: None)
+    model = DummyModel()
+
+    acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.5, 0.6, 0.61, 0.7, 0.5]
+
+    @trainer.on(Events.EPOCH_COMPLETED)
+    def validate(engine):
+        evaluator.run([0, 1])
+
+    @evaluator.on(Events.EPOCH_COMPLETED)
+    def set_eval_metric(engine):
+        engine.state.metrics = {"acc": acc_scores[trainer.state.epoch - 1]}
+
+    save_handler = MagicMock()
+
+    gen_save_best_models_by_val_score(
+        save_handler, evaluator, {"a": model, "b": model}, metric_name="acc", n_saved=2, trainer=trainer
+    )
+
+    trainer.run([0, 1], max_epochs=len(acc_scores))
+
+    assert save_handler.call_count == len(acc_scores) - 2  # 2 score values (0.3 and 0.5) are not the best
+    print(save_handler.mock_calls)
+    obj_to_save = {"a": model.state_dict(), "b": model.state_dict()}
+    save_handler.assert_has_calls(
+        [
+            call(
+                obj_to_save,
+                "best_checkpoint_{}_val_acc={:.4f}.pt".format(e, p),
+                dict([("basename", "best_checkpoint"), ("score_name", "val_acc"), ("priority", p)]),
+            )
+            for e, p in zip([1, 2, 3, 4, 6, 7, 8, 9], [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.61, 0.7])
+        ],
+        any_order=True,
+    )
@@ -206,9 +264,7 @@ def test_add_early_stopping_by_val_score():
 
     @trainer.on(Events.EPOCH_COMPLETED)
     def validate(engine):
-        evaluator.run(
-            [0,]
-        )
+        evaluator.run([0, 1])
 
     @evaluator.on(Events.EPOCH_COMPLETED)
     def set_eval_metric(engine):
@@ -216,10 +272,7 @@ def set_eval_metric(engine):
 
     add_early_stopping_by_val_score(patience=3, evaluator=evaluator, trainer=trainer, metric_name="acc")
 
-    data = [
-        0,
-    ]
-    state = trainer.run(data, max_epochs=len(acc_scores))
+    state = trainer.run([0, 1], max_epochs=len(acc_scores))
 
     assert state.epoch == 7
 
@@ -259,9 +312,7 @@ def _test_setup_logging(
 
     @trainer.on(Events.EPOCH_COMPLETED)
     def validate(engine):
-        evaluator.run(
-            [0,]
-        )
+        evaluator.run([0, 1])
 
     @evaluator.on(Events.EPOCH_COMPLETED)
     def set_eval_metric(engine):
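
Usage note (not part of the patch): a sketch of ``gen_save_best_models_by_val_score`` with a dictionary of modules, mirroring what the new test exercises with a mock saver. The engines, the placeholder metric value and the output directory are illustrative assumptions; the helper's signature and the ``DiskSaver`` import come from the changes above.

# Sketch only (not part of the patch): keep the two best checkpoints of several
# modules according to evaluator.state.metrics["acc"], delegating storage to DiskSaver.
import torch.nn as nn

from ignite.contrib.engines.common import gen_save_best_models_by_val_score
from ignite.engine import Engine, Events
from ignite.handlers import DiskSaver

generator = nn.Linear(16, 16)
discriminator = nn.Linear(16, 1)

trainer = Engine(lambda engine, batch: None)
evaluator = Engine(lambda engine, batch: None)


@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(engine):
    evaluator.run([0, 1])


@evaluator.on(Events.EPOCH_COMPLETED)
def set_metric(engine):
    engine.state.metrics["acc"] = 0.5  # placeholder metric value


# Both modules end up in the same "best" checkpoint dict handed to the save handler.
gen_save_best_models_by_val_score(
    save_handler=DiskSaver(dirname="/tmp/best_models", require_empty=False),  # assumed path
    evaluator=evaluator,
    models={"generator": generator, "discriminator": discriminator},
    metric_name="acc",
    n_saved=2,
    trainer=trainer,
)

``save_best_model_by_val_score`` remains available as a thin wrapper that builds the ``DiskSaver`` from ``output_path`` and forwards everything else to ``gen_save_best_models_by_val_score``.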