Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3284 torch version check #3285

Merged
merged 17 commits into from
Nov 12, 2021
Merged
8 changes: 4 additions & 4 deletions monai/engines/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
from monai.transforms import Transform
from monai.utils import PT_BEFORE_1_7, min_version, optional_import
from monai.utils import min_version, optional_import, pytorch_after
from monai.utils.enums import CommonKeys as Keys

if TYPE_CHECKING:
Expand Down Expand Up @@ -190,7 +190,7 @@ def _compute_pred_loss():

self.network.train()
# `set_to_none` only work from PyTorch 1.7.0
if PT_BEFORE_1_7:
if not pytorch_after(1, 7):
self.optimizer.zero_grad()
else:
self.optimizer.zero_grad(set_to_none=self.optim_set_to_none)
Expand Down Expand Up @@ -359,7 +359,7 @@ def _iteration(
d_total_loss = torch.zeros(1)
for _ in range(self.d_train_steps):
# `set_to_none` only work from PyTorch 1.7.0
if PT_BEFORE_1_7:
if not pytorch_after(1, 7):
self.d_optimizer.zero_grad()
else:
self.d_optimizer.zero_grad(set_to_none=self.optim_set_to_none)
Expand All @@ -377,7 +377,7 @@ def _iteration(
non_blocking=engine.non_blocking, # type: ignore
)
g_output = self.g_inferer(g_input, self.g_network)
if PT_BEFORE_1_7:
if not pytorch_after(1, 7):
self.g_optimizer.zero_grad()
else:
self.g_optimizer.zero_grad(set_to_none=self.optim_set_to_none)
Expand Down
14 changes: 7 additions & 7 deletions monai/networks/layers/simplelayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,17 @@
from monai.networks.layers.convutils import gaussian_1d
from monai.networks.layers.factories import Conv
from monai.utils import (
PT_BEFORE_1_7,
ChannelMatching,
InvalidPyTorchVersionError,
SkipMode,
look_up_option,
optional_import,
version_leq,
pytorch_after,
)
from monai.utils.misc import issequenceiterable

_C, _ = optional_import("monai._C")
if not PT_BEFORE_1_7:
if pytorch_after(1, 7):
fft, _ = optional_import("torch.fft")

__all__ = [
Expand Down Expand Up @@ -295,11 +294,12 @@ def apply_filter(x: torch.Tensor, kernel: torch.Tensor, **kwargs) -> torch.Tenso
x = x.view(1, kernel.shape[0], *spatials)
conv = [F.conv1d, F.conv2d, F.conv3d][n_spatial - 1]
if "padding" not in kwargs:
if version_leq(torch.__version__, "1.10.0b"):
if pytorch_after(1, 10):
kwargs["padding"] = "same"
else:
# even-sized kernels are not supported
kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]
else:
kwargs["padding"] = "same"

if "stride" not in kwargs:
kwargs["stride"] = 1
output = conv(x, kernel, groups=kernel.shape[0], bias=None, **kwargs)
Expand Down Expand Up @@ -387,7 +387,7 @@ class HilbertTransform(nn.Module):

def __init__(self, axis: int = 2, n: Union[int, None] = None) -> None:

if PT_BEFORE_1_7:
if not pytorch_after(1, 7):
raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__)

super().__init__()
Expand Down
4 changes: 2 additions & 2 deletions monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.misc import ensure_tuple, set_determinism
from monai.utils.module import PT_BEFORE_1_7
from monai.utils.module import pytorch_after

__all__ = [
"one_hot",
Expand Down Expand Up @@ -464,7 +464,7 @@ def convert_to_torchscript(
with torch.no_grad():
script_module = torch.jit.script(model)
if filename_or_obj is not None:
if PT_BEFORE_1_7:
if not pytorch_after(1, 7):
torch.jit.save(m=script_module, f=filename_or_obj)
else:
torch.jit.save(m=script_module, f=filename_or_obj, _extra_files=extra_files)
Expand Down
4 changes: 2 additions & 2 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@
from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array
from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where
from monai.utils import (
PT_BEFORE_1_7,
InvalidPyTorchVersionError,
convert_data_type,
convert_to_dst_type,
ensure_tuple,
ensure_tuple_rep,
ensure_tuple_size,
fall_back_tuple,
pytorch_after,
)
from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.enums import TransformBackends
Expand Down Expand Up @@ -1072,7 +1072,7 @@ class DetectEnvelope(Transform):

def __init__(self, axis: int = 1, n: Union[int, None] = None) -> None:

if PT_BEFORE_1_7:
if not pytorch_after(1, 7):
raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__)

if axis < 0:
Expand Down
2 changes: 1 addition & 1 deletion monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
zip_with,
)
from .module import (
PT_BEFORE_1_7,
InvalidPyTorchVersionError,
OptionalImportError,
damerau_levenshtein_distance,
Expand All @@ -71,6 +70,7 @@
look_up_option,
min_version,
optional_import,
pytorch_after,
require_pkg,
version_leq,
)
Expand Down
57 changes: 52 additions & 5 deletions monai/utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import os
import re
import sys
import warnings
from functools import wraps
Expand Down Expand Up @@ -36,8 +39,8 @@
"get_full_type_name",
"get_package_version",
"get_torch_version_tuple",
"PT_BEFORE_1_7",
"version_leq",
"pytorch_after",
]


