
Add Iteration base class #3472


Merged
merged 10 commits on Dec 13, 2021
1 change: 1 addition & 0 deletions monai/apps/deepgrow/interaction.py
@@ -22,6 +22,7 @@
class Interaction:
"""
Ignite process_function used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation.
+ For more details please refer to: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
This implementation is based on:

Sakinis et al., Interactive segmentation of medical images through
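For context, `Interaction` is wired into the trainer as its Ignite process function, so it follows the `(engine, batchdata) -> output` contract described in the linked Engine docs. Below is a minimal sketch of that contract, not the actual `Interaction` implementation (the click simulation is only hinted at in a comment):

```python
from ignite.engine import Engine

def simulate_clicks(engine: Engine, batchdata):
    # a real Interaction step would add guidance channels (simulated clicks)
    # to the inputs and re-run the network before handing control back
    return {"image": batchdata["image"], "label": batchdata["label"]}

# an Engine accepts any callable with this (engine, batchdata) signature as its
# process function; whatever it returns is stored in engine.state.output
engine = Engine(simulate_clicks)
```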
32 changes: 22 additions & 10 deletions monai/engines/evaluator.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

- from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import torch
from torch.utils.data import DataLoader
@@ -45,9 +45,13 @@ class Evaluator(Workflow):
epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
- prepare_batch: function to parse image and label for current iteration.
+ prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
+ from `engine.state.batch` for every iteration, for more details please refer to:
+ https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
iteration_update: the callable function for every iteration, expect to accept `engine`
- and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
+ and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
+ if not provided, use `self._iteration()` instead. for more details please refer to:
+ https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
postprocessing: execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
key_val_metric: compute metric when every iteration completed, and save average value to
@@ -80,7 +84,7 @@ def __init__(
epoch_length: Optional[int] = None,
non_blocking: bool = False,
prepare_batch: Callable = default_prepare_batch,
- iteration_update: Optional[Callable] = None,
+ iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
postprocessing: Optional[Transform] = None,
key_val_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
@@ -147,9 +151,13 @@ class SupervisedEvaluator(Evaluator):
epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
- prepare_batch: function to parse image and label for current iteration.
+ prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
+ from `engine.state.batch` for every iteration, for more details please refer to:
+ https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
iteration_update: the callable function for every iteration, expect to accept `engine`
- and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
+ and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
+ if not provided, use `self._iteration()` instead. for more details please refer to:
+ https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
postprocessing: execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
@@ -184,7 +192,7 @@ def __init__(
epoch_length: Optional[int] = None,
non_blocking: bool = False,
prepare_batch: Callable = default_prepare_batch,
- iteration_update: Optional[Callable] = None,
+ iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
inferer: Optional[Inferer] = None,
postprocessing: Optional[Transform] = None,
key_val_metric: Optional[Dict[str, Metric]] = None,
@@ -275,9 +283,13 @@ class EnsembleEvaluator(Evaluator):
the length must exactly match the number of networks.
non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
- prepare_batch: function to parse image and label for current iteration.
+ prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
+ from `engine.state.batch` for every iteration, for more details please refer to:
+ https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
iteration_update: the callable function for every iteration, expect to accept `engine`
- and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
+ and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
+ if not provided, use `self._iteration()` instead. for more details please refer to:
+ https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
postprocessing: execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
@@ -313,7 +325,7 @@ def __init__(
epoch_length: Optional[int] = None,
non_blocking: bool = False,
prepare_batch: Callable = default_prepare_batch,
- iteration_update: Optional[Callable] = None,
+ iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
inferer: Optional[Inferer] = None,
postprocessing: Optional[Transform] = None,
key_val_metric: Optional[Dict[str, Metric]] = None,
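To illustrate the updated `iteration_update` docstring, here is a hedged sketch of a custom callable plugged into `SupervisedEvaluator`. It assumes a dict-style batch with `image`/`label` keys and a toy network; it is not part of this PR:

```python
import torch
from torch.utils.data import DataLoader
from monai.engines import SupervisedEvaluator

def my_iteration(engine, batchdata):
    images, labels = batchdata["image"], batchdata["label"]
    with torch.no_grad():
        preds = engine.network(images)
    # the returned dict is stored in engine.state.output
    return {"image": images, "label": labels, "pred": preds}

net = torch.nn.Conv2d(1, 1, 3, padding=1)
data = [{"image": torch.rand(1, 8, 8), "label": torch.rand(1, 8, 8)}] * 4
evaluator = SupervisedEvaluator(
    device=torch.device("cpu"),
    val_data_loader=DataLoader(data, batch_size=2),
    network=net,
    iteration_update=my_iteration,  # overrides the default self._iteration()
)
evaluator.run()
```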
24 changes: 16 additions & 8 deletions monai/engines/trainer.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

- from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import torch
from torch.optim.optimizer import Optimizer
@@ -75,9 +75,13 @@ class SupervisedTrainer(Trainer):
epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
- prepare_batch: function to parse image and label for current iteration.
+ prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
+ from `engine.state.batch` for every iteration, for more details please refer to:
+ https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
iteration_update: the callable function for every iteration, expect to accept `engine`
- and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
+ and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
+ if not provided, use `self._iteration()` instead. for more details please refer to:
+ https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
postprocessing: execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
@@ -115,7 +119,7 @@ def __init__(
epoch_length: Optional[int] = None,
non_blocking: bool = False,
prepare_batch: Callable = default_prepare_batch,
- iteration_update: Optional[Callable] = None,
+ iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
inferer: Optional[Inferer] = None,
postprocessing: Optional[Transform] = None,
key_train_metric: Optional[Dict[str, Metric]] = None,
@@ -241,12 +245,16 @@ class GanTrainer(Trainer):
non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
d_prepare_batch: callback function to prepare batchdata for D inferer.
- Defaults to return ``GanKeys.REALS`` in batchdata dict.
+ Defaults to return ``GanKeys.REALS`` in batchdata dict. for more details please refer to:
+ https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
g_prepare_batch: callback function to create batch of latent input for G inferer.
- Defaults to return random latents.
+ Defaults to return random latents. for more details please refer to:
+ https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
g_update_latents: Calculate G loss with new latent codes. Defaults to ``True``.
iteration_update: the callable function for every iteration, expect to accept `engine`
- and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
+ and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
+ if not provided, use `self._iteration()` instead. for more details please refer to:
+ https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
postprocessing: execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
key_train_metric: compute metric when every iteration completed, and save average value to
@@ -286,7 +294,7 @@ def __init__(
d_prepare_batch: Callable = default_prepare_batch,
g_prepare_batch: Callable = default_make_latent,
g_update_latents: bool = True,
- iteration_update: Optional[Callable] = None,
+ iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
postprocessing: Optional[Transform] = None,
key_train_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
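Similarly, the `prepare_batch` wording now matches the contract used by `ignite.engine.create_supervised_trainer`: the callable receives `engine.state.batch` plus `device`/`non_blocking` and returns the tensors the iteration needs. A small illustrative sketch follows (the `image`/`label` key names are assumptions about the dataset, not part of this PR):

```python
import torch

def my_prepare_batch(batchdata, device=None, non_blocking=False):
    # parse the expected keys out of engine.state.batch and move them to the device
    image = batchdata["image"].to(device=device, non_blocking=non_blocking)
    label = batchdata["label"].to(device=device, non_blocking=non_blocking)
    return image, label

# used as: SupervisedTrainer(..., prepare_batch=my_prepare_batch)
```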
12 changes: 8 additions & 4 deletions monai/engines/workflow.py
@@ -11,7 +11,7 @@

import warnings
from abc import ABC, abstractmethod
- from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Union
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Union

import torch
import torch.distributed as dist
@@ -67,9 +67,13 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona
epoch_length: number of iterations for one epoch, default to `len(data_loader)`.
non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
- prepare_batch: function to parse image and label for every iteration.
+ prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
+ from `engine.state.batch` for every iteration, for more details please refer to:
+ https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
iteration_update: the callable function for every iteration, expect to accept `engine`
- and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
+ and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
+ if not provided, use `self._iteration()` instead. for more details please refer to:
+ https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
postprocessing: execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
key_metric: compute metric when every iteration completed, and save average value to
@@ -107,7 +111,7 @@ def __init__(
epoch_length: Optional[int] = None,
non_blocking: bool = False,
prepare_batch: Callable = default_prepare_batch,
- iteration_update: Optional[Callable] = None,
+ iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
postprocessing: Optional[Callable] = None,
key_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
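The net effect of the typing change across these files is that `iteration_update` is declared as `Optional[Callable[[Engine, Any], Any]]` everywhere, so static checkers can verify the `(engine, batchdata)` signature. A small sketch of what a type checker would now catch (function names are illustrative only):

```python
from typing import Any, Callable
from ignite.engine import Engine

def good_update(engine: Engine, batchdata: Any) -> Any:
    return {"processed": True}  # stored in engine.state.output

def bad_update(batchdata: Any) -> Any:  # missing the Engine parameter
    return {"processed": True}

iteration_update: Callable[[Engine, Any], Any] = good_update  # accepted
# iteration_update = bad_update  # mypy: incompatible assignment
```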