diff --git a/tests/utils/test_sample_gradient.py b/tests/utils/test_sample_gradient.py index 8f49235e72..8f8279b678 100644 --- a/tests/utils/test_sample_gradient.py +++ b/tests/utils/test_sample_gradient.py @@ -5,6 +5,7 @@ import torch from captum._utils.sample_gradient import SampleGradientWrapper, SUPPORTED_MODULES +from packaging import version from tests.helpers.basic import assertTensorAlmostEqual, BaseTest from tests.helpers.basic_models import ( BasicModel_ConvNet_One_Conv, @@ -37,7 +38,7 @@ def test_sample_grads_conv_mean_multi_inp(self) -> None: self._compare_sample_grads_per_sample(model, inp, lambda x: torch.mean(x)) def test_sample_grads_modified_conv_mean(self) -> None: - if torch.__version__ < "1.8": + if version.parse(torch.__version__) < version.parse("1.8.0"): raise unittest.SkipTest( "Skipping sample gradient test with 3D linear module" "since torch version < 1.8" @@ -50,7 +51,7 @@ def test_sample_grads_modified_conv_mean(self) -> None: ) def test_sample_grads_modified_conv_sum(self) -> None: - if torch.__version__ < "1.8": + if version.parse(torch.__version__) < version.parse("1.8.0"): raise unittest.SkipTest( "Skipping sample gradient test with 3D linear module" "since torch version < 1.8"