Expand Down Expand Up @@ -450,7 +453,51 @@ def _try_cast(val: str):
return True


try:
PT_BEFORE_1_7 = torch.__version__ != "1.7.0" and version_leq(torch.__version__, "1.7.0")
except (AttributeError, TypeError):
PT_BEFORE_1_7 = True
def pytorch_after(major, minor, patch=0, current_ver_string=None) -> bool:
"""
Compute whether the current pytorch version is after or equal to the specified version.
The current system pytorch version is determined by `torch.__version__` or
via system environment variable `PYTORCH_VER`.

Args:
major: major version number to be compared with
minor: minor version number to be compared with
patch: patch version number to be compared with
current_ver_string: if None, `torch.__version__` will be used.

Returns:
True if the current pytorch version is greater than or equal to the specified version.
"""

try:
if current_ver_string is None:
_env_var = os.environ.get("PYTORCH_VER", "")
current_ver_string = _env_var if _env_var else torch.__version__
ver, has_ver = optional_import("pkg_resources", name="parse_version")
if has_ver:
return ver(".".join((f"{major}", f"{minor}", f"{patch}"))) <= ver(f"{current_ver_string}") # type: ignore
parts = f"{current_ver_string}".split("+", 1)[0].split(".", 3)
while len(parts) < 3:
parts += ["0"]
c_major, c_minor, c_patch = parts[:3]
except (AttributeError, ValueError, TypeError):
c_major, c_minor = get_torch_version_tuple()
c_patch = "0"
c_mn = int(c_major), int(c_minor)
mn = int(major), int(minor)
if c_mn != mn:
return c_mn > mn
wyli marked this conversation as resolved.
Show resolved Hide resolved
is_prerelease = ("a" in f"{c_patch}".lower()) or ("rc" in f"{c_patch}".lower())
c_p = 0
try:
p_reg = re.search(r"\d+", f"{c_patch}")
if p_reg:
c_p = int(p_reg.group())
except (AttributeError, TypeError, ValueError):
is_prerelease = True
patch = int(patch)
if c_p != patch:
return c_p > patch # type: ignore
if is_prerelease:
return False
return True
4 changes: 2 additions & 2 deletions tests/test_map_label_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from parameterized import parameterized

from monai.transforms import MapLabelValue
from monai.utils import PT_BEFORE_1_7
from monai.utils import pytorch_after
from tests.utils import TEST_NDARRAYS

TESTS = []
Expand All @@ -34,7 +34,7 @@
]
)
# PyTorch 1.5.1 doesn't support rich dtypes
if not PT_BEFORE_1_7:
if pytorch_after(1, 7):
TESTS.append(
[
{"orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": np.int8},
Expand Down
47 changes: 47 additions & 0 deletions tests/test_pytorch_version_after.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

from parameterized import parameterized

from monai.utils import pytorch_after

TEST_CASES = (
(1, 5, 9, "1.6.0"),
(1, 6, 0, "1.6.0"),
(1, 6, 1, "1.6.0", False),
(1, 7, 0, "1.6.0", False),
(2, 6, 0, "1.6.0", False),
(0, 6, 0, "1.6.0a0+3fd9dcf"),
(1, 5, 9, "1.6.0a0+3fd9dcf"),
(1, 6, 0, "1.6.0a0+3fd9dcf", False),
(1, 6, 1, "1.6.0a0+3fd9dcf", False),
(2, 6, 0, "1.6.0a0+3fd9dcf", False),
(1, 6, 0, "1.6.0-rc0+3fd9dcf", False), # defaults to prerelease
(1, 6, 0, "1.6.0rc0", False),
(1, 6, 0, "1.6", True),
(1, 6, 0, "1", False),
(1, 6, 0, "1.6.0+cpu", True),
(1, 6, 1, "1.6.0+cpu", False),
)


class TestPytorchVersionCompare(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_compare(self, a, b, p, current, expected=True):
"""Test pytorch_after with a and b"""
self.assertEqual(pytorch_after(a, b, p, current), expected)


if __name__ == "__main__":
unittest.main()
8 changes: 3 additions & 5 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@
from monai.data import create_test_image_2d, create_test_image_3d
from monai.networks import convert_to_torchscript
from monai.utils import optional_import
from monai.utils.misc import is_module_ver_at_least
from monai.utils.module import version_leq
from monai.utils.module import pytorch_after, version_leq
from monai.utils.type_conversion import convert_data_type

nib, _ = optional_import("nibabel")
Expand Down Expand Up @@ -193,7 +192,7 @@ class SkipIfBeforePyTorchVersion:

def __init__(self, pytorch_version_tuple):
self.min_version = pytorch_version_tuple
self.version_too_old = not is_module_ver_at_least(torch, pytorch_version_tuple)
self.version_too_old = not pytorch_after(*pytorch_version_tuple)

def __call__(self, obj):
return unittest.skipIf(
Expand All @@ -207,8 +206,7 @@ class SkipIfAtLeastPyTorchVersion:

def __init__(self, pytorch_version_tuple):
self.max_version = pytorch_version_tuple
test_ver = ".".join(map(str, self.max_version))
self.version_too_new = version_leq(test_ver, torch.__version__)
self.version_too_new = pytorch_after(*pytorch_version_tuple)

def __call__(self, obj):
return unittest.skipIf(
Expand Down