Skip to content

Commit

Permalink
fixes tests
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli committed Nov 7, 2021
1 parent 001bade commit 67842d8
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 12 deletions.
5 changes: 3 additions & 2 deletions monai/networks/layers/simplelayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,11 @@ def apply_filter(x: torch.Tensor, kernel: torch.Tensor, **kwargs) -> torch.Tenso
conv = [F.conv1d, F.conv2d, F.conv3d][n_spatial - 1]
if "padding" not in kwargs:
if pytorch_after(1, 10):
kwargs["padding"] = "same"
else:
# even-sized kernels are not supported
kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]
else:
kwargs["padding"] = "same"

if "stride" not in kwargs:
kwargs["stride"] = 1
output = conv(x, kernel, groups=kernel.shape[0], bias=None, **kwargs)
Expand Down
28 changes: 20 additions & 8 deletions monai/utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import os
import re
import sys
import warnings
from functools import wraps
Expand Down Expand Up @@ -453,6 +456,8 @@ def _try_cast(val: str):
def pytorch_after(major, minor, patch=0, current_ver_string=None) -> bool:
"""
Compute whether the current pytorch version is after or equal to the specified version.
The current system pytorch version is determined by `torch.__version__` or
via system environment variable `PYTORCH_VER`.
Args:
major: major version number to be compared with
Expand All @@ -465,23 +470,30 @@ def pytorch_after(major, minor, patch=0, current_ver_string=None) -> bool:
"""
try:
if current_ver_string is None:
current_ver_string = torch.__version__
c_major, c_minor, c_patch = current_ver_string.split("+", 1)[0].split(".", 3)
_env_var = os.environ.get("PYTORCH_VER", "")
current_ver_string = _env_var if _env_var else torch.__version__
parts = f"{current_ver_string}".split("+", 1)[0].split(".", 3)
while len(parts) < 3:
parts += ["0"]
c_major, c_minor, c_patch = parts[:3]
except (AttributeError, ValueError, TypeError):
c_major, c_minor = get_torch_version_tuple()
c_patch = 0
c_patch = "0"
c_mn = int(c_major), int(c_minor)
mn = int(major), int(minor)
if c_mn != mn:
return c_mn > mn
is_prerelease = "a" in c_patch
is_prerelease = ("a" in f"{c_patch}".lower()) or ("rc" in f"{c_patch}".lower())
c_p = 0
try:
c_patch = int(c_patch) if not is_prerelease else int(c_patch.split("a", 1)[0])
p_reg = re.search(r"\d+", f"{c_patch}")
if p_reg:
c_p = int(p_reg.group())
except (AttributeError, TypeError, ValueError):
c_patch = 0
is_prerelease = True
if c_patch != patch:
return c_patch > patch
patch = int(patch)
if c_p != patch:
return c_p > patch # type: ignore
if is_prerelease:
return False
return True
2 changes: 1 addition & 1 deletion tests/test_pytorch_version_after.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
(1, 6, 0, "1.6.0-rc0+3fd9dcf", False), # defaults to prerelease
(1, 6, 0, "1.6.0rc0", False),
(1, 6, 0, "1.6", True),
(1, 6, 0, "1", True),
(1, 6, 0, "1", False),
(1, 6, 0, "1.6.0+cpu", True),
(1, 6, 1, "1.6.0+cpu", False),
)
Expand Down
1 change: 0 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from monai.data import create_test_image_2d, create_test_image_3d
from monai.networks import convert_to_torchscript
from monai.utils import optional_import
from monai.utils.misc import is_module_ver_at_least
from monai.utils.module import pytorch_after, version_leq
from monai.utils.type_conversion import convert_data_type

Expand Down

0 comments on commit 67842d8

Please sign in to comment.