Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 31 additions & 20 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
look_up_option,
)
from monai.utils.enums import TransformBackends
from monai.utils.type_conversion import convert_data_type
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type

__all__ = [
"SpatialPad",
Expand Down Expand Up @@ -406,19 +406,30 @@ def __init__(
self.slices = list(roi_slices)
else:
if roi_center is not None and roi_size is not None:
roi_center = torch.as_tensor(roi_center, dtype=torch.int16)
roi_size = torch.as_tensor(roi_size, dtype=torch.int16, device=roi_center.device)
roi_start_torch = maximum( # type: ignore
roi_center, *_ = convert_data_type(
data=roi_center,
output_type=torch.Tensor,
dtype=torch.int16,
wrap_sequence=True,
)
roi_size, *_ = convert_to_dst_type(src=roi_size, dst=roi_center, wrap_sequence=True)
roi_start_torch = maximum(
roi_center - floor_divide(roi_size, 2),
torch.zeros_like(roi_center),
torch.zeros_like(roi_center), # type: ignore
)
roi_end_torch = maximum(roi_start_torch + roi_size, roi_start_torch)
else:
if roi_start is None or roi_end is None:
raise ValueError("Please specify either roi_center, roi_size or roi_start, roi_end.")
roi_start_torch = torch.as_tensor(roi_start, dtype=torch.int16)
roi_start_torch, *_ = convert_data_type( # type: ignore
data=roi_start,
output_type=torch.Tensor,
dtype=torch.int16,
wrap_sequence=True,
)
roi_start_torch = maximum(roi_start_torch, torch.zeros_like(roi_start_torch)) # type: ignore
roi_end_torch = maximum(torch.as_tensor(roi_end, dtype=torch.int16), roi_start_torch)
roi_end_torch, *_ = convert_to_dst_type(src=roi_end, dst=roi_start_torch, wrap_sequence=True)
roi_end_torch = maximum(roi_end_torch, roi_start_torch)
# convert to slices (accounting for 1d)
if roi_start_torch.numel() == 1:
self.slices = [slice(int(roi_start_torch.item()), int(roi_end_torch.item()))]
Expand Down Expand Up @@ -632,7 +643,7 @@ class RandSpatialCropSamples(Randomizable, Transform):

"""

backend = RandScaleCrop.backend
backend = RandSpatialCrop.backend

def __init__(
self,
Expand Down Expand Up @@ -706,7 +717,7 @@ def __init__(
margin: Union[Sequence[int], int] = 0,
return_coords: bool = False,
k_divisible: Union[Sequence[int], int] = 1,
mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT,
mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = NumpyPadMode.CONSTANT,
**np_kwargs,
) -> None:
"""
Expand All @@ -718,10 +729,12 @@ def __init__(
return_coords: whether return the coordinates of spatial bounding box for foreground.
k_divisible: make each spatial dimension to be divisible by k, default to 1.
if `k_divisible` is an int, the same `k` be applied to all the input spatial dimensions.
mode: padding mode {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
one of the listed string values or a user supplied function. Defaults to ``"constant"``.
see also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
mode: available modes for numpy array: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension.
more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html

Expand All @@ -734,20 +747,18 @@ def __init__(
self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode)
self.np_kwargs = np_kwargs

def compute_bounding_box(self, img: NdarrayOrTensor) -> Tuple[np.ndarray, np.ndarray]:
def compute_bounding_box(self, img: NdarrayOrTensor):
"""
Compute the start points and end points of bounding box to crop.
And adjust bounding box coords to be divisible by `k`.

"""
box_start, box_end = generate_spatial_bounding_box(img, self.select_fn, self.channel_indices, self.margin)
box_start = [i.cpu() if isinstance(i, torch.Tensor) else i for i in box_start] # type: ignore
box_end = [i.cpu() if isinstance(i, torch.Tensor) else i for i in box_end] # type: ignore
box_start_ = np.asarray(box_start, dtype=np.int16)
box_end_ = np.asarray(box_end, dtype=np.int16)
box_start_, *_ = convert_data_type(box_start, output_type=np.ndarray, dtype=np.int16, wrap_sequence=True)
box_end_, *_ = convert_data_type(box_end, output_type=np.ndarray, dtype=np.int16, wrap_sequence=True)
orig_spatial_size = box_end_ - box_start_
# make the spatial size divisible by `k`
spatial_size = np.asarray(compute_divisible_spatial_size(spatial_shape=orig_spatial_size, k=self.k_divisible))
spatial_size = np.asarray(compute_divisible_spatial_size(orig_spatial_size.tolist(), k=self.k_divisible))
# update box_start and box_end
box_start_ = box_start_ - np.floor_divide(np.asarray(spatial_size) - orig_spatial_size, 2)
box_end_ = box_start_ + spatial_size
Expand All @@ -758,7 +769,7 @@ def crop_pad(
img: NdarrayOrTensor,
box_start: np.ndarray,
box_end: np.ndarray,
mode: Optional[Union[NumpyPadMode, str]] = None,
mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None,
):
"""
Crop and pad based on the bounding box.
Expand Down
21 changes: 11 additions & 10 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
from monai.utils import ImageMetaKey as Key
from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple
from monai.utils.enums import InverseKeys
from monai.utils.type_conversion import convert_data_type

__all__ = [
"PadModeSequence",
Expand Down Expand Up @@ -810,6 +809,8 @@ class CropForegroundd(MapTransform, InvertibleTransform):
channels. And it can also add margin to every dim of the bounding box of foreground object.
"""

backend = CropForeground.backend

def __init__(
self,
keys: KeysCollection,
Expand All @@ -818,7 +819,7 @@ def __init__(
channel_indices: Optional[IndexSelection] = None,
margin: Union[Sequence[int], int] = 0,
k_divisible: Union[Sequence[int], int] = 1,
mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT,
mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = NumpyPadMode.CONSTANT,
start_coord_key: str = "foreground_start_coord",
end_coord_key: str = "foreground_end_coord",
allow_missing_keys: bool = False,
Expand All @@ -835,10 +836,12 @@ def __init__(
margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims.
k_divisible: make each spatial dimension to be divisible by k, default to 1.
if `k_divisible` is an int, the same `k` be applied to all the input spatial dimensions.
mode: padding mode {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
one of the listed string values or a user supplied function. Defaults to ``"constant"``.
see also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
mode: available modes for numpy array: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
it also can be a sequence of string, each element corresponds to a key in ``keys``.
start_coord_key: key to record the start coordinate of spatial bounding box for foreground.
end_coord_key: key to record the end coordinate of spatial bounding box for foreground.
Expand All @@ -860,11 +863,9 @@ def __init__(
)
self.mode = ensure_tuple_rep(mode, len(self.keys))

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
img: np.ndarray
img, *_ = convert_data_type(d[self.source_key], np.ndarray) # type: ignore
box_start, box_end = self.cropper.compute_bounding_box(img=img)
box_start, box_end = self.cropper.compute_bounding_box(img=d[self.source_key])
d[self.start_coord_key] = box_start
d[self.end_coord_key] = box_end
for key, m in self.key_iterator(d, self.mode):
Expand Down
59 changes: 37 additions & 22 deletions monai/utils/type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,17 +119,14 @@ def convert_to_tensor(
if data.ndim > 0:
data = np.ascontiguousarray(data)
return torch.as_tensor(data, dtype=dtype, device=device) # type: ignore
elif (
has_cp
and isinstance(data, cp_ndarray)
or isinstance(data, (float, int, bool))
or (isinstance(data, Sequence) and wrap_sequence)
):
elif (has_cp and isinstance(data, cp_ndarray)) or isinstance(data, (float, int, bool)):
return torch.as_tensor(data, dtype=dtype, device=device) # type: ignore
elif isinstance(data, list):
return [convert_to_tensor(i, dtype=dtype, device=device) for i in data]
list_ret = [convert_to_tensor(i, dtype=dtype, device=device) for i in data]
return torch.as_tensor(list_ret, dtype=dtype, device=device) if wrap_sequence else list_ret # type: ignore
elif isinstance(data, tuple):
return tuple(convert_to_tensor(i, dtype=dtype, device=device) for i in data)
tuple_ret = tuple(convert_to_tensor(i, dtype=dtype, device=device) for i in data)
return torch.as_tensor(tuple_ret, dtype=dtype, device=device) if wrap_sequence else tuple_ret # type: ignore
elif isinstance(data, dict):
return {k: convert_to_tensor(v, dtype=dtype, device=device) for k, v in data.items()}

Expand All @@ -153,12 +150,14 @@ def convert_to_numpy(data, dtype: Optional[DtypeLike] = None, wrap_sequence: boo
data = data.detach().to(dtype=get_equivalent_dtype(dtype, torch.Tensor), device="cpu").numpy()
elif has_cp and isinstance(data, cp_ndarray):
data = cp.asnumpy(data).astype(dtype)
elif isinstance(data, (np.ndarray, float, int, bool)) or (isinstance(data, Sequence) and wrap_sequence):
elif isinstance(data, (np.ndarray, float, int, bool)):
data = np.asarray(data, dtype=dtype)
elif isinstance(data, list):
return [convert_to_numpy(i, dtype=dtype) for i in data]
list_ret = [convert_to_numpy(i, dtype=dtype) for i in data]
return np.asarray(list_ret) if wrap_sequence else list_ret
elif isinstance(data, tuple):
return tuple(convert_to_numpy(i, dtype=dtype) for i in data)
tuple_ret = tuple(convert_to_numpy(i, dtype=dtype) for i in data)
return np.asarray(tuple_ret) if wrap_sequence else tuple_ret
elif isinstance(data, dict):
return {k: convert_to_numpy(v, dtype=dtype) for k, v in data.items()}

Expand All @@ -184,14 +183,14 @@ def convert_to_cupy(data, dtype, wrap_sequence: bool = True):
"""

# direct calls
if isinstance(data, (cp_ndarray, np.ndarray, torch.Tensor, float, int, bool)) or (
isinstance(data, Sequence) and wrap_sequence
):
if isinstance(data, (cp_ndarray, np.ndarray, torch.Tensor, float, int, bool)):
data = cp.asarray(data, dtype)
elif isinstance(data, list):
return [convert_to_cupy(i, dtype) for i in data]
list_ret = [convert_to_cupy(i, dtype) for i in data]
return cp.asarray(list_ret) if wrap_sequence else list_ret
elif isinstance(data, tuple):
return tuple(convert_to_cupy(i, dtype) for i in data)
tuple_ret = tuple(convert_to_cupy(i, dtype) for i in data)
return cp.asarray(tuple_ret) if wrap_sequence else tuple_ret
elif isinstance(data, dict):
return {k: convert_to_cupy(v, dtype) for k, v in data.items()}
# make it contiguous
Expand All @@ -208,6 +207,7 @@ def convert_data_type(
output_type: Optional[type] = None,
device: Optional[torch.device] = None,
dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
wrap_sequence: bool = False,
) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]:
"""
Convert to `torch.Tensor`/`np.ndarray` from `torch.Tensor`/`np.ndarray`/`float`/`int` etc.
Expand All @@ -219,6 +219,8 @@ def convert_data_type(
dtype: dtype of output data. Converted to correct library type (e.g.,
`np.float32` is converted to `torch.float32` if output type is `torch.Tensor`).
If left blank, it remains unchanged.
wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`.
If `True`, then `[1, 2]` -> `array([1, 2])`.
Returns:
modified data, orig_type, orig_device

Expand Down Expand Up @@ -248,24 +250,31 @@ def convert_data_type(
dtype_ = get_equivalent_dtype(dtype or get_dtype(data), output_type)

if output_type is torch.Tensor:
data = convert_to_tensor(data, dtype=dtype_, device=device)
data = convert_to_tensor(data, dtype=dtype_, device=device, wrap_sequence=wrap_sequence)
elif output_type is np.ndarray:
data = convert_to_numpy(data, dtype=dtype_)
data = convert_to_numpy(data, dtype=dtype_, wrap_sequence=wrap_sequence)
elif has_cp and output_type is cp.ndarray:
data = convert_to_cupy(data, dtype=dtype_)
data = convert_to_cupy(data, dtype=dtype_, wrap_sequence=wrap_sequence)
else:
raise ValueError(f"Unsupported output type: {output_type}")
return data, orig_type, orig_device


def convert_to_dst_type(
src: Any, dst: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch.dtype]] = None
src: Any, dst: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch.dtype]] = None, wrap_sequence: bool = False
) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]:
"""
Convert source data to the same data type and device as the destination data.
If `dst` is an instance of `torch.Tensor` or its subclass, convert `src` to `torch.Tensor` with the same data type as `dst`,
if `dst` is an instance of `numpy.ndarray` or its subclass, convert to `numpy.ndarray` with the same data type as `dst`,
otherwise, convert to the type of `dst` directly.
`dtype` is an optional argument if the target `dtype` is different from the original `dst`'s data type.

Args:
src: source data to be converted.
dst: destination data; the source is converted to the same data type (and device) as this.
dtype: an optional argument if the target `dtype` is different from the original `dst`'s data type.
wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`.
If `True`, then `[1, 2]` -> `array([1, 2])`.

See Also:
:func:`convert_data_type`
Expand All @@ -281,4 +290,10 @@ def convert_to_dst_type(
output_type = np.ndarray
else:
output_type = type(dst)
return convert_data_type(data=src, output_type=output_type, device=device, dtype=dtype)
return convert_data_type(
data=src,
output_type=output_type,
device=device,
dtype=dtype,
wrap_sequence=wrap_sequence,
)
27 changes: 27 additions & 0 deletions tests/test_convert_data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@
for out_type in TEST_NDARRAYS:
TESTS.append((in_type(np.array(1.0)), out_type(np.array(1.0)))) # type: ignore

# Parameterized cases for list inputs to `convert_data_type`, exercising `wrap_sequence`.
# Each tuple is (input_list, expected_output, wrap_sequence_flag).
TESTS_LIST: List[Tuple] = []
for in_type in TEST_NDARRAYS + (int, float):
    for out_type in TEST_NDARRAYS:
        # wrap_sequence=True: the 2-element list collapses into a single array/tensor.
        TESTS_LIST.append(
            ([in_type(np.array(1.0)), in_type(np.array(1.0))], out_type(np.array([1.0, 1.0])), True),  # type: ignore
        )
        # wrap_sequence=False: conversion recurses, yielding a list of converted elements.
        TESTS_LIST.append(
            (
                [in_type(np.array(1.0)), in_type(np.array(1.0))],  # type: ignore
                [out_type(np.array(1.0)), out_type(np.array(1.0))],
                False,
            )
        )

class TestTensor(torch.Tensor):
pass
Expand Down Expand Up @@ -51,6 +65,19 @@ def test_ill_arg(self):
convert_data_type(None, torch.Tensor)
convert_data_type(None, np.ndarray)

@parameterized.expand(TESTS_LIST)
def test_convert_list(self, in_image, im_out, wrap):
    """Converting a list input should honor `wrap_sequence` and preserve dtype.

    Args:
        in_image: list of inputs to convert.
        im_out: expected result (a single array/tensor when ``wrap`` is True,
            otherwise a list of converted elements).
        wrap: value passed as ``wrap_sequence`` to ``convert_data_type``.
    """
    output_type = type(im_out) if wrap else type(im_out[0])
    converted_im, *_ = convert_data_type(in_image, output_type, wrap_sequence=wrap)
    # check output is desired type
    if not wrap:
        converted_im = converted_im[0]
        im_out = im_out[0]
    self.assertEqual(type(converted_im), type(im_out))
    # check dtype is unchanged; inspect the *converted output*, not the module-level
    # loop variable `in_type` (which is a constructor, never an ndarray/Tensor instance,
    # so the original check could never fire)
    if isinstance(converted_im, (np.ndarray, torch.Tensor)):
        self.assertEqual(converted_im.dtype, im_out.dtype)


class TestConvertDataSame(unittest.TestCase):
# add test for subclass of Tensor
Expand Down