Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/dev' into sliding-windows
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli committed Apr 2, 2023
2 parents e85efe9 + 05533ab commit 3c05d17
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 19 deletions.
4 changes: 4 additions & 0 deletions monai/apps/auto3dseg/auto_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,10 @@ def run(self):
self.datalist_filename, self.dataroot, output_path=self.datastats_filename, **self.analyze_params
)
da.get_all_case_stats()

da = None # type: ignore
torch.cuda.empty_cache()

self.export_cache(analyze=True, datastats=self.datastats_filename)
else:
logger.info("Skipping data analysis...")
Expand Down
2 changes: 1 addition & 1 deletion monai/inferers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def sliding_window_inference(
[slice(idx // num_win, idx // num_win + 1), slice(None)] + list(slices[idx % num_win])
for idx in slice_range
]
if len(unravel_slice) > 1:
if sw_batch_size > 1:
win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
else:
win_data = inputs[unravel_slice[0]].to(sw_device)
Expand Down
18 changes: 4 additions & 14 deletions monai/networks/layers/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def use_factory(fact_args):
from collections.abc import Callable
from typing import Any

import torch
import torch.nn as nn

from monai.utils import look_up_option, optional_import
Expand Down Expand Up @@ -262,22 +261,13 @@ def instance_nvfuser_factory(dim):
https://github.com/NVIDIA/apex#installation
"""
types = (nn.InstanceNorm1d, nn.InstanceNorm2d)

if dim != 3:
types = (nn.InstanceNorm1d, nn.InstanceNorm2d)
warnings.warn(f"`InstanceNorm3dNVFuser` only supports 3d cases, use {types[dim - 1]} instead.")
return types[dim - 1]
# test InstanceNorm3dNVFuser installation with a basic example
has_nvfuser_flag = has_nvfuser
if not torch.cuda.is_available():
return nn.InstanceNorm3d
try:
layer = InstanceNorm3dNVFuser(num_features=1, affine=True).to("cuda:0")
inp = torch.randn([1, 1, 1, 1, 1]).to("cuda:0")
out = layer(inp)
del inp, out, layer
except Exception:
has_nvfuser_flag = False
if not has_nvfuser_flag:

if not has_nvfuser:
warnings.warn(
"`apex.normalization.InstanceNorm3dNVFuser` is not installed properly, use nn.InstanceNorm3d instead."
)
Expand Down
19 changes: 15 additions & 4 deletions tests/test_dynunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@

from monai.networks import eval_mode
from monai.networks.nets import DynUNet
from tests.utils import assert_allclose, test_script_save
from monai.utils import optional_import
from tests.utils import assert_allclose, skip_if_no_cuda, skip_if_windows, test_script_save

InstanceNorm3dNVFuser, _ = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")

device = "cuda" if torch.cuda.is_available() else "cpu"

Expand Down Expand Up @@ -122,10 +125,18 @@ def test_script(self):
test_script_save(net, test_data)


# @skip_if_no_cuda
# @skip_if_windows
@unittest.skip("temporary skip for 22.12/23.02")
@skip_if_no_cuda
@skip_if_windows
class TestDynUNetWithInstanceNorm3dNVFuser(unittest.TestCase):
def setUp(self):
try:
layer = InstanceNorm3dNVFuser(num_features=1, affine=False).to("cuda:0")
inp = torch.randn([1, 1, 1, 1, 1]).to("cuda:0")
out = layer(inp)
del inp, out, layer
except Exception:
self.skipTest("NVFuser not available")

@parameterized.expand([TEST_CASE_DYNUNET_3D[0]])
def test_consistency(self, input_param, input_shape, _):
for eps in [1e-4, 1e-5]:
Expand Down

0 comments on commit 3c05d17

Please sign in to comment.