diff --git a/tests/test_keep_largest_connected_component.py b/tests/test_keep_largest_connected_component.py index 87a5a81b75..5c96b62131 100644 --- a/tests/test_keep_largest_connected_component.py +++ b/tests/test_keep_largest_connected_component.py @@ -10,6 +10,7 @@ # limitations under the License. import unittest +from copy import deepcopy import torch import torch.nn.functional as F @@ -18,7 +19,7 @@ from monai.transforms import KeepLargestConnectedComponent from monai.transforms.utils_pytorch_numpy_unification import moveaxis from monai.utils.type_conversion import convert_to_dst_type -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose def to_onehot(x): @@ -350,6 +351,21 @@ def test_correct_results(self, _, args, input_image, expected): converter = KeepLargestConnectedComponent(**args) result = converter(input_image) assert_allclose(result, expected, type_test=False) + + @parameterized.expand(TESTS) + @SkipIfBeforePyTorchVersion((1, 7)) + def test_correct_results_before_after_onehot(self, _, args, input_image, expected): + """ + From torch==1.7, torch.argmax changes its mechanism that if there are multiple maximal values then the + indices of the first maximal value are returned (before this version, the indices of the last maximal value + are returned). + Therefore, we can may use of this changes to convert the onehotted labels into un-onehot format directly + and then check if the result stays the same. + + """ + converter = KeepLargestConnectedComponent(**args) + result = converter(deepcopy(input_image)) + if "is_onehot" in args: args["is_onehot"] = not args["is_onehot"] # if not onehotted, onehot it and make sure result stays the same