Skip to content

Commit cb200af

Browse files
ericspodpre-commit-ci[bot]KumoLiu
authored
Tests Cleanup (#8535)
Fixes #8534. ### Description This cleans up the tests further as discussed in the issue. This also additionally updates the zarr tests to save the .zarr directories to a temp directory rather than the current one, this is just cleaner when running local tests. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent 1e6c661 commit cb200af

File tree

71 files changed

+244
-584
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+244
-584
lines changed

tests/__init__.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,32 +8,3 @@
88
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
11-
12-
from __future__ import annotations
13-
14-
import sys
15-
import unittest
16-
import warnings
17-
18-
19-
def _enter_pr_4800(self):
20-
"""
21-
code from https://github.com/python/cpython/pull/4800
22-
"""
23-
# The __warningregistry__'s need to be in a pristine state for tests
24-
# to work properly.
25-
for v in list(sys.modules.values()):
26-
if getattr(v, "__warningregistry__", None):
27-
v.__warningregistry__ = {}
28-
self.warnings_manager = warnings.catch_warnings(record=True)
29-
self.warnings = self.warnings_manager.__enter__()
30-
warnings.simplefilter("always", self.expected)
31-
return self
32-
33-
34-
# FIXME: workaround for https://bugs.python.org/issue29620
35-
try:
36-
# Suppression for issue #494: tests/__init__.py:34: error: Cannot assign to a method
37-
unittest.case._AssertWarnsContext.__enter__ = _enter_pr_4800 # type: ignore
38-
except AttributeError:
39-
pass

tests/apps/detection/networks/test_retinanet.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from monai.networks import eval_mode
2121
from monai.networks.nets import resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200
2222
from monai.utils import ensure_tuple, optional_import
23-
from tests.test_utils import SkipIfBeforePyTorchVersion, dict_product, skip_if_quick, test_onnx_save, test_script_save
23+
from tests.test_utils import dict_product, skip_if_quick, test_onnx_save, test_script_save
2424

2525
_, has_torchvision = optional_import("torchvision")
2626

@@ -94,7 +94,6 @@
9494
TEST_CASES_TS = [[params["model"], *params["case"]] for params in dict_product(model=MODEL_LIST, case=[TEST_CASE_1])]
9595

9696

97-
@SkipIfBeforePyTorchVersion((1, 12))
9897
@unittest.skipUnless(has_torchvision, "Requires torchvision")
9998
@skip_if_quick
10099
class TestRetinaNet(unittest.TestCase):

tests/apps/detection/networks/test_retinanet_detector.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from monai.apps.detection.utils.anchor_utils import AnchorGeneratorWithAnchorShape
2222
from monai.networks import eval_mode, train_mode
2323
from monai.utils import optional_import
24-
from tests.test_utils import SkipIfBeforePyTorchVersion, skip_if_quick, test_script_save
24+
from tests.test_utils import skip_if_quick, test_script_save
2525

2626
_, has_torchvision = optional_import("torchvision")
2727

@@ -110,7 +110,6 @@ def forward(self, images):
110110
return {self.cls_key: [torch.randn(out_cls_shape)], self.box_reg_key: [torch.randn(out_box_reg_shape)]}
111111

112112

113-
@SkipIfBeforePyTorchVersion((1, 11))
114113
@unittest.skipUnless(has_torchvision, "Requires torchvision")
115114
@skip_if_quick
116115
class TestRetinaNetDetector(unittest.TestCase):

tests/apps/detection/utils/test_anchor_box.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from monai.apps.detection.utils.anchor_utils import AnchorGenerator, AnchorGeneratorWithAnchorShape
2020
from monai.utils import optional_import
21-
from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, test_script_save
21+
from tests.test_utils import assert_allclose, test_script_save
2222

2323
_, has_torchvision = optional_import("torchvision")
2424

@@ -39,7 +39,6 @@
3939
]
4040

4141

42-
@SkipIfBeforePyTorchVersion((1, 11))
4342
@unittest.skipUnless(has_torchvision, "Requires torchvision")
4443
class TestAnchorGenerator(unittest.TestCase):
4544
@parameterized.expand(TEST_CASES_2D)

tests/apps/maisi/networks/test_autoencoderkl_maisi.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi
2020
from monai.networks import eval_mode
2121
from monai.utils import optional_import
22-
from tests.test_utils import SkipIfBeforePyTorchVersion
2322

