From 4476fa88ea66c0f242fb8ff22c0d88013f320b14 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Wed, 11 May 2022 18:28:14 -0600 Subject: [PATCH 1/2] Fix version check bug --- tests/utils/test_sample_gradient.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/utils/test_sample_gradient.py b/tests/utils/test_sample_gradient.py index 8f49235e72..75aea9e11e 100644 --- a/tests/utils/test_sample_gradient.py +++ b/tests/utils/test_sample_gradient.py @@ -5,6 +5,7 @@ import torch from captum._utils.sample_gradient import SampleGradientWrapper, SUPPORTED_MODULES +from packaging import version from tests.helpers.basic import assertTensorAlmostEqual, BaseTest from tests.helpers.basic_models import ( BasicModel_ConvNet_One_Conv, @@ -37,7 +38,7 @@ def test_sample_grads_conv_mean_multi_inp(self) -> None: self._compare_sample_grads_per_sample(model, inp, lambda x: torch.mean(x)) def test_sample_grads_modified_conv_mean(self) -> None: - if torch.__version__ < "1.8": + if version.parse(torch.__version__) <= version.parse("1.8.0"): raise unittest.SkipTest( "Skipping sample gradient test with 3D linear module" "since torch version < 1.8" @@ -50,7 +51,7 @@ def test_sample_grads_modified_conv_mean(self) -> None: ) def test_sample_grads_modified_conv_sum(self) -> None: - if torch.__version__ < "1.8": + if version.parse(torch.__version__) <= version.parse("1.8.0"): raise unittest.SkipTest( "Skipping sample gradient test with 3D linear module" "since torch version < 1.8" From 6e4a7073ac113e64b8bc223446820544eed1bf35 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Wed, 11 May 2022 18:30:26 -0600 Subject: [PATCH 2/2] Fix version check --- tests/utils/test_sample_gradient.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/utils/test_sample_gradient.py b/tests/utils/test_sample_gradient.py index 75aea9e11e..8f8279b678 100644 --- a/tests/utils/test_sample_gradient.py +++ b/tests/utils/test_sample_gradient.py @@ -38,7 +38,7 @@ def test_sample_grads_conv_mean_multi_inp(self) -> None: self._compare_sample_grads_per_sample(model, inp, lambda x: torch.mean(x)) def test_sample_grads_modified_conv_mean(self) -> None: - if version.parse(torch.__version__) <= version.parse("1.8.0"): + if version.parse(torch.__version__) < version.parse("1.8.0"): raise unittest.SkipTest( "Skipping sample gradient test with 3D linear module" "since torch version < 1.8" @@ -51,7 +51,7 @@ def test_sample_grads_modified_conv_mean(self) -> None: ) def test_sample_grads_modified_conv_sum(self) -> None: - if version.parse(torch.__version__) <= version.parse("1.8.0"): + if version.parse(torch.__version__) < version.parse("1.8.0"): raise unittest.SkipTest( "Skipping sample gradient test with 3D linear module" "since torch version < 1.8"