Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1568 Enhance tests to use self random state #1599

Merged
merged 10 commits into from
Feb 21, 2021
2 changes: 1 addition & 1 deletion tests/test_rand_rotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor
self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False
)
expected = np.stack(expected).astype(np.float32)
np.testing.assert_allclose(expected, rotated[0])
np.testing.assert_allclose(expected, rotated[0], rtol=1e-2, atol=1)


class TestRandRotate3D(NumpyImageTestCase3D):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_rand_rotated.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor
self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False
)
expected = np.stack(expected).astype(np.float32)
self.assertTrue(np.allclose(expected, rotated["img"][0]))
self.assertTrue(np.allclose(expected, rotated["img"][0], rtol=1e-2, atol=1))


class TestRandRotated3D(NumpyImageTestCase3D):
Expand Down
32 changes: 28 additions & 4 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from monai.config.deviceconfig import USE_COMPILED
from monai.data import create_test_image_2d, create_test_image_3d
from monai.transforms import Randomizable
from monai.utils import ensure_tuple, optional_import, set_determinism
from monai.utils.module import get_torch_version_tuple

Expand Down Expand Up @@ -434,14 +435,25 @@ def _call_original_func(name, module, *args, **kwargs):
return f(*args, **kwargs)


class NumpyImageTestCase2D(unittest.TestCase):
class NumpyImageTestCase2D(unittest.TestCase, Randomizable):
im_shape = (128, 64)
input_channels = 1
output_channels = 4
num_classes = 3

def randomize(self, data=None):
return create_test_image_2d(
width=self.im_shape[0],
height=self.im_shape[1],
num_objs=4,
rad_max=20,
noise_max=0,
num_seg_classes=self.num_classes,
random_state=self.R,
)

def setUp(self):
im, msk = create_test_image_2d(self.im_shape[0], self.im_shape[1], 4, 20, 0, self.num_classes)
im, msk = self.randomize()

self.imt = im[None, None]
self.seg1 = (msk[None, None] > 0).astype(np.float32)
Expand All @@ -456,14 +468,26 @@ def setUp(self):
self.segn = torch.tensor(self.segn)


class NumpyImageTestCase3D(unittest.TestCase):
class NumpyImageTestCase3D(unittest.TestCase, Randomizable):
im_shape = (64, 48, 80)
input_channels = 1
output_channels = 4
num_classes = 3

def randomize(self, data=None):
return create_test_image_3d(
height=self.im_shape[0],
width=self.im_shape[1],
depth=self.im_shape[2],
num_objs=4,
rad_max=20,
noise_max=0,
num_seg_classes=self.num_classes,
random_state=self.R,
)

def setUp(self):
im, msk = create_test_image_3d(self.im_shape[0], self.im_shape[1], self.im_shape[2], 4, 20, 0, self.num_classes)
im, msk = self.randomize()

self.imt = im[None, None]
self.seg1 = (msk[None, None] > 0).astype(np.float32)
Expand Down