104 changes: 88 additions & 16 deletions ignite/contrib/engines/common.py
@@ -4,6 +4,7 @@
from functools import partial

import torch
import torch.nn as nn
from torch.utils.data.distributed import DistributedSampler

import ignite.distributed as idist
@@ -40,6 +41,8 @@ def setup_common_training_handlers(
device=None,
stop_on_nan=True,
clear_cuda_cache=True,
save_handler=None,
**kwargs
):
"""Helper method to setup trainer with common handlers (it also supports distributed configuration):
- :class:`~ignite.handlers.TerminateOnNan`
@@ -57,7 +60,8 @@
:class:`~ignite.handlers.Checkpoint` instance.
save_every_iters (int, optional): saving interval. By default, `to_save` objects are stored
every 1000 iterations.
output_path (str, optional): output path to indicate where `to_save` objects are stored.
output_path (str, optional): output path to indicate where `to_save` objects are stored. Argument is mutually
exclusive with ``save_handler``.
lr_scheduler (ParamScheduler or subclass of `torch.optim.lr_scheduler._LRScheduler`): learning rate scheduler
as native torch LRScheduler or ignite's parameter scheduler.
with_gpu_stats (bool, optional): if True, :class:`~ignite.contrib.metrics.handlers.GpuInfo` is attached to the
@@ -73,12 +77,16 @@
Default, True.
clear_cuda_cache (bool, optional): if True, `torch.cuda.empty_cache()` is called every end of epoch.
Default, True.
save_handler (callable or :class:`~ignite.handlers.checkpoint.BaseSaveHandler`, optional): Method or callable
class to use to store ``to_save``. See :class:`~ignite.handlers.checkpoint.Checkpoint` for more details.
Argument is mutually exclusive with ``output_path``.
**kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`.
device (str or torch.device, optional): deprecated argument, it will be removed in v0.5.0.
"""
if device is not None:
warnings.warn("Argument device is unused and deprecated. It will be removed in v0.5.0")

kwargs = dict(
_kwargs = dict(
to_save=to_save,
save_every_iters=save_every_iters,
output_path=output_path,
@@ -90,10 +98,12 @@ log_every_iters=log_every_iters,
log_every_iters=log_every_iters,
stop_on_nan=stop_on_nan,
clear_cuda_cache=clear_cuda_cache,
save_handler=save_handler,
)
_kwargs.update(kwargs)

if idist.get_world_size() > 1:
_setup_common_distrib_training_handlers(trainer, train_sampler=train_sampler, **kwargs)
_setup_common_distrib_training_handlers(trainer, train_sampler=train_sampler, **_kwargs)
else:
if train_sampler is not None and isinstance(train_sampler, DistributedSampler):
warnings.warn(
@@ -102,7 +112,7 @@ def setup_common_training_handlers(
"Train sampler argument will be ignored",
UserWarning,
)
_setup_common_training_handlers(trainer, **kwargs)
_setup_common_training_handlers(trainer, **_kwargs)


setup_common_distrib_training_handlers = setup_common_training_handlers
@@ -121,7 +131,14 @@ def _setup_common_training_handlers(
log_every_iters=100,
stop_on_nan=True,
clear_cuda_cache=True,
save_handler=None,
**kwargs
):
if output_path is not None and save_handler is not None:
raise ValueError(
"Arguments output_path and save_handler are mutually exclusive. Please, define only one of them"
)

if stop_on_nan:
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

@@ -137,11 +154,15 @@ trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)
trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

if to_save is not None:
if output_path is None:
raise ValueError("If to_save argument is provided then output_path argument should be also defined")
checkpoint_handler = Checkpoint(
to_save, DiskSaver(dirname=output_path, require_empty=False), filename_prefix="training",
)

if output_path is None and save_handler is None:
raise ValueError(
"If to_save argument is provided then output_path or save_handler arguments should be also defined"
)
if output_path is not None:
save_handler = DiskSaver(dirname=output_path, require_empty=False)

checkpoint_handler = Checkpoint(to_save, save_handler, filename_prefix="training", **kwargs)
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler)

if with_gpu_stats:
@@ -192,6 +213,8 @@ def _setup_common_distrib_training_handlers(
log_every_iters=100,
stop_on_nan=True,
clear_cuda_cache=True,
save_handler=None,
**kwargs
):

_setup_common_training_handlers(
@@ -207,6 +230,8 @@ log_every_iters=log_every_iters,
log_every_iters=log_every_iters,
stop_on_nan=stop_on_nan,
clear_cuda_cache=clear_cuda_cache,
save_handler=save_handler,
**kwargs,
)

if train_sampler is not None:
@@ -450,19 +475,29 @@ def wrapper(engine):
return wrapper


def save_best_model_by_val_score(output_path, evaluator, model, metric_name, n_saved=3, trainer=None, tag="val"):
"""Method adds a handler to `evaluator` to save best models based on the score (named by `metric_name`)
provided by `evaluator`.
def gen_save_best_models_by_val_score(
save_handler, evaluator, models, metric_name, n_saved=3, trainer=None, tag="val", **kwargs
):
"""Method adds a handler to ``evaluator`` to save ``n_saved`` of best models based on the metric
(named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).
The logic of how to store objects is delegated to ``save_handler``.

Args:
output_path (str): output path to indicate where to save best models
save_handler (callable or :class:`~ignite.handlers.checkpoint.BaseSaveHandler`): Method or callable class to
use to save engine and other provided objects. Function receives two objects: checkpoint as a dictionary
and filename. If ``save_handler`` is a callable class, it can inherit from
:class:`~ignite.handlers.checkpoint.BaseSaveHandler` and optionally implement a ``remove`` method
to keep a fixed number of saved checkpoints. If the engine's checkpoint needs to be saved to disk,
``save_handler`` can be defined with :class:`~ignite.handlers.DiskSaver`.
evaluator (Engine): evaluation engine used to provide the score
model (nn.Module): model to store
models (nn.Module or Mapping): model or dictionary with the objects to save. Objects should implement
``state_dict`` and ``load_state_dict`` methods.
metric_name (str): metric name to use for score evaluation. This metric should be present in
`evaluator.state.metrics`.
n_saved (int, optional): number of best models to store
trainer (Engine, optional): trainer engine to fetch the epoch when saving the best model.
tag (str, optional): score name prefix: `{tag}_{metric_name}`. By default, tag is "val".
**kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`.

Returns:
A :class:`~ignite.handlers.checkpoint.Checkpoint` handler.
@@ -471,14 +506,19 @@ def save_best_model_by_val_score(output_path, evaluator, model, metric_name, n_s
if trainer is not None:
global_step_transform = global_step_from_engine(trainer)

to_save = models
if isinstance(models, nn.Module):
to_save = {"model": models}

best_model_handler = Checkpoint(
{"model": model,},
DiskSaver(dirname=output_path, require_empty=False),
to_save,
save_handler,
filename_prefix="best",
n_saved=n_saved,
global_step_transform=global_step_transform,
score_name="{}_{}".format(tag, metric_name.lower()),
score_function=get_default_score_fn(metric_name),
**kwargs,
)
evaluator.add_event_handler(
Events.COMPLETED, best_model_handler,
@@ -487,6 +527,38 @@ def save_best_model_by_val_score(output_path, evaluator, model, metric_name, n_s
return best_model_handler


def save_best_model_by_val_score(
output_path, evaluator, model, metric_name, n_saved=3, trainer=None, tag="val", **kwargs
):
"""Method adds a handler to ``evaluator`` to save on a disk ``n_saved`` of best models based on the metric
(named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).

Args:
output_path (str): output path to indicate where to save best models
evaluator (Engine): evaluation engine used to provide the score
model (nn.Module): model to store
metric_name (str): metric name to use for score evaluation. This metric should be present in
`evaluator.state.metrics`.
n_saved (int, optional): number of best models to store
trainer (Engine, optional): trainer engine to fetch the epoch when saving the best model.
tag (str, optional): score name prefix: `{tag}_{metric_name}`. By default, tag is "val".
**kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`.

Returns:
A :class:`~ignite.handlers.checkpoint.Checkpoint` handler.
"""
return gen_save_best_models_by_val_score(
save_handler=DiskSaver(dirname=output_path, require_empty=False),
evaluator=evaluator,
models=model,
metric_name=metric_name,
n_saved=n_saved,
trainer=trainer,
tag=tag,
**kwargs,
)


def add_early_stopping_by_val_score(patience, evaluator, trainer, metric_name):
"""Method setups early stopping handler based on the score (named by `metric_name`) provided by `evaluator`.

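
A similar sketch for the new ``gen_save_best_models_by_val_score`` helper, assuming an ``evaluator`` engine that exposes an ``accuracy`` metric in ``evaluator.state.metrics`` and reusing the hypothetical ``model`` and ``trainer`` from above:

from ignite.handlers import DiskSaver
from ignite.contrib.engines.common import gen_save_best_models_by_val_score

# Keeps the two best checkpoints according to evaluator.state.metrics["accuracy"],
# scoring files as "val_accuracy" and stamping them with the trainer epoch.
best_model_handler = gen_save_best_models_by_val_score(
    save_handler=DiskSaver(dirname="/tmp/best_models", require_empty=False),
    evaluator=evaluator,
    models={"model": model},
    metric_name="accuracy",
    n_saved=2,
    trainer=trainer,
    tag="val",
)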
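
For comparison, the retained ``save_best_model_by_val_score`` wrapper keeps its disk-based behaviour by delegating to the helper above; a sketch with the same hypothetical names:

from ignite.contrib.engines.common import save_best_model_by_val_score

# Equivalent to the previous example, with the DiskSaver constructed internally
# from output_path.
best_on_disk_handler = save_best_model_by_val_score(
    output_path="/tmp/best_models",
    evaluator=evaluator,
    model=model,
    metric_name="accuracy",
    n_saved=2,
    trainer=trainer,
)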