|
28 | 28 | from monai.bundle.scripts import _examine_monai_version, _list_latest_versions, download
|
29 | 29 | from monai.utils import optional_import
|
30 | 30 | from tests.test_utils import (
|
31 |
| - SkipIfBeforePyTorchVersion, |
32 | 31 | assert_allclose,
|
33 | 32 | command_line_tests,
|
34 | 33 | skip_if_downloading_fails,
|
35 | 34 | skip_if_no_cuda,
|
36 | 35 | skip_if_quick,
|
| 36 | + skip_if_windows, |
37 | 37 | )
|
38 | 38 |
|
39 | 39 | _, has_huggingface_hub = optional_import("huggingface_hub")
|
|
95 | 95 | {"model.pt": "27952767e2e154e3b0ee65defc5aed38", "model.ts": "97746870fe591f69ac09827175b00675"},
|
96 | 96 | ]
|
97 | 97 |
|
| 98 | +TEST_CASE_NGC_1 = [ |
| 99 | + "spleen_ct_segmentation", |
| 100 | + "0.3.7", |
| 101 | + None, |
| 102 | + "monai_spleen_ct_segmentation", |
| 103 | + "models/model.pt", |
| 104 | + "b418a2dc8672ce2fd98dc255036e7a3d", |
| 105 | +] |
| 106 | +TEST_CASE_NGC_2 = [ |
| 107 | + "monai_spleen_ct_segmentation", |
| 108 | + "0.3.7", |
| 109 | + "monai_", |
| 110 | + "spleen_ct_segmentation", |
| 111 | + "models/model.pt", |
| 112 | + "b418a2dc8672ce2fd98dc255036e7a3d", |
| 113 | +] |
| 114 | + |
| 115 | +TESTCASE_NGC_WEIGHTS = { |
| 116 | + "key": "model.0.conv.unit0.adn.N.bias", |
| 117 | + "value": torch.tensor( |
| 118 | + [ |
| 119 | + -0.0705, |
| 120 | + -0.0937, |
| 121 | + -0.0422, |
| 122 | + -0.2068, |
| 123 | + 0.1023, |
| 124 | + -0.2007, |
| 125 | + -0.0883, |
| 126 | + 0.0018, |
| 127 | + -0.1719, |
| 128 | + 0.0116, |
| 129 | + 0.0285, |
| 130 | + -0.0044, |
| 131 | + 0.1223, |
| 132 | + -0.1287, |
| 133 | + -0.1858, |
| 134 | + 0.0460, |
| 135 | + ] |
| 136 | + ), |
| 137 | +} |
| 138 | + |
98 | 139 |
|
99 | 140 | class TestDownload(unittest.TestCase):
|
100 | 141 | @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
|
@@ -356,7 +397,6 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override)
|
356 | 397 |
|
357 | 398 | @parameterized.expand([TEST_CASE_9])
|
358 | 399 | @skip_if_quick
|
359 |
| - @SkipIfBeforePyTorchVersion((1, 7, 1)) |
360 | 400 | def test_load_ts_module(self, bundle_files, bundle_name, version, repo, device, model_file):
|
361 | 401 | with skip_if_downloading_fails():
|
362 | 402 | # load ts module
|
@@ -419,5 +459,31 @@ def test_url_download_large_files(self, bundle_files, bundle_name, url, hash_val
|
419 | 459 | self.assertTrue(check_hash(filepath=file_path, val=hash_val[file]))
|
420 | 460 |
|
421 | 461 |
|
| 462 | +@skip_if_windows |
| 463 | +class TestNgcBundleDownload(unittest.TestCase): |
| 464 | + @parameterized.expand([TEST_CASE_NGC_1, TEST_CASE_NGC_2]) |
| 465 | + @skip_if_quick |
| 466 | + def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download_name, file_path, hash_val): |
| 467 | + with skip_if_downloading_fails(): |
| 468 | + with tempfile.TemporaryDirectory() as tempdir: |
| 469 | + download( |
| 470 | + name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix |
| 471 | + ) |
| 472 | + full_file_path = os.path.join(tempdir, download_name, file_path) |
| 473 | + self.assertTrue(os.path.exists(full_file_path)) |
| 474 | + self.assertTrue(check_hash(filepath=full_file_path, val=hash_val)) |
| 475 | + |
| 476 | + model = load( |
| 477 | + name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix |
| 478 | + ) |
| 479 | + assert_allclose( |
| 480 | + model.state_dict()[TESTCASE_NGC_WEIGHTS["key"]], |
| 481 | + TESTCASE_NGC_WEIGHTS["value"], |
| 482 | + atol=1e-4, |
| 483 | + rtol=1e-4, |
| 484 | + type_test=False, |
| 485 | + ) |
| 486 | + |
| 487 | + |
422 | 488 | if __name__ == "__main__":
|
423 | 489 | unittest.main()
|
0 commit comments