
Commit 42844e8

refactoring
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent 5d7085d commit 42844e8

File tree

3 files changed: +30 −25 lines changed

monai/transforms/croppad/array.py

Lines changed: 20 additions & 0 deletions
@@ -141,6 +141,26 @@ def _pt_pad(img: torch.Tensor, pad_width, mode, **kwargs) -> torch.Tensor:
         # torch.pad expects `[B, C, H, W, [D]]` shape
         return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **kwargs).squeeze(0)

+    @staticmethod
+    def pad_nd(img_t, to_pad_, mode, **kwargs):
+        if mode in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}:
+            return Pad._np_pad(img_t, pad_width=to_pad_, mode=mode, **kwargs)
+        mode = convert_pad_mode(dst=img_t, mode=mode).value
+        try:
+            _pad = (
+                Pad._pt_pad
+                if mode in {"reflect", "replicate"}
+                and img_t.dtype not in {torch.int16, torch.int64, torch.bool, torch.uint8}
+                else Pad._np_pad
+            )
+            return _pad(img_t, pad_width=to_pad_, mode=mode, **kwargs)
+        except (ValueError, TypeError, RuntimeError) as err:
+            if isinstance(err, NotImplementedError) or any(
+                k in str(err) for k in ("supported", "unexpected keyword", "implemented")
+            ):
+                return Pad._np_pad(img_t, pad_width=to_pad_, mode=mode, **kwargs)
+            raise ValueError(f"{img_t.shape} {to_pad_} {mode} {kwargs} {img_t.dtype} {img_t.device}") from err
+
     def __call__(  # type: ignore
         self, img: torch.Tensor, to_pad: list[tuple[int, int]] | None = None, mode: str | None = None, **kwargs
     ) -> torch.Tensor:
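
Note: a minimal sketch of calling the new helper directly, for orientation. The tensor shape and pad widths are illustrative assumptions (not values from this commit), and it presumes a MONAI build that already contains this change.

# Illustrative sketch: exercising the Pad.pad_nd helper added above.
import torch
from monai.transforms import Pad

img = torch.zeros(1, 4, 4)            # channel-first 2D image: (C, H, W)
to_pad = [(0, 0), (2, 2), (1, 1)]     # no channel padding; H padded by 2, W by 1 on each side

# "constant" is resolved via convert_pad_mode; numpy-only modes such as
# "median" are routed straight to Pad._np_pad inside pad_nd.
out = Pad.pad_nd(img, to_pad, mode="constant")
print(out.shape)                      # torch.Size([1, 8, 6])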

monai/transforms/croppad/functional.py

Lines changed: 5 additions & 23 deletions
@@ -22,7 +22,7 @@
 from monai.data.meta_obj import get_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.transforms.inverse import TraceableTransform
-from monai.transforms.utils import convert_pad_mode, create_translate
+from monai.transforms.utils import create_translate
 from monai.utils import TraceKeys, convert_to_dst_type, ensure_tuple

 __all__ = ["pad_func", "crop_func"]
@@ -37,7 +37,9 @@ def pad_func(img_t, to_pad_, mode, kwargs, transform_info):
         else torch.eye(4, device=torch.device("cpu"), dtype=torch.float64)
     )
     spatial_rank = max(len(_affine) - 1, 1)
-    if np.asarray(to_pad_).any():
+    if not np.asarray(to_pad_).any():
+        out = img_t
+    else:
         to_pad_ = list(to_pad_)
         if len(to_pad_) < len(img_t.shape):
             to_pad_ = list(to_pad_) + [(0, 0)] * (len(img_t.shape) - len(to_pad_))
@@ -55,27 +57,7 @@ def pad_func(img_t, to_pad_, mode, kwargs, transform_info):
             extra_info=extra_info,
             transform_info=transform_info,
         )
-        if mode in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}:
-            out = monai.transforms.Pad._np_pad(img_t, pad_width=to_pad_, mode=mode, **kwargs)
-        else:
-            mode = convert_pad_mode(dst=img_t, mode=mode).value
-            try:
-                _pad = (
-                    monai.transforms.Pad._pt_pad
-                    if mode in {"reflect", "replicate"}
-                    and img_t.dtype not in {torch.int16, torch.int64, torch.bool, torch.uint8}
-                    else monai.transforms.Pad._np_pad
-                )
-                out = _pad(img_t, pad_width=to_pad_, mode=mode, **kwargs)
-            except (ValueError, TypeError, RuntimeError) as err:
-                if isinstance(err, NotImplementedError) or any(
-                    k in str(err) for k in ("supported", "unexpected keyword", "implemented")
-                ):
-                    out = monai.transforms.Pad._np_pad(img_t, pad_width=to_pad_, mode=mode, **kwargs)
-                else:
-                    raise ValueError(f"{img_t.shape} {to_pad_} {mode} {kwargs} {img_t.dtype} {img_t.device}") from err
-    else:
-        out = img_t
+        out = monai.transforms.Pad.pad_nd(img_t, to_pad_, mode, **kwargs)
     if get_track_meta():
         to_shift = [-s[0] for s in to_pad_[1:]]  # skipping the channel pad
         out.affine @= convert_to_dst_type(create_translate(spatial_rank, to_shift), _affine)[0]  # type: ignore
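
Note: the dispatch removed here (now centralised in Pad.pad_nd) exists because torch's reflect/replicate padding is not implemented for bool and some integer dtypes, while several numpy pad modes have no torch equivalent at all. A standalone illustration of that gap follows; it has no MONAI dependency, and the exact error message may vary across torch versions.

# Why pad_nd keeps a numpy fallback: torch rejects reflect padding on bool input.
import numpy as np
import torch
import torch.nn.functional as F

img_bool = torch.zeros(1, 4, 4, dtype=torch.bool)

try:
    F.pad(img_bool.unsqueeze(0), [1, 1, 1, 1], mode="reflect")  # (B, C, H, W) layout
except RuntimeError as err:
    print("torch reflect pad failed:", err)

# np.pad handles the same request regardless of dtype, which is the
# fallback path pad_nd takes after catching the error above.
out = np.pad(img_bool.numpy(), [(0, 0), (1, 1), (1, 1)], mode="reflect")
print(out.shape)  # (1, 6, 6)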

monai/transforms/spatial/array.py

Lines changed: 5 additions & 2 deletions
@@ -470,7 +470,7 @@ def __call__(
         affine_: np.ndarray
         if affine is not None:
             warnings.warn("arg `affine` is deprecated, the affine of MetaTensor in data_array has higher priority.")
-        input_affine = data_array.affine if isinstance(data_array, MetaTensor) else affine
+        input_affine = data_array.peek_pending_affine() if isinstance(data_array, MetaTensor) else affine
         if input_affine is None:
             warnings.warn("`data_array` is not of type MetaTensor, assuming affine to be identity.")
             # default to identity
@@ -517,7 +517,10 @@ def __call__(
             dtype=dtype,
         )
         if self.recompute_affine and isinstance(data_array, MetaTensor):
-            data_array.affine = scale_affine(affine_, original_spatial_shape, actual_shape)
+            if not self.lazy_evaluation:
+                data_array.affine = scale_affine(affine_, original_spatial_shape, actual_shape)
+            else:
+                raise NotImplementedError("recompute_affine is not supported with lazy evaluation.")
         return data_array

     def inverse(self, data: torch.Tensor) -> torch.Tensor:
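
Note: a hedged sketch of how the new guard surfaces to callers of Spacing. The constructor arguments come from the existing transform; the lazy_evaluation attribute assignment is an assumption about the in-progress lazy-resampling API, not documented usage.

# Hypothetical usage sketch; the lazy_evaluation toggle is assumed, not documented.
import torch
from monai.data import MetaTensor
from monai.transforms import Spacing

img = MetaTensor(torch.rand(1, 16, 16, 16))                 # identity affine by default
spacing = Spacing(pixdim=(2.0, 2.0, 2.0), recompute_affine=True)

out = spacing(img)   # eager path: the affine is recomputed from the actual output shape
print(out.shape)     # spatial size roughly halves; out.affine reflects ~2.0 unit spacing

spacing.lazy_evaluation = True                              # assumed way to enable lazy mode
# spacing(img) would now raise NotImplementedError, per the guard added above.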