2423
tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
2524
_, has_einops = optional_import("einops")
@@ -87,7 +86,6 @@ def test_shape(self, input_param, input_shape, expected_shape, expected_latent_s
8786
self.assertEqual(result[2].shape, expected_latent_shape)
8887

8988
@parameterized.expand(CASES)
90-
@SkipIfBeforePyTorchVersion((1, 11))
9189
def test_shape_with_convtranspose_and_checkpointing(
9290
self, input_param, input_shape, expected_shape, expected_latent_shape
9391
):
@@ -152,7 +150,6 @@ def test_shape_reconstruction(self):
152150
result = net.reconstruct(torch.randn(input_shape).to(device))
153151
self.assertEqual(result.shape, expected_shape)
154152

155-
@SkipIfBeforePyTorchVersion((1, 11))
156153
def test_shape_reconstruction_with_convtranspose_and_checkpointing(self):
157154
input_param, input_shape, expected_shape, _ = CASES[0]
158155
input_param = input_param.copy()
@@ -170,7 +167,6 @@ def test_shape_encode(self):
170167
self.assertEqual(result[0].shape, expected_latent_shape)
171168
self.assertEqual(result[1].shape, expected_latent_shape)
172169

173-
@SkipIfBeforePyTorchVersion((1, 11))
174170
def test_shape_encode_with_convtranspose_and_checkpointing(self):
175171
input_param, input_shape, _, expected_latent_shape = CASES[0]
176172
input_param = input_param.copy()
@@ -190,7 +186,6 @@ def test_shape_sampling(self):
190186
)
191187
self.assertEqual(result.shape, expected_latent_shape)
192188

193-
@SkipIfBeforePyTorchVersion((1, 11))
194189
def test_shape_sampling_convtranspose_and_checkpointing(self):
195190
input_param, _, _, expected_latent_shape = CASES[0]
196191
input_param = input_param.copy()
@@ -209,7 +204,6 @@ def test_shape_decode(self):
209204
result = net.decode(torch.randn(latent_shape).to(device))
210205
self.assertEqual(result.shape, expected_input_shape)
211206

212-
@SkipIfBeforePyTorchVersion((1, 11))
213207
def test_shape_decode_convtranspose_and_checkpointing(self):
214208
input_param, expected_input_shape, _, latent_shape = CASES[0]
215209
input_param = input_param.copy()

tests/apps/maisi/networks/test_controlnet_maisi.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi
2121
from monai.networks import eval_mode
2222
from monai.utils import optional_import
23-
from tests.test_utils import SkipIfBeforePyTorchVersion
2423

2524
_, has_einops = optional_import("einops")
2625

@@ -127,7 +126,6 @@
127126
]
128127

129128

130-
@SkipIfBeforePyTorchVersion((2, 0))
131129
class TestControlNet(unittest.TestCase):
132130
@parameterized.expand(TEST_CASES)
133131
@skipUnless(has_einops, "Requires einops")
File renamed without changes.

tests/apps/test_auto3dseg_bundlegen.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,7 @@
2626
from monai.bundle.config_parser import ConfigParser
2727
from monai.data import create_test_image_3d
2828
from monai.utils import set_determinism
29-
from tests.test_utils import (
30-
SkipIfBeforePyTorchVersion,
31-
get_testing_algo_template_path,
32-
skip_if_downloading_fails,
33-
skip_if_no_cuda,
34-
skip_if_quick,
35-
)
29+
from tests.test_utils import get_testing_algo_template_path, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick
3630

3731
num_images_perfold = max(torch.cuda.device_count(), 4)
3832
num_images_per_batch = 2
@@ -104,7 +98,6 @@ def run_auto3dseg_before_bundlegen(test_path, work_dir):
10498

10599

106100
@skip_if_no_cuda
107-
@SkipIfBeforePyTorchVersion((1, 11, 1))
108101
@skip_if_quick
109102
class TestBundleGen(unittest.TestCase):
110103
def setUp(self) -> None:

tests/apps/vista3d/test_point_based_window_inferer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from monai.networks import eval_mode
2121
from monai.networks.nets.vista3d import vista3d132
2222
from monai.utils import optional_import
23-
from tests.test_utils import SkipIfBeforePyTorchVersion, skip_if_quick
23+
from tests.test_utils import skip_if_quick
2424

