From 2eaa799c4847780c3a8d8aabf60f02b3250accd4 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Wed, 27 Jul 2022 11:00:15 -0600 Subject: [PATCH 1/3] Improve version checking --- captum/_utils/common.py | 3 ++- captum/influence/_core/similarity_influence.py | 3 ++- captum/influence/_utils/common.py | 4 +++- setup.py | 2 +- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/captum/_utils/common.py b/captum/_utils/common.py index 6db0727024..3b39925bf9 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -14,6 +14,7 @@ TensorOrTupleOfTensorsGeneric, TupleOrTensorOrBoolGeneric, ) +from packaging import version from torch import device, Tensor from torch.nn import Module @@ -671,7 +672,7 @@ def _register_backward_hook( ): return module.register_backward_hook(hook) - if torch.__version__ >= "1.9": + if version.parse(torch.__version__) >= version.parse("1.9.0"): # Only supported for torch >= 1.9 return module.register_full_backward_hook(hook) else: diff --git a/captum/influence/_core/similarity_influence.py b/captum/influence/_core/similarity_influence.py index 83cb2966fa..d9bfdbd0b6 100644 --- a/captum/influence/_core/similarity_influence.py +++ b/captum/influence/_core/similarity_influence.py @@ -9,6 +9,7 @@ from captum._utils.av import AV from captum.attr import LayerActivation from captum.influence._core.influence import DataInfluence +from packaging import version from torch import Tensor from torch.nn import Module from torch.utils.data import DataLoader, Dataset @@ -40,7 +41,7 @@ def cosine_similarity(test, train, replace_nan=0) -> Tensor: test = test.view(test.shape[0], -1) train = train.view(train.shape[0], -1) - if torch.__version__ <= "1.6.0": + if version.parse(torch.__version__) <= version.parse("1.6.0"): test_norm = torch.norm(test, p=None, dim=1, keepdim=True) train_norm = torch.norm(train, p=None, dim=1, keepdim=True) else: diff --git a/captum/influence/_utils/common.py b/captum/influence/_utils/common.py index b86ddf9f93..29c1ebaa76 100644 --- a/captum/influence/_utils/common.py +++ b/captum/influence/_utils/common.py @@ -5,6 +5,8 @@ import torch import torch.nn as nn from captum._utils.progress import progress + +from packaging import version from torch import Tensor from torch.nn import Module from torch.utils.data import DataLoader, Dataset @@ -125,7 +127,7 @@ def _jacobian_loss_wrt_inputs( "Must be either 'sum' or 'mean'." ) - if torch.__version__ >= "1.8": + if version.parse(torch.__version__) >= version.parse("1.8.0"): input_jacobians = torch.autograd.functional.jacobian( lambda out: loss_fn(out, targets), out, vectorize=vectorize ) diff --git a/setup.py b/setup.py index e03ab29915..1136decac9 100755 --- a/setup.py +++ b/setup.py @@ -147,7 +147,7 @@ def get_package_files(root, subdirs): long_description=long_description, long_description_content_type="text/markdown", python_requires=">=3.6", - install_requires=["matplotlib", "numpy", "torch>=1.6"], + install_requires=["matplotlib", "numpy", "packaging", "torch>=1.6"], packages=find_packages(exclude=("tests", "tests.*")), extras_require={ "dev": DEV_REQUIRES, From 876cdcd8525d17f970b8034e0358a187e6a3c918 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Mon, 8 Aug 2022 14:55:47 -0600 Subject: [PATCH 2/3] Use `_parse_version` instead of `version.parse` --- captum/_utils/common.py | 33 ++++++++++++-- .../influence/_core/similarity_influence.py | 3 +- captum/influence/_utils/common.py | 4 +- setup.py | 2 +- tests/utils/test_common.py | 45 ++++++++++++++++++- 5 files changed, 78 insertions(+), 9 deletions(-) diff --git a/captum/_utils/common.py b/captum/_utils/common.py index 3b39925bf9..c9a125568d 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -3,7 +3,7 @@ from enum import Enum from functools import reduce from inspect import signature -from typing import Any, Callable, cast, Dict, List, overload, Tuple, Union +from typing import Any, Callable, cast, Dict, List, Optional, overload, Tuple, Union import numpy as np import torch @@ -14,11 +14,38 @@ TensorOrTupleOfTensorsGeneric, TupleOrTensorOrBoolGeneric, ) -from packaging import version from torch import device, Tensor from torch.nn import Module +def _parse_version(v: str, length: Optional[int] = 3) -> Tuple[int, ...]: + """ + Parse version strings into tuples for comparison. + + Versions should be in the form of "..", ".", + or "". The "dev", "post" and other letter portions of the given version will + be ignored. + + Args: + + v (str): A version string. + length (int, optional): The expected length of the output tuple. If the output + is less than the expected length, then it will be padded with 0 values. Set + to None for no padding or length checks. + Default: ``3`` + + Returns: + version_tuple (tuple of int): A tuple of integer values to use for version + comparison. + """ + v = [n for n in v.split(".") if n.isdigit()] + assert v != [] + if length is not None: + v += ["0"] * (length - len(v)) + assert len(v) == length + return tuple(map(int, v)) + + class ExpansionTypes(Enum): repeat = 1 repeat_interleave = 2 @@ -672,7 +699,7 @@ def _register_backward_hook( ): return module.register_backward_hook(hook) - if version.parse(torch.__version__) >= version.parse("1.9.0"): + if _parse_version(torch.__version__) >= (1, 9, 0): # Only supported for torch >= 1.9 return module.register_full_backward_hook(hook) else: diff --git a/captum/influence/_core/similarity_influence.py b/captum/influence/_core/similarity_influence.py index d9bfdbd0b6..0fd21eedb7 100644 --- a/captum/influence/_core/similarity_influence.py +++ b/captum/influence/_core/similarity_influence.py @@ -9,7 +9,6 @@ from captum._utils.av import AV from captum.attr import LayerActivation from captum.influence._core.influence import DataInfluence -from packaging import version from torch import Tensor from torch.nn import Module from torch.utils.data import DataLoader, Dataset @@ -41,7 +40,7 @@ def cosine_similarity(test, train, replace_nan=0) -> Tensor: test = test.view(test.shape[0], -1) train = train.view(train.shape[0], -1) - if version.parse(torch.__version__) <= version.parse("1.6.0"): + if common._parse_version(torch.__version__) <= (1, 6, 0): test_norm = torch.norm(test, p=None, dim=1, keepdim=True) train_norm = torch.norm(train, p=None, dim=1, keepdim=True) else: diff --git a/captum/influence/_utils/common.py b/captum/influence/_utils/common.py index edf78fa292..c7b60529c7 100644 --- a/captum/influence/_utils/common.py +++ b/captum/influence/_utils/common.py @@ -5,9 +5,9 @@ import torch import torch.nn as nn +from captum._utils.common import _parse_version from captum._utils.progress import progress -from packaging import version from torch import Tensor from torch.nn import Module from torch.utils.data import DataLoader, Dataset @@ -128,7 +128,7 @@ def _jacobian_loss_wrt_inputs( "Must be either 'sum' or 'mean'." ) - if version.parse(torch.__version__) >= version.parse("1.8.0"): + if _parse_version(torch.__version__) >= (1, 8, 0): input_jacobians = torch.autograd.functional.jacobian( lambda out: loss_fn(out, targets), out, vectorize=vectorize ) diff --git a/setup.py b/setup.py index 1136decac9..e03ab29915 100755 --- a/setup.py +++ b/setup.py @@ -147,7 +147,7 @@ def get_package_files(root, subdirs): long_description=long_description, long_description_content_type="text/markdown", python_requires=">=3.6", - install_requires=["matplotlib", "numpy", "packaging", "torch>=1.6"], + install_requires=["matplotlib", "numpy", "torch>=1.6"], packages=find_packages(exclude=("tests", "tests.*")), extras_require={ "dev": DEV_REQUIRES, diff --git a/tests/utils/test_common.py b/tests/utils/test_common.py index 5bea797e97..9509980997 100644 --- a/tests/utils/test_common.py +++ b/tests/utils/test_common.py @@ -3,7 +3,13 @@ from typing import cast, List, Tuple import torch -from captum._utils.common import _reduce_list, _select_targets, _sort_key_list, safe_div +from captum._utils.common import ( + _parse_version, + _reduce_list, + _select_targets, + _sort_key_list, + safe_div, +) from tests.helpers.basic import assertTensorAlmostEqual, BaseTest @@ -109,3 +115,40 @@ def test_select_target_3d(self) -> None: # Verify error is raised if too many dimensions are provided. with self.assertRaises(AssertionError): _select_targets(output_tensor, (1, 2, 3)) + + +class TestParseVersion(BaseTest): + def test_parse_version_dev(self) -> None: + version_str = "1.12.0.dev20201109" + output = _parse_version(version_str) + self.assertEqual(output, (1, 12, 0)) + + def test_parse_version_post(self) -> None: + version_str = "1.3.0.post2" + output = _parse_version(version_str) + self.assertEqual(output, (1, 3, 0)) + + def test_parse_version_1_12_0(self) -> None: + version_str = "1.12.0" + output = _parse_version(version_str) + self.assertEqual(output, (1, 12, 0)) + + def test_parse_version_1_12_2(self) -> None: + version_str = "1.12.2" + output = _parse_version(version_str) + self.assertEqual(output, (1, 12, 2)) + + def test_parse_version_1_6_0(self) -> None: + version_str = "1.6.0" + output = _parse_version(version_str) + self.assertEqual(output, (1, 6, 0)) + + def test_parse_version_1_12(self) -> None: + version_str = "1.12" + output = _parse_version(version_str) + self.assertEqual(output, (1, 12, 0)) + + def test_parse_version_length(self) -> None: + version_str = "1.12.0.1" + output = _parse_version(version_str, 4) + self.assertEqual(output, (1, 12, 0, 1)) From 442352bae66c366ca624d9300a7215f051b10991 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sun, 14 Aug 2022 19:15:01 -0600 Subject: [PATCH 3/3] Remove length option from `_parse_version` --- captum/_utils/common.py | 11 ++--------- tests/utils/test_common.py | 7 +------ 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/captum/_utils/common.py b/captum/_utils/common.py index c9a125568d..1bad602896 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -3,7 +3,7 @@ from enum import Enum from functools import reduce from inspect import signature -from typing import Any, Callable, cast, Dict, List, Optional, overload, Tuple, Union +from typing import Any, Callable, cast, Dict, List, overload, Tuple, Union import numpy as np import torch @@ -18,7 +18,7 @@ from torch.nn import Module -def _parse_version(v: str, length: Optional[int] = 3) -> Tuple[int, ...]: +def _parse_version(v: str) -> Tuple[int, ...]: """ Parse version strings into tuples for comparison. @@ -29,10 +29,6 @@ def _parse_version(v: str, length: Optional[int] = 3) -> Tuple[int, ...]: Args: v (str): A version string. - length (int, optional): The expected length of the output tuple. If the output - is less than the expected length, then it will be padded with 0 values. Set - to None for no padding or length checks. - Default: ``3`` Returns: version_tuple (tuple of int): A tuple of integer values to use for version @@ -40,9 +36,6 @@ def _parse_version(v: str, length: Optional[int] = 3) -> Tuple[int, ...]: """ v = [n for n in v.split(".") if n.isdigit()] assert v != [] - if length is not None: - v += ["0"] * (length - len(v)) - assert len(v) == length return tuple(map(int, v)) diff --git a/tests/utils/test_common.py b/tests/utils/test_common.py index 9509980997..e19c3c26b9 100644 --- a/tests/utils/test_common.py +++ b/tests/utils/test_common.py @@ -146,9 +146,4 @@ def test_parse_version_1_6_0(self) -> None: def test_parse_version_1_12(self) -> None: version_str = "1.12" output = _parse_version(version_str) - self.assertEqual(output, (1, 12, 0)) - - def test_parse_version_length(self) -> None: - version_str = "1.12.0.1" - output = _parse_version(version_str, 4) - self.assertEqual(output, (1, 12, 0, 1)) + self.assertEqual(output, (1, 12))