2222from monai .data .meta_obj import get_track_meta
2323from monai .data .meta_tensor import MetaTensor
2424from monai .transforms .inverse import TraceableTransform
25- from monai .transforms .utils import convert_pad_mode , create_translate
25+ from monai .transforms .utils import create_translate
2626from monai .utils import TraceKeys , convert_to_dst_type , ensure_tuple
2727
2828__all__ = ["pad_func" , "crop_func" ]
@@ -37,7 +37,9 @@ def pad_func(img_t, to_pad_, mode, kwargs, transform_info):
3737 else torch .eye (4 , device = torch .device ("cpu" ), dtype = torch .float64 )
3838 )
3939 spatial_rank = max (len (_affine ) - 1 , 1 )
40- if np .asarray (to_pad_ ).any ():
40+ if not np .asarray (to_pad_ ).any ():
41+ out = img_t
42+ else :
4143 to_pad_ = list (to_pad_ )
4244 if len (to_pad_ ) < len (img_t .shape ):
4345 to_pad_ = list (to_pad_ ) + [(0 , 0 )] * (len (img_t .shape ) - len (to_pad_ ))
@@ -55,27 +57,7 @@ def pad_func(img_t, to_pad_, mode, kwargs, transform_info):
5557 extra_info = extra_info ,
5658 transform_info = transform_info ,
5759 )
58- if mode in {"linear_ramp" , "maximum" , "mean" , "median" , "minimum" , "symmetric" , "empty" }:
59- out = monai .transforms .Pad ._np_pad (img_t , pad_width = to_pad_ , mode = mode , ** kwargs )
60- else :
61- mode = convert_pad_mode (dst = img_t , mode = mode ).value
62- try :
63- _pad = (
64- monai .transforms .Pad ._pt_pad
65- if mode in {"reflect" , "replicate" }
66- and img_t .dtype not in {torch .int16 , torch .int64 , torch .bool , torch .uint8 }
67- else monai .transforms .Pad ._np_pad
68- )
69- out = _pad (img_t , pad_width = to_pad_ , mode = mode , ** kwargs )
70- except (ValueError , TypeError , RuntimeError ) as err :
71- if isinstance (err , NotImplementedError ) or any (
72- k in str (err ) for k in ("supported" , "unexpected keyword" , "implemented" )
73- ):
74- out = monai .transforms .Pad ._np_pad (img_t , pad_width = to_pad_ , mode = mode , ** kwargs )
75- else :
76- raise ValueError (f"{ img_t .shape } { to_pad_ } { mode } { kwargs } { img_t .dtype } { img_t .device } " ) from err
77- else :
78- out = img_t
60+ out = monai .transforms .Pad .pad_nd (img_t , to_pad_ , mode , ** kwargs )
7961 if get_track_meta ():
8062 to_shift = [- s [0 ] for s in to_pad_ [1 :]] # skipping the channel pad
8163 out .affine @= convert_to_dst_type (create_translate (spatial_rank , to_shift ), _affine )[0 ] # type: ignore
0 commit comments