2525
device = "cuda" if torch.cuda.is_available() else "cpu"
2626

@@ -60,7 +60,6 @@
6060
]
6161

6262

63-
@SkipIfBeforePyTorchVersion((1, 11))
6463
@skip_if_quick
6564
class TestPointBasedWindowInferer(unittest.TestCase):
6665
@parameterized.expand(TEST_CASES)

tests/bundle/test_bundle_download.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@
2828
from monai.bundle.scripts import _examine_monai_version, _list_latest_versions, download
2929
from monai.utils import optional_import
3030
from tests.test_utils import (
31-
SkipIfBeforePyTorchVersion,
3231
assert_allclose,
3332
command_line_tests,
3433
skip_if_downloading_fails,
3534
skip_if_no_cuda,
3635
skip_if_quick,
36+
skip_if_windows,
3737
)
3838

3939
_, has_huggingface_hub = optional_import("huggingface_hub")
@@ -95,6 +95,47 @@
9595
{"model.pt": "27952767e2e154e3b0ee65defc5aed38", "model.ts": "97746870fe591f69ac09827175b00675"},
9696
]
9797

98+
TEST_CASE_NGC_1 = [
99+
"spleen_ct_segmentation",
100+
"0.3.7",
101+
None,
102+
"monai_spleen_ct_segmentation",
103+
"models/model.pt",
104+
"b418a2dc8672ce2fd98dc255036e7a3d",
105+
]
106+
TEST_CASE_NGC_2 = [
107+
"monai_spleen_ct_segmentation",
108+
"0.3.7",
109+
"monai_",
110+
"spleen_ct_segmentation",
111+
"models/model.pt",
112+
"b418a2dc8672ce2fd98dc255036e7a3d",
113+
]
114+
115+
TESTCASE_NGC_WEIGHTS = {
116+
"key": "model.0.conv.unit0.adn.N.bias",
117+
"value": torch.tensor(
118+
[
119+
-0.0705,
120+
-0.0937,
121+
-0.0422,
122+
-0.2068,
123+
0.1023,
124+
-0.2007,
125+
-0.0883,
126+
0.0018,
127+
-0.1719,
128+
0.0116,
129+
0.0285,
130+
-0.0044,
131+
0.1223,
132+
-0.1287,
133+
-0.1858,
134+
0.0460,
135+
]
136+
),
137+
}
138+
98139

99140
class TestDownload(unittest.TestCase):
100141
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
@@ -356,7 +397,6 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override)
356397

357398
@parameterized.expand([TEST_CASE_9])
358399
@skip_if_quick
359-
@SkipIfBeforePyTorchVersion((1, 7, 1))
360400
def test_load_ts_module(self, bundle_files, bundle_name, version, repo, device, model_file):
361401
with skip_if_downloading_fails():
362402
# load ts module
@@ -419,5 +459,31 @@ def test_url_download_large_files(self, bundle_files, bundle_name, url, hash_val
419459
self.assertTrue(check_hash(filepath=file_path, val=hash_val[file]))
420460

421461

462+
@skip_if_windows
463+
class TestNgcBundleDownload(unittest.TestCase):
464+
@parameterized.expand([TEST_CASE_NGC_1, TEST_CASE_NGC_2])
465+
@skip_if_quick
466+
def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download_name, file_path, hash_val):
467+
with skip_if_downloading_fails():
468+
with tempfile.TemporaryDirectory() as tempdir:
469+
download(
470+
name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix
471+
)
472+
full_file_path = os.path.join(tempdir, download_name, file_path)
473+
self.assertTrue(os.path.exists(full_file_path))
474+
self.assertTrue(check_hash(filepath=full_file_path, val=hash_val))
475+
476+
model = load(
477+
name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix
478+
)
479+
assert_allclose(
480+
model.state_dict()[TESTCASE_NGC_WEIGHTS["key"]],
481+
TESTCASE_NGC_WEIGHTS["value"],
482+
atol=1e-4,
483+
rtol=1e-4,
484+
type_test=False,
485+
)
486+
487+
422488
if __name__ == "__main__":
423489
unittest.main()

0 commit comments

Comments
 (0)