From 6b42f6d9b9c9a32aab71799f41ebfc360d43c62f Mon Sep 17 00:00:00 2001 From: Felix Schnabel Date: Sun, 8 Jan 2023 16:46:32 +0100 Subject: [PATCH 1/9] Type deprecated decorator Signed-off-by: Felix Schnabel --- monai/utils/deprecate_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/utils/deprecate_utils.py b/monai/utils/deprecate_utils.py index eb182aae47..9a6b162b81 100644 --- a/monai/utils/deprecate_utils.py +++ b/monai/utils/deprecate_utils.py @@ -14,14 +14,14 @@ import warnings from functools import wraps from types import FunctionType -from typing import Any, Optional +from typing import Any, Optional, Callable, TypeVar from monai.utils.module import version_leq from .. import __version__ __all__ = ["deprecated", "deprecated_arg", "DeprecatedError", "deprecated_arg_default"] - +T = TypeVar('T', type, Callable) class DeprecatedError(Exception): pass @@ -40,7 +40,7 @@ def deprecated( msg_suffix: str = "", version_val: str = __version__, warning_category=FutureWarning, -): +) -> Callable[[T], T]: """ Marks a function or class as deprecated. If `since` is given this should be a version at or earlier than the current version and states at what version of the definition was marked as deprecated. If `removed` is given From cfb7ec3d8c44877ef4e223a4d43247d4ab73e072 Mon Sep 17 00:00:00 2001 From: Felix Schnabel Date: Sun, 8 Jan 2023 16:47:28 +0100 Subject: [PATCH 2/9] Type deprecated_arg decorator and fix typing issues Signed-off-by: Felix Schnabel --- monai/apps/reconstruction/networks/nets/complex_unet.py | 1 + monai/transforms/spatial/array.py | 4 ++-- monai/transforms/spatial/dictionary.py | 6 +++--- monai/utils/deprecate_utils.py | 2 +- tests/test_orientation.py | 4 +++- tests/test_orientationd.py | 5 +++-- tests/test_spacingd.py | 6 +++--- 7 files changed, 16 insertions(+), 12 deletions(-) diff --git a/monai/apps/reconstruction/networks/nets/complex_unet.py b/monai/apps/reconstruction/networks/nets/complex_unet.py index ccbb5731a1..3747669174 100644 --- a/monai/apps/reconstruction/networks/nets/complex_unet.py +++ b/monai/apps/reconstruction/networks/nets/complex_unet.py @@ -65,6 +65,7 @@ def __init__( conv_net: Optional[nn.Module] = None, ): super().__init__() + self.unet: nn.Module if conv_net is None: self.unet = BasicUNet( spatial_dims=spatial_dims, diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 455b1c62ae..d63dfd5309 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -16,7 +16,7 @@ from copy import deepcopy from enum import Enum from itertools import zip_longest -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast import numpy as np import torch @@ -2730,7 +2730,7 @@ def __call__( grid = CenterSpatialCrop(roi_size=sp_size)(grid[0]) else: _device = img.device if isinstance(img, torch.Tensor) else self.device - grid = create_grid(spatial_size=sp_size, device=_device, backend="torch") + grid = cast(torch.Tensor, create_grid(spatial_size=sp_size, device=_device, backend="torch")) out: torch.Tensor = self.resampler( img, grid, diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 706e8d7f8b..f30c97e452 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -15,7 +15,7 @@ Class names are ended with 'd' to denote dictionary-based transforms. 
""" -from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union, cast import numpy as np import torch @@ -426,7 +426,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): - d[key] = self.spacing_transform.inverse(d[key]) + d[key] = self.spacing_transform.inverse(cast(torch.Tensor, d[key])) return d @@ -1045,7 +1045,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N ) grid = CenterSpatialCrop(roi_size=sp_size)(grid[0]) else: - grid = create_grid(spatial_size=sp_size, device=device, backend="torch") + grid = cast(torch.Tensor, create_grid(spatial_size=sp_size, device=device, backend="torch")) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) # type: ignore diff --git a/monai/utils/deprecate_utils.py b/monai/utils/deprecate_utils.py index 9a6b162b81..6204cc4bec 100644 --- a/monai/utils/deprecate_utils.py +++ b/monai/utils/deprecate_utils.py @@ -124,7 +124,7 @@ def deprecated_arg( version_val: str = __version__, new_name: Optional[str] = None, warning_category=FutureWarning, -): +) -> Callable[[T], T]: """ Marks a particular named argument of a callable as deprecated. The same conditions for `since` and `removed` as described in the `deprecated` decorator. diff --git a/tests/test_orientation.py b/tests/test_orientation.py index 979f6ae485..48e85fd212 100644 --- a/tests/test_orientation.py +++ b/tests/test_orientation.py @@ -10,6 +10,7 @@ # limitations under the License. import unittest +from typing import cast import nibabel as nib import numpy as np @@ -186,7 +187,7 @@ def test_ornt_meta( ): img = MetaTensor(img, affine=affine).to(device) ornt = Orientation(**init_param) - res: MetaTensor = ornt(img) + res = cast(MetaTensor, ornt(img)) assert_allclose(res, expected_data.to(device)) new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=ornt.labels) self.assertEqual("".join(new_code), expected_code) @@ -204,6 +205,7 @@ def test_ornt_torch(self, init_param, img: torch.Tensor, track_meta: bool, devic assert_allclose(res, expected_data) if track_meta: self.assertIsInstance(res, MetaTensor) + assert isinstance(res, MetaTensor) # for mypy type narrowing new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=ornt.labels) self.assertEqual("".join(new_code), expected_code) else: diff --git a/tests/test_orientationd.py b/tests/test_orientationd.py index 1b4660a60a..f0a599268c 100644 --- a/tests/test_orientationd.py +++ b/tests/test_orientationd.py @@ -10,7 +10,7 @@ # limitations under the License. 
import unittest -from typing import Optional +from typing import Optional, cast import nibabel as nib import numpy as np @@ -74,7 +74,7 @@ def test_orntd( data = {k: img.clone() for k in ornt.keys} res = ornt(data) for k in ornt.keys: - _im = res[k] + _im = cast(MetaTensor, res[k]) self.assertIsInstance(_im, MetaTensor) np.testing.assert_allclose(_im.shape, expected_shape) code = nib.aff2axcodes(_im.affine.cpu(), ornt.ornt_transform.labels) @@ -94,6 +94,7 @@ def test_orntd_torch(self, init_param, img: torch.Tensor, track_meta: bool, devi np.testing.assert_allclose(_im.shape, expected_shape) if track_meta: self.assertIsInstance(_im, MetaTensor) + assert isinstance(_im, MetaTensor) # for mypy type narrowing code = nib.aff2axcodes(_im.affine.cpu(), ornt.ornt_transform.labels) self.assertEqual("".join(code), expected_code) else: diff --git a/tests/test_spacingd.py b/tests/test_spacingd.py index 22729fd1b2..f265f95c16 100644 --- a/tests/test_spacingd.py +++ b/tests/test_spacingd.py @@ -10,7 +10,7 @@ # limitations under the License. import unittest -from typing import List, Tuple +from typing import List, Tuple, Mapping import numpy as np import torch @@ -104,11 +104,11 @@ def test_spacingd(self, _, data, kw_args, expected_shape, expected_affine, devic def test_orntd_torch(self, init_param, img: torch.Tensor, track_meta: bool, device): set_track_meta(track_meta) tr = Spacingd(**init_param) - data = {"seg": img.to(device)} - res = tr(data)["seg"] + res = tr({"seg": img.to(device)})["seg"] if track_meta: self.assertIsInstance(res, MetaTensor) + assert isinstance(res, MetaTensor) # for mypy type narrowing new_spacing = affine_to_spacing(res.affine, 3) assert_allclose(new_spacing, init_param["pixdim"], type_test=False) self.assertNotEqual(img.shape, res.shape) From 69269bbe92dd0e6812a14fe7775daf94eb899975 Mon Sep 17 00:00:00 2001 From: Felix Schnabel Date: Sun, 8 Jan 2023 16:48:21 +0100 Subject: [PATCH 3/9] Type deprecated_arg_default decorator Signed-off-by: Felix Schnabel Update docstrings. Signed-off-by: Felix Schnabel --- monai/utils/deprecate_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/monai/utils/deprecate_utils.py b/monai/utils/deprecate_utils.py index 6204cc4bec..01f1a3ff0a 100644 --- a/monai/utils/deprecate_utils.py +++ b/monai/utils/deprecate_utils.py @@ -138,8 +138,6 @@ def deprecated_arg( using the Sphinx directives such as `.. versionchanged:: version` and `.. deprecated:: version`. https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded - In the current implementation type annotations are not preserved. - Args: name: name of position or keyword argument to mark as deprecated. @@ -234,7 +232,7 @@ def deprecated_arg_default( msg_suffix: str = "", version_val: str = __version__, warning_category=FutureWarning, -): +) -> Callable[[T], T]: """ Marks a particular arguments default of a callable as deprecated. It is changed from `old_default` to `new_default` in version `changed`. @@ -247,8 +245,6 @@ def deprecated_arg_default( using the Sphinx directives such as `.. versionchanged:: version` and `.. deprecated:: version`. https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded - In the current implementation type annotations are not preserved. - Args: name: name of position or keyword argument where the default is deprecated/changed. 
From 34d713f9887a85f140630a75e9261b89f0005c84 Mon Sep 17 00:00:00 2001 From: Felix Schnabel Date: Sun, 8 Jan 2023 18:04:16 +0100 Subject: [PATCH 4/9] Use typed Ignite Engine as superclass of `Workflow`. Signed-off-by: Felix Schnabel Revert Workflow. Signed-off-by: Felix Schnabel --- monai/apps/deepedit/interaction.py | 2 +- monai/apps/deepgrow/interaction.py | 2 +- monai/engines/evaluator.py | 10 +++++----- monai/engines/trainer.py | 6 +++--- monai/engines/workflow.py | 25 ++++++++++++------------- monai/fl/client/monai_algo.py | 4 ++-- 6 files changed, 24 insertions(+), 25 deletions(-) diff --git a/monai/apps/deepedit/interaction.py b/monai/apps/deepedit/interaction.py index dce81f095e..14af8c975b 100644 --- a/monai/apps/deepedit/interaction.py +++ b/monai/apps/deepedit/interaction.py @@ -96,4 +96,4 @@ def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchd # first item in batch only engine.state.batch = batchdata - return engine._iteration(engine, batchdata) + return engine._iteration(engine, batchdata) # type: ignore[arg-type] diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py index 73bc8e7e0b..e8e95f87f5 100644 --- a/monai/apps/deepgrow/interaction.py +++ b/monai/apps/deepgrow/interaction.py @@ -85,4 +85,4 @@ def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchd # collate list into a batch for next round interaction batchdata = list_data_collate(batchdata_list) - return engine._iteration(engine, batchdata) + return engine._iteration(engine, batchdata) # type: ignore[arg-type] diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index cc6e3c4253..a082094918 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence, Type import torch from torch.utils.data import DataLoader @@ -99,7 +99,7 @@ def __init__( val_handlers: Sequence | None = None, amp: bool = False, mode: ForwardMode | str = ForwardMode.EVAL, - event_names: list[str | EventEnum] | None = None, + event_names: list[str | EventEnum | Type[EventEnum]] | None = None, event_to_attr: dict | None = None, decollate: bool = True, to_kwargs: dict | None = None, @@ -133,7 +133,7 @@ def __init__( else: raise ValueError(f"unsupported mode: {mode}, should be 'eval' or 'train'.") - def run(self, global_epoch: int = 1) -> None: + def run(self, global_epoch: int = 1) -> None: # type: ignore[override] """ Execute validation/evaluation based on Ignite Engine. 
@@ -237,7 +237,7 @@ def __init__( val_handlers: Sequence | None = None, amp: bool = False, mode: ForwardMode | str = ForwardMode.EVAL, - event_names: list[str | EventEnum] | None = None, + event_names: list[str | EventEnum | Type[EventEnum]] | None = None, event_to_attr: dict | None = None, decollate: bool = True, to_kwargs: dict | None = None, @@ -380,7 +380,7 @@ def __init__( val_handlers: Sequence | None = None, amp: bool = False, mode: ForwardMode | str = ForwardMode.EVAL, - event_names: list[str | EventEnum] | None = None, + event_names: list[str | EventEnum | Type[EventEnum]] | None = None, event_to_attr: dict | None = None, decollate: bool = True, to_kwargs: dict | None = None, diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 2394688b9e..b1cd6bce9c 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence, Type import torch from torch.optim.optimizer import Optimizer @@ -43,7 +43,7 @@ class Trainer(Workflow): """ - def run(self) -> None: + def run(self) -> None: # type: ignore[override] """ Execute training based on Ignite Engine. If call this function multiple times, it will continuously run from the previous state. @@ -151,7 +151,7 @@ def __init__( metric_cmp_fn: Callable = default_metric_cmp_fn, train_handlers: Sequence | None = None, amp: bool = False, - event_names: list[str | EventEnum] | None = None, + event_names: list[str | EventEnum | Type[EventEnum]] | None = None, event_to_attr: dict | None = None, decollate: bool = True, optim_set_to_none: bool = False, diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index da6c086be9..64c655e696 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -10,7 +10,7 @@ # limitations under the License. import warnings -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Union, Type import torch import torch.distributed as dist @@ -24,7 +24,6 @@ from .utils import engine_apply_transform -IgniteEngine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine", as_type="") State, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "State") Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") @@ -43,7 +42,7 @@ ) -class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optional_import +class Workflow(Engine): """ Workflow defines the core work process inheriting from Ignite engine. 
All trainer, validator and evaluator share this same workflow as base class, @@ -114,7 +113,7 @@ def __init__( metric_cmp_fn: Callable = default_metric_cmp_fn, handlers: Optional[Sequence] = None, amp: bool = False, - event_names: Optional[List[Union[str, EventEnum]]] = None, + event_names: Optional[List[Union[str, EventEnum, Type[EventEnum]]]] = None, event_to_attr: Optional[dict] = None, decollate: bool = True, to_kwargs: Optional[Dict] = None, @@ -140,7 +139,7 @@ def set_sampler_epoch(engine: Engine): raise ValueError("If data_loader is not PyTorch DataLoader, must specify the epoch_length.") # set all sharable data for the workflow based on Ignite engine.state - self.state = State( + self.state: Any = State( rank=dist.get_rank() if dist.is_available() and dist.is_initialized() else 0, seed=0, iteration=0, @@ -167,18 +166,18 @@ def set_sampler_epoch(engine: Engine): self.scaler: Optional[torch.cuda.amp.GradScaler] = None if event_names is None: - event_names = [IterationEvents] # type: ignore + event_names = [IterationEvents] else: if not isinstance(event_names, list): - raise ValueError("`event_names` must be a list or string or EventEnum.") - event_names += [IterationEvents] # type: ignore + raise ValueError("`event_names` must be a list of strings or EventEnums.") + event_names += [IterationEvents] for name in event_names: - if isinstance(name, str): - self.register_events(name, event_to_attr=event_to_attr) - elif issubclass(name, EventEnum): # type: ignore + if isinstance(name, (str, EventEnum)): + self.register_events(name, event_to_attr=event_to_attr) # type: ignore[arg-type] + elif issubclass(name, EventEnum): self.register_events(*name, event_to_attr=event_to_attr) else: - raise ValueError("`event_names` must be a list or string or EventEnum.") + raise ValueError("`event_names` must be a list of strings or EventEnums.") if decollate: self._register_decollate() @@ -267,7 +266,7 @@ def _register_handlers(self, handlers: Sequence): for handler in handlers_: handler.attach(self) - def run(self) -> None: + def run(self) -> None: # type: ignore[override] """ Execute training, validation or evaluation based on Ignite Engine. """ diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py index 2cdabdca9a..d7309d147e 100644 --- a/monai/fl/client/monai_algo.py +++ b/monai/fl/client/monai_algo.py @@ -21,7 +21,7 @@ from monai.apps.auto3dseg.data_analyzer import DataAnalyzer from monai.auto3dseg import SegSummarizer from monai.bundle import DEFAULT_EXP_MGMT_SETTINGS, ConfigComponent, ConfigItem, ConfigParser, patch_bundle_tracking -from monai.engines import Trainer +from monai.engines import Trainer, SupervisedTrainer from monai.fl.client import ClientAlgo, ClientAlgoStats from monai.fl.utils.constants import ( BundleKeys, @@ -429,7 +429,7 @@ def __init__( self.train_parser: Optional[ConfigParser] = None self.eval_parser: Optional[ConfigParser] = None self.filter_parser: Optional[ConfigParser] = None - self.trainer: Optional[Trainer] = None + self.trainer: Optional[SupervisedTrainer] = None self.evaluator: Optional[Any] = None self.pre_filters = None self.post_weight_filters = None From f9d472af05ceb56513771ea2faccebb79edc1a7a Mon Sep 17 00:00:00 2001 From: Felix Schnabel Date: Sun, 8 Jan 2023 18:14:04 +0100 Subject: [PATCH 5/9] Use typed Ignite reinit__is_reduced decorator and Metric as superclass of IgniteMetric. 
Signed-off-by: Felix Schnabel --- monai/handlers/ignite_metric.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py index d6f3f50144..813fb58635 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -19,17 +19,21 @@ from monai.utils import min_version, optional_import idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") -Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric", as_type="base") -reinit__is_reduced, _ = optional_import( - "ignite.metrics.metric", IgniteInfo.OPT_IMPORT_VERSION, min_version, "reinit__is_reduced", as_type="decorator" -) + + if TYPE_CHECKING: from ignite.engine import Engine + from ignite.metrics import Metric + from ignite.metrics.metric import reinit__is_reduced else: Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric", as_type="base") + reinit__is_reduced, _ = optional_import( + "ignite.metrics.metric", IgniteInfo.OPT_IMPORT_VERSION, min_version, "reinit__is_reduced", as_type="decorator" + ) -class IgniteMetric(Metric): # type: ignore[valid-type, misc] # due to optional_import +class IgniteMetric(Metric): """ Base Metric class based on ignite event handler mechanism. The input `prediction` or `label` data can be a PyTorch Tensor or numpy array with batch dim and channel dim, @@ -107,7 +111,7 @@ def compute(self) -> Any: result = result.item() return result - def attach(self, engine: Engine, name: str) -> None: + def attach(self, engine: Engine, name: str) -> None: # type: ignore[override] """ Attaches current metric to provided engine. On the end of engine's run, `engine.state.metrics` dictionary will contain computed metric's value under provided name. From a9b0cd5061226f4e68bc9c3763ba339f1523681a Mon Sep 17 00:00:00 2001 From: Felix Schnabel Date: Sun, 8 Jan 2023 18:16:40 +0100 Subject: [PATCH 6/9] Change mypy setting to disallow untyped decorators. Signed-off-by: Felix Schnabel --- setup.cfg | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.cfg b/setup.cfg index 30352d50db..095d77c91f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -208,7 +208,10 @@ ignore_errors = True ignore_errors = True [mypy-monai.*] +# Also check the body of functions with no types in their type signature. check_untyped_defs = True +# Warns about usage of untyped decorators. +disallow_untyped_decorators = True [pytype] # Space-separated list of files or directories to exclude. 
From 7a0d5eb5fe84a8726b238acb26762494070928ec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 8 Jan 2023 18:52:48 +0000 Subject: [PATCH 7/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/engines/evaluator.py | 8 ++++---- monai/engines/trainer.py | 4 ++-- tests/test_spacingd.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index a082094918..e333acabf6 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence, Type +from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence import torch from torch.utils.data import DataLoader @@ -99,7 +99,7 @@ def __init__( val_handlers: Sequence | None = None, amp: bool = False, mode: ForwardMode | str = ForwardMode.EVAL, - event_names: list[str | EventEnum | Type[EventEnum]] | None = None, + event_names: list[str | EventEnum | type[EventEnum]] | None = None, event_to_attr: dict | None = None, decollate: bool = True, to_kwargs: dict | None = None, @@ -237,7 +237,7 @@ def __init__( val_handlers: Sequence | None = None, amp: bool = False, mode: ForwardMode | str = ForwardMode.EVAL, - event_names: list[str | EventEnum | Type[EventEnum]] | None = None, + event_names: list[str | EventEnum | type[EventEnum]] | None = None, event_to_attr: dict | None = None, decollate: bool = True, to_kwargs: dict | None = None, @@ -380,7 +380,7 @@ def __init__( val_handlers: Sequence | None = None, amp: bool = False, mode: ForwardMode | str = ForwardMode.EVAL, - event_names: list[str | EventEnum | Type[EventEnum]] | None = None, + event_names: list[str | EventEnum | type[EventEnum]] | None = None, event_to_attr: dict | None = None, decollate: bool = True, to_kwargs: dict | None = None, diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index b1cd6bce9c..58a730e4c8 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence, Type +from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence import torch from torch.optim.optimizer import Optimizer @@ -151,7 +151,7 @@ def __init__( metric_cmp_fn: Callable = default_metric_cmp_fn, train_handlers: Sequence | None = None, amp: bool = False, - event_names: list[str | EventEnum | Type[EventEnum]] | None = None, + event_names: list[str | EventEnum | type[EventEnum]] | None = None, event_to_attr: dict | None = None, decollate: bool = True, optim_set_to_none: bool = False, diff --git a/tests/test_spacingd.py b/tests/test_spacingd.py index f265f95c16..6505d2ceb8 100644 --- a/tests/test_spacingd.py +++ b/tests/test_spacingd.py @@ -10,7 +10,7 @@ # limitations under the License. import unittest -from typing import List, Tuple, Mapping +from typing import List, Tuple import numpy as np import torch From 731865f27a55a0fb9efcbc20d790112619035b5d Mon Sep 17 00:00:00 2001 From: Felix Schnabel Date: Sun, 8 Jan 2023 20:01:28 +0100 Subject: [PATCH 8/9] Formatting. 
Signed-off-by: Felix Schnabel --- monai/engines/workflow.py | 2 +- monai/fl/client/monai_algo.py | 2 +- monai/utils/deprecate_utils.py | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 64c655e696..8123d14fac 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -10,7 +10,7 @@ # limitations under the License. import warnings -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Union, Type +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Type, Union import torch import torch.distributed as dist diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py index d7309d147e..a424b3afc3 100644 --- a/monai/fl/client/monai_algo.py +++ b/monai/fl/client/monai_algo.py @@ -21,7 +21,7 @@ from monai.apps.auto3dseg.data_analyzer import DataAnalyzer from monai.auto3dseg import SegSummarizer from monai.bundle import DEFAULT_EXP_MGMT_SETTINGS, ConfigComponent, ConfigItem, ConfigParser, patch_bundle_tracking -from monai.engines import Trainer, SupervisedTrainer +from monai.engines import SupervisedTrainer, Trainer from monai.fl.client import ClientAlgo, ClientAlgoStats from monai.fl.utils.constants import ( BundleKeys, diff --git a/monai/utils/deprecate_utils.py b/monai/utils/deprecate_utils.py index 01f1a3ff0a..793bcc2104 100644 --- a/monai/utils/deprecate_utils.py +++ b/monai/utils/deprecate_utils.py @@ -14,14 +14,15 @@ import warnings from functools import wraps from types import FunctionType -from typing import Any, Optional, Callable, TypeVar +from typing import Any, Callable, Optional, TypeVar from monai.utils.module import version_leq from .. import __version__ __all__ = ["deprecated", "deprecated_arg", "DeprecatedError", "deprecated_arg_default"] -T = TypeVar('T', type, Callable) +T = TypeVar("T", type, Callable) + class DeprecatedError(Exception): pass From 2aebe3512ae9a18912e797b6cdafca89773365c9 Mon Sep 17 00:00:00 2001 From: Felix Schnabel Date: Sun, 8 Jan 2023 20:18:32 +0100 Subject: [PATCH 9/9] Formatting Signed-off-by: Felix Schnabel --- monai/engines/evaluator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index e333acabf6..468ffdb19a 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -237,7 +237,7 @@ def __init__( val_handlers: Sequence | None = None, amp: bool = False, mode: ForwardMode | str = ForwardMode.EVAL, - event_names: list[str | EventEnum | type[EventEnum]] | None = None, + event_names: list[str | EventEnum | type[EventEnum]] | None = None, event_to_attr: dict | None = None, decollate: bool = True, to_kwargs: dict | None = None,
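
For reference, a minimal self-contained sketch of the decorator-typing pattern these patches apply to `deprecated`, `deprecated_arg` and `deprecated_arg_default`: a TypeVar constrained to classes and callables, with the factory returning `Callable[[T], T]`. The names below (`deprecated_example`, `add`) are illustrative only and not part of MONAI.

    from typing import Callable, TypeVar

    # T is constrained to classes and callables, mirroring what the
    # decorators in this patch series accept (functions or classes).
    T = TypeVar("T", type, Callable)


    def deprecated_example(msg_suffix: str = "") -> Callable[[T], T]:
        """Hypothetical decorator factory typed like the patched `deprecated`."""

        def _decorator(obj: T) -> T:
            # The real decorators wrap `obj` and emit a FutureWarning; the
            # sketch returns it unchanged to keep the example minimal.
            return obj

        return _decorator


    @deprecated_example("use something_new instead")
    def add(a: int, b: int) -> int:
        return a + b


    # mypy now treats the decorator as typed (satisfying the new
    # `disallow_untyped_decorators = True` setting in setup.cfg) and still
    # knows the decorated object is a callable rather than `Any`.
    print(add(1, 2))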