diff --git a/monai/apps/deepgrow/dataset.py b/monai/apps/deepgrow/dataset.py index acaeba0bc3..b0095c96e6 100644 --- a/monai/apps/deepgrow/dataset.py +++ b/monai/apps/deepgrow/dataset.py @@ -17,6 +17,7 @@ from monai.transforms import AsChannelFirstd, Compose, LoadImaged, Orientationd, Spacingd from monai.utils import GridSampleMode +from monai.utils.misc import convert_data_type def create_dataset( @@ -82,7 +83,7 @@ def create_dataset( if not len(datalist): raise ValueError("Input datalist is empty") - transforms = _default_transforms(image_key, label_key, pixdim) if transforms is None else transforms + transforms = transforms or _default_transforms(image_key, label_key, pixdim) new_datalist = [] for idx in range(len(datalist)): if limit and idx >= limit: @@ -133,25 +134,34 @@ def _default_transforms(image_key, label_key, pixdim): def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): + if vol_image is not None: + vol_image_np, *_ = convert_data_type(vol_image, np.ndarray) + else: + vol_image_np = vol_image + if vol_label is not None: + vol_label_np, *_ = convert_data_type(vol_label, np.ndarray) + else: + vol_label_np = vol_label + data_list = [] - if len(vol_image.shape) == 4: + if len(vol_image_np.shape) == 4: logging.info( "4D-Image, pick only first series; Image: {}; Label: {}".format( - vol_image.shape, vol_label.shape if vol_label is not None else None + vol_image_np.shape, vol_label_np.shape if vol_label_np is not None else None ) ) - vol_image = vol_image[0] - vol_image = np.moveaxis(vol_image, -1, 0) + vol_image_np = vol_image_np[0] + vol_image_np = np.moveaxis(vol_image_np, -1, 0) image_count = 0 label_count = 0 unique_labels_count = 0 - for sid in range(vol_image.shape[0]): - image = vol_image[sid, ...] - label = vol_label[sid, ...] if vol_label is not None else None + for sid in range(vol_image_np.shape[0]): + image = vol_image_np[sid, ...] + label = vol_label_np[sid, ...] 
if vol_label_np is not None else None - if vol_label is not None and np.sum(label) == 0: + if vol_label_np is not None and np.sum(label) == 0: continue image_file_prefix = "vol_idx_{:0>4d}_slice_{:0>3d}".format(vol_idx, sid) @@ -163,7 +173,7 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): image_count += 1 # Test Data - if vol_label is None: + if vol_label_np is None: data_list.append( { "image": image_file.replace(dataset_dir + os.pathsep, "") if relative_path else image_file, @@ -200,9 +210,9 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): logging.info( "{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format( vol_idx, - vol_image.shape, + vol_image_np.shape, image_count, - vol_label.shape if vol_label is not None else None, + vol_label_np.shape if vol_label_np is not None else None, label_count, unique_labels_count, ) @@ -211,16 +221,25 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): + if vol_image is not None: + vol_image_np, *_ = convert_data_type(vol_image, np.ndarray) + else: + vol_image_np = vol_image + if vol_label is not None: + vol_label_np, *_ = convert_data_type(vol_label, np.ndarray) + else: + vol_label_np = vol_label + data_list = [] - if len(vol_image.shape) == 4: + if len(vol_image_np.shape) == 4: logging.info( "4D-Image, pick only first series; Image: {}; Label: {}".format( - vol_image.shape, vol_label.shape if vol_label is not None else None + vol_image_np.shape, vol_label_np.shape if vol_label_np is not None else None ) ) - vol_image = vol_image[0] - vol_image = np.moveaxis(vol_image, -1, 0) + vol_image_np = vol_image_np[0] + vol_image_np = np.moveaxis(vol_image_np, -1, 0) image_count = 0 label_count = 0 @@ -231,11 +250,11 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): image_file += ".npy" os.makedirs(os.path.join(dataset_dir, "images"), exist_ok=True) - np.save(image_file, vol_image) + np.save(image_file, vol_image_np) image_count += 1 # Test Data - if vol_label is None: + if vol_label_np is None: data_list.append( { "image": image_file.replace(dataset_dir + os.pathsep, "") if relative_path else image_file, @@ -243,7 +262,7 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): ) else: # For all Labels - unique_labels = np.unique(vol_label.flatten()) + unique_labels = np.unique(vol_label_np.flatten()) unique_labels = unique_labels[unique_labels != 0] unique_labels_count = max(unique_labels_count, len(unique_labels)) @@ -252,7 +271,7 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): label_file = os.path.join(dataset_dir, "labels", label_file_prefix) label_file += ".npy" - curr_label = (vol_label == idx).astype(np.float32) + curr_label = (vol_label_np == idx).astype(np.float32) os.makedirs(os.path.join(dataset_dir, "labels"), exist_ok=True) np.save(label_file, curr_label) @@ -271,9 +290,9 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): logging.info( "{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format( vol_idx, - vol_image.shape, + vol_image_np.shape, image_count, - vol_label.shape if vol_label is not None else None, + vol_label_np.shape if vol_label_np is not None else None, label_count, unique_labels_count, ) diff --git a/monai/apps/pathology/utils.py b/monai/apps/pathology/utils.py index 0d1f530bff..9e51bcf6df 100644 --- 
a/monai/apps/pathology/utils.py +++ b/monai/apps/pathology/utils.py @@ -9,13 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import List import numpy as np -import torch from monai.transforms.post.array import ProbNMS from monai.utils import optional_import +from monai.utils.enums import DataObjects measure, _ = optional_import("skimage.measure") ndimage, _ = optional_import("scipy.ndimage") @@ -67,7 +67,7 @@ class PathologyProbNMS(ProbNMS): def __call__( self, - probs_map: Union[np.ndarray, torch.Tensor], + probs_map: DataObjects.Images, resolution_level: int = 0, ): """ diff --git a/monai/data/csv_saver.py b/monai/data/csv_saver.py index c79cd1016a..c333f69c4c 100644 --- a/monai/data/csv_saver.py +++ b/monai/data/csv_saver.py @@ -12,12 +12,13 @@ import os import warnings from collections import OrderedDict -from typing import Dict, Optional, Union +from typing import Dict, Optional import numpy as np import torch from monai.utils import ImageMetaKey as Key +from monai.utils.enums import DataObjects class CSVSaver: @@ -75,7 +76,7 @@ def finalize(self) -> None: # clear cache content after writing self.reset_cache() - def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: + def save(self, data: DataObjects.Images, meta_data: Optional[Dict] = None) -> None: """Save data into the cache dictionary. The metadata should have the following key: - ``'filename_or_obj'`` -- save the data corresponding to file name or object. If meta_data is None, use the default index from 0 to save data instead. @@ -92,7 +93,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] data = data.detach().cpu().numpy() self._cache_dict[save_key] = np.asarray(data, dtype=float) - def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: + def save_batch(self, batch_data: DataObjects.Images, meta_data: Optional[Dict] = None) -> None: """Save a batch of data into the cache dictionary. Args: diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index 2aa9b44058..45026749f9 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -19,6 +19,7 @@ from monai.data.utils import create_file_basename from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key +from monai.utils.enums import DataObjects class NiftiSaver: @@ -104,7 +105,7 @@ def __init__( self.separate_folder = separate_folder self.print_log = print_log - def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: + def save(self, data: DataObjects.Images, meta_data: Optional[Dict] = None) -> None: """ Save data into a Nifti file. The meta_data could optionally have the following keys: @@ -175,7 +176,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] if self.print_log: print(f"file written: {path}.") - def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: + def save_batch(self, batch_data: DataObjects.Images, meta_data: Optional[Dict] = None) -> None: """ Save a batch of data into Nifti format files. 
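The savers above all funnel through the same two utilities: the `DataObjects.Images` union alias and `convert_data_type`. Neither body appears in this patch; the sketch below is inferred purely from the call sites in this diff (`convert_data_type(img, np.ndarray)` returning a `(data, orig_type, orig_device)` triple, plus `dtype`/`device` keywords), and is illustrative rather than the actual MONAI implementation.

```python
# Illustrative only: shape inferred from call sites in this diff, not MONAI's code.
from typing import Union

import numpy as np
import torch


class DataObjects:
    Images = Union[np.ndarray, torch.Tensor]  # the union alias used in the new signatures


def convert_data_type(data, output_type=None, device=None, dtype=None):
    """Convert between np.ndarray and torch.Tensor, remembering the source.

    Returns (converted, orig_type, orig_device) so callers can round-trip later.
    """
    orig_type = type(data)
    orig_device = data.device if isinstance(data, torch.Tensor) else None
    output_type = output_type or orig_type
    if output_type is np.ndarray and isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()
    elif output_type is torch.Tensor and isinstance(data, np.ndarray):
        data = torch.as_tensor(data, device=device)
    if dtype is not None:
        if isinstance(data, np.ndarray):
            data = data.astype(dtype)
        else:  # map a numpy-style dtype onto its torch equivalent
            data = data.to(torch.as_tensor(np.zeros((), dtype=dtype)).dtype)
    return data, orig_type, orig_device
```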
diff --git a/monai/data/nifti_writer.py b/monai/data/nifti_writer.py index c56d4c1e8d..17740a27bc 100644 --- a/monai/data/nifti_writer.py +++ b/monai/data/nifti_writer.py @@ -18,12 +18,14 @@ from monai.data.utils import compute_shape_offset, to_affine_nd from monai.networks.layers import AffineTransform from monai.utils import GridSampleMode, GridSamplePadMode, optional_import +from monai.utils.enums import DataObjects +from monai.utils.misc import convert_data_type nib, _ = optional_import("nibabel") def write_nifti( - data: np.ndarray, + data: DataObjects.Images, file_name: str, affine: Optional[np.ndarray] = None, target_affine: Optional[np.ndarray] = None, @@ -36,7 +38,7 @@ def write_nifti( output_dtype: DtypeLike = np.float32, ) -> None: """ - Write numpy data into NIfTI files to disk. This function converts data + Write numpy or torch data into NIfTI files to disk. This function converts data into the coordinate system defined by `target_affine` when `target_affine` is specified. @@ -96,21 +98,27 @@ def write_nifti( If None, use the data type of input data. output_dtype: data type for saving data. Defaults to ``np.float32``. """ - if not isinstance(data, np.ndarray): - raise AssertionError("input data must be numpy array.") - dtype = dtype or data.dtype - sr = min(data.ndim, 3) + if not isinstance(data, (np.ndarray, torch.Tensor)): + raise AssertionError("input data must be numpy array or torch.Tensor.") + # if torch, convert to numpy + data_np: np.ndarray + data_np, *_ = convert_data_type(data, np.ndarray) # type: ignore + if target_affine is not None: + target_affine, *_ = convert_data_type(target_affine, np.ndarray) # type: ignore + + dtype = dtype or data_np.dtype + sr = min(data_np.ndim, 3) if affine is None: affine = np.eye(4, dtype=np.float64) - affine = to_affine_nd(sr, affine) + affine = to_affine_nd(sr, affine) # type: ignore if target_affine is None: target_affine = affine - target_affine = to_affine_nd(sr, target_affine) + target_affine = to_affine_nd(sr, target_affine) # type: ignore if np.allclose(affine, target_affine, atol=1e-3): # no affine changes, save (data, affine) - results_img = nib.Nifti1Image(data.astype(output_dtype), to_affine_nd(3, target_affine)) + results_img = nib.Nifti1Image(data_np.astype(output_dtype), to_affine_nd(3, target_affine)) nib.save(results_img, file_name) return @@ -118,11 +126,11 @@ def write_nifti( start_ornt = nib.orientations.io_orientation(affine) target_ornt = nib.orientations.io_orientation(target_affine) ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt) - data_shape = data.shape - data = nib.orientations.apply_orientation(data, ornt_transform) + data_shape = data_np.shape + data_np = nib.orientations.apply_orientation(data_np, ornt_transform) _affine = affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) if np.allclose(_affine, target_affine, atol=1e-3) or not resample: - results_img = nib.Nifti1Image(data.astype(output_dtype), to_affine_nd(3, _affine)) + results_img = nib.Nifti1Image(data_np.astype(output_dtype), to_affine_nd(3, _affine)) nib.save(results_img, file_name) return @@ -132,13 +140,13 @@ def write_nifti( ) transform = np.linalg.inv(_affine) @ target_affine if output_spatial_shape is None: - output_spatial_shape, _ = compute_shape_offset(data.shape, _affine, target_affine) + output_spatial_shape, _ = compute_shape_offset(data_np.shape, _affine, target_affine) output_spatial_shape_ = list(output_spatial_shape) if output_spatial_shape is not None else [] - if data.ndim > 3: # multi 
channel, resampling each channel + if data_np.ndim > 3: # multi channel, resampling each channel while len(output_spatial_shape_) < 3: output_spatial_shape_ = output_spatial_shape_ + [1] - spatial_shape, channel_shape = data.shape[:3], data.shape[3:] - data_np = data.reshape(list(spatial_shape) + [-1]) + spatial_shape, channel_shape = data_np.shape[:3], data_np.shape[3:] + data_np = data_np.reshape(list(spatial_shape) + [-1]) data_np = np.moveaxis(data_np, -1, 0) # channel first for pytorch data_torch = affine_xform( torch.as_tensor(np.ascontiguousarray(data_np).astype(dtype)).unsqueeze(0), @@ -149,12 +157,12 @@ def write_nifti( data_np = np.moveaxis(data_np, 0, -1) # channel last for nifti data_np = data_np.reshape(list(data_np.shape[:3]) + list(channel_shape)) else: # single channel image, need to expand to have batch and channel - while len(output_spatial_shape_) < len(data.shape): + while len(output_spatial_shape_) < len(data_np.shape): output_spatial_shape_ = output_spatial_shape_ + [1] data_torch = affine_xform( - torch.as_tensor(np.ascontiguousarray(data).astype(dtype)[None, None]), + torch.as_tensor(np.ascontiguousarray(data_np).astype(dtype)[None, None]), torch.as_tensor(np.ascontiguousarray(transform).astype(dtype)), - spatial_size=output_spatial_shape_[: len(data.shape)], + spatial_size=output_spatial_shape_[: len(data_np.shape)], ) data_np = data_torch.squeeze(0).squeeze(0).detach().cpu().numpy() diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py index d0aa787850..1ce787ba4e 100644 --- a/monai/data/png_saver.py +++ b/monai/data/png_saver.py @@ -18,6 +18,7 @@ from monai.data.utils import create_file_basename from monai.utils import ImageMetaKey as Key from monai.utils import InterpolateMode, look_up_option +from monai.utils.enums import DataObjects class PNGSaver: @@ -82,7 +83,7 @@ def __init__( self._data_index = 0 - def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: + def save(self, data: DataObjects.Images, meta_data: Optional[Dict] = None) -> None: """ Save data into a png file. The meta_data could optionally have the following keys: @@ -143,7 +144,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] if self.print_log: print(f"file written: {path}.") - def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: + def save_batch(self, batch_data: DataObjects.Images, meta_data: Optional[Dict] = None) -> None: """Save a batch of data into png format files. Args: diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index 2baec3b872..4feb288cb2 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -12,15 +12,18 @@ from typing import Optional, Sequence, Union import numpy as np +import torch from monai.transforms.spatial.array import Resize from monai.utils import InterpolateMode, ensure_tuple_rep, look_up_option, optional_import +from monai.utils.enums import DataObjects +from monai.utils.misc import convert_data_type Image, _ = optional_import("PIL", name="Image") def write_png( - data: np.ndarray, + data: DataObjects.Images, file_name: str, output_spatial_shape: Optional[Sequence[int]] = None, mode: Union[InterpolateMode, str] = InterpolateMode.BICUBIC, @@ -47,40 +50,43 @@ def write_png( ValueError: When ``scale`` is not one of [255, 65535]. 
""" - if not isinstance(data, np.ndarray): - raise AssertionError("input data must be numpy array.") - if len(data.shape) == 3 and data.shape[2] == 1: # PIL Image can't save image with 1 channel - data = data.squeeze(2) + if not isinstance(data, (np.ndarray, torch.Tensor)): + raise AssertionError("input data must be np.ndarray/torch.Tensor.") + data_np: np.ndarray + data_np, *_ = convert_data_type(data, np.ndarray) # type: ignore + if len(data_np.shape) == 3 and data_np.shape[2] == 1: # PIL Image can't save image with 1 channel + data_np = data_np.squeeze(2) if output_spatial_shape is not None: output_spatial_shape_ = ensure_tuple_rep(output_spatial_shape, 2) mode = look_up_option(mode, InterpolateMode) align_corners = None if mode in (InterpolateMode.NEAREST, InterpolateMode.AREA) else False xform = Resize(spatial_size=output_spatial_shape_, mode=mode, align_corners=align_corners) - _min, _max = np.min(data), np.max(data) - if len(data.shape) == 3: - data = np.moveaxis(data, -1, 0) # to channel first - data = xform(data) - data = np.moveaxis(data, 0, -1) + _min, _max = np.min(data_np), np.max(data_np) + if len(data_np.shape) == 3: + data_np = np.moveaxis(data_np, -1, 0) # to channel first + data_np = xform(data_np) # type: ignore + data_np = np.moveaxis(data_np, 0, -1) else: # (H, W) - data = np.expand_dims(data, 0) # make a channel - data = xform(data)[0] # first channel + data_np = np.expand_dims(data_np, 0) # make a channel + # first channel + data_np = xform(data_np)[0] # type: ignore if mode != InterpolateMode.NEAREST: - data = np.clip(data, _min, _max) # type: ignore + data_np = np.clip(data_np, _min, _max) # type: ignore if scale is not None: - data = np.clip(data, 0.0, 1.0) # type: ignore # png writer only can scale data in range [0, 1] + data_np = np.clip(data_np, 0.0, 1.0) # type: ignore # png writer only can scale data in range [0, 1] if scale == np.iinfo(np.uint8).max: - data = (scale * data).astype(np.uint8) + data_np = (scale * data_np).astype(np.uint8) elif scale == np.iinfo(np.uint16).max: - data = (scale * data).astype(np.uint16) + data_np = (scale * data_np).astype(np.uint16) else: raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535]") # PNG data must be int number - if data.dtype not in (np.uint8, np.uint16): # type: ignore - data = data.astype(np.uint8) + if data_np.dtype not in (np.uint8, np.uint16): # type: ignore + data_np = data_np.astype(np.uint8) - data = np.moveaxis(data, 0, 1) - img = Image.fromarray(data) + data_np = np.moveaxis(data_np, 0, 1) + img = Image.fromarray(data_np) img.save(file_name, "PNG") return diff --git a/monai/data/synthetic.py b/monai/data/synthetic.py index 20a7829cab..6eec9fd277 100644 --- a/monai/data/synthetic.py +++ b/monai/data/synthetic.py @@ -76,7 +76,7 @@ def create_test_image_2d( labels = np.ceil(image).astype(np.int32) norm = rs.uniform(0, num_seg_classes * noise_max, size=image.shape) - noisyimage = rescale_array(np.maximum(image, norm)) + noisyimage: np.ndarray = rescale_array(np.maximum(image, norm)) # type: ignore if channel_dim is not None: if not (isinstance(channel_dim, int) and channel_dim in (-1, 0, 2)): @@ -151,7 +151,7 @@ def create_test_image_3d( labels = np.ceil(image).astype(np.int32) norm = rs.uniform(0, num_seg_classes * noise_max, size=image.shape) - noisyimage = rescale_array(np.maximum(image, norm)) + noisyimage: np.ndarray = rescale_array(np.maximum(image, norm)) # type: ignore if channel_dim is not None: if not (isinstance(channel_dim, int) and channel_dim in (-1, 0, 3)): diff 
--git a/monai/data/utils.py b/monai/data/utils.py index 94c8582e9a..344006ac64 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -40,7 +40,8 @@ look_up_option, optional_import, ) -from monai.utils.enums import Method +from monai.utils.enums import DataObjects, Method +from monai.utils.misc import convert_data_type pd, _ = optional_import("pandas") DataFrame, _ = optional_import("pandas", name="DataFrame") @@ -353,6 +354,10 @@ def decollate_batch(batch, detach: bool = True): if out_list[0].ndim == 0 and detach: return [t.item() for t in out_list] return list(out_list) + if isinstance(batch, np.ndarray): + if batch.ndim == 0: + return batch + return list(batch) if isinstance(batch, Mapping): _dict_list = {key: decollate_batch(batch[key], detach) for key in batch} return [dict(zip(_dict_list, item)) for item in zip(*_dict_list.values())] @@ -543,7 +548,7 @@ def rectify_header_sform_qform(img_nii): return img_nii -def zoom_affine(affine: np.ndarray, scale: Sequence[float], diagonal: bool = True): +def zoom_affine(affine: DataObjects.Images, scale: Sequence[float], diagonal: bool = True): """ To make column norm of `affine` the same as `scale`. If diagonal is False, returns an affine that combines orthogonal rotation and the new scale. @@ -568,8 +573,7 @@ def zoom_affine(affine: np.ndarray, scale: Sequence[float], diagonal: bool = Tru the updated `n x n` affine. """ - - affine = np.array(affine, dtype=float, copy=True) + affine, *_ = convert_data_type(deepcopy(affine), np.ndarray, dtype=float) if len(affine) != len(affine[0]): raise ValueError(f"affine must be n x n, got {len(affine)} x {len(affine[0])}.") scale_np = np.array(scale, dtype=float, copy=True) @@ -584,14 +588,15 @@ def zoom_affine(affine: np.ndarray, scale: Sequence[float], diagonal: bool = Tru scale_np[scale_np == 0] = 1.0 if diagonal: - return np.diag(np.append(scale_np, [1.0])) - rzs = affine[:-1, :-1] # rotation zoom scale - zs = np.linalg.cholesky(rzs.T @ rzs).T - rotation = rzs @ np.linalg.inv(zs) - s = np.sign(np.diag(zs)) * np.abs(scale_np) - # construct new affine with rotation and zoom - new_affine = np.eye(len(affine)) - new_affine[:-1, :-1] = rotation @ np.diag(s) + new_affine = np.diag(np.append(scale_np, [1.0])) + else: + rzs = affine[:-1, :-1] # rotation zoom scale + zs = np.linalg.cholesky(rzs.T @ rzs).T + rotation = rzs @ np.linalg.inv(zs) + s = np.sign(np.diag(zs)) * np.abs(scale_np) + # construct new affine with rotation and zoom + new_affine = np.eye(len(affine)) + new_affine[:-1, :-1] = rotation @ np.diag(s) return new_affine @@ -611,8 +616,8 @@ def compute_shape_offset( """ shape = np.array(spatial_shape, copy=True, dtype=float) sr = len(shape) - in_affine = to_affine_nd(sr, in_affine) - out_affine = to_affine_nd(sr, out_affine) + in_affine = to_affine_nd(sr, in_affine) # type: ignore + out_affine = to_affine_nd(sr, out_affine) # type: ignore in_coords = [(0.0, dim - 1.0) for dim in shape] corners = np.asarray(np.meshgrid(*in_coords, indexing="ij")).reshape((len(shape), -1)) corners = np.concatenate((corners, np.ones_like(corners[:1]))) @@ -631,7 +636,7 @@ def compute_shape_offset( return out_shape.astype(int), offset -def to_affine_nd(r: Union[np.ndarray, int], affine: np.ndarray) -> np.ndarray: +def to_affine_nd(r: Union[DataObjects.Images, int], affine: DataObjects.Images) -> DataObjects.Images: """ Using elements from affine, to create a new affine matrix by assigning the rotation/zoom/scaling matrix and the translation vector. 
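Worth noting for the `decollate_batch` hunk earlier in this file: the new `np.ndarray` branch mirrors the existing tensor handling, returning 0-d arrays as-is and splitting everything else along the batch dimension. A minimal usage check of the behaviour the added lines imply:

```python
import numpy as np

from monai.data.utils import decollate_batch

batch = np.zeros((4, 1, 8, 8))       # a collated batch of 4 single-channel images
items = decollate_batch(batch)       # new branch: list(batch)
assert isinstance(items, list) and len(items) == 4
assert items[0].shape == (1, 8, 8)

scalar = np.asarray(2.0)             # 0-d arrays are returned unchanged
assert decollate_batch(scalar) is scalar
```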
@@ -658,19 +663,20 @@ def to_affine_nd(r: Union[np.ndarray, int], affine: np.ndarray) -> np.ndarray: an (r+1) x (r+1) matrix """ - affine_np = np.array(affine, dtype=np.float64) - if affine_np.ndim != 2: - raise ValueError(f"affine must have 2 dimensions, got {affine_np.ndim}.") - new_affine = np.array(r, dtype=np.float64, copy=True) + if affine.ndim != 2: + raise ValueError(f"affine must have 2 dimensions, got {affine.ndim}.") + device = affine.device if isinstance(affine, torch.Tensor) else None + new_affine, *_ = convert_data_type(deepcopy(r), type(affine), dtype=np.float64, device=device) if new_affine.ndim == 0: - sr: int = int(new_affine.astype(np.uint)) + sr: int = int(new_affine) if not np.isfinite(sr) or sr < 0: raise ValueError(f"r must be positive, got {sr}.") new_affine = np.eye(sr + 1, dtype=np.float64) - d = max(min(len(new_affine) - 1, len(affine_np) - 1), 1) - new_affine[:d, :d] = affine_np[:d, :d] + new_affine, *_ = convert_data_type(new_affine, type(affine), device=device) + d = max(min(len(new_affine) - 1, len(affine) - 1), 1) + new_affine[:d, :d] = affine[:d, :d] if d > 1: - new_affine[:d, -1] = affine_np[:d, -1] + new_affine[:d, -1] = affine[:d, -1] return new_affine diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 793683f2c5..e67c979ea4 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -18,6 +18,7 @@ from monai.config import IgniteInfo, KeysCollection from monai.utils import deprecated, ensure_tuple, get_torch_version_tuple, look_up_option, min_version, optional_import +from monai.utils.enums import DataObjects idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") if TYPE_CHECKING: @@ -126,8 +127,8 @@ def string_list_all_gather(strings: List[str]) -> List[str]: def write_metrics_reports( save_dir: str, images: Optional[Sequence[str]], - metrics: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]], - metric_details: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]], + metrics: Optional[Dict[str, DataObjects.Images]], + metric_details: Optional[Dict[str, DataObjects.Images]], summary_ops: Optional[Union[str, Sequence[str]]], deli: str = "\t", output_type: str = "csv", diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 325c5300ea..cff107acc2 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -20,8 +20,9 @@ from monai.losses.focal_loss import FocalLoss from monai.losses.spatial_mask import MaskedLoss -from monai.networks import one_hot +from monai.networks.utils import one_hot_torch from monai.utils import LossReduction, Weight, look_up_option +from monai.utils.enums import DataObjects class DiceLoss(_Loss): @@ -129,7 +130,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: - target = one_hot(target, num_classes=n_pred_ch) + target = one_hot_torch(target, num_classes=n_pred_ch) if not self.include_background: if n_pred_ch == 1: @@ -305,7 +306,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: - target = one_hot(target, num_classes=n_pred_ch) + target = one_hot_torch(target, num_classes=n_pred_ch) if not self.include_background: if n_pred_ch == 1: @@ -368,7 +369,7 @@ class GeneralizedWassersteinDiceLoss(_Loss): def __init__( self, - dist_matrix: Union[np.ndarray, torch.Tensor], + dist_matrix: DataObjects.Images, 
weighting_mode: str = "default", reduction: Union[LossReduction, str] = LossReduction.MEAN, smooth_nr: float = 1e-5, diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index b4b3698e5b..7f11ecaca0 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -16,7 +16,7 @@ import torch.nn.functional as F from torch.nn.modules.loss import _Loss -from monai.networks import one_hot +from monai.networks.utils import one_hot_torch from monai.utils import LossReduction @@ -96,7 +96,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: - target = one_hot(target, num_classes=n_pred_ch) + target = one_hot_torch(target, num_classes=n_pred_ch) if not self.include_background: if n_pred_ch == 1: diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index 1d75b9e8cc..355a662901 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -15,7 +15,7 @@ import torch from torch.nn.modules.loss import _Loss -from monai.networks import one_hot +from monai.networks.utils import one_hot_torch from monai.utils import LossReduction @@ -120,7 +120,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: - target = one_hot(target, num_classes=n_pred_ch) + target = one_hot_torch(target, num_classes=n_pred_ch) if not self.include_background: if n_pred_ch == 1: diff --git a/monai/metrics/froc.py b/monai/metrics/froc.py index faebbbf7a6..93ceee0efc 100644 --- a/monai/metrics/froc.py +++ b/monai/metrics/froc.py @@ -10,17 +10,19 @@ # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple import numpy as np import torch +from monai.utils.enums import DataObjects + def compute_fp_tp_probs( - probs: Union[np.ndarray, torch.Tensor], - y_coord: Union[np.ndarray, torch.Tensor], - x_coord: Union[np.ndarray, torch.Tensor], - evaluation_mask: Union[np.ndarray, torch.Tensor], + probs: DataObjects.Images, + y_coord: DataObjects.Images, + x_coord: DataObjects.Images, + evaluation_mask: DataObjects.Images, labels_to_exclude: Optional[List] = None, resolution_level: int = 0, ): @@ -77,8 +79,8 @@ def compute_fp_tp_probs( def compute_froc_curve_data( - fp_probs: Union[np.ndarray, torch.Tensor], - tp_probs: Union[np.ndarray, torch.Tensor], + fp_probs: DataObjects.Images, + tp_probs: DataObjects.Images, num_targets: int, num_images: int, ): diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 12f3b49d32..06118b87b4 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -17,6 +17,7 @@ from monai.metrics.utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background from monai.utils import MetricReduction +from monai.utils.enums import DataObjects from .metric import CumulativeIterationMetric @@ -114,8 +115,8 @@ def aggregate(self): # type: ignore def compute_hausdorff_distance( - y_pred: Union[np.ndarray, torch.Tensor], - y: Union[np.ndarray, torch.Tensor], + y_pred: DataObjects.Images, + y: DataObjects.Images, include_background: bool = False, distance_metric: str = "euclidean", percentile: Optional[float] = None, diff --git a/monai/metrics/rocauc.py b/monai/metrics/rocauc.py index 3bd6c0d69c..50aec9be12 100644 --- a/monai/metrics/rocauc.py +++ b/monai/metrics/rocauc.py @@ 
-156,6 +156,6 @@ def compute_roc_auc( if average == Average.MACRO: return np.mean(auc_values) if average == Average.WEIGHTED: - weights = [sum(y_) for y_ in y] + weights = [sum(y_.cpu()) for y_ in y] return np.average(auc_values, weights=weights) raise ValueError(f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].') diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 6039f1b55e..eacf63102b 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -17,6 +17,7 @@ from monai.metrics.utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background from monai.utils import MetricReduction +from monai.utils.enums import DataObjects from .metric import CumulativeIterationMetric @@ -106,8 +107,8 @@ def aggregate(self): # type: ignore def compute_average_surface_distance( - y_pred: Union[np.ndarray, torch.Tensor], - y: Union[np.ndarray, torch.Tensor], + y_pred: DataObjects.Images, + y: DataObjects.Images, include_background: bool = False, symmetric: bool = False, distance_metric: str = "euclidean", diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 84de834f74..ffb52ffe49 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -16,7 +16,8 @@ from monai.transforms.croppad.array import SpatialCrop from monai.transforms.utils import generate_spatial_bounding_box -from monai.utils import MetricReduction, look_up_option, optional_import +from monai.utils.enums import DataObjects, MetricReduction +from monai.utils.module import look_up_option, optional_import binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion") distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt") @@ -26,8 +27,8 @@ def ignore_background( - y_pred: Union[np.ndarray, torch.Tensor], - y: Union[np.ndarray, torch.Tensor], + y_pred: DataObjects.Images, + y: DataObjects.Images, ): """ This function is used to remove background (the first channel) for `y_pred` and `y`. 
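The metrics helpers now share the `DataObjects.Images` signature. `ignore_background` itself is behaviourally unchanged; a small reminder of its contract (dropping the first channel of channel-first data), which now applies to either array type:

```python
import torch

from monai.metrics.utils import ignore_background

y_pred = torch.rand(2, 3, 8, 8)                  # (batch, channel, H, W)
y = torch.randint(0, 2, (2, 3, 8, 8))
y_pred, y = ignore_background(y_pred=y_pred, y=y)
assert y_pred.shape == (2, 2, 8, 8)              # background channel removed
assert y.shape == (2, 2, 8, 8)
```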
@@ -106,8 +107,8 @@ def do_metric_reduction( def get_mask_edges( - seg_pred: Union[np.ndarray, torch.Tensor], - seg_gt: Union[np.ndarray, torch.Tensor], + seg_pred: DataObjects.Images, + seg_gt: DataObjects.Images, label_idx: int = 1, crop: bool = True, ) -> Tuple[np.ndarray, np.ndarray]: diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 3c347dad22..b74805ef88 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -16,6 +16,8 @@ normal_init, normalize_transform, one_hot, + one_hot_np, + one_hot_torch, pixelshuffle, predict_segmentation, slice_channels, diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 9d20d2a83b..cd1485fd1d 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -17,11 +17,18 @@ from contextlib import contextmanager from typing import Any, Callable, Mapping, Optional, Sequence, Union +import numpy as np import torch import torch.nn as nn +from monai.config.type_definitions import DtypeLike +from monai.utils.enums import DataObjects +from monai.utils.misc import dtype_convert + __all__ = [ "one_hot", + "one_hot_np", + "one_hot_torch", "slice_channels", "predict_segmentation", "normalize_transform", @@ -35,7 +42,54 @@ ] -def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor: +def _one_hot_pre_process(labels, dim: int): + # if `dim` is bigger, add singleton dim at the end + if labels.ndim < dim + 1: + shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape)) + labels = labels.reshape(shape) + + sh = list(labels.shape) + + if sh[dim] != 1: + raise AssertionError("labels should have a channel with length equal to one.") + return labels, sh + + +def one_hot_np(labels: np.ndarray, num_classes: int, dtype: DtypeLike = np.float32, dim: int = 1) -> np.ndarray: + """ + Numpy implementation of `one_hot`. + See also: :py:meth:`monai.networks.utils.one_hot`. + """ + labels, _ = _one_hot_pre_process(labels, dim) + + label: np.ndarray + label = np.eye(num_classes)[labels.astype(np.longlong)] # adds one hot to end + label = label.astype(dtype) + label = label.squeeze(dim) # remove singleton + label = np.moveaxis(label, -1, dim) # move one hot dim to desired index + return label + + +def one_hot_torch( + labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1 +) -> torch.Tensor: + """ + Torch implementation of `one_hot`. + See also: :py:meth:`monai.networks.utils.one_hot`. + """ + labels, sh = _one_hot_pre_process(labels, dim) + + sh[dim] = num_classes + + o = torch.zeros(size=sh, dtype=dtype, device=labels.device) + labels = o.scatter_(dim=dim, index=labels.long(), value=1) + + return labels + + +def one_hot( + labels: DataObjects.Images, num_classes: int, dtype: Union[DtypeLike, torch.dtype] = torch.float, dim: int = 1 +) -> DataObjects.Images: """ For every value v in `labels`, the value in the output will be either 1 or 0. 
Each vector along the `dim`-th dimension has the "one-hot" format, i.e., it has a total length of `num_classes`, @@ -67,25 +121,11 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f a = torch.randint(0, 2, size=(2, 1, 2, 2, 2)) out = one_hot(a, num_classes=2, dim=1) print(out.shape) # torch.Size([2, 2, 2, 2, 2]) - """ - - # if `dim` is bigger, add singleton dim at the end - if labels.ndim < dim + 1: - shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape)) - labels = torch.reshape(labels, shape) - - sh = list(labels.shape) - - if sh[dim] != 1: - raise AssertionError("labels should have a channel with length equal to one.") - - sh[dim] = num_classes - - o = torch.zeros(size=sh, dtype=dtype, device=labels.device) - labels = o.scatter_(dim=dim, index=labels.long(), value=1) - - return labels + dtype = dtype_convert(dtype, type(labels)) + if isinstance(labels, np.ndarray): + return one_hot_np(labels, num_classes, dtype, dim) # type: ignore + return one_hot_torch(labels, num_classes, dtype, dim) # type: ignore def slice_channels(tensor: torch.Tensor, *slicevals: Optional[int]) -> torch.Tensor: diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 487a995e5e..c7e6b98764 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -18,6 +18,7 @@ CenterSpatialCrop, CropForeground, DivisiblePad, + Pad, RandCropByLabelClasses, RandCropByPosNegLabel, RandScaleCrop, @@ -306,7 +307,15 @@ ZoomD, ZoomDict, ) -from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform +from .transform import ( + MapTransform, + NumpyTransform, + Randomizable, + RandomizableTransform, + ThreadUnsafe, + Transform, + apply_transform, +) from .utility.array import ( AddChannel, AddExtremePointsChannel, diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index b380f7d42a..a488f289de 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -12,6 +12,7 @@ A collection of generic interfaces for MONAI transforms. """ +import math import warnings from typing import Any, Callable, Optional, Sequence, Union @@ -22,8 +23,10 @@ # For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform) from monai.transforms.transform import ( # noqa: F401 MapTransform, + NumpyTransform, Randomizable, RandomizableTransform, + TorchTransform, Transform, apply_transform, ) @@ -168,3 +171,23 @@ def inverse(self, data): for t in reversed(invertible_transforms): data = apply_transform(t.inverse, data, self.map_items, self.unpack_items) return data + + def get_number_np_torch_conversions(self): + """Get the number of times that the data need to be converted from numpy to torch or vice versa. 
+ Returns `math.nan` if any of the transforms is neither `NumpyTransform` nor `TorchTransform`.""" + num_conversions = 0 + tr = self.flatten().transforms + # if any are unknown, return math.nan + if not all([isinstance(t, (NumpyTransform, TorchTransform)) for t in tr]): + return math.nan + # ignore transforms that are both numpy- and torch-compatible, as they work with either type + tr = [t for t in tr if not (isinstance(t, NumpyTransform) and isinstance(t, TorchTransform))] + if len(tr) > 0: + t_prev = tr[0] + for t in tr[1:]: + if isinstance(t, NumpyTransform) and isinstance(t_prev, TorchTransform): + num_conversions += 1 + elif isinstance(t, TorchTransform) and isinstance(t_prev, NumpyTransform): + num_conversions += 1 + t_prev = t + return num_conversions diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index fe482270f0..8828dbd102 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -19,10 +19,11 @@ import numpy as np import torch +from torch.nn.functional import pad as pad_pt from monai.config import IndexSelection from monai.data.utils import get_random_patch, get_valid_patch_size -from monai.transforms.transform import Randomizable, Transform +from monai.transforms.transform import NumpyTransform, Randomizable, TorchTransform from monai.transforms.utils import ( compute_divisible_spatial_size, generate_label_classes_crop_centers, @@ -34,8 +35,11 @@ weighted_patch_samples, ) from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, look_up_option +from monai.utils.enums import DataObjects +from monai.utils.misc import convert_data_type, is_module_ver_at_least __all__ = [ + "Pad", "SpatialPad", "BorderPad", "DivisiblePad", @@ -54,9 +58,76 @@ ] -class SpatialPad(Transform): +class Pad(TorchTransform, NumpyTransform): + """ + Perform padding, given the amount of padding in each dimension. + + If input is `torch.Tensor` and mode is `constant`, `torch.nn.functional.pad` will be used. + Otherwise, `np.pad` will be used (input converted to `np.ndarray` if necessary). + When `np.pad` is used, any numpy padding mode can be given; see numpy.lib.arraypad.pad + for additional details. + + Args: + to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...]. + mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, + ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + One of the listed string values or a user supplied function. Defaults to ``"constant"``. 
+ See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + """ + + def __init__( + self, + to_pad: List[Tuple[int, int]], + mode: Union[NumpyPadMode, str, None] = NumpyPadMode.CONSTANT, + **np_kwargs, + ) -> None: + self.to_pad = to_pad + self.mode = mode or NumpyPadMode.CONSTANT + self.np_kwargs = np_kwargs + + @staticmethod + def _np_pad(img: DataObjects.Images, all_pad_width, mode, **np_kwargs) -> DataObjects.Images: + out, orig_type, orig_device = convert_data_type(img, np.ndarray) + out = np.pad(out, all_pad_width, mode=mode, **np_kwargs) + out, *_ = convert_data_type(out, orig_type, orig_device) + return out + + @staticmethod + def _pt_pad(img: DataObjects.Images, all_pad_width, mode, **np_kwargs) -> DataObjects.Images: + out, orig_type, orig_device = convert_data_type(img, torch.Tensor) + pt_pad_width = [val for sublist in all_pad_width for val in sublist[::-1]][::-1] + out = pad_pt(out, pt_pad_width, mode=mode, **np_kwargs) # type: ignore + out, *_ = convert_data_type(out, orig_type, orig_device) + return out + + def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, str]] = None) -> DataObjects.Images: + """ + Args: + img: data to be transformed, assuming `img` is channel-first and + padding doesn't apply to the channel dim. + mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, + ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + One of the listed string values or a user supplied function. Defaults to ``self.mode``. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + """ + if not np.asarray(self.to_pad).any(): + # all zeros, skip padding + return img + mode = mode or self.mode + mode = mode.value if isinstance(mode, NumpyPadMode) else mode + if isinstance(img, torch.Tensor) and mode == "constant" and not self.np_kwargs: + pad = self._pt_pad + else: + pad = self._np_pad + return pad(img, self.to_pad, mode, **self.np_kwargs) + + +class SpatialPad(TorchTransform, NumpyTransform): """ Performs padding to the data, symmetric for all sides or all on one side for each dimension. + + If input is `torch.Tensor` and mode is `constant`, `torch.nn.functional.pad` will be used. + Otherwise, `np.pad` will be used (input converted to `np.ndarray` if necessary). Uses np.pad so in practice, a mode needs to be provided. See numpy.lib.arraypad.pad for additional details. @@ -99,7 +170,7 @@ def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int return pad_width return [(0, max(sp_i - data_shape[i], 0)) for i, sp_i in enumerate(spatial_size)] - def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None) -> np.ndarray: + def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, str]] = None) -> DataObjects.Images: """ Args: img: data to be transformed, assuming `img` is channel-first and @@ -114,13 +185,12 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N if not np.asarray(all_pad_width).any(): # all zeros, skip padding return img + mode = look_up_option(mode or self.mode, NumpyPadMode) + padder = Pad(all_pad_width, mode, **self.np_kwargs) + return padder(img) - mode = look_up_option(self.mode if mode is None else mode, NumpyPadMode).value - img = np.pad(img, all_pad_width, mode=mode, **self.np_kwargs) - return img - -class BorderPad(Transform): +class BorderPad(TorchTransform, NumpyTransform): """ Pad the input data by adding specified borders to every dimension. 
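A usage sketch of the new `Pad` transform defined above: with a `torch.Tensor` and the default `constant` mode it dispatches to `torch.nn.functional.pad` (no numpy round-trip); any other mode, extra numpy kwargs, or a numpy input goes through `np.pad`:

```python
import numpy as np
import torch

from monai.transforms import Pad

padder = Pad(to_pad=[(0, 0), (2, 2), (1, 3)])       # channel dim left unpadded
out_t = padder(torch.zeros(1, 10, 10))              # torch path (constant mode)
assert isinstance(out_t, torch.Tensor) and out_t.shape == (1, 14, 14)

out_n = padder(np.zeros((1, 10, 10)), mode="edge")  # numpy path
assert isinstance(out_n, np.ndarray) and out_n.shape == (1, 14, 14)
```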
@@ -155,7 +225,7 @@ def __init__( self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) self.np_kwargs = np_kwargs - def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None): + def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, str]] = None) -> DataObjects.Images: """ Args: img: data to be transformed, assuming `img` is channel-first and @@ -188,12 +258,13 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N f"Unsupported spatial_border length: {len(spatial_border)}, available options are " f"[1, len(spatial_shape)={len(spatial_shape)}, 2*len(spatial_shape)={2*len(spatial_shape)}]." ) - - mode = look_up_option(self.mode if mode is None else mode, NumpyPadMode).value - return np.pad(img, [(0, 0)] + data_pad_width, mode=mode, **self.np_kwargs) + all_pad_width = [(0, 0)] + data_pad_width + mode = look_up_option(mode or self.mode, NumpyPadMode) + padder = Pad(all_pad_width, mode, **self.np_kwargs) + return padder(img) -class DivisiblePad(Transform): +class DivisiblePad(TorchTransform, NumpyTransform): """ Pad the input data, so that the spatial sizes are divisible by `k`. """ @@ -226,7 +297,7 @@ def __init__( self.method: Method = Method(method) self.np_kwargs = np_kwargs - def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None) -> np.ndarray: + def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, str]] = None) -> DataObjects.Images: """ Args: img: data to be transformed, assuming `img` is channel-first @@ -247,7 +318,7 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N return spatial_pad(img) -class SpatialCrop(Transform): +class SpatialCrop(TorchTransform, NumpyTransform): """ General purpose cropper to produce sub-volume region of interest (ROI). If a dimension of the expected ROI size is bigger than the input image size, will not crop that dimension. 
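The crop/pad transforms in this file now inherit both `TorchTransform` and `NumpyTransform`, which is what makes `Compose.get_number_np_torch_conversions` (added earlier in this patch) useful: dual-typed transforms are skipped, and only genuine numpy/torch boundaries are counted. A sketch with two hypothetical single-backend transforms:

```python
from monai.transforms import Compose
from monai.transforms.transform import NumpyTransform, TorchTransform


class NumpyOnly(NumpyTransform):    # hypothetical numpy-backed transform
    def __call__(self, x):
        return x


class TorchOnly(TorchTransform):    # hypothetical torch-backed transform
    def __call__(self, x):
        return x


pipeline = Compose([NumpyOnly(), TorchOnly(), NumpyOnly()])
# two type boundaries: numpy -> torch -> numpy
assert pipeline.get_number_np_torch_conversions() == 2
```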
@@ -261,12 +332,27 @@ class SpatialCrop(Transform): - the start and end coordinates of the ROI """ + @staticmethod + def _maximum(a, b): + if isinstance(a, np.ndarray): + return np.maximum(a, b) + # is torch and has torch.maximum (pt>1.6) + if hasattr(torch, "maximum"): + return torch.maximum(a, b) + return np.maximum(a.cpu(), b.cpu()) + + @staticmethod + def _floor_div(a, b): + if is_module_ver_at_least(torch, (1, 8, 0)): + return torch.div(a, b, rounding_mode="floor") + return torch.floor_divide(a, b) + def __init__( self, - roi_center: Union[Sequence[int], np.ndarray, None] = None, - roi_size: Union[Sequence[int], np.ndarray, None] = None, - roi_start: Union[Sequence[int], np.ndarray, None] = None, - roi_end: Union[Sequence[int], np.ndarray, None] = None, + roi_center: Optional[Union[Sequence[int], DataObjects.Images]] = None, + roi_size: Optional[Union[Sequence[int], DataObjects.Images]] = None, + roi_start: Optional[Union[Sequence[int], DataObjects.Images]] = None, + roi_end: Optional[Union[Sequence[int], DataObjects.Images]] = None, roi_slices: Optional[Sequence[slice]] = None, ) -> None: """ @@ -285,22 +371,26 @@ def __init__( self.slices = list(roi_slices) else: if roi_center is not None and roi_size is not None: - roi_center = np.asarray(roi_center, dtype=np.int16) - roi_size = np.asarray(roi_size, dtype=np.int16) - roi_start_np = np.maximum(roi_center - np.floor_divide(roi_size, 2), 0) - roi_end_np = np.maximum(roi_start_np + roi_size, roi_start_np) + roi_center = torch.as_tensor(roi_center, dtype=torch.int16) + roi_size = torch.as_tensor(roi_size, dtype=torch.int16, device=roi_center.device) + roi_start_torch = self._maximum( + roi_center - self._floor_div(roi_size, 2), + torch.tensor(0, device=roi_center.device), + ) + roi_end_torch = self._maximum(roi_start_torch + roi_size, roi_start_torch) else: if roi_start is None or roi_end is None: raise ValueError("Please specify either roi_center, roi_size or roi_start, roi_end.") - roi_start_np = np.maximum(np.asarray(roi_start, dtype=np.int16), 0) - roi_end_np = np.maximum(np.asarray(roi_end, dtype=np.int16), roi_start_np) - # Allow for 1D by converting back to np.array (since np.maximum will convert to int) - roi_start_np = roi_start_np if isinstance(roi_start_np, np.ndarray) else np.array([roi_start_np]) - roi_end_np = roi_end_np if isinstance(roi_end_np, np.ndarray) else np.array([roi_end_np]) - # convert to slices - self.slices = [slice(s, e) for s, e in zip(roi_start_np, roi_end_np)] - - def __call__(self, img: Union[np.ndarray, torch.Tensor]): + roi_start_torch = torch.as_tensor(roi_start, dtype=torch.int16) + roi_start_torch = self._maximum(roi_start_torch, torch.tensor(0, device=roi_start_torch.device)) + roi_end_torch = self._maximum(torch.as_tensor(roi_end, dtype=torch.int16), roi_start_torch) + # convert to slices (accounting for 1d) + if roi_start_torch.numel() == 1: + self.slices = [slice(int(roi_start_torch.item()), int(roi_end_torch.item()))] + else: + self.slices = [slice(int(s.item()), int(e.item())) for s, e in zip(roi_start_torch, roi_end_torch)] + + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -310,7 +400,7 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]): return img[tuple(slices)] -class CenterSpatialCrop(Transform): +class CenterSpatialCrop(TorchTransform, NumpyTransform): """ Crop at the center of image with specified ROI size. 
If a dimension of the expected ROI size is bigger than the input image size, will not crop that dimension. @@ -328,7 +418,7 @@ class CenterSpatialCrop(Transform): def __init__(self, roi_size: Union[Sequence[int], int]) -> None: self.roi_size = roi_size - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -339,7 +429,7 @@ def __call__(self, img: np.ndarray): return cropper(img) -class CenterScaleCrop(Transform): +class CenterScaleCrop(TorchTransform, NumpyTransform): """ Crop at the center of image with specified scale of ROI size. @@ -352,7 +442,7 @@ class CenterScaleCrop(Transform): def __init__(self, roi_scale: Union[Sequence[float], float]): self.roi_scale = roi_scale - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: img_size = img.shape[1:] ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] @@ -360,7 +450,7 @@ def __call__(self, img: np.ndarray): return sp_crop(img=img) -class RandSpatialCrop(Randomizable, Transform): +class RandSpatialCrop(Randomizable, TorchTransform, NumpyTransform): """ Crop image with random size or specific size ROI. It can crop at a random position as center or at the image center. And allows to set the minimum and maximum size to limit the randomly generated ROI. @@ -409,7 +499,7 @@ def randomize(self, img_size: Sequence[int]) -> None: valid_size = get_valid_patch_size(img_size, self._size) self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -455,7 +545,7 @@ def __init__( self.roi_scale = roi_scale self.max_roi_scale = max_roi_scale - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -470,7 +560,7 @@ def __call__(self, img: np.ndarray): return super().__call__(img=img) -class RandSpatialCropSamples(Randomizable, Transform): +class RandSpatialCropSamples(Randomizable, TorchTransform, NumpyTransform): """ Crop image with random size or specific size ROI to generate a list of N samples. It can crop at a random position as center or at the image center. And allows to set @@ -521,10 +611,10 @@ def set_random_state( self.cropper.set_random_state(state=self.R) return self - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, _: Optional[Any] = None) -> None: pass - def __call__(self, img: np.ndarray) -> List[np.ndarray]: + def __call__(self, img: DataObjects.Images) -> List[DataObjects.Images]: """ Apply the transform to `img`, assuming `img` is channel-first and cropping doesn't change the channel dim. @@ -532,7 +622,7 @@ def __call__(self, img: np.ndarray) -> List[np.ndarray]: return [self.cropper(img) for _ in range(self.num_samples)] -class CropForeground(Transform): +class CropForeground(TorchTransform, NumpyTransform): """ Crop an image using a bounding box. The bounding box is generated by selecting foreground using select_fn at channels channel_indices. margin is added in each spatial dimension of the bounding box. 
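The ROI arithmetic in `SpatialCrop` earlier in this file now runs on torch tensors; `_floor_div` exists because `torch.div(..., rounding_mode="floor")` only appeared in torch 1.8, with `torch.floor_divide` as the older fallback. From the caller's side nothing should change:

```python
import torch

from monai.transforms import SpatialCrop

cropper = SpatialCrop(roi_center=[5, 5], roi_size=[4, 4])
img = torch.arange(100).reshape(1, 10, 10)   # channel-first image
out = cropper(img)
assert out.shape == (1, 4, 4)                # slices [3:7, 3:7]
```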
@@ -595,13 +685,15 @@ def __init__( self.k_divisible = k_divisible self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) - def compute_bounding_box(self, img: np.ndarray): + def compute_bounding_box(self, img: DataObjects.Images) -> Tuple[np.ndarray, np.ndarray]: """ Compute the start points and end points of bounding box to crop. And adjust bounding box coords to be divisible by `k`. """ box_start, box_end = generate_spatial_bounding_box(img, self.select_fn, self.channel_indices, self.margin) + box_start = [i.cpu() if isinstance(i, torch.Tensor) else i for i in box_start] # type: ignore + box_end = [i.cpu() if isinstance(i, torch.Tensor) else i for i in box_end] # type: ignore box_start_ = np.asarray(box_start, dtype=np.int16) box_end_ = np.asarray(box_end, dtype=np.int16) orig_spatial_size = box_end_ - box_start_ @@ -612,7 +704,7 @@ def compute_bounding_box(self, img: np.ndarray): box_end_ = box_start_ + spatial_size return box_start_, box_end_ - def crop_pad(self, img: np.ndarray, box_start: np.ndarray, box_end: np.ndarray): + def crop_pad(self, img: DataObjects.Images, box_start: np.ndarray, box_end: np.ndarray) -> DataObjects.Images: """ Crop and pad based on the bounding box. @@ -623,7 +715,7 @@ def crop_pad(self, img: np.ndarray, box_start: np.ndarray, box_end: np.ndarray): pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) return BorderPad(spatial_border=pad, mode=self.mode)(cropped) - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images): """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't change the channel dim. @@ -636,7 +728,7 @@ def __call__(self, img: np.ndarray): return cropped -class RandWeightedCrop(Randomizable, Transform): +class RandWeightedCrop(Randomizable, TorchTransform, NumpyTransform): """ Samples a list of `num_samples` image patches according to the provided `weight_map`. @@ -650,19 +742,24 @@ class RandWeightedCrop(Randomizable, Transform): """ def __init__( - self, spatial_size: Union[Sequence[int], int], num_samples: int = 1, weight_map: Optional[np.ndarray] = None + self, + spatial_size: Union[Sequence[int], int], + num_samples: int = 1, + weight_map: Optional[DataObjects.Images] = None, ): self.spatial_size = ensure_tuple(spatial_size) self.num_samples = int(num_samples) self.weight_map = weight_map - self.centers: List[np.ndarray] = [] + self.centers: List[DataObjects.Images] = [] - def randomize(self, weight_map: np.ndarray) -> None: + def randomize(self, weight_map: DataObjects.Images) -> None: self.centers = weighted_patch_samples( spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R ) # using only the first channel as weight map - def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) -> List[np.ndarray]: + def __call__( + self, img: DataObjects.Images, weight_map: Optional[DataObjects.Images] = None + ) -> List[DataObjects.Images]: """ Args: img: input image to sample patches from. assuming `img` is a channel-first array. 
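`CropForeground.compute_bounding_box` above moves any tensor coordinates to CPU before `np.asarray`, since `generate_spatial_bounding_box` can now hand back torch scalars. Caller-facing behaviour should be unchanged; a rough check:

```python
import torch

from monai.transforms import CropForeground

img = torch.zeros(1, 8, 8)
img[0, 2:6, 3:7] = 1                          # 4x4 foreground block
cropper = CropForeground(select_fn=lambda x: x > 0, margin=0)
assert cropper(img).shape == (1, 4, 4)
```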
@@ -681,14 +778,14 @@ def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) -> raise ValueError(f"image and weight map spatial shape mismatch: {img.shape[1:]} vs {weight_map.shape[1:]}.") self.randomize(weight_map) _spatial_size = fall_back_tuple(self.spatial_size, weight_map.shape[1:]) - results = [] + results: List[DataObjects.Images] = [] for center in self.centers: cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size) results.append(cropper(img)) return results -class RandCropByPosNegLabel(Randomizable, Transform): +class RandCropByPosNegLabel(Randomizable, TorchTransform, NumpyTransform): """ Crop random fixed sized regions with the center being a foreground or background voxel based on the Pos Neg Ratio. @@ -741,14 +838,14 @@ class RandCropByPosNegLabel(Randomizable, Transform): def __init__( self, spatial_size: Union[Sequence[int], int], - label: Optional[np.ndarray] = None, + label: Optional[DataObjects.Images] = None, pos: float = 1.0, neg: float = 1.0, num_samples: int = 1, - image: Optional[np.ndarray] = None, + image: Optional[DataObjects.Images] = None, image_threshold: float = 0.0, - fg_indices: Optional[np.ndarray] = None, - bg_indices: Optional[np.ndarray] = None, + fg_indices: Optional[DataObjects.Images] = None, + bg_indices: Optional[DataObjects.Images] = None, ) -> None: self.spatial_size = ensure_tuple(spatial_size) self.label = label @@ -766,10 +863,10 @@ def __init__( def randomize( self, - label: np.ndarray, - fg_indices: Optional[np.ndarray] = None, - bg_indices: Optional[np.ndarray] = None, - image: Optional[np.ndarray] = None, + label: DataObjects.Images, + fg_indices: Optional[DataObjects.Images] = None, + bg_indices: Optional[DataObjects.Images] = None, + image: Optional[DataObjects.Images] = None, ) -> None: self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) if fg_indices is None or bg_indices is None: @@ -787,12 +884,12 @@ def randomize( def __call__( self, - img: np.ndarray, - label: Optional[np.ndarray] = None, - image: Optional[np.ndarray] = None, - fg_indices: Optional[np.ndarray] = None, - bg_indices: Optional[np.ndarray] = None, - ) -> List[np.ndarray]: + img: DataObjects.Images, + label: Optional[DataObjects.Images] = None, + image: Optional[DataObjects.Images] = None, + fg_indices: Optional[DataObjects.Images] = None, + bg_indices: Optional[DataObjects.Images] = None, + ) -> List[DataObjects.Images]: """ Args: img: input data to crop samples from based on the pos/neg ratio of `label` and `image`. @@ -815,7 +912,7 @@ def __call__( image = self.image self.randomize(label, fg_indices, bg_indices, image) - results: List[np.ndarray] = [] + results: List[DataObjects.Images] = [] if self.centers is not None: for center in self.centers: cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore @@ -824,7 +921,7 @@ def __call__( return results -class RandCropByLabelClasses(Randomizable, Transform): +class RandCropByLabelClasses(Randomizable, TorchTransform, NumpyTransform): """ Crop random fixed sized regions with the center being a class based on the specified ratios of every class. The label data can be One-Hot format array or Argmax data. 
And will return a list of arrays for all the @@ -887,12 +984,12 @@ def __init__( self, spatial_size: Union[Sequence[int], int], ratios: Optional[List[Union[float, int]]] = None, - label: Optional[np.ndarray] = None, + label: Optional[DataObjects.Images] = None, num_classes: Optional[int] = None, num_samples: int = 1, - image: Optional[np.ndarray] = None, + image: Optional[DataObjects.Images] = None, image_threshold: float = 0.0, - indices: Optional[List[np.ndarray]] = None, + indices: Optional[List[DataObjects.Images]] = None, ) -> None: self.spatial_size = ensure_tuple(spatial_size) self.ratios = ratios @@ -906,30 +1003,29 @@ def __init__( def randomize( self, - label: np.ndarray, - indices: Optional[List[np.ndarray]] = None, - image: Optional[np.ndarray] = None, + label: DataObjects.Images, + indices: Optional[List[DataObjects.Images]] = None, + image: Optional[DataObjects.Images] = None, ) -> None: self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) - indices_: List[np.ndarray] - if indices is None: - if self.indices is not None: - indices_ = self.indices - else: - indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) - else: + indices_: List[DataObjects.Images] + if indices is not None: indices_ = indices + elif self.indices is not None: + indices_ = self.indices + else: + indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) self.centers = generate_label_classes_crop_centers( self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R ) def __call__( self, - img: np.ndarray, - label: Optional[np.ndarray] = None, - image: Optional[np.ndarray] = None, - indices: Optional[List[np.ndarray]] = None, - ) -> List[np.ndarray]: + img: DataObjects.Images, + label: Optional[DataObjects.Images] = None, + image: Optional[DataObjects.Images] = None, + indices: Optional[List[DataObjects.Images]] = None, + ) -> List[DataObjects.Images]: """ Args: img: input data to crop samples from based on the ratios of every class, assumes `img` is a @@ -948,7 +1044,7 @@ def __call__( image = self.image self.randomize(label, indices, image) - results: List[np.ndarray] = [] + results: List[DataObjects.Images] = [] if self.centers is not None: for center in self.centers: cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore @@ -957,7 +1053,7 @@ def __call__( return results -class ResizeWithPadOrCrop(Transform): +class ResizeWithPadOrCrop(TorchTransform, NumpyTransform): """ Resize an image to a target spatial size by either centrally cropping the image or padding it evenly with a user-specified mode. @@ -988,7 +1084,7 @@ def __init__( self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, **np_kwargs) self.cropper = CenterSpatialCrop(roi_size=spatial_size) - def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None) -> np.ndarray: + def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, str]] = None) -> DataObjects.Images: """ Args: img: data to pad or crop, assuming `img` is channel-first and @@ -1002,7 +1098,7 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N return self.padder(self.cropper(img), mode=mode) -class BoundingRect(Transform): +class BoundingRect(TorchTransform, NumpyTransform): """ Compute coordinates of axis-aligned bounding rectangles from input image `img`. 
The output format of the coordinates is (shape is [channel, 2 * spatial dims]): @@ -1029,7 +1125,7 @@ class BoundingRect(Transform): def __init__(self, select_fn: Callable = is_positive) -> None: self.select_fn = select_fn - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`. """ @@ -1038,5 +1134,5 @@ def __call__(self, img: np.ndarray) -> np.ndarray: for channel in range(img.shape[0]): start_, end_ = generate_spatial_bounding_box(img, select_fn=self.select_fn, channel_indices=channel) bbox.append([i for k in zip(start_, end_) for i in k]) - + bbox = [[j.cpu() if isinstance(j, torch.Tensor) else j for j in i] for i in bbox] # type: ignore return np.stack(bbox, axis=0) diff --git a/monai/transforms/croppad/batch.py b/monai/transforms/croppad/batch.py index 956dff7881..078084a62d 100644 --- a/monai/transforms/croppad/batch.py +++ b/monai/transforms/croppad/batch.py @@ -14,17 +14,16 @@ """ from copy import deepcopy -from typing import Any, Dict, Hashable, Union +from typing import Any, Sequence, Union import numpy as np import torch from monai.data.utils import list_data_collate -from monai.transforms.compose import Compose from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad from monai.transforms.inverse import InvertibleTransform -from monai.transforms.utility.array import ToTensor -from monai.utils.enums import InverseKeys, Method, NumpyPadMode +from monai.transforms.transform import NumpyTransform, TorchTransform +from monai.utils.enums import DataObjects, InverseKeys, Method, NumpyPadMode __all__ = [ "PadListDataCollate", @@ -39,11 +38,14 @@ def replace_element(to_replace, batch, idx, key_or_idx): batch[idx] = tuple(batch_idx_list) # else, replace else: - batch[idx][key_or_idx] = to_replace + if key_or_idx is not None: + batch[idx][key_or_idx] = to_replace + else: + batch[idx] = to_replace return batch -class PadListDataCollate(InvertibleTransform): +class PadListDataCollate(InvertibleTransform, TorchTransform, NumpyTransform): """ Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest tensor in each dimension. 
This transform is useful if some of the applied transforms generate batch data of @@ -75,6 +77,35 @@ def __init__( self.mode = mode self.np_kwargs = np_kwargs + def replace_batch_element(self, batch, key_or_idx, is_list_of_dicts): + # calculate max size of each dimension + max_shapes = [] + for elem in batch: + im = elem[key_or_idx] if key_or_idx is not None else elem + if not isinstance(im, (torch.Tensor, np.ndarray)): + return batch + max_shapes.append(im.shape[1:]) + max_shape = np.array(max_shapes).max(axis=0) + # If all same size, skip + if np.all(np.array(max_shapes).min(axis=0) == max_shape): + return batch + + # Use `SpatialPad` to match sizes + # Default params are central padding, padding with 0's + padder = SpatialPad(spatial_size=max_shape, method=self.method, mode=self.mode, **self.np_kwargs) + + for idx, elem in enumerate(batch): + im = elem[key_or_idx] if key_or_idx is not None else elem + orig_size = im.shape[1:] + padded = padder(im) + batch = replace_element(padded, batch, idx, key_or_idx) + + # If we have a dictionary of data, append to list + if is_list_of_dicts: + self.push_transform(batch[idx], key_or_idx, orig_size=orig_size) + + return batch + def __call__(self, batch: Any): """ Args: @@ -82,46 +113,23 @@ def __call__(self, batch: Any): """ # data is either list of dicts or list of lists is_list_of_dicts = isinstance(batch[0], dict) - # loop over items inside of each element in a batch - for key_or_idx in batch[0].keys() if is_list_of_dicts else range(len(batch[0])): - # calculate max size of each dimension - max_shapes = [] - for elem in batch: - if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)): - break - max_shapes.append(elem[key_or_idx].shape[1:]) - # len > 0 if objects were arrays, else skip as no padding to be done - if not max_shapes: - continue - max_shape = np.array(max_shapes).max(axis=0) - # If all same size, skip - if np.all(np.array(max_shapes).min(axis=0) == max_shape): - continue - # Do we need to convert output to Tensor? 
-            output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor)
-
-            # Use `SpatialPadd` or `SpatialPad` to match sizes
-            # Default params are central padding, padding with 0's
-            # If input is dictionary, use the dictionary version so that the transformation is recorded
-
-            padder = SpatialPad(spatial_size=max_shape, method=self.method, mode=self.mode, **self.np_kwargs)
-            transform = padder if not output_to_tensor else Compose([padder, ToTensor()])
-
-            for idx, batch_i in enumerate(batch):
-                im = batch_i[key_or_idx]
-                orig_size = im.shape[1:]
-                padded = transform(batch_i[key_or_idx])
-                batch = replace_element(padded, batch, idx, key_or_idx)
-
-                # If we have a dictionary of data, append to list
-                if is_list_of_dicts:
-                    self.push_transform(batch[idx], key_or_idx, orig_size=orig_size)
+        # if data is a list of dictionaries, loop over keys
+        if is_list_of_dicts:
+            for key in batch[0].keys():
+                batch = self.replace_batch_element(batch, key, is_list_of_dicts)
+        # elif is a list of lists/tuples
+        elif isinstance(batch[0], Sequence):
+            for idx in range(len(batch[0])):
+                batch = self.replace_batch_element(batch, idx, is_list_of_dicts)
+        # elif there's only one element per batch, either a torch.Tensor or np.ndarray
+        elif isinstance(batch[0], (torch.Tensor, np.ndarray)):
+            batch = self.replace_batch_element(batch, None, is_list_of_dicts)

         # After padding, use default list collator
         return list_data_collate(batch)

     @staticmethod
-    def inverse(data: dict) -> Dict[Hashable, np.ndarray]:
+    def inverse(data: dict) -> DataObjects.Dict:
         if not isinstance(data, dict):
             raise RuntimeError("Inverse can only currently be applied on dictionaries.")
diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py
index 346071aa3b..56f78d1ac4 100644
--- a/monai/transforms/croppad/dictionary.py
+++ b/monai/transforms/croppad/dictionary.py
@@ -49,7 +49,8 @@
 )
 from monai.utils import ImageMetaKey as Key
 from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple
-from monai.utils.enums import InverseKeys
+from monai.utils.enums import DataObjects, InverseKeys
+from monai.utils.misc import convert_data_type

 __all__ = [
     "NumpyPadModeSequence",
@@ -140,14 +141,14 @@ def __init__(
         self.mode = ensure_tuple_rep(mode, len(self.keys))
         self.padder = SpatialPad(spatial_size, method, **np_kwargs)

-    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict:
         d = dict(data)
         for key, m in self.key_iterator(d, self.mode):
             self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m})
             d[key] = self.padder(d[key], mode=m)
         return d

-    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict:
         d = deepcopy(dict(data))
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
@@ -211,14 +212,14 @@ def __init__(
         self.mode = ensure_tuple_rep(mode, len(self.keys))
         self.padder = BorderPad(spatial_border=spatial_border, **np_kwargs)

-    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict:
         d = dict(data)
         for key, m in self.key_iterator(d, self.mode):
             self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m})
             d[key] = self.padder(d[key], mode=m)
         return d

-    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> 
Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -283,14 +284,14 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padder = DivisiblePad(k=k, method=method, **np_kwargs) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key, m in self.key_iterator(d, self.mode): self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) d[key] = self.padder(d[key], mode=m) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -350,14 +351,14 @@ def __init__( super().__init__(keys, allow_missing_keys) self.cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) d[key] = self.cropper(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -403,7 +404,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.cropper = CenterSpatialCrop(roi_size) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): orig_size = d[key].shape[1:] @@ -411,7 +412,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.push_transform(d, key, orig_size=orig_size) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -453,7 +454,7 @@ def __init__( super().__init__(keys, allow_missing_keys=allow_missing_keys) self.roi_scale = roi_scale - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) # use the spatial size of first image to scale, expect all images have the same spatial size img_size = data[self.keys[0]].shape[1:] @@ -466,7 +467,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, DataObjects.Images]) -> Dict[Hashable, DataObjects.Images]: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -546,7 +547,7 @@ def randomize(self, img_size: Sequence[int]) -> None: valid_size = get_valid_patch_size(img_size, self._size) self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) self.randomize(d[self.keys[0]].shape[1:]) # image shape from the first data key if self._size is None: @@ -561,7 +562,7 @@ def 
__call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = cropper(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -640,7 +641,7 @@ def __init__( self.roi_scale = roi_scale self.max_roi_scale = max_roi_scale - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: img_size = data[self.keys[0]].shape[1:] ndim = len(img_size) self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] @@ -736,7 +737,7 @@ def set_random_state( def randomize(self, data: Optional[Any] = None) -> None: pass - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]: + def __call__(self, data: DataObjects.Mapping) -> List[DataObjects.Dict]: ret = [] for i in range(self.num_samples): d = dict(data) @@ -757,7 +758,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n ret.append(cropped) return ret - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) # We changed the transform name from RandSpatialCropd to RandSpatialCropSamplesd # Need to revert that since we're calling RandSpatialCropd's inverse @@ -826,7 +827,7 @@ def __init__( mode=mode, ) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) box_start, box_end = self.cropper.compute_bounding_box(img=d[self.source_key]) d[self.start_coord_key] = box_start @@ -836,7 +837,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.cropper.crop_pad(d[key], box_start, box_end) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -920,13 +921,14 @@ def randomize(self, weight_map: np.ndarray) -> None: spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R ) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]: + def __call__(self, data: DataObjects.Mapping) -> List[DataObjects.Dict]: d = dict(data) - self.randomize(d[self.w_key]) + weight_map: np.ndarray = convert_data_type(d[self.w_key], np.ndarray)[0] # type: ignore + self.randomize(weight_map) _spatial_size = fall_back_tuple(self.spatial_size, d[self.w_key].shape[1:]) # initialize returned list with shallow copy to preserve key ordering - results: List[Dict[Hashable, np.ndarray]] = [dict(data) for _ in range(self.num_samples)] + results: List[Dict[Hashable, Any]] = [dict(data) for _ in range(self.num_samples)] # fill in the extra keys with unmodified data for i in range(self.num_samples): for key in set(data.keys()).difference(set(self.keys)): @@ -956,7 +958,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n return results - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = 
deepcopy(dict(data))
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
@@ -1087,12 +1089,16 @@ def randomize(
             self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R
         )

-    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
+    def __call__(self, data: DataObjects.Mapping) -> List[DataObjects.Dict]:
         d = dict(data)
-        label = d[self.label_key]
-        image = d[self.image_key] if self.image_key else None
-        fg_indices = d.get(self.fg_indices_key) if self.fg_indices_key is not None else None
-        bg_indices = d.get(self.bg_indices_key) if self.bg_indices_key is not None else None
+
+        def get_as_np(x) -> np.ndarray:
+            return convert_data_type(x, np.ndarray)[0]  # type: ignore
+
+        label = get_as_np(d[self.label_key])
+        image = get_as_np(d[self.image_key]) if self.image_key else None
+        fg_indices = get_as_np(d[self.fg_indices_key]) if self.fg_indices_key in d else None
+        bg_indices = get_as_np(d[self.bg_indices_key]) if self.bg_indices_key in d else None

         self.randomize(label, fg_indices, bg_indices, image)
         if not isinstance(self.spatial_size, tuple):
@@ -1111,7 +1117,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
                 img = d[key]
                 cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)  # type: ignore
                 orig_size = img.shape[1:]
-                results[i][key] = cropper(img)
+                results[i][key] = cropper(img)  # type: ignore
                 self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size)
             # add `patch_index` to the meta data
             for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
@@ -1138,7 +1144,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
             pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist())))
             inverse_transform = BorderPad(pad)
             # Apply inverse transform
-            d[key] = inverse_transform(d[key])
+            d[key] = inverse_transform(d[key])  # type: ignore
             # Remove the applied transform
             self.pop_transform(d, key)

@@ -1254,11 +1260,11 @@ def __init__(
     def randomize(
         self,
         label: np.ndarray,
-        indices: Optional[List[np.ndarray]] = None,
+        indices: Optional[List[DataObjects.Images]] = None,
         image: Optional[np.ndarray] = None,
     ) -> None:
         self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
-        indices_: List[np.ndarray]
+        indices_: List[DataObjects.Images]
         if indices is None:
             indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold)
         else:
@@ -1267,7 +1273,7 @@ def randomize(
             self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R
         )

-    def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarray]]:
+    def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, DataObjects.Images]]:
         d = dict(data)
         label = d[self.label_key]
         image = d[self.image_key] if self.image_key else None
@@ -1280,7 +1286,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarr
             raise ValueError("no available ROI centers to crop.")

         # initialize returned list with shallow copy to preserve key ordering
-        results: List[Dict[Hashable, np.ndarray]] = [dict(data) for _ in range(self.num_samples)]
+        results: List[Dict[Hashable, Any]] = [dict(data) for _ in range(self.num_samples)]

         for i, center in enumerate(self.centers):
             # fill in the extra keys with unmodified data
@@ -1301,7 +1307,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> 
List[Dict[Hashable, np.ndarr return results - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -1359,7 +1365,7 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, method=method, **np_kwargs) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key, m in self.key_iterator(d, self.mode): orig_size = d[key].shape[1:] @@ -1374,7 +1380,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -1434,7 +1440,7 @@ def __init__( self.bbox = BoundingRect(select_fn=select_fn) self.bbox_key_postfix = bbox_key_postfix - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: """ See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`. """ diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index dfbac7465c..0d53cbb8c5 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -14,7 +14,9 @@ """ from collections.abc import Iterable -from typing import Any, List, Optional, Sequence, Tuple, Union +from copy import deepcopy +from functools import partial +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union from warnings import warn import numpy as np @@ -23,17 +25,18 @@ from monai.config import DtypeLike from monai.data.utils import get_random_patch, get_valid_patch_size from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter -from monai.transforms.transform import RandomizableTransform, Transform +from monai.transforms.transform import NumpyTransform, RandomizableTransform, TorchTransform from monai.transforms.utils import rescale_array from monai.utils import ( PT_BEFORE_1_7, InvalidPyTorchVersionError, - dtype_torch_to_numpy, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, ) +from monai.utils.enums import DataObjects +from monai.utils.misc import convert_data_type, dtype_convert __all__ = [ "RandGaussianNoise", @@ -67,7 +70,7 @@ ] -class RandGaussianNoise(RandomizableTransform): +class RandGaussianNoise(TorchTransform, NumpyTransform, RandomizableTransform): """ Add Gaussian noise to image. @@ -87,20 +90,23 @@ def randomize(self, im_shape: Sequence[int]) -> None: super().randomize(None) self._noise = self.R.normal(self.mean, self.R.uniform(0, self.std), size=im_shape) - def __call__(self, img: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. 
""" self.randomize(img.shape) + if self._noise is None: raise AssertionError if not self._do_transform: return img - dtype = dtype_torch_to_numpy(img.dtype) if isinstance(img, torch.Tensor) else img.dtype - return img + self._noise.astype(dtype) + noise, *_ = convert_data_type( + self._noise, type(img), dtype=img.dtype, device=img.device if isinstance(img, torch.Tensor) else None + ) + return img + noise -class RandRicianNoise(RandomizableTransform): +class RandRicianNoise(TorchTransform, NumpyTransform, RandomizableTransform): """ Add Rician noise to image. Rician noise in MRI is the result of performing a magnitude operation on complex @@ -139,20 +145,23 @@ def __init__( self.channel_wise = channel_wise self.relative = relative self.sample_std = sample_std - self._noise1: np.ndarray - self._noise2: np.ndarray + self._noise1: DataObjects.Images + self._noise2: DataObjects.Images - def _add_noise(self, img: Union[torch.Tensor, np.ndarray], mean: float, std: float): + def _add_noise(self, img: DataObjects.Images, mean: float, std: float): + dtype_np = dtype_convert(img.dtype, np.ndarray) im_shape = img.shape _std = self.R.uniform(0, std) if self.sample_std else std - self._noise1 = self.R.normal(mean, _std, size=im_shape) - self._noise2 = self.R.normal(mean, _std, size=im_shape) - if self._noise1 is None or self._noise2 is None: - raise AssertionError - dtype = dtype_torch_to_numpy(img.dtype) if isinstance(img, torch.Tensor) else img.dtype - return np.sqrt((img + self._noise1.astype(dtype)) ** 2 + self._noise2.astype(dtype) ** 2) + self._noise1 = self.R.normal(mean, _std, size=im_shape).astype(dtype_np) + self._noise2 = self.R.normal(mean, _std, size=im_shape).astype(dtype_np) + if isinstance(img, torch.Tensor): + n1 = torch.tensor(self._noise1, device=img.device) + n2 = torch.tensor(self._noise2, device=img.device) + return torch.sqrt((img + n1) ** 2 + n2 ** 2) + + return np.sqrt((img + self._noise1) ** 2 + self._noise2 ** 2) # type: ignore - def __call__(self, img: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ @@ -176,7 +185,7 @@ def __call__(self, img: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, return img -class ShiftIntensity(Transform): +class ShiftIntensity(TorchTransform, NumpyTransform): """ Shift intensity uniformly for the entire image with specified `offset`. @@ -187,14 +196,17 @@ class ShiftIntensity(Transform): def __init__(self, offset: float) -> None: self.offset = offset - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ - return np.asarray((img + self.offset), dtype=img.dtype) + out = img + self.offset + if isinstance(out, torch.Tensor): + return out.type(img.dtype) # type: ignore + return out.astype(img.dtype) # type: ignore -class RandShiftIntensity(RandomizableTransform): +class RandShiftIntensity(RandomizableTransform, TorchTransform, NumpyTransform): """ Randomly shift intensity with randomly picked offset. """ @@ -219,7 +231,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) super().randomize(None) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. 
""" @@ -230,7 +242,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: return shifter(img) -class StdShiftIntensity(Transform): +class StdShiftIntensity(TorchTransform, NumpyTransform): """ Shift intensity for the image with a factor and the standard deviation of the image by: ``v = v + factor * std(v)``. @@ -253,28 +265,38 @@ def __init__( self.channel_wise = channel_wise self.dtype = dtype - def _stdshift(self, img: np.ndarray) -> np.ndarray: - slices = (img != 0) if self.nonzero else np.ones(img.shape, dtype=bool) - if not np.any(slices): - return img - offset = self.factor * np.std(img[slices]) - img[slices] = img[slices] + offset + def _stdshift(self, img: DataObjects.Images) -> DataObjects.Images: + ones: Callable + std: Callable + if isinstance(img, torch.Tensor): + ones = torch.ones + std = partial(torch.std, unbiased=False) + else: + ones = np.ones + std = np.std + + slices = (img != 0) if self.nonzero else ones(img.shape, dtype=bool) + if slices.any(): + out = deepcopy(img) + offset = self.factor * std(out[slices]) + out[slices] = out[slices] + offset + return out return img - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ - img = img.astype(self.dtype) + img, *_ = convert_data_type(img, dtype=self.dtype) if self.channel_wise: for i, d in enumerate(img): - img[i] = self._stdshift(d) + img[i] = self._stdshift(d) # type: ignore else: img = self._stdshift(img) return img -class RandStdShiftIntensity(RandomizableTransform): +class RandStdShiftIntensity(TorchTransform, NumpyTransform, RandomizableTransform): """ Shift intensity for the image with a factor and the standard deviation of the image by: ``v = v + factor * std(v)`` where the `factor` is randomly picked. @@ -314,7 +336,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) super().randomize(None) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ @@ -327,7 +349,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: return shifter(img) -class ScaleIntensity(Transform): +class ScaleIntensity(TorchTransform, NumpyTransform): """ Scale the intensity of input image to the given value range (minv, maxv). If `minv` and `maxv` not provided, use `factor` to scale image by ``v = v * (1 + factor)``. @@ -347,7 +369,7 @@ def __init__( self.maxv = maxv self.factor = factor - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. 
@@ -356,13 +378,17 @@ def __call__(self, img: np.ndarray) -> np.ndarray: """ if self.minv is not None and self.maxv is not None: - return np.asarray(rescale_array(img, self.minv, self.maxv, img.dtype)) + return rescale_array(img, self.minv, self.maxv, img.dtype) if self.factor is not None: - return np.asarray(img * (1 + self.factor), dtype=img.dtype) + out = img * (1 + self.factor) + if isinstance(out, torch.Tensor): + return out.to(img.dtype) # type: ignore + else: + return out.astype(img.dtype) # type: ignore raise ValueError("Incompatible values: minv=None or maxv=None and factor=None.") -class RandScaleIntensity(RandomizableTransform): +class RandScaleIntensity(TorchTransform, NumpyTransform, RandomizableTransform): """ Randomly scale the intensity of input image by ``v = v * (1 + factor)`` where the `factor` is randomly picked. @@ -389,7 +415,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) super().randomize(None) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ @@ -400,7 +426,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: return scaler(img) -class RandBiasField(RandomizableTransform): +class RandBiasField(RandomizableTransform, TorchTransform, NumpyTransform): """ Random bias field augmentation for MR images. The bias field is considered as a linear combination of smoothly varying basis (polynomial) @@ -459,12 +485,12 @@ def _generate_random_field(self, spatial_shape: Sequence[int], degree: int, coef return np.polynomial.legendre.leggrid3d(coords[0], coords[1], coords[2], coeff_mat) raise NotImplementedError("only supports 2D or 3D fields") - def randomize(self, data: np.ndarray) -> None: + def randomize(self, data: DataObjects.Images) -> None: super().randomize(None) n_coeff = int(np.prod([(self.degree + k) / k for k in range(1, len(data.shape[1:]) + 1)])) self._coeff = self.R.uniform(*self.coeff_range, n_coeff).tolist() - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ @@ -479,10 +505,18 @@ def __call__(self, img: np.ndarray): ], axis=0, ) - return (img * np.exp(_bias_fields)).astype(self.dtype) + _bias_fields_exp: DataObjects.Images + _bias_fields_exp = np.exp(_bias_fields) + _bias_fields_exp, *_ = convert_data_type( + _bias_fields_exp, type(img), dtype=self.dtype, device=img.device if isinstance(img, torch.Tensor) else None + ) + out = img * _bias_fields_exp + out, *_ = convert_data_type(out, dtype=self.dtype) + return out -class NormalizeIntensity(Transform): + +class NormalizeIntensity(TorchTransform, NumpyTransform): """ Normalize input based on provided args, using calculated mean and std if not provided. 
This transform can normalize only non-zero values or entire image, and can also calculate @@ -501,8 +535,8 @@ class NormalizeIntensity(Transform): def __init__( self, - subtrahend: Union[Sequence, np.ndarray, None] = None, - divisor: Union[Sequence, np.ndarray, None] = None, + subtrahend: Optional[Union[Sequence, DataObjects.Images]] = None, + divisor: Optional[Union[Sequence, DataObjects.Images]] = None, nonzero: bool = False, channel_wise: bool = False, dtype: DtypeLike = np.float32, @@ -513,26 +547,55 @@ def __init__( self.channel_wise = channel_wise self.dtype = dtype - def _normalize(self, img: np.ndarray, sub=None, div=None) -> np.ndarray: - slices = (img != 0) if self.nonzero else np.ones(img.shape, dtype=bool) - if not np.any(slices): + @staticmethod + def _mean(x): + if isinstance(x, np.ndarray): + return np.mean(x) + x = torch.mean(x.float()) + return x.item() if x.numel() == 1 else x + + @staticmethod + def _std(x): + if isinstance(x, np.ndarray): + return np.std(x) + x = torch.std(x.float(), unbiased=False) + return x.item() if x.numel() == 1 else x + + def _normalize(self, img: DataObjects.Images, sub=None, div=None) -> DataObjects.Images: + img, *_ = convert_data_type(img, dtype=torch.float32) + + if self.nonzero: + slices = img != 0 + else: + if isinstance(img, np.ndarray): + slices = np.ones_like(img, dtype=bool) + else: + slices = torch.ones_like(img, dtype=torch.bool) + if not slices.any(): return img - _sub = sub if sub is not None else np.mean(img[slices]) - if isinstance(_sub, np.ndarray): + _sub = sub if sub is not None else self._mean(img[slices]) + if isinstance(_sub, (np.ndarray, torch.Tensor)): + _sub, *_ = convert_data_type( + _sub, type(img), dtype=img.dtype, device=img.device if isinstance(img, torch.Tensor) else None + ) _sub = _sub[slices] - _div = div if div is not None else np.std(img[slices]) + _div = div if div is not None else self._std(img[slices]) if np.isscalar(_div): if _div == 0.0: _div = 1.0 - elif isinstance(_div, np.ndarray): + elif isinstance(_div, (np.ndarray, torch.Tensor)): + _div, *_ = convert_data_type( + _div, type(img), dtype=img.dtype, device=img.device if isinstance(img, torch.Tensor) else None + ) _div = _div[slices] _div[_div == 0.0] = 1.0 img[slices] = (img[slices] - _sub) / _div + return img - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True, """ @@ -543,7 +606,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: raise ValueError(f"img has {len(img)} channels, but divisor has {len(self.divisor)} components.") for i, d in enumerate(img): - img[i] = self._normalize( + img[i] = self._normalize( # type: ignore d, sub=self.subtrahend[i] if self.subtrahend is not None else None, div=self.divisor[i] if self.divisor is not None else None, @@ -551,10 +614,11 @@ def __call__(self, img: np.ndarray) -> np.ndarray: else: img = self._normalize(img, self.subtrahend, self.divisor) - return img.astype(self.dtype) + out, *_ = convert_data_type(img, dtype=self.dtype) + return out -class ThresholdIntensity(Transform): +class ThresholdIntensity(TorchTransform, NumpyTransform): """ Filter the intensity values of whole image to below threshold or above threshold. And fill the remaining parts of the image to the `cval` value. 
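
The `_mean`/`_std` static helpers in `NormalizeIntensity` absorb two backend gaps: torch reductions insist on floating-point input, and they return zero-dimensional tensors rather than Python scalars. Stripped of the subtrahend/divisor plumbing, the masked normalization they support reduces to roughly this sketch (`normalize_nonzero` is an illustrative stand-in, not the transform itself):

import numpy as np
import torch

def normalize_nonzero(img):
    # Zero-mean / unit-std over the non-zero voxels; torch or numpy input.
    mask = img != 0  # identical syntax for both backends
    if not mask.any():
        return img
    vals = img[mask]
    if isinstance(img, torch.Tensor):
        mean, std = vals.mean(), vals.std(unbiased=False)
    else:
        mean, std = vals.mean(), vals.std()
    img[mask] = (vals - mean) / (std if std != 0 else 1.0)
    return img

print(normalize_nonzero(np.array([0.0, 2.0, 4.0])))      # [ 0. -1.  1.]
print(normalize_nonzero(torch.tensor([0.0, 2.0, 4.0])))  # tensor([ 0., -1.,  1.])
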
@@ -572,16 +636,18 @@ def __init__(self, threshold: float, above: bool = True, cval: float = 0.0) -> N self.above = above self.cval = cval - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ - return np.asarray( - np.where(img > self.threshold if self.above else img < self.threshold, img, self.cval), dtype=img.dtype - ) + where = np.where if isinstance(img, np.ndarray) else torch.where + cval = self.cval if isinstance(img, np.ndarray) else torch.tensor(self.cval, dtype=img.dtype, device=img.device) + res = where(img > self.threshold if self.above else img < self.threshold, img, cval) + res, *_ = convert_data_type(res, dtype=img.dtype) + return res # type: ignore -class ScaleIntensityRange(Transform): +class ScaleIntensityRange(TorchTransform, NumpyTransform): """ Apply specific intensity scaling to the whole numpy array. Scaling from [a_min, a_max] to [b_min, b_max] with clip option. @@ -601,7 +667,7 @@ def __init__(self, a_min: float, a_max: float, b_min: float, b_max: float, clip: self.b_max = b_max self.clip = clip - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ @@ -612,11 +678,12 @@ def __call__(self, img: np.ndarray): img = (img - self.a_min) / (self.a_max - self.a_min) img = img * (self.b_max - self.b_min) + self.b_min if self.clip: - img = np.asarray(np.clip(img, self.b_min, self.b_max)) + mod = torch if isinstance(img, torch.Tensor) else np + img = mod.clip(img, self.b_min, self.b_max) # type: ignore return img -class AdjustContrast(Transform): +class AdjustContrast(TorchTransform, NumpyTransform): """ Changes image intensity by gamma. Each pixel/voxel intensity is updated as:: @@ -631,17 +698,17 @@ def __init__(self, gamma: float) -> None: raise AssertionError("gamma must be a float or int number.") self.gamma = gamma - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ epsilon = 1e-7 img_min = img.min() img_range = img.max() - img_min - return np.power(((img - img_min) / float(img_range + epsilon)), self.gamma) * img_range + img_min + return ((img - img_min) / float(img_range + epsilon)) ** self.gamma * img_range + img_min -class RandAdjustContrast(RandomizableTransform): +class RandAdjustContrast(TorchTransform, NumpyTransform, RandomizableTransform): """ Randomly changes image intensity by gamma. Each pixel/voxel intensity is updated as:: @@ -686,7 +753,7 @@ def __call__(self, img: np.ndarray): return adjuster(img) -class ScaleIntensityRangePercentiles(Transform): +class ScaleIntensityRangePercentiles(TorchTransform, NumpyTransform): """ Apply range scaling to a numpy array based on the intensity distribution of the input. @@ -755,7 +822,7 @@ def __init__( self.clip = clip self.relative = relative - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ @@ -772,12 +839,13 @@ def __call__(self, img: np.ndarray): img = scalar(img) if self.clip: - img = np.asarray(np.clip(img, self.b_min, self.b_max)) + mod = torch if isinstance(img, torch.Tensor) else np + img = mod.clip(img, self.b_min, self.b_max) # type: ignore return img -class MaskIntensity(Transform): +class MaskIntensity(TorchTransform, NumpyTransform): """ Mask the intensity values of input image with the specified mask data. 
Mask data must have the same spatial size as the input image, and all @@ -792,10 +860,10 @@ class MaskIntensity(Transform): """ - def __init__(self, mask_data: Optional[np.ndarray]) -> None: + def __init__(self, mask_data: Optional[DataObjects.Images]) -> None: self.mask_data = mask_data - def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> np.ndarray: + def __call__(self, img: DataObjects.Images, mask_data: Optional[DataObjects.Images] = None) -> DataObjects.Images: """ Args: mask_data: if mask data is single channel, apply to every channel @@ -808,24 +876,25 @@ def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> n - ValueError: When ``mask_data`` and ``img`` channels differ and ``mask_data`` is not single channel. """ - if self.mask_data is None and mask_data is None: - raise ValueError("Unknown mask_data.") - mask_data_ = np.array([[1]]) - if self.mask_data is not None and mask_data is None: - mask_data_ = self.mask_data > 0 + if mask_data is not None: mask_data_ = mask_data > 0 - mask_data_ = np.asarray(mask_data_) + elif self.mask_data is not None: + mask_data_ = self.mask_data > 0 + else: + raise ValueError("Unknown mask_data.") if mask_data_.shape[0] != 1 and mask_data_.shape[0] != img.shape[0]: raise ValueError( "When mask_data is not single channel, mask_data channels must match img, " f"got img={img.shape[0]} mask_data={mask_data_.shape[0]}." ) - - return np.asarray(img * mask_data_) + mask_data_, *_ = convert_data_type( + mask_data_, type(img), device=img.device if isinstance(img, torch.Tensor) else None + ) + return img * mask_data_ -class SavitzkyGolaySmooth(Transform): +class SavitzkyGolaySmooth(TorchTransform): """ Smooth the input data along the given axis using a Savitzky-Golay filter. @@ -847,23 +916,27 @@ def __init__(self, window_length: int, order: int, axis: int = 1, mode: str = "z self.axis = axis self.mode = mode - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: - img: numpy.ndarray containing input data. Must be real and in shape [channels, spatial1, spatial2, ...]. + img: array containing input data. Must be real and in shape [channels, spatial1, spatial2, ...]. Returns: - np.ndarray containing smoothed result. + array containing smoothed result. """ + img_t: torch.Tensor + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor) # type: ignore + # add one to transform axis because a batch axis will be added at dimension 0 savgol_filter = SavitzkyGolayFilter(self.window_length, self.order, self.axis + 1, self.mode) # convert to Tensor and add Batch axis expected by HilbertTransform - input_data = torch.as_tensor(np.ascontiguousarray(img)).unsqueeze(0) - return savgol_filter(input_data).squeeze(0).numpy() + out_t = savgol_filter(img_t.unsqueeze(0)).squeeze(0) + out, *_ = convert_data_type(out_t, orig_type, orig_device) + return out -class DetectEnvelope(Transform): +class DetectEnvelope(TorchTransform): """ Find the envelope of the input data along the requested axis using a Hilbert transform. Requires PyTorch 1.7.0+ and the PyTorch FFT module (which is not included in NVIDIA PyTorch Release 20.10). @@ -886,24 +959,28 @@ def __init__(self, axis: int = 1, n: Union[int, None] = None) -> None: self.axis = axis self.n = n - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: - img: numpy.ndarray containing input data. 
Must be real and in shape [channels, spatial1, spatial2, ...]. + img: array containing input data. Must be real and in shape [channels, spatial1, spatial2, ...]. Returns: - np.ndarray containing envelope of data in img along the specified axis. + array containing envelope of data in img along the specified axis. """ + img_t: torch.Tensor + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor) # type: ignore + # add one to transform axis because a batch axis will be added at dimension 0 hilbert_transform = HilbertTransform(self.axis + 1, self.n) # convert to Tensor and add Batch axis expected by HilbertTransform - input_data = torch.as_tensor(np.ascontiguousarray(img)).unsqueeze(0) - return np.abs(hilbert_transform(input_data).squeeze(0).numpy()) + out_t = torch.abs(hilbert_transform(img_t.unsqueeze(0))).squeeze(0) + out, *_ = convert_data_type(out_t, orig_type, orig_device) + return out -class GaussianSmooth(Transform): +class GaussianSmooth(TorchTransform): """ Apply Gaussian smooth to the input data based on specified `sigma` parameter. A default value `sigma=1.0` is provided for reference. @@ -921,13 +998,18 @@ def __init__(self, sigma: Union[Sequence[float], float] = 1.0, approx: str = "er self.sigma = sigma self.approx = approx - def __call__(self, img: np.ndarray): - gaussian_filter = GaussianFilter(img.ndim - 1, self.sigma, approx=self.approx) - input_data = torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0) - return gaussian_filter(input_data).squeeze(0).detach().numpy() + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: + img_t: torch.Tensor + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor) # type: ignore + img_t = img_t.to(torch.float) + + gaussian_filter = GaussianFilter(img.ndim - 1, self.sigma, approx=self.approx).to(img_t.device) + out_t = gaussian_filter(img_t.unsqueeze(0)).squeeze(0) + out, *_ = convert_data_type(out_t, orig_type, orig_device) + return out -class RandGaussianSmooth(RandomizableTransform): +class RandGaussianSmooth(TorchTransform, RandomizableTransform): """ Apply Gaussian smooth to the input data based on randomly selected `sigma` parameters. @@ -965,7 +1047,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.y = self.R.uniform(low=self.sigma_y[0], high=self.sigma_y[1]) self.z = self.R.uniform(low=self.sigma_z[0], high=self.sigma_z[1]) - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: self.randomize() if not self._do_transform: return img @@ -973,7 +1055,7 @@ def __call__(self, img: np.ndarray): return GaussianSmooth(sigma=sigma, approx=self.approx)(img) -class GaussianSharpen(Transform): +class GaussianSharpen(TorchTransform): """ Sharpen images using the Gaussian Blur filter. Referring to: http://scipy-lectures.org/advanced/image_processing/auto_examples/plot_sharpen.html. 
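
`SavitzkyGolaySmooth`, `DetectEnvelope` and `GaussianSmooth` now all share one shape: convert whatever arrives into a torch tensor, run the torch-only filter on the device the data already lives on, and convert the result back to the caller's container type. Assuming `convert_data_type` keeps the `(data, orig_type, orig_device)` return convention used above, the round-trip boils down to the following sketch (`run_torch_filter` is a hypothetical helper, not part of the patch):

import numpy as np
import torch

def run_torch_filter(img, op):
    # Remember the caller's container type so the output matches the input.
    orig_is_numpy = isinstance(img, np.ndarray)
    img_t = torch.as_tensor(np.ascontiguousarray(img)) if orig_is_numpy else img
    out_t = op(img_t.float().unsqueeze(0)).squeeze(0)  # filters expect a batch dim
    return out_t.cpu().numpy() if orig_is_numpy else out_t.to(img.device)

out = run_torch_filter(np.ones((1, 8)), lambda t: t * 0.5)  # ndarray in, ndarray out
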
@@ -1012,16 +1094,22 @@ def __init__( self.alpha = alpha self.approx = approx - def __call__(self, img: np.ndarray): - gaussian_filter1 = GaussianFilter(img.ndim - 1, self.sigma1, approx=self.approx) - gaussian_filter2 = GaussianFilter(img.ndim - 1, self.sigma2, approx=self.approx) - input_data = torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0) - blurred_f = gaussian_filter1(input_data) - filter_blurred_f = gaussian_filter2(blurred_f) - return (blurred_f + self.alpha * (blurred_f - filter_blurred_f)).squeeze(0).detach().numpy() + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: + img_t: torch.Tensor + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor, dtype=torch.float32) # type: ignore + + gf1, gf2 = [ + GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx).to(img_t.device) + for sigma in (self.sigma1, self.sigma2) + ] + blurred_f = gf1(img_t.unsqueeze(0)) + filter_blurred_f = gf2(blurred_f) + out_t = (blurred_f + self.alpha * (blurred_f - filter_blurred_f)).squeeze(0) + out, *_ = convert_data_type(out_t, orig_type, orig_device) + return out -class RandGaussianSharpen(RandomizableTransform): +class RandGaussianSharpen(TorchTransform, RandomizableTransform): """ Sharpen images using the Gaussian Blur filter based on randomly selected `sigma1`, `sigma2` and `alpha`. The algorithm is :py:class:`monai.transforms.GaussianSharpen`. @@ -1087,7 +1175,7 @@ def __call__(self, img: np.ndarray): return GaussianSharpen(sigma1=sigma1, sigma2=sigma2, alpha=self.a, approx=self.approx)(img) -class RandHistogramShift(RandomizableTransform): +class RandHistogramShift(RandomizableTransform, NumpyTransform): """ Apply random nonlinear transform to the image's intensity histogram. @@ -1122,19 +1210,23 @@ def randomize(self, data: Optional[Any] = None) -> None: self.floating_control_points[i - 1], self.floating_control_points[i + 1] ) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: self.randomize() if not self._do_transform: return img - img_min, img_max = img.min(), img.max() + + img_np: np.ndarray + img_np, orig_type, orig_device = convert_data_type(img, np.ndarray) # type: ignore + img_min, img_max = img_np.min(), img_np.max() reference_control_points_scaled = self.reference_control_points * (img_max - img_min) + img_min floating_control_points_scaled = self.floating_control_points * (img_max - img_min) + img_min - return np.asarray( - np.interp(img, reference_control_points_scaled, floating_control_points_scaled), dtype=img.dtype - ) + + out_np = np.interp(img_np, reference_control_points_scaled, floating_control_points_scaled) + out, *_ = convert_data_type(out_np, orig_type, orig_device, dtype=img.dtype) + return out -class RandGibbsNoise(RandomizableTransform): +class RandGibbsNoise(TorchTransform, NumpyTransform, RandomizableTransform): """ Naturalistic image augmentation via Gibbs artifacts. The transform randomly applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts @@ -1153,10 +1245,9 @@ class RandGibbsNoise(RandomizableTransform): values in the interval [0,1] with alpha = 0 acting as the identity mapping. If a length-2 list is given as [a,b] then the value of alpha will be sampled uniformly from the interval [a,b]. 0 <= a <= b <= 1. - as_tensor_output: if true return torch.Tensor, else return np.array. default: True. 
""" - def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), as_tensor_output: bool = True) -> None: + def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0)) -> None: if len(alpha) != 2: raise AssertionError("alpha length must be 2.") @@ -1167,24 +1258,18 @@ def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), as_te self.alpha = alpha self.sampled_alpha = -1.0 # stores last alpha sampled by randomize() - self.as_tensor_output = as_tensor_output RandomizableTransform.__init__(self, prob=prob) - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: # randomize application and possibly alpha self._randomize(None) if self._do_transform: # apply transform - transform = GibbsNoise(self.sampled_alpha, self.as_tensor_output) + transform = GibbsNoise(self.sampled_alpha) img = transform(img) - else: - if isinstance(img, np.ndarray) and self.as_tensor_output: - img = torch.Tensor(img) - elif isinstance(img, torch.Tensor) and not self.as_tensor_output: - img = img.detach().cpu().numpy() return img def _randomize(self, _: Any) -> None: @@ -1196,7 +1281,7 @@ def _randomize(self, _: Any) -> None: self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1]) -class GibbsNoise(Transform): +class GibbsNoise(TorchTransform, NumpyTransform): """ The transform applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts are one of the common type of type artifacts appearing in MRI scans. @@ -1211,61 +1296,56 @@ class GibbsNoise(Transform): Args: alpha (float): Parametrizes the intensity of the Gibbs noise filter applied. Takes values in the interval [0,1] with alpha = 0 acting as the identity mapping. - as_tensor_output: if true return torch.Tensor, else return np.array. default: True. - """ - def __init__(self, alpha: float = 0.5, as_tensor_output: bool = True) -> None: + def __init__(self, alpha: float = 0.5) -> None: if alpha > 1 or alpha < 0: raise AssertionError("alpha must take values in the interval [0,1].") self.alpha = alpha - self.as_tensor_output = as_tensor_output - self._device = torch.device("cpu") - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: n_dims = len(img.shape[1:]) - - # convert to ndarray to work with np.fft - _device = None - if isinstance(img, torch.Tensor): - _device = img.device - img = img.cpu().detach().numpy() - + data_type = type(img) # FT - k = self._shift_fourier(img, n_dims) + k = self._shift_fourier(img, n_dims, data_type) # build and apply mask - k = self._apply_mask(k) + k = self._apply_mask(k, data_type) # map back - img = self._inv_shift_fourier(k, n_dims) - return torch.Tensor(img).to(_device or self._device) if self.as_tensor_output else img + return self._inv_shift_fourier(k, n_dims, data_type) - def _shift_fourier(self, x: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray: + def _shift_fourier(self, x: DataObjects.Images, n_dims: int, data_type: type) -> DataObjects.Images: """ Applies fourier transform and shifts its output. Only the spatial dimensions get transformed. Args: - x (np.ndarray): tensor to fourier transform. + x: tensor to fourier transform. 
""" - out: np.ndarray = np.fft.fftshift(np.fft.fftn(x, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) + # argument is dim if torch, else axes + mod, arg = (torch, "dim") if data_type is torch.Tensor else (np, "axes") + arg_dict = {arg: tuple(range(-n_dims, 0))} + out: DataObjects.Images = mod.fft.fftshift(mod.fft.fftn(x, **arg_dict), **arg_dict) # type: ignore return out - def _inv_shift_fourier(self, k: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray: + def _inv_shift_fourier(self, k: DataObjects.Images, n_dims: int, data_type: type) -> DataObjects.Images: """ Applies inverse shift and fourier transform. Only the spatial dimensions are transformed. """ - out: np.ndarray = np.fft.ifftn( - np.fft.ifftshift(k, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)) - ).real - return out + dims = tuple(range(-n_dims, 0)) + out: DataObjects.Images + if data_type is torch.Tensor: + out = torch.fft.ifftn(torch.fft.ifftshift(k, dim=dims), dim=dims, norm="backward") + else: + out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims) + return out.real - def _apply_mask(self, k: np.ndarray) -> np.ndarray: + def _apply_mask(self, k: DataObjects.Images, data_type: type) -> DataObjects.Images: """Builds and applies a mask on the spatial dimensions. Args: - k (np.ndarray): k-space version of the image. + k: k-space version of the image. Returns: masked version of the k-space image. """ @@ -1286,12 +1366,15 @@ def _apply_mask(self, k: np.ndarray) -> np.ndarray: # add channel dimension into mask mask = np.repeat(mask[None], k.shape[0], axis=0) + if data_type is torch.Tensor: + mask = torch.Tensor(mask).to(k.device) # type: ignore + # apply binary mask - k_masked: np.ndarray = k * mask - return k_masked + out: DataObjects.Images = k * mask # type: ignore + return out -class KSpaceSpikeNoise(Transform): +class KSpaceSpikeNoise(TorchTransform, NumpyTransform): """ Apply localized spikes in `k`-space at the given locations and intensities. Spike (Herringbone) artifact is a type of data acquisition artifact which @@ -1318,8 +1401,6 @@ class KSpaceSpikeNoise(Transform): receive a sequence of intensities. This value should be tested as it is data-dependent. The default values are the 2.5 the mean of the log-intensity for each channel. - as_tensor_output: if ``True`` return torch.Tensor, else return np.array. - Default: ``True``. 
Example: When working with 4D data, ``KSpaceSpikeNoise(loc = ((3,60,64,32), (64,60,32)), k_intensity = (13,14))`` @@ -1332,11 +1413,9 @@ def __init__( self, loc: Union[Tuple, Sequence[Tuple]], k_intensity: Optional[Union[Sequence[float], float]] = None, - as_tensor_output: bool = True, ): self.loc = ensure_tuple(loc) - self.as_tensor_output = as_tensor_output self.k_intensity = k_intensity # assert one-to-one relationship between factors and locations @@ -1351,7 +1430,7 @@ def __init__( if not isinstance(self.k_intensity, Sequence): raise AssertionError("There must be one intensity_factor value for each tuple of indices in loc.") - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: img (np.array or torch.tensor): image with dimensions (C, H, W) or (C, H, W, D) @@ -1367,23 +1446,18 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, raise AssertionError("Input images of dimension 4 need location tuple to be length 3 or 4") n_dims = len(img.shape[1:]) - - # convert to ndarray to work with np.fft - if isinstance(img, torch.Tensor): - device = img.device - img = img.cpu().detach().numpy() - else: - device = torch.device("cpu") + data_type = type(img) + lib = np if isinstance(img, np.ndarray) else torch # FT - k = self._shift_fourier(img, n_dims) - log_abs = np.log(np.absolute(k) + 1e-10) - phase = np.angle(k) + k = self._shift_fourier(img, n_dims, data_type) + log_abs = lib.log(lib.absolute(k) + 1e-10) # type: ignore + phase = lib.angle(k) # type: ignore k_intensity = self.k_intensity # default log intensity if k_intensity is None: - k_intensity = tuple(np.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5) + k_intensity = tuple(lib.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5) # type: ignore # highlight if isinstance(self.loc[0], Sequence): @@ -1392,9 +1466,10 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, else: self._set_spike(log_abs, self.loc, k_intensity) # map back - k = np.exp(log_abs) * np.exp(1j * phase) - img = self._inv_shift_fourier(k, n_dims) - return torch.Tensor(img, device=device) if self.as_tensor_output else img + k = lib.exp(log_abs) * lib.exp(1j * phase) # type: ignore + img = self._inv_shift_fourier(k, n_dims, data_type) + + return img def _check_indices(self, img) -> None: """Helper method to check consistency of self.loc and input image. @@ -1414,7 +1489,7 @@ def _check_indices(self, img) -> None: f"The index value at position {i} of one of the tuples in loc = {self.loc} is out of bounds for current image." ) - def _set_spike(self, k: np.ndarray, idx: Tuple, val: Union[Sequence[float], float]): + def _set_spike(self, k: DataObjects.Images, idx: Tuple, val: Union[Sequence[float], float]): """ Helper function to introduce a given intensity at given location. @@ -1429,11 +1504,11 @@ def _set_spike(self, k: np.ndarray, idx: Tuple, val: Union[Sequence[float], floa else: k[idx] = val elif len(k.shape) == 4 and len(idx) == 3: - k[:, idx[0], idx[1], idx[2]] = val + k[:, idx[0], idx[1], idx[2]] = val # type: ignore elif len(k.shape) == 3 and len(idx) == 2: - k[:, idx[0], idx[1]] = val + k[:, idx[0], idx[1]] = val # type: ignore - def _shift_fourier(self, x: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray: + def _shift_fourier(self, x: DataObjects.Images, n_dims: int, data_type: type) -> DataObjects.Images: """ Applies fourier transform and shifts its output. 
Only the spatial dimensions get transformed. @@ -1441,21 +1516,27 @@ def _shift_fourier(self, x: Union[np.ndarray, torch.Tensor], n_dims: int) -> np. Args: x (np.ndarray): tensor to fourier transform. """ - out: np.ndarray = np.fft.fftshift(np.fft.fftn(x, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) + # argument is dim if torch, else axes + mod, arg = (torch, "dim") if data_type is torch.Tensor else (np, "axes") + arg_dict = {arg: tuple(range(-n_dims, 0))} + out: DataObjects.Images = mod.fft.fftshift(mod.fft.fftn(x, **arg_dict), **arg_dict) # type: ignore return out - def _inv_shift_fourier(self, k: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray: + def _inv_shift_fourier(self, k: DataObjects.Images, n_dims: int, data_type: type) -> DataObjects.Images: """ Applies inverse shift and fourier transform. Only the spatial dimensions are transformed. """ - out: np.ndarray = np.fft.ifftn( - np.fft.ifftshift(k, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)) - ).real - return out + dims = tuple(range(-n_dims, 0)) + out: DataObjects.Images + if data_type is torch.Tensor: + out = torch.fft.ifftn(torch.fft.ifftshift(k, dim=dims), dim=dims, norm="backward") + else: + out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims) + return out.real -class RandKSpaceSpikeNoise(RandomizableTransform): +class RandKSpaceSpikeNoise(RandomizableTransform, TorchTransform, NumpyTransform): """ Naturalistic data augmentation via spike artifacts. The transform applies localized spikes in `k`-space, and it is the random version of @@ -1484,8 +1565,6 @@ class RandKSpaceSpikeNoise(RandomizableTransform): log-intensity for each channel. channel_wise: treat each channel independently. True by default. - as_tensor_output: if True return torch.Tensor, else - return np.array. default: True. Example: To apply `k`-space spikes randomly with probability 0.5, and @@ -1499,12 +1578,10 @@ def __init__( prob: float = 0.1, intensity_range: Optional[Sequence[Union[Sequence[float], float]]] = None, channel_wise=True, - as_tensor_output: bool = True, ): self.intensity_range = intensity_range self.channel_wise = channel_wise - self.as_tensor_output = as_tensor_output self.sampled_k_intensity: List[float] = [] self.sampled_locs: List[Tuple] = [] @@ -1516,7 +1593,7 @@ def __init__( super().__init__(prob) - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply transform to `img`. Assumes data is in channel-first form. 
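The `dim`/`axes` dispatch used by `_shift_fourier` and `_inv_shift_fourier` above can be exercised on its own. A minimal sketch of the same pattern (the helper name `shift_fourier` and the shapes are illustrative, not part of this patch):

import numpy as np
import torch

def shift_fourier(x, n_dims: int):
    # torch.fft.fftn takes `dim`, np.fft.fftn takes `axes`
    mod, arg = (torch, "dim") if isinstance(x, torch.Tensor) else (np, "axes")
    kwargs = {arg: tuple(range(-n_dims, 0))}
    return mod.fft.fftshift(mod.fft.fftn(x, **kwargs), **kwargs)

img_np = np.random.rand(1, 8, 8)  # channel-first, two spatial dims
k_np = shift_fourier(img_np, 2)
k_t = shift_fourier(torch.as_tensor(img_np), 2)
assert np.allclose(k_np, k_t.numpy())  # both backends give the same k-space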
@@ -1532,19 +1609,16 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, self.sampled_k_intensity = [] self.sampled_locs = [] - # convert to ndarray to work with np.fft - x, device = self._to_numpy(img) - intensity_range = self._make_sequence(x) - self._randomize(x, intensity_range) + intensity_range = self._make_sequence(img) + self._randomize(img, intensity_range) # build/apply transform only if there are spike locations if self.sampled_locs: - transform = KSpaceSpikeNoise(self.sampled_locs, self.sampled_k_intensity, self.as_tensor_output) - return transform(x) - - return torch.Tensor(x, device=device) if self.as_tensor_output else x + transform = KSpaceSpikeNoise(self.sampled_locs, self.sampled_k_intensity) + return transform(img) + return img - def _randomize(self, img: np.ndarray, intensity_range: Sequence[Sequence[float]]) -> None: + def _randomize(self, img: DataObjects.Images, intensity_range: Sequence[Sequence[float]]) -> None: """ Helper method to sample both the location and intensity of the spikes. When not working channel wise (channel_wise=False) it uses the random @@ -1572,7 +1646,7 @@ def _randomize(self, img: np.ndarray, intensity_range: Sequence[Sequence[float]] else: self.sampled_k_intensity = [self.R.uniform(*self.intensity_range)] * len(img) # type: ignore - def _make_sequence(self, x: np.ndarray) -> Sequence[Sequence[float]]: + def _make_sequence(self, x: DataObjects.Images) -> Sequence[Sequence[float]]: """ Formats the sequence of intensity ranges to Sequence[Sequence[float]]. """ @@ -1586,18 +1660,21 @@ def _make_sequence(self, x: np.ndarray) -> Sequence[Sequence[float]]: # set default range if one not provided return self._set_default_range(x) - def _set_default_range(self, x: np.ndarray) -> Sequence[Sequence[float]]: + def _set_default_range(self, x: DataObjects.Images) -> Sequence[Sequence[float]]: """ Sets default intensity ranges to be sampled. Args: - x (np.ndarray): tensor to fourier transform. + x: tensor to fourier transform. """ n_dims = len(x.shape[1:]) - k = np.fft.fftshift(np.fft.fftn(x, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) - log_abs = np.log(np.absolute(k) + 1e-10) - shifted_means = np.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5 + mod, arg = (torch, "dim") if type(x) is torch.Tensor else (np, "axes") + arg_dict = {arg: tuple(range(-n_dims, 0))} + k: DataObjects.Images = mod.fft.fftshift(mod.fft.fftn(x, **arg_dict), **arg_dict) # type: ignore + + log_abs = mod.log(mod.absolute(k) + 1e-10) # type: ignore + shifted_means = mod.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5 # type: ignore intensity_sequence = tuple((i * 0.95, i * 1.1) for i in shifted_means) return intensity_sequence @@ -1608,7 +1685,7 @@ def _to_numpy(self, img: Union[np.ndarray, torch.Tensor]) -> Tuple[np.ndarray, t return img, torch.device("cpu") -class RandCoarseDropout(RandomizableTransform): +class RandCoarseDropout(RandomizableTransform, TorchTransform, NumpyTransform): """ Randomly drop out coarse regions in the image, then fill in the rectangular regions with the specified value.
Refer to: https://arxiv.org/abs/1708.04552 and: @@ -1664,7 +1741,7 @@ def randomize(self, img_size: Sequence[int]) -> None: valid_size = get_valid_patch_size(img_size, size) self.hole_coords.append((slice(None),) + get_random_patch(img_size, valid_size, self.R)) - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: self.randomize(img.shape[1:]) if self._do_transform: for h in self.hole_coords: diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 49f20ea419..7b90505705 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -16,7 +16,7 @@ """ from collections.abc import Iterable -from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Hashable, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -42,7 +42,9 @@ ThresholdIntensity, ) from monai.transforms.transform import MapTransform, RandomizableTransform -from monai.utils import dtype_torch_to_numpy, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple +from monai.utils import ensure_tuple_rep, ensure_tuple_size, fall_back_tuple +from monai.utils.enums import DataObjects +from monai.utils.misc import convert_data_type __all__ = [ "RandGaussianNoised", @@ -159,7 +161,7 @@ def randomize(self, im_shape: Sequence[int]) -> None: for m in self.mean: self._noise.append(self.R.normal(m, self.R.uniform(0, self.std), size=im_shape)) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) image_shape = d[self.keys[0]].shape # image shape from the first data key @@ -169,8 +171,13 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda if not self._do_transform: return d for key, noise in self.key_iterator(d, self._noise): - dtype = dtype_torch_to_numpy(d[key].dtype) if isinstance(d[key], torch.Tensor) else d[key].dtype - d[key] = d[key] + noise.astype(dtype) + noise, *_ = convert_data_type( + noise, + type(d[key]), + dtype=d[key].dtype, + device=d[key].device if isinstance(d[key], torch.Tensor) else None, + ) + d[key] = d[key] + noise return d @@ -215,9 +222,11 @@ def __init__( RandomizableTransform.__init__(self, global_prob) self.rand_rician_noise = RandRicianNoise(prob, mean, std, channel_wise, relative, sample_std) - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def set_random_state(self, seed=None, state=None): + super().set_random_state(seed, state) + self.rand_rician_noise.set_random_state(seed, state) + + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) super().randomize(None) if not self._do_transform: @@ -243,7 +252,7 @@ def __init__(self, keys: KeysCollection, offset: float, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.shifter = ShiftIntensity(offset) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.shifter(d[key]) @@ -287,7 +296,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) super().randomize(None) - def __call__(self, data: Mapping[Hashable, 
np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) self.randomize() if not self._do_transform: @@ -326,7 +335,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.shifter = StdShiftIntensity(factor, nonzero, channel_wise, dtype) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.shifter(d[key]) @@ -378,7 +387,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) super().randomize(None) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) self.randomize() if not self._do_transform: @@ -418,7 +427,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.scaler = ScaleIntensity(minv, maxv, factor) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.scaler(d[key]) @@ -463,7 +472,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) super().randomize(None) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) self.randomize() if not self._do_transform: @@ -508,7 +517,7 @@ def __init__( def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) self.randomize() if not self._do_transform: @@ -549,7 +558,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.normalizer = NormalizeIntensity(subtrahend, divisor, nonzero, channel_wise, dtype) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.normalizer(d[key]) @@ -580,7 +589,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.filter = ThresholdIntensity(threshold, above, cval) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.filter(d[key]) @@ -615,7 +624,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.scaler(d[key]) @@ -640,7 +649,7 @@ def __init__(self, keys: KeysCollection, gamma: float, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.adjuster = AdjustContrast(gamma) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in 
self.key_iterator(d): d[key] = self.adjuster(d[key]) @@ -690,7 +699,7 @@ def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) self.gamma_value = self.R.uniform(low=self.gamma[0], high=self.gamma[1]) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) self.randomize() if self.gamma_value is None: @@ -733,7 +742,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.scaler(d[key]) @@ -769,7 +778,7 @@ def __init__( self.converter = MaskIntensity(mask_data) self.mask_key = mask_key if mask_data is None else None - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key], d[self.mask_key]) if self.mask_key is not None else self.converter(d[key]) @@ -802,7 +811,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = GaussianSmooth(sigma, approx=approx) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -849,7 +858,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.y = self.R.uniform(low=self.sigma_y[0], high=self.sigma_y[1]) self.z = self.R.uniform(low=self.sigma_z[0], high=self.sigma_z[1]) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) self.randomize() if not self._do_transform: @@ -892,7 +901,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = GaussianSharpen(sigma1, sigma2, alpha, approx=approx) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -1018,17 +1027,21 @@ def randomize(self, data: Optional[Any] = None) -> None: self.floating_control_points[i - 1], self.floating_control_points[i + 1] ) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) self.randomize() if not self._do_transform: return d for key in self.key_iterator(d): - img_min, img_max = d[key].min(), d[key].max() + img_np: np.ndarray + img_np, orig_type, orig_device = convert_data_type(d[key], np.ndarray) # type: ignore + img_min, img_max = img_np.min(), img_np.max() reference_control_points_scaled = self.reference_control_points * (img_max - img_min) + img_min floating_control_points_scaled = self.floating_control_points * (img_max - img_min) + img_min - dtype = d[key].dtype - d[key] = np.interp(d[key], reference_control_points_scaled, floating_control_points_scaled).astype(dtype) + + img_np = np.interp(img_np, reference_control_points_scaled, floating_control_points_scaled) + out, *_ = convert_data_type(img_np, 
orig_type, orig_device, dtype=d[key].dtype) + d[key] = out return d @@ -1054,7 +1067,6 @@ class RandGibbsNoised(RandomizableTransform, MapTransform): values in the interval [0,1] with alpha = 0 acting as the identity mapping. If a length-2 list is given as [a,b] then the value of alpha will be sampled uniformly from the interval [a,b]. - as_tensor_output: if true return torch.Tensor, else return np.array. default: True. allow_missing_keys: do not raise exception if key is missing. """ @@ -1063,7 +1075,6 @@ def __init__( self, keys: KeysCollection, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), - as_tensor_output: bool = True, allow_missing_keys: bool = False, ) -> None: @@ -1071,11 +1082,8 @@ def __init__( MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob=prob) self.alpha = alpha self.sampled_alpha = -1.0 # stores last alpha sampled by randomize() - self.as_tensor_output = as_tensor_output - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) self._randomize(None) @@ -1083,13 +1091,8 @@ def __call__( for i, key in enumerate(self.key_iterator(d)): if self._do_transform: if i == 0: - transform = GibbsNoise(self.sampled_alpha, self.as_tensor_output) + transform = GibbsNoise(self.sampled_alpha) d[key] = transform(d[key]) - else: - if isinstance(d[key], np.ndarray) and self.as_tensor_output: - d[key] = torch.Tensor(d[key]) - elif isinstance(d[key], torch.Tensor) and not self.as_tensor_output: - d[key] = self._to_numpy(d[key]) return d def _randomize(self, _: Any) -> None: @@ -1100,7 +1103,7 @@ def _randomize(self, _: Any) -> None: super().randomize(None) self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1]) - def _to_numpy(self, d: Union[torch.Tensor, np.ndarray]) -> np.ndarray: + def _to_numpy(self, d: DataObjects.Images) -> np.ndarray: if isinstance(d, torch.Tensor): d_numpy: np.ndarray = d.cpu().detach().numpy() return d_numpy @@ -1122,20 +1125,15 @@ class GibbsNoised(MapTransform): you need to transform. alpha (float): Parametrizes the intensity of the Gibbs noise filter applied. Takes values in the interval [0,1] with alpha = 0 acting as the identity mapping. - as_tensor_output: if true return torch.Tensor, else return np.array. default: True. allow_missing_keys: do not raise exception if key is missing. """ - def __init__( - self, keys: KeysCollection, alpha: float = 0.5, as_tensor_output: bool = True, allow_missing_keys: bool = False - ) -> None: + def __init__(self, keys: KeysCollection, alpha: float = 0.5, allow_missing_keys: bool = False) -> None: MapTransform.__init__(self, keys, allow_missing_keys) - self.transform = GibbsNoise(alpha, as_tensor_output) + self.transform = GibbsNoise(alpha) - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) for key in self.key_iterator(d): @@ -1174,8 +1172,6 @@ class KSpaceSpikeNoised(MapTransform): receive a sequence of intensities. This value should be tested as it is data-dependent. The default values are 2.5 times the mean of the log-intensity for each channel. - as_tensor_output: if ``True`` return torch.Tensor, else return np.array. - Default: ``True``. allow_missing_keys: do not raise exception if key is missing.
Example: @@ -1191,16 +1187,13 @@ def __init__( keys: KeysCollection, loc: Union[Tuple, Sequence[Tuple]], k_intensity: Optional[Union[Sequence[float], float]] = None, - as_tensor_output: bool = True, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.transform = KSpaceSpikeNoise(loc, k_intensity, as_tensor_output) + self.transform = KSpaceSpikeNoise(loc, k_intensity) - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: """ Args: data: Expects image/label to have dimensions (C, H, W) or @@ -1249,8 +1242,6 @@ class RandKSpaceSpikeNoised(RandomizableTransform, MapTransform): common_sampling: If ``True`` same values for location and log-intensity will be sampled for the image and label. common_seed: Seed to be used in case ``common_sampling = True``. - as_tensor_output: if ``True`` return torch.Tensor, else return - np.array. Default: ``True``. allow_missing_keys: do not raise exception if key is missing. Example: @@ -1270,7 +1261,6 @@ def __init__( channel_wise: bool = True, common_sampling: bool = False, common_seed: int = 42, - as_tensor_output: bool = True, allow_missing_keys: bool = False, ): @@ -1279,14 +1269,11 @@ def __init__( self.common_sampling = common_sampling self.common_seed = common_seed - self.as_tensor_output = as_tensor_output # the spikes artifact is amplitude dependent so we instantiate one per key - self.t_img = RandKSpaceSpikeNoise(prob, img_intensity_range, channel_wise, self.as_tensor_output) - self.t_label = RandKSpaceSpikeNoise(prob, label_intensity_range, channel_wise, self.as_tensor_output) + self.t_img = RandKSpaceSpikeNoise(prob, img_intensity_range, channel_wise) + self.t_label = RandKSpaceSpikeNoise(prob, label_intensity_range, channel_wise) - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: """ Args: data: Expects image/label to have dimensions (C, H, W) or @@ -1304,11 +1291,6 @@ def __call__( if self._do_transform: transform = self.t_img if key == "image" else self.t_label d[key] = transform(d[key]) - else: - if isinstance(d[key], np.ndarray) and self.as_tensor_output: - d[key] = torch.Tensor(d[key]) - elif isinstance(d[key], torch.Tensor) and not self.as_tensor_output: - d[key] = self._to_numpy(d[key]) return d def set_rand_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> None: diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 5d6b4d87fd..58f3526086 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -9,9 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Hashable, Optional, Tuple +from typing import Hashable, Optional, Tuple -import numpy as np import torch from monai.transforms.transform import RandomizableTransform, Transform @@ -113,7 +112,7 @@ def pop_transform(self, data: dict, key: Hashable) -> None: """Remove most recent transform.""" data[str(key) + InverseKeys.KEY_SUFFIX].pop() - def inverse(self, data: dict) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: dict) -> dict: """ Inverse of ``__call__``. 
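The files above and below all lean on the same convert/compute/convert-back idiom around `convert_data_type`. A self-contained sketch of that idiom, assuming the `(converted, orig_type, orig_device)` return convention seen in these hunks (this simplified `convert` helper is illustrative, not the real `monai.utils.misc.convert_data_type`):

from typing import Optional, Type, Union
import numpy as np
import torch

Image = Union[np.ndarray, torch.Tensor]

def convert(data: Image, output_type: Type, device: Optional[torch.device] = None):
    # remember where the data came from so the caller can restore it
    orig_type = type(data)
    orig_device = data.device if isinstance(data, torch.Tensor) else None
    if output_type is torch.Tensor and isinstance(data, np.ndarray):
        data = torch.as_tensor(data, device=device)
    elif output_type is np.ndarray and isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()
    return data, orig_type, orig_device

img = np.zeros((1, 4, 4), dtype=np.float32)
img_t, orig_type, orig_device = convert(img, torch.Tensor)   # to the compute backend
out, *_ = convert(img_t + 1, orig_type, device=orig_device)  # and back again
assert isinstance(out, np.ndarray) and out.dtype == np.float32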
diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 2c1a3c89ff..557baca470 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -23,10 +23,12 @@ from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader from monai.data.nifti_saver import NiftiSaver from monai.data.png_saver import PNGSaver -from monai.transforms.transform import Transform +from monai.transforms.transform import NumpyTransform, TorchTransform, Transform from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key from monai.utils import InterpolateMode, ensure_tuple, optional_import +from monai.utils.enums import DataObjects +from monai.utils.misc import convert_data_type from monai.utils.module import look_up_option nib, _ = optional_import("nibabel") @@ -80,6 +82,7 @@ def __init__( reader: Optional[Union[ImageReader, str]] = None, image_only: bool = False, dtype: DtypeLike = np.float32, + as_tensor: bool = True, *args, **kwargs, ) -> None: @@ -91,6 +94,7 @@ def __init__( "PILReader", "ITKReader", "NumpyReader". image_only: if True return only the image volume, otherwise return image data array and header dict. dtype: if not None convert the loaded image to this data type. + as_tensor: output as `torch.Tensor`, defaults to `True`. args: additional parameters for reader if providing a reader name. kwargs: additional parameters for reader if providing a reader name. @@ -116,6 +120,7 @@ def __init__( self.image_only = image_only self.dtype = dtype + self.as_tensor = as_tensor def register(self, reader: ImageReader) -> List[ImageReader]: """ @@ -160,9 +165,14 @@ def __call__( ) img = reader.read(filename) + img_array: DataObjects.Images img_array, meta_data = reader.get_data(img) img_array = img_array.astype(self.dtype) + # convert to desired output type + if self.as_tensor: + img_array, *_ = convert_data_type(img_array, torch.Tensor) + if self.image_only: return img_array meta_data[Key.FILENAME_OR_OBJ] = ensure_tuple(filename)[0] @@ -172,7 +182,7 @@ def __call__( return img_array, meta_data -class SaveImage(Transform): +class SaveImage(TorchTransform, NumpyTransform): """ Save transformed data into files, support NIfTI and PNG formats. It can work for both numpy array and PyTorch Tensor in both preprocessing transform @@ -286,7 +296,7 @@ def __init__( else: raise ValueError(f"unsupported output extension: {output_ext}.") - def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None): + def __call__(self, img: DataObjects.Images, meta_data: Optional[Dict] = None): """ Args: img: target data content that save into file. diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index db043848c7..0a95f065af 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -18,6 +18,7 @@ from typing import Optional, Union import numpy as np +import torch from monai.config import DtypeLike, KeysCollection from monai.data.image_reader import ImageReader @@ -61,6 +62,7 @@ def __init__( meta_key_postfix: str = "meta_dict", overwriting: bool = False, image_only: bool = False, + as_tensor: bool = True, allow_missing_keys: bool = False, *args, **kwargs, @@ -85,12 +87,13 @@ def __init__( default is False, which will raise exception if encountering existing key. image_only: if True return dictionary containing just only the image volumes, otherwise return dictionary containing image data array and header dict per input key. 
+ as_tensor: output as `torch.Tensor`, defaults to `True`. allow_missing_keys: don't raise exception if key is missing. args: additional parameters for reader if providing a reader name. kwargs: additional parameters for reader if providing a reader name. """ super().__init__(keys, allow_missing_keys) - self._loader = LoadImage(reader, image_only, dtype, *args, **kwargs) + self._loader = LoadImage(reader, image_only, dtype, as_tensor, *args, **kwargs) if not isinstance(meta_key_postfix, str): raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.") self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) @@ -112,7 +115,7 @@ def __call__(self, data, reader: Optional[ImageReader] = None): for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): data = self._loader(d[key], reader) if self._loader.image_only: - if not isinstance(data, np.ndarray): + if not isinstance(data, (torch.Tensor, np.ndarray)): raise ValueError("loader must return a numpy array or torch tensor (because image_only=True was used).") d[key] = data else: diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 397b14e2e2..47c0fae741 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -14,7 +14,8 @@ """ import warnings -from typing import Callable, Optional, Sequence, Union +from copy import deepcopy +from typing import Callable, List, Optional, Sequence, Union import numpy as np import torch @@ -22,9 +23,11 @@ from monai.networks import one_hot from monai.networks.layers import GaussianFilter -from monai.transforms.transform import Transform +from monai.transforms.transform import NumpyTransform, TorchTransform from monai.transforms.utils import get_largest_connected_component_mask from monai.utils import ensure_tuple +from monai.utils.enums import DataObjects +from monai.utils.misc import convert_data_type __all__ = [ "Activations", @@ -37,7 +40,24 @@ ] -class Activations(Transform): +def _sigmoid(z): + if isinstance(z, torch.Tensor): + return torch.sigmoid(z) + return 1 / (1 + np.exp(-z)) + + +def _softmax(z, dim): + if isinstance(z, torch.Tensor): + return torch.softmax(z, dim=dim) + + z_max = np.max(z, axis=dim, keepdims=True) # row-wise max, dims kept for broadcasting + e_x = np.exp(z - z_max) # subtract the row max for numerical stability + denom = np.sum(e_x, axis=dim, keepdims=True) # row-wise sum, dims kept + f_x = e_x / denom + return f_x + + +class Activations(TorchTransform, NumpyTransform): """ Add activation operations to the model output, typically `Sigmoid` or `Softmax`. @@ -63,11 +83,11 @@ def __init__(self, sigmoid: bool = False, softmax: bool = False, other: Optional def __call__( self, - img: torch.Tensor, + img: DataObjects.Images, sigmoid: Optional[bool] = None, softmax: Optional[bool] = None, other: Optional[Callable] = None, - ) -> torch.Tensor: + ) -> DataObjects.Images: """ Args: sigmoid: whether to execute sigmoid function on model output before transform.
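The numpy branch of `_softmax` above uses the standard max-subtraction trick for numerical stability; a quick standalone check that it agrees with `torch.softmax` on float inputs:

import numpy as np
import torch

def np_softmax(z: np.ndarray, dim: int) -> np.ndarray:
    z_max = np.max(z, axis=dim, keepdims=True)  # subtract the per-row max for stability
    e_z = np.exp(z - z_max)
    return e_z / np.sum(e_z, axis=dim, keepdims=True)

logits = np.random.randn(3, 4).astype(np.float32)
expected = torch.softmax(torch.from_numpy(logits), dim=0).numpy()
assert np.allclose(np_softmax(logits, dim=0), expected, atol=1e-6)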
@@ -89,20 +109,26 @@ def __call__( raise TypeError(f"other must be None or callable but is {type(other).__name__}.") # convert to float as activation must operate on float tensor - img = img.float() if sigmoid or self.sigmoid: - img = torch.sigmoid(img) + img = _sigmoid(img) if softmax or self.softmax: - img = torch.softmax(img, dim=0) + img = _softmax(img, dim=0) act_func = self.other if other is None else other if act_func is not None: - img = act_func(img) + try: + img = act_func(img) + except TypeError as te: + # callable only works on torch.Tensors + if "must be Tensor, not numpy.ndarray" in str(te): + img, *_ = convert_data_type(img, torch.Tensor) + img = act_func(img) + img, *_ = convert_data_type(img, np.ndarray) return img -class AsDiscrete(Transform): +class AsDiscrete(TorchTransform, NumpyTransform): """ Execute after model forward to transform model output to discrete values. It can complete below operations: @@ -141,13 +167,13 @@ def __init__( def __call__( self, - img: torch.Tensor, + img: DataObjects.Images, argmax: Optional[bool] = None, to_onehot: Optional[bool] = None, n_classes: Optional[int] = None, threshold_values: Optional[bool] = None, logit_thresh: Optional[float] = None, - ) -> torch.Tensor: + ) -> DataObjects.Images: """ Args: img: the input tensor data to convert, if no channel dimension when converting to `One-Hot`, @@ -165,7 +191,10 @@ def __call__( """ if argmax or self.argmax: - img = torch.argmax(img, dim=0, keepdim=True) + if isinstance(img, torch.Tensor): + img = torch.argmax(img, dim=0, keepdim=True) + else: + img = np.argmax(img, axis=0)[None] if to_onehot or self.to_onehot: _nclasses = self.n_classes if n_classes is None else n_classes @@ -174,12 +203,13 @@ img = one_hot(img, num_classes=_nclasses, dim=0) if threshold_values or self.threshold_values: img = img >= (self.logit_thresh if logit_thresh is None else logit_thresh) - return img.float() + out, *_ = convert_data_type(img, dtype=torch.float32) + return out -class KeepLargestConnectedComponent(Transform): +class KeepLargestConnectedComponent(TorchTransform): """ Keeps only the largest connected component in the image. This transform can be used as a post-processing step to clean up over-segment areas in model output. @@ -245,7 +275,12 @@ def __init__( self.independent = independent self.connectivity = connectivity - def __call__(self, img: torch.Tensor) -> torch.Tensor: + @staticmethod + def _astype(x, dtype=torch.uint8): + x, *_ = convert_data_type(x, dtype=dtype) + return x + + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: img: shape must be (C, spatial_dim1[, spatial_dim2, ...]). @@ -254,42 +289,45 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: A PyTorch Tensor with shape (C, spatial_dim1[, spatial_dim2, ...]).
""" if img.shape[0] == 1: - img = torch.squeeze(img, dim=0) + img = img.squeeze(0) if self.independent: for i in self.applied_labels: - foreground = (img == i).type(torch.uint8) + foreground = self._astype(img == i) mask = get_largest_connected_component_mask(foreground, self.connectivity) img[foreground != mask] = 0 else: - foreground = torch.zeros_like(img) + foreground = torch.zeros_like(img) if isinstance(img, torch.Tensor) else np.zeros_like(img) for i in self.applied_labels: - foreground += (img == i).type(torch.uint8) + foreground += self._astype(img == i) mask = get_largest_connected_component_mask(foreground, self.connectivity) img[foreground != mask] = 0 - output = torch.unsqueeze(img, dim=0) + output = img[None] else: # one-hot data is assumed to have binary value in each channel if self.independent: for i in self.applied_labels: - foreground = img[i, ...].type(torch.uint8) + foreground = self._astype(img[i, ...]) mask = get_largest_connected_component_mask(foreground, self.connectivity) img[i, ...][foreground != mask] = 0 else: - applied_img = img[self.applied_labels, ...].type(torch.uint8) - foreground = torch.any(applied_img, dim=0) + applied_img = self._astype(img[self.applied_labels, ...]) + foreground = applied_img.any(0) mask = get_largest_connected_component_mask(foreground, self.connectivity) - background_mask = torch.unsqueeze(foreground != mask, dim=0) - background_mask = torch.repeat_interleave(background_mask, len(self.applied_labels), dim=0) + background_mask = (foreground != mask)[None] + if isinstance(background_mask, torch.Tensor): + background_mask = torch.repeat_interleave(background_mask, len(self.applied_labels), dim=0) + else: + background_mask = np.repeat(background_mask, len(self.applied_labels), axis=0) applied_img[background_mask] = 0 - img[self.applied_labels, ...] = applied_img.type(img.type()) + img[self.applied_labels, ...] = self._astype(applied_img, img.dtype) output = img return output -class LabelToContour(Transform): +class LabelToContour(TorchTransform): """ Return the contour of binary input images that only compose of 0 and 1, with Laplace kernel set as default for edge detection. Typical usage is to plot the edge of label or segmentation output. @@ -307,7 +345,7 @@ def __init__(self, kernel_type: str = "Laplace") -> None: raise NotImplementedError('Currently only kernel_type="Laplace" is supported.') self.kernel_type = kernel_type - def __call__(self, img: torch.Tensor) -> torch.Tensor: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: img: torch tensor data to extract the contour, with shape: [channels, height, width[, depth]] @@ -323,25 +361,31 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: ideally the edge should be thin enough, but now it has a thickness. 
""" - channels = img.shape[0] - img_ = img.unsqueeze(0) - if img.ndimension() == 3: - kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32, device=img.device) + img_t: torch.Tensor + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor) # type: ignore + + channels = img_t.shape[0] + img_ = img_t.unsqueeze(0) + if img_t.ndimension() == 3: + kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32, device=img_t.device) kernel = kernel.repeat(channels, 1, 1, 1) contour_img = F.conv2d(img_, kernel, bias=None, stride=1, padding=1, dilation=1, groups=channels) - elif img.ndimension() == 4: - kernel = -1 * torch.ones(3, 3, 3, dtype=torch.float32, device=img.device) + elif img_t.ndimension() == 4: + kernel = -1 * torch.ones(3, 3, 3, dtype=torch.float32, device=img_t.device) kernel[1, 1, 1] = 26 kernel = kernel.repeat(channels, 1, 1, 1, 1) contour_img = F.conv3d(img_, kernel, bias=None, stride=1, padding=1, dilation=1, groups=channels) else: - raise ValueError(f"Unsupported img dimension: {img.ndimension()}, available options are [4, 5].") + raise ValueError(f"Unsupported img dimension: {img_t.ndimension()}, available options are [4, 5].") contour_img.clamp_(min=0.0, max=1.0) - return contour_img.squeeze(0) + contour_img = contour_img.squeeze(0) + + out, *_ = convert_data_type(contour_img, orig_type, orig_device) + return out -class MeanEnsemble(Transform): +class MeanEnsemble(TorchTransform, NumpyTransform): """ Execute mean ensemble on the input data. The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]], @@ -364,24 +408,41 @@ class MeanEnsemble(Transform): """ - def __init__(self, weights: Optional[Union[Sequence[float], torch.Tensor, np.ndarray]] = None) -> None: - self.weights = torch.as_tensor(weights, dtype=torch.float) if weights is not None else None + def __init__(self, weights: Optional[Union[Sequence[float], DataObjects.Images]] = None) -> None: + if weights is None: + self.weights = None + elif isinstance(weights, (torch.Tensor, np.ndarray)): + self.weights = weights + else: + self.weights = torch.as_tensor(weights, dtype=torch.float) + + def __call__(self, img: Union[Sequence[DataObjects.Images], DataObjects.Images]) -> DataObjects.Images: + if isinstance(img, (torch.Tensor, np.ndarray)): + img_ = img + elif isinstance(img[0], torch.Tensor): + img_ = torch.stack(img) # type: ignore + else: + img_ = np.stack(img) - def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Tensor: - img_ = torch.stack(img) if isinstance(img, (tuple, list)) else torch.as_tensor(img) if self.weights is not None: - self.weights = self.weights.to(img_.device) + self.weights, *_ = convert_data_type( + self.weights, type(img_), device=img_.device if isinstance(img_, torch.Tensor) else None + ) shape = tuple(self.weights.shape) - for _ in range(img_.ndimension() - self.weights.ndimension()): + for _ in range(img_.ndim - self.weights.ndim): shape += (1,) weights = self.weights.reshape(*shape) - img_ = img_ * weights / weights.mean(dim=0, keepdim=True) + if isinstance(img_, torch.Tensor): + # torch can only do the mean on floats + img_ = img_ * weights / weights.float().mean(dim=0, keepdim=True) # type: ignore + else: + img_ = img_ * weights / weights.mean(axis=0, keepdims=True) # type: ignore - return torch.mean(img_, dim=0) + return img_.mean(0) # type: ignore -class VoteEnsemble(Transform): +class VoteEnsemble(TorchTransform, NumpyTransform): """ Execute vote ensemble on the input data. 
The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]], @@ -404,28 +465,39 @@ class VoteEnsemble(Transform): def __init__(self, num_classes: Optional[int] = None) -> None: self.num_classes = num_classes - def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Tensor: - img_ = torch.stack(img) if isinstance(img, (tuple, list)) else torch.as_tensor(img) + def __call__(self, img: Union[Sequence[DataObjects.Images], DataObjects.Images]) -> DataObjects.Images: + if isinstance(img, (torch.Tensor, np.ndarray)): + img_ = img + elif isinstance(img[0], torch.Tensor): + img_ = torch.stack(img) # type: ignore + else: + img_ = np.stack(img) + if self.num_classes is not None: has_ch_dim = True - if img_.ndimension() > 1 and img_.shape[1] > 1: + if img_.ndim > 1 and img_.shape[1] > 1: warnings.warn("no need to specify num_classes for One-Hot format data.") else: - if img_.ndimension() == 1: + if img_.ndim == 1: # if no channel dim, need to remove channel dim after voting has_ch_dim = False img_ = one_hot(img_, self.num_classes, dim=1) - img_ = torch.mean(img_.float(), dim=0) + img_ = torch.mean(img_.float(), dim=0) if isinstance(img_, torch.Tensor) else np.mean(img_, axis=0) if self.num_classes is not None: # if not One-Hot, use "argmax" to vote the most common class - return torch.argmax(img_, dim=0, keepdim=has_ch_dim) + if isinstance(img_, torch.Tensor): + return torch.argmax(img_, dim=0, keepdim=has_ch_dim) + else: + img_ = np.argmax(img_, axis=0) + img_ = np.array(img_) if np.isscalar(img_) else img_ # numpy returns scalar if input was 1d + return img_[None] if has_ch_dim else img_ # for One-Hot data, round the float number to 0 or 1 - return torch.round(img_) + return torch.round(img_) if isinstance(img_, torch.Tensor) else np.round(img_) -class ProbNMS(Transform): +class ProbNMS(TorchTransform): """ Performs probability based non-maximum suppression (NMS) on the probabilities map via iteratively selecting the coordinate with highest probability and then move it as well @@ -484,29 +556,23 @@ def __init__( def __call__( self, - prob_map: Union[np.ndarray, torch.Tensor], - ): + prob_map: DataObjects.Images, + ) -> List[List]: """ prob_map: the input probabilities map, it must have shape (H[, W, ...]). 
""" + prob_map_t: torch.Tensor + prob_map_t, *_ = convert_data_type(deepcopy(prob_map), torch.Tensor, dtype=torch.float32) # type: ignore if self.sigma != 0: - if not isinstance(prob_map, torch.Tensor): - prob_map = torch.as_tensor(prob_map, dtype=torch.float) - self.filter.to(prob_map) - prob_map = self.filter(prob_map) - else: - if not isinstance(prob_map, torch.Tensor): - prob_map = prob_map.copy() - - if isinstance(prob_map, torch.Tensor): - prob_map = prob_map.detach().cpu().numpy() + self.filter.to(prob_map_t) + prob_map_t = self.filter(prob_map_t) - prob_map_shape = prob_map.shape + prob_map_shape = prob_map_t.shape outputs = [] - while np.max(prob_map) > self.prob_threshold: - max_idx = np.unravel_index(prob_map.argmax(), prob_map_shape) - prob_max = prob_map[max_idx] + while prob_map_t.max() > self.prob_threshold: + max_idx = np.unravel_index(prob_map_t.argmax().cpu(), prob_map_shape) + prob_max = prob_map_t[max_idx].item() max_idx_arr = np.asarray(max_idx) outputs.append([prob_max] + list(max_idx_arr)) @@ -514,6 +580,6 @@ def __call__( idx_max_range = (max_idx_arr + self.box_upper_bd).clip(None, prob_map_shape) # for each dimension, set values during index ranges to 0 slices = tuple(slice(idx_min_range[i], idx_max_range[i]) for i in range(self.spatial_dims)) - prob_map[slices] = 0 + prob_map_t[slices] = 0 return outputs diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 6cba08948b..0a8c30b5c8 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -17,9 +17,8 @@ import warnings from copy import deepcopy -from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Union +from typing import Callable, List, Optional, Sequence, Union -import numpy as np import torch from monai.config import KeysCollection @@ -38,7 +37,7 @@ from monai.transforms.utility.array import ToTensor from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode from monai.utils import ensure_tuple, ensure_tuple_rep -from monai.utils.enums import InverseKeys +from monai.utils.enums import DataObjects, InverseKeys __all__ = [ "Activationsd", @@ -106,7 +105,7 @@ def __init__( self.other = ensure_tuple_rep(other, len(self.keys)) self.converter = Activations() - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Mapping: d = dict(data) for key, sigmoid, softmax, other in self.key_iterator(d, self.sigmoid, self.softmax, self.other): d[key] = self.converter(d[key], sigmoid, softmax, other) @@ -153,7 +152,7 @@ def __init__( self.logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys)) self.converter = AsDiscrete() - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key, argmax, to_onehot, n_classes, threshold_values, logit_thresh in self.key_iterator( d, self.argmax, self.to_onehot, self.n_classes, self.threshold_values, self.logit_thresh @@ -201,7 +200,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = KeepLargestConnectedComponent(applied_labels, independent, connectivity) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Mapping: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -225,7 
+224,7 @@ def __init__(self, keys: KeysCollection, kernel_type: str = "Laplace", allow_mis super().__init__(keys, allow_missing_keys) self.converter = LabelToContour(kernel_type=kernel_type) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -241,7 +240,7 @@ class Ensembled(MapTransform): def __init__( self, keys: KeysCollection, - ensemble: Callable[[Union[Sequence[torch.Tensor], torch.Tensor]], torch.Tensor], + ensemble: Callable[[Union[Sequence[DataObjects.Images], DataObjects.Images]], DataObjects.Images], output_key: Optional[str] = None, allow_missing_keys: bool = False, ) -> None: @@ -267,9 +266,9 @@ def __init__( raise ValueError("Incompatible values: len(self.keys) > 1 and output_key=None.") self.output_key = output_key if output_key is not None else self.keys[0] - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) - items: Union[List[torch.Tensor], torch.Tensor] + items: Union[List[DataObjects.Images], DataObjects.Images] if len(self.keys) == 1: items = d[self.keys[0]] else: @@ -288,7 +287,7 @@ def __init__( self, keys: KeysCollection, output_key: Optional[str] = None, - weights: Optional[Union[Sequence[float], torch.Tensor, np.ndarray]] = None, + weights: Optional[Union[Sequence[float], DataObjects.Images]] = None, ) -> None: """ Args: @@ -382,7 +381,7 @@ def __init__( box_size=box_size, ) - def __call__(self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]): + def __call__(self, data: DataObjects.Mapping): d = dict(data) for key in self.key_iterator(d): d[key] = self.prob_nms(d[key]) @@ -483,7 +482,7 @@ def __init__( self.post_func = ensure_tuple_rep(post_func, len(self.keys)) self._totensor = ToTensor() - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for ( key, diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index d9c10cf9c0..9eda4ced83 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -14,6 +14,7 @@ """ import warnings +from copy import deepcopy from math import ceil from typing import Any, List, Optional, Sequence, Tuple, Union @@ -23,8 +24,8 @@ from monai.config import USE_COMPILED, DtypeLike from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull -from monai.transforms.croppad.array import CenterSpatialCrop -from monai.transforms.transform import Randomizable, RandomizableTransform, ThreadUnsafe, Transform +from monai.transforms.croppad.array import CenterSpatialCrop, Pad +from monai.transforms.transform import NumpyTransform, Randomizable, RandomizableTransform, ThreadUnsafe, TorchTransform from monai.transforms.utils import ( create_control_grid, create_grid, @@ -46,6 +47,8 @@ issequenceiterable, optional_import, ) +from monai.utils.enums import DataObjects +from monai.utils.misc import convert_data_type from monai.utils.module import look_up_option nib, _ = optional_import("nibabel") @@ -77,7 +80,7 @@ RandRange = Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] -class Spacing(Transform): +class Spacing(TorchTransform): """ Resample input image into the 
specified `pixdim`. """ @@ -133,14 +136,14 @@ def __init__( def __call__( self, - data_array: np.ndarray, - affine: Optional[np.ndarray] = None, + data_array: DataObjects.Images, + affine: Optional[DataObjects.Images] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, dtype: DtypeLike = None, output_spatial_shape: Optional[np.ndarray] = None, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + ) -> Tuple[DataObjects.Images, DataObjects.Images, DataObjects.Images]: """ Args: data_array: in shape (num_channels, H[, W, ...]). @@ -169,7 +172,10 @@ def __call__( """ _dtype = dtype or self.dtype or data_array.dtype - sr = data_array.ndim - 1 + data_array_torch: torch.Tensor + data_array_torch, orig_type, orig_device = convert_data_type(data_array, torch.Tensor, dtype=_dtype) # type: ignore + + sr = data_array_torch.ndim - 1 if sr <= 0: raise ValueError("data_array must have at least one spatial dimension.") if affine is None: @@ -177,7 +183,8 @@ def __call__( affine = np.eye(sr + 1, dtype=np.float64) affine_ = np.eye(sr + 1, dtype=np.float64) else: - affine_ = to_affine_nd(sr, affine) + affine_ = to_affine_nd(sr, affine) # type: ignore + affine_, *_ = convert_data_type(affine_, np.ndarray) # type: ignore out_d = self.pixdim[:sr] if out_d.size < sr: @@ -185,39 +192,39 @@ def __call__( # compute output affine, shape and offset new_affine = zoom_affine(affine_, out_d, diagonal=self.diagonal) - output_shape, offset = compute_shape_offset(data_array.shape[1:], affine_, new_affine) + output_shape, offset = compute_shape_offset(data_array_torch.shape[1:], affine_, new_affine) new_affine[:sr, -1] = offset[:sr] - transform = np.linalg.inv(affine_) @ new_affine + affine_inv = np.linalg.inv(affine_) + transform = affine_inv @ new_affine # adapt to the actual rank transform = to_affine_nd(sr, transform) # no resampling if it's identity transform if np.allclose(transform, np.diag(np.ones(len(transform))), atol=1e-3): - output_data = data_array.copy().astype(np.float32) + output_data, *_ = convert_data_type(deepcopy(data_array), dtype=_dtype) + new_affine = to_affine_nd(affine, new_affine) + else: + # resample + affine_xform = AffineTransform( + normalized=False, + mode=look_up_option(mode or self.mode, GridSampleMode), + padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode), + align_corners=self.align_corners if align_corners is None else align_corners, + reverse_indexing=True, + ) + output_data = affine_xform( + # AffineTransform requires a batch dim + data_array_torch.unsqueeze(0), + convert_data_type(transform, torch.Tensor, data_array_torch.device, dtype=_dtype)[0], + spatial_size=output_shape if output_spatial_shape is None else output_spatial_shape, + ).squeeze(0) + output_data, *_ = convert_data_type(output_data, orig_type, dtype=np.float32) # type: ignore new_affine = to_affine_nd(affine, new_affine) - return output_data, affine, new_affine - - # resample - affine_xform = AffineTransform( - normalized=False, - mode=look_up_option(mode or self.mode, GridSampleMode), - padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode), - align_corners=self.align_corners if align_corners is None else align_corners, - reverse_indexing=True, - ) - output_data = affine_xform( - # AffineTransform requires a batch dim - torch.as_tensor(np.ascontiguousarray(data_array).astype(_dtype)).unsqueeze(0), - 
torch.as_tensor(np.ascontiguousarray(transform).astype(_dtype)), - spatial_size=output_shape if output_spatial_shape is None else output_spatial_shape, - ) - output_data = np.asarray(output_data.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # type: ignore - new_affine = to_affine_nd(affine, new_affine) return output_data, affine, new_affine -class Orientation(Transform): +class Orientation(NumpyTransform): """ Change the input image's orientation to the one specified by `axcodes`. """ @@ -255,8 +262,8 @@ def __init__( self.labels = labels def __call__( - self, data_array: np.ndarray, affine: Optional[np.ndarray] = None - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + self, data_array: DataObjects.Images, affine: Optional[DataObjects.Images] = None + ) -> Tuple[DataObjects.Images, DataObjects.Images, np.ndarray]: """ original orientation of `data_array` is defined by `affine`. @@ -272,14 +279,17 @@ def __call__( data_array (reoriented in `self.axcodes`), original axcodes, current axcodes. """ - sr = data_array.ndim - 1 + data_np: np.ndarray + data_np, orig_type, orig_device = convert_data_type(data_array, np.ndarray) # type: ignore + sr = data_np.ndim - 1 if sr <= 0: raise ValueError("data_array must have at least one spatial dimension.") if affine is None: affine = np.eye(sr + 1, dtype=np.float64) affine_ = np.eye(sr + 1, dtype=np.float64) else: - affine_ = to_affine_nd(sr, affine) + affine, *_ = convert_data_type(affine, np.ndarray) + affine_ = to_affine_nd(sr, affine) # type: ignore src = nib.io_orientation(affine_) if self.as_closest_canonical: spatial_ornt = src @@ -295,15 +305,17 @@ def __call__( ornt = spatial_ornt.copy() ornt[:, 0] += 1 # skip channel dim ornt = np.concatenate([np.array([[0, 1]]), ornt]) - shape = data_array.shape[1:] - data_array = np.ascontiguousarray(nib.orientations.apply_orientation(data_array, ornt)) + shape = data_np.shape[1:] + data_np = np.ascontiguousarray(nib.orientations.apply_orientation(data_np, ornt)) new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, shape) new_affine = to_affine_nd(affine, new_affine) - return data_array, affine, new_affine + data_out, *_ = convert_data_type(data_np, orig_type, orig_device) + + return data_out, affine, new_affine -class Flip(Transform): +class Flip(TorchTransform): """ Reverses the order of elements along the given spatial axis. Preserves shape. Uses ``torch.flip`` in practice, converting non-tensor inputs as needed. See torch.flip for additional details: @@ -321,17 +333,21 @@ class Flip(Transform): def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: self.spatial_axis = spatial_axis - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ + img_t: torch.Tensor + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor) # type: ignore + + result_t = torch.flip(img_t, map_spatial_axes(img.ndim, self.spatial_axis)) - result: np.ndarray = np.flip(img, map_spatial_axes(img.ndim, self.spatial_axis)) - return result.astype(img.dtype) + result, *_ = convert_data_type(result_t, orig_type, orig_device) + return result -class Resize(Transform): +class Resize(TorchTransform): """ Resize the input image to given spatial size (with scaling, not cropping/padding). Implemented using :py:class:`torch.nn.functional.interpolate`.
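As with `Flip` above, `Resize` now keeps the computation on torch and converts back afterwards. A minimal sketch of that round trip for a numpy input, assuming `torch.nn.functional.interpolate`'s channel-first batching convention (`resize_np` is an illustrative name, not part of this patch):

import numpy as np
import torch
import torch.nn.functional as F

def resize_np(img: np.ndarray, spatial_size) -> np.ndarray:
    img_t = torch.as_tensor(img, dtype=torch.float32)  # to torch for the interpolation
    resized = F.interpolate(img_t.unsqueeze(0), size=spatial_size, mode="bilinear", align_corners=False)
    return resized.squeeze(0).numpy()  # back to the input's array type

out = resize_np(np.random.rand(1, 16, 16), (32, 32))
assert out.shape == (1, 32, 32)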
@@ -368,10 +384,10 @@ def __init__( def __call__( self, - img: np.ndarray, + img: DataObjects.Images, mode: Optional[Union[InterpolateMode, str]] = None, align_corners: Optional[bool] = None, - ) -> np.ndarray: + ) -> DataObjects.Images: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]). @@ -386,35 +402,38 @@ def __call__( ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. """ + img_t: torch.Tensor + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor, dtype=float) # type: ignore if self.size_mode == "all": - input_ndim = img.ndim - 1 # spatial ndim + input_ndim = img_t.ndim - 1 # spatial ndim output_ndim = len(ensure_tuple(self.spatial_size)) if output_ndim > input_ndim: - input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) - img = img.reshape(input_shape) + input_shape = ensure_tuple_size(img_t.shape, output_ndim + 1, 1) + img_t = img_t.reshape(input_shape) elif output_ndim < input_ndim: raise ValueError( "len(spatial_size) must be greater or equal to img spatial dimensions, " f"got spatial_size={output_ndim} img={input_ndim}." ) - spatial_size_ = fall_back_tuple(self.spatial_size, img.shape[1:]) - else: # for the "longest" mode + spatial_size_ = fall_back_tuple(self.spatial_size, img_t.shape[1:]) + else: img_size = img.shape[1:] if not isinstance(self.spatial_size, int): raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") scale = self.spatial_size / max(img_size) spatial_size_ = tuple(ceil(s * scale) for s in img_size) - resized = torch.nn.functional.interpolate( # type: ignore - input=torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0), + resized = torch.nn.functional.interpolate( + input=img_t.unsqueeze(0), size=spatial_size_, mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, align_corners=self.align_corners if align_corners is None else align_corners, ) - resized = resized.squeeze(0).detach().cpu().numpy() - return np.asarray(resized) + resized = resized.squeeze(0) + out, *_ = convert_data_type(resized, orig_type, orig_device) + return out -class Rotate(Transform, ThreadUnsafe): +class Rotate(TorchTransform, ThreadUnsafe): """ Rotates an input image by given angle using :py:class:`monai.networks.layers.AffineTransform`. @@ -455,12 +474,12 @@ def __init__( def __call__( self, - img: np.ndarray, + img: DataObjects.Images, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, dtype: DtypeLike = None, - ) -> np.ndarray: + ) -> DataObjects.Images: """ Args: img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D]. 
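For the ``longest`` size mode touched above, every spatial dimension is scaled by one common factor so the longest side lands exactly on `spatial_size`; a worked example of the arithmetic:

from math import ceil

img_size = (20, 10, 8)  # spatial shape, channel dim excluded
spatial_size = 40       # requested size for the longest side
scale = spatial_size / max(img_size)  # 40 / 20 = 2.0
spatial_size_ = tuple(ceil(s * scale) for s in img_size)
assert spatial_size_ == (40, 20, 16)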
@@ -483,7 +502,10 @@ def __call__( """ _dtype = dtype or self.dtype or img.dtype - im_shape = np.asarray(img.shape[1:]) # spatial dimensions + img_t: torch.Tensor + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor, dtype=_dtype) # type: ignore + + im_shape = np.asarray(img_t.shape[1:]) # spatial dimensions input_ndim = len(im_shape) if input_ndim not in (2, 3): raise ValueError(f"Unsupported img dimension: {input_ndim}, available options are [2, 3].") @@ -500,6 +522,8 @@ def __call__( output_shape = np.asarray(corners.ptp(axis=1) + 0.5, dtype=int) shift_1 = create_translate(input_ndim, (-(output_shape - 1) / 2).tolist()) transform = shift @ transform @ shift_1 + transform_t: torch.Tensor + transform_t, *_ = convert_data_type(transform, torch.Tensor, dtype=_dtype, device=img_t.device) # type: ignore xform = AffineTransform( normalized=False, @@ -509,12 +533,13 @@ def __call__( reverse_indexing=True, ) output = xform( - torch.as_tensor(np.ascontiguousarray(img).astype(_dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(transform).astype(_dtype)), + img_t.unsqueeze(0), + transform_t, spatial_size=output_shape, ) self._rotation_matrix = transform - return np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + out, *_ = convert_data_type(output.squeeze(0).float(), orig_type, orig_device) + return out def get_rotation_matrix(self) -> Optional[np.ndarray]: """ @@ -524,7 +549,7 @@ def get_rotation_matrix(self) -> Optional[np.ndarray]: return self._rotation_matrix -class Zoom(Transform): +class Zoom(TorchTransform): """ Zooms an ND image using :py:class:`torch.nn.functional.interpolate`. For details, please see https://pytorch.org/docs/stable/nn.functional.html#interpolate. @@ -565,11 +590,11 @@ def __init__( def __call__( self, - img: np.ndarray, + img: DataObjects.Images, mode: Optional[Union[InterpolateMode, str]] = None, padding_mode: Optional[Union[NumpyPadMode, str]] = None, align_corners: Optional[bool] = None, - ): + ) -> DataObjects.Images: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]). 
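When `keep_size` is disabled, Rotate derives `output_shape` from the rotated corner coordinates, as in the `corners.ptp(axis=1) + 0.5` line above. A hedged numpy-only sketch of that computation (it ignores the centering translations for brevity):

import numpy as np

angle = np.pi / 4
rot = np.array([[np.cos(angle), -np.sin(angle)],
                [np.sin(angle), np.cos(angle)]])
im_shape = np.asarray([64, 64])
# corner coordinates of the input plane, one corner per column
corners = np.asarray(np.meshgrid(*[(0, dim - 1) for dim in im_shape], indexing="ij")).reshape(2, -1)
corners = rot @ corners
output_shape = np.asarray(corners.ptp(axis=1) + 0.5, dtype=int)
assert tuple(output_shape) == (89, 89)  # a 45-degree rotation needs a larger canvas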
@@ -585,34 +610,41 @@ def __call__( See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate """ - _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim + img_t: torch.Tensor + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor) # type: ignore + + _zoom = ensure_tuple_rep(self.zoom, img_t.ndim - 1) # match the spatial image dim zoomed = torch.nn.functional.interpolate( # type: ignore recompute_scale_factor=True, - input=torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0), + input=img_t.float().unsqueeze(0), scale_factor=list(_zoom), mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, align_corners=self.align_corners if align_corners is None else align_corners, ) - zoomed = zoomed.squeeze(0).detach().cpu().numpy() - if not self.keep_size or np.allclose(img.shape, zoomed.shape): - return zoomed - - pad_vec = [[0, 0]] * len(img.shape) - slice_vec = [slice(None)] * len(img.shape) - for idx, (od, zd) in enumerate(zip(img.shape, zoomed.shape)): - diff = od - zd - half = abs(diff) // 2 - if diff > 0: # need padding - pad_vec[idx] = [half, diff - half] - elif diff < 0: # need slicing - slice_vec[idx] = slice(half, half + od) - - padding_mode = look_up_option(self.padding_mode if padding_mode is None else padding_mode, NumpyPadMode) - zoomed = np.pad(zoomed, pad_vec, mode=padding_mode.value) # type: ignore - return zoomed[tuple(slice_vec)] - - -class Rotate90(Transform): + zoomed = zoomed.squeeze(0) + + if self.keep_size and not np.allclose(img_t.shape, zoomed.shape): + + pad_vec = [(0, 0)] * len(img_t.shape) + slice_vec = [slice(None)] * len(img_t.shape) + for idx, (od, zd) in enumerate(zip(img_t.shape, zoomed.shape)): + diff = od - zd + half = abs(diff) // 2 + if diff > 0: # need padding + pad_vec[idx] = (half, diff - half) + elif diff < 0: # need slicing + slice_vec[idx] = slice(half, half + od) + + padding_mode = look_up_option(padding_mode or self.padding_mode, NumpyPadMode) + padder = Pad(pad_vec, padding_mode) + zoomed = padder(zoomed) + zoomed = zoomed[tuple(slice_vec)] + + out, *_ = convert_data_type(zoomed, orig_type, orig_device) + return out + + +class Rotate90(TorchTransform, NumpyTransform): """ Rotate an array by 90 degrees in the plane specified by `axes`. See np.rot90 for additional details: @@ -634,17 +666,17 @@ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: raise ValueError("spatial_axes must be 2 int numbers to indicate the axes to rotate 90 degrees.") self.spatial_axes = spatial_axes_ - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ + if isinstance(img, torch.Tensor): + return torch.rot90(img, self.k, map_spatial_axes(img.ndim, self.spatial_axes)).to(img.dtype) + return np.rot90(img, self.k, map_spatial_axes(img.ndim, self.spatial_axes)).astype(img.dtype) # type: ignore - result: np.ndarray = np.rot90(img, self.k, map_spatial_axes(img.ndim, self.spatial_axes)) - return result.astype(img.dtype) - -class RandRotate90(RandomizableTransform): +class RandRotate90(TorchTransform, NumpyTransform, RandomizableTransform): """ With probability `prob`, input arrays are rotated by 90 degrees in the plane specified by `spatial_axes`. 
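The `keep_size` branch above restores the input shape by symmetric padding or centre cropping per axis; the same bookkeeping in isolation, with illustrative shapes:

orig_shape, zoomed_shape = (1, 64, 64), (1, 60, 70)
pad_vec = [(0, 0)] * len(orig_shape)
slice_vec = [slice(None)] * len(orig_shape)
for idx, (od, zd) in enumerate(zip(orig_shape, zoomed_shape)):
    diff = od - zd
    half = abs(diff) // 2
    if diff > 0:    # zoomed output smaller than the input -> pad back up
        pad_vec[idx] = (half, diff - half)
    elif diff < 0:  # zoomed output larger than the input -> crop back down
        slice_vec[idx] = slice(half, half + od)
assert pad_vec == [(0, 0), (2, 2), (0, 0)]
assert slice_vec == [slice(None), slice(None), slice(3, 67)]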
@@ -669,7 +701,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self._rand_k = self.R.randint(self.max_k) + 1 super().randomize(None) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), @@ -681,7 +713,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: return rotator(img) -class RandRotate(RandomizableTransform): +class RandRotate(TorchTransform, RandomizableTransform): """ Randomly rotate the input arrays. @@ -750,12 +782,12 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__( self, - img: np.ndarray, + img: DataObjects.Images, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, dtype: DtypeLike = None, - ) -> np.ndarray: + ) -> DataObjects.Images: """ Args: img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D). @@ -780,12 +812,12 @@ def __call__( mode=look_up_option(mode or self.mode, GridSampleMode), padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode), align_corners=self.align_corners if align_corners is None else align_corners, - dtype=dtype or self.dtype or img.dtype, + dtype=dtype or self.dtype or img.dtype, # type: ignore ) - return np.array(rotator(img)) + return rotator(img) -class RandFlip(RandomizableTransform): +class RandFlip(TorchTransform, RandomizableTransform): """ Randomly flips the image along axes. Preserves shape. See numpy.flip for additional details. @@ -800,7 +832,7 @@ def __init__(self, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int] RandomizableTransform.__init__(self, prob) self.flipper = Flip(spatial_axis=spatial_axis) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), @@ -811,7 +843,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: return self.flipper(img) -class RandAxisFlip(RandomizableTransform): +class RandAxisFlip(TorchTransform, RandomizableTransform): """ Randomly select a spatial axis and flip along it. See numpy.flip for additional details. @@ -826,11 +858,11 @@ def __init__(self, prob: float = 0.1) -> None: RandomizableTransform.__init__(self, prob) self._axis: Optional[int] = None - def randomize(self, data: np.ndarray) -> None: + def randomize(self, data: DataObjects.Images) -> None: super().randomize(None) self._axis = self.R.randint(data.ndim - 1) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), @@ -842,7 +874,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: return flipper(img) -class RandZoom(RandomizableTransform): +class RandZoom(TorchTransform, RandomizableTransform): """ Randomly zooms input arrays with given probability within given zoom range. 
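A short usage sketch of what these signature changes mean in practice, assuming this branch: the flip transforms now hand back whatever container type they are given.

import numpy as np
import torch
from monai.transforms import Flip

flip = Flip(spatial_axis=0)
out_np = flip(np.zeros((1, 4, 4), dtype=np.float32))  # numpy in -> numpy out
out_t = flip(torch.zeros(1, 4, 4))                    # tensor in -> tensor out
assert isinstance(out_np, np.ndarray)
assert isinstance(out_t, torch.Tensor)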
@@ -899,11 +931,11 @@ def randomize(self, data: Optional[Any] = None) -> None:

     def __call__(
         self,
-        img: np.ndarray,
+        img: DataObjects.Images,
         mode: Optional[Union[InterpolateMode, str]] = None,
         padding_mode: Optional[Union[NumpyPadMode, str]] = None,
         align_corners: Optional[bool] = None,
-    ) -> np.ndarray:
+    ) -> DataObjects.Images:
         """
         Args:
             img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D).
@@ -920,28 +952,27 @@ def __call__(
         """
         # match the spatial image dim
         self.randomize()
-        _dtype = np.float32
+
         if not self._do_transform:
-            return img.astype(_dtype)
+            return img
+
         if len(self._zoom) == 1:
             # to keep the spatial shape ratio, use same random zoom factor for all dims
             self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 1)
         elif len(self._zoom) == 2 and img.ndim > 3:
             # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim
             self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 2) + ensure_tuple(self._zoom[-1])
-        zoomer = Zoom(self._zoom, keep_size=self.keep_size)
-        return np.asarray(
-            zoomer(
-                img,
-                mode=look_up_option(mode or self.mode, InterpolateMode),
-                padding_mode=look_up_option(padding_mode or self.padding_mode, NumpyPadMode),
-                align_corners=self.align_corners if align_corners is None else align_corners,
-            ),
-            dtype=_dtype,
+        zoomer = Zoom(
+            self._zoom,
+            keep_size=self.keep_size,
+            mode=look_up_option(mode or self.mode, InterpolateMode),
+            padding_mode=look_up_option(padding_mode or self.padding_mode, NumpyPadMode),
+            align_corners=self.align_corners if align_corners is None else align_corners,
         )
+        return zoomer(img)


-class AffineGrid(Transform):
+class AffineGrid(TorchTransform):
     """
     Affine transforms on the coordinates.
@@ -976,25 +1007,21 @@ def __init__(
         shear_params: Optional[Union[Sequence[float], float]] = None,
         translate_params: Optional[Union[Sequence[float], float]] = None,
         scale_params: Optional[Union[Sequence[float], float]] = None,
-        as_tensor_output: bool = True,
         device: Optional[torch.device] = None,
-        affine: Optional[Union[np.ndarray, torch.Tensor]] = None,
+        affine: Optional[DataObjects.Images] = None,
     ) -> None:
         self.rotate_params = rotate_params
         self.shear_params = shear_params
         self.translate_params = translate_params
         self.scale_params = scale_params
-
-        self.as_tensor_output = as_tensor_output
         self.device = device
-
         self.affine = affine

     def __call__(
         self,
         spatial_size: Optional[Sequence[int]] = None,
-        grid: Optional[Union[np.ndarray, torch.Tensor]] = None,
-    ) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]:
+        grid: Optional[DataObjects.Images] = None,
+    ) -> Tuple[DataObjects.Images, DataObjects.Images]:
         """
         Args:
             spatial_size: output grid size.
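The `_zoom` broadcasting in `RandZoom.__call__` above reduces to two `ensure_tuple*` calls; a worked example of both branches with illustrative factors:

from monai.utils import ensure_tuple, ensure_tuple_rep

img_ndim = 4  # a (C, H, W, D) volume
# one factor: repeat it for every spatial dim to keep the aspect ratio
assert ensure_tuple_rep(0.9, img_ndim - 1) == (0.9, 0.9, 0.9)
# two factors on 3D data: first factor for H and W, second for D
assert ensure_tuple_rep(0.9, img_ndim - 2) + ensure_tuple(1.1) == (0.9, 0.9, 1.1)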
@@ -1010,7 +1037,7 @@ def __call__( else: raise ValueError("Incompatible values: grid=None and spatial_size=None.") - affine: Union[torch.Tensor, np.ndarray] + affine: DataObjects.Images if self.affine is None: spatial_dims = len(grid.shape) - 1 affine = np.eye(spatial_dims + 1) @@ -1025,20 +1052,17 @@ def __call__( else: affine = self.affine - if isinstance(affine, np.ndarray): - affine = torch.as_tensor(np.ascontiguousarray(affine)) + grid, orig_type, orig_device = convert_data_type(grid, torch.Tensor, dtype=float, device=self.device) + affine, *_ = convert_data_type(affine, torch.Tensor, dtype=float, device=grid.device) # type: ignore - grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone() - if self.device: - affine = affine.to(self.device) - grid = grid.to(self.device) - grid = (affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:])) + grid = (affine @ grid.reshape((grid.shape[0], -1))).reshape([-1] + list(grid.shape[1:])) if grid is None or not isinstance(grid, torch.Tensor): raise ValueError("Unknown grid.") - return grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()), affine + grid, *_ = convert_data_type(grid, orig_type, orig_device) + return grid, affine -class RandAffineGrid(Randomizable, Transform): +class RandAffineGrid(Randomizable, TorchTransform): """ Generate randomised affine grid. """ @@ -1049,7 +1073,6 @@ def __init__( shear_range: RandRange = None, translate_range: RandRange = None, scale_range: RandRange = None, - as_tensor_output: bool = True, device: Optional[torch.device] = None, ) -> None: """ @@ -1064,8 +1087,6 @@ def __init__( translate_range: translate_range with format matching `rotate_range`. scale_range: scaling_range with format matching `rotate_range`. A value of 1.0 is added to the result. This allows 0 to correspond to no change (i.e., a scaling of 1). - as_tensor_output: whether to output tensor instead of numpy array. - defaults to True. device: device to store the output grid data. See also: @@ -1084,9 +1105,8 @@ def __init__( self.translate_params: Optional[List[float]] = None self.scale_params: Optional[List[float]] = None - self.as_tensor_output = as_tensor_output self.device = device - self.affine: Optional[Union[np.ndarray, torch.Tensor]] = None + self.affine: Optional[DataObjects.Images] = None def _get_rand_param(self, param_range, add_scalar: float = 0.0): out_param = [] @@ -1108,8 +1128,8 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__( self, spatial_size: Optional[Sequence[int]] = None, - grid: Optional[Union[np.ndarray, torch.Tensor]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + grid: Optional[DataObjects.Images] = None, + ) -> DataObjects.Images: """ Args: spatial_size: output grid size. @@ -1124,18 +1144,17 @@ def __call__( shear_params=self.shear_params, translate_params=self.translate_params, scale_params=self.scale_params, - as_tensor_output=self.as_tensor_output, device=self.device, ) grid, self.affine = affine_grid(spatial_size, grid) return grid - def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]]: + def get_transformation_matrix(self) -> Optional[DataObjects.Images]: """Get the most recently applied transformation matrix""" return self.affine -class RandDeformGrid(Randomizable, Transform): +class RandDeformGrid(Randomizable, TorchTransform, NumpyTransform): """ Generate random deformation grid. 
""" @@ -1185,12 +1204,11 @@ def __call__(self, spatial_size: Sequence[int]): return control_grid -class Resample(Transform): +class Resample(TorchTransform): def __init__( self, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, - as_tensor_output: bool = False, device: Optional[torch.device] = None, ) -> None: """ @@ -1204,21 +1222,19 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: whether to return a torch tensor. Defaults to False. device: device on which the tensor will be allocated. """ self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) - self.as_tensor_output = as_tensor_output self.device = device def __call__( self, - img: Union[np.ndarray, torch.Tensor], - grid: Optional[Union[np.ndarray, torch.Tensor]] = None, + img: DataObjects.Images, + grid: Optional[DataObjects.Images] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> DataObjects.Images: """ Args: img: shape must be (num_channels, H, W[, D]). @@ -1230,18 +1246,17 @@ def __call__( Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ - - if not isinstance(img, torch.Tensor): - img = torch.as_tensor(np.ascontiguousarray(img)) if grid is None: raise AssertionError("Error, grid argument must be supplied as an ndarray or tensor ") - grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone() - if self.device: - img = img.to(self.device) - grid = grid.to(self.device) + + img_t: torch.Tensor + img_t, orig_type, orig_device = convert_data_type( # type: ignore + img, torch.Tensor, device=self.device, dtype=torch.float32 + ) + grid, *_ = convert_data_type(deepcopy(grid), torch.Tensor, device=img_t.device, dtype=float) if USE_COMPILED: - for i, dim in enumerate(img.shape[1:]): + for i, dim in enumerate(img_t.shape[1:]): grid[i] += (dim - 1.0) / 2.0 grid = grid[:-1] / grid[-1:] grid = grid.permute(list(range(grid.ndimension()))[1:] + [0]) @@ -1256,32 +1271,32 @@ def __call__( bound = 1 _interp_mode = look_up_option(self.mode if mode is None else mode, GridSampleMode).value out = grid_pull( - img.unsqueeze(0).float(), + img_t.unsqueeze(0).float(), grid.unsqueeze(0).float(), bound=bound, extrapolate=True, interpolation=1 if _interp_mode == "bilinear" else _interp_mode, )[0] else: - for i, dim in enumerate(img.shape[1:]): + for i, dim in enumerate(img_t.shape[1:]): grid[i] = 2.0 * grid[i] / (dim - 1.0) grid = grid[:-1] / grid[-1:] - index_ordering: List[int] = list(range(img.ndimension() - 2, -1, -1)) + index_ordering: List[int] = list(range(img_t.ndimension() - 2, -1, -1)) grid = grid[index_ordering] grid = grid.permute(list(range(grid.ndimension()))[1:] + [0]) out = torch.nn.functional.grid_sample( - img.unsqueeze(0).float(), + img_t.unsqueeze(0).float(), grid.unsqueeze(0).float(), mode=self.mode.value if mode is None else GridSampleMode(mode).value, padding_mode=self.padding_mode.value if padding_mode is None else GridSamplePadMode(padding_mode).value, align_corners=True, )[0] - if self.as_tensor_output: - return torch.as_tensor(out) - 
return np.asarray(out.cpu().numpy()) + + out, *_ = convert_data_type(out, orig_type, orig_device) + return out # type: ignore -class Affine(Transform): +class Affine(TorchTransform): """ Transform ``img`` given the affine parameters. """ @@ -1295,7 +1310,6 @@ def __init__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, device: Optional[torch.device] = None, image_only: bool = False, ) -> None: @@ -1321,8 +1335,6 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. image_only: if True return only the image volume, otherwise return (image, affine). """ @@ -1331,22 +1343,21 @@ def __init__( shear_params=shear_params, translate_params=translate_params, scale_params=scale_params, - as_tensor_output=True, device=device, ) self.image_only = image_only - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.resampler = Resample(device=device) self.spatial_size = spatial_size self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: DataObjects.Images, spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ): + ) -> Union[DataObjects.Images, Tuple[DataObjects.Images, DataObjects.Images]]: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1369,7 +1380,7 @@ def __call__( return ret if self.image_only else (ret, affine) -class RandAffine(RandomizableTransform): +class RandAffine(RandomizableTransform, TorchTransform): """ Random affine transform. """ @@ -1385,7 +1396,6 @@ def __init__( mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, cache_grid: bool = False, - as_tensor_output: bool = True, device: Optional[torch.device] = None, ) -> None: """ @@ -1417,8 +1427,6 @@ def __init__( cache_grid: whether to cache the identity sampling grid. If the spatial size is not dynamically defined by input image, enabling this option could accelerate the transform. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. 
See also: @@ -1432,10 +1440,9 @@ def __init__( shear_range=shear_range, translate_range=translate_range, scale_range=scale_range, - as_tensor_output=True, device=device, ) - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.resampler = Resample(device=device) self.spatial_size = spatial_size self.cache_grid = cache_grid @@ -1492,11 +1499,11 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: DataObjects.Images, spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> DataObjects.Images: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1518,8 +1525,8 @@ def __call__( sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) do_resampling = self._do_transform or (sp_size != ensure_tuple(img.shape[1:])) if not do_resampling: - img = img.float() if isinstance(img, torch.Tensor) else img.astype("float32") - return torch.Tensor(img) if self.resampler.as_tensor_output else np.array(img) + img, *_ = convert_data_type(img, dtype=np.float32) + return img grid = self.get_identity_grid(sp_size) if self._do_transform: grid = self.rand_affine_grid(grid=grid) @@ -1528,7 +1535,7 @@ def __call__( ) -class Rand2DElastic(RandomizableTransform): +class Rand2DElastic(TorchTransform, RandomizableTransform): """ Random elastic deformation and affine in 2D """ @@ -1545,7 +1552,6 @@ def __init__( spatial_size: Optional[Union[Tuple[int, int], int]] = None, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, device: Optional[torch.device] = None, ) -> None: """ @@ -1577,8 +1583,6 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. 
See also: @@ -1594,10 +1598,9 @@ def __init__( shear_range=shear_range, translate_range=translate_range, scale_range=scale_range, - as_tensor_output=True, device=device, ) - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.resampler = Resample(device=device) self.spatial_size = spatial_size self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) @@ -1618,11 +1621,11 @@ def randomize(self, spatial_size: Sequence[int]) -> None: def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: DataObjects.Images, spatial_size: Optional[Union[Tuple[int, int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> DataObjects.Images: """ Args: img: shape must be (num_channels, H, W), @@ -1654,7 +1657,7 @@ def __call__( return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) -class Rand3DElastic(RandomizableTransform): +class Rand3DElastic(TorchTransform, RandomizableTransform): """ Random elastic deformation and affine in 3D """ @@ -1671,7 +1674,6 @@ def __init__( spatial_size: Optional[Union[Tuple[int, int, int], int]] = None, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, device: Optional[torch.device] = None, ) -> None: """ @@ -1705,8 +1707,6 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. See also: @@ -1714,8 +1714,8 @@ def __init__( - :py:class:`Affine` for the affine transformation parameters configurations. 
""" RandomizableTransform.__init__(self, prob) - self.rand_affine_grid = RandAffineGrid(rotate_range, shear_range, translate_range, scale_range, True, device) - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.rand_affine_grid = RandAffineGrid(rotate_range, shear_range, translate_range, scale_range, device) + self.resampler = Resample(device=device) self.sigma_range = sigma_range self.magnitude_range = magnitude_range @@ -1745,11 +1745,11 @@ def randomize(self, grid_size: Sequence[int]) -> None: def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: DataObjects.Images, spatial_size: Optional[Union[Tuple[int, int, int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> DataObjects.Images: """ Args: img: shape must be (num_channels, H, W, D), @@ -1769,15 +1769,16 @@ def __call__( if self._do_transform: if self.rand_offset is None: raise AssertionError - grid = torch.as_tensor(np.ascontiguousarray(grid), device=self.device) + grid, *_ = convert_data_type(grid, torch.Tensor, device=self.device) + offset, *_ = convert_data_type(self.rand_offset, torch.Tensor, device=self.device) + offset = offset.unsqueeze(0) # type: ignore gaussian = GaussianFilter(3, self.sigma, 3.0).to(device=self.device) - offset = torch.as_tensor(self.rand_offset, device=self.device).unsqueeze(0) grid[:3] += gaussian(offset)[0] * self.magnitude grid = self.rand_affine_grid(grid=grid) return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) -class AddCoordinateChannels(Transform): +class AddCoordinateChannels(TorchTransform, NumpyTransform): """ Appends additional channels encoding coordinates of the input. Useful when e.g. training using patch-based sampling, to allow feeding of the patch's location into the network. @@ -1799,7 +1800,7 @@ def __init__( """ self.spatial_channels = spatial_channels - def __call__(self, img: Union[np.ndarray, torch.Tensor]): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: img: data to be transformed, assuming `img` is channel first. @@ -1813,8 +1814,15 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]): raise ValueError("cannot add AddCoordinateChannels channel for dimension 0, as 0 is channel dim.") spatial_dims = img.shape[1:] - coord_channels = np.array(np.meshgrid(*tuple(np.linspace(-0.5, 0.5, s) for s in spatial_dims), indexing="ij")) + coord_channels = np.meshgrid(*tuple(np.linspace(-0.5, 0.5, s) for s in spatial_dims), indexing="ij") + coord_channels, *_ = convert_data_type( + coord_channels, type(img), device=img.device if isinstance(img, torch.Tensor) else None + ) # only keep required dimensions. 
need to subtract 1 since im will be 0-based # but user input is 1-based (because channel dim is 0) coord_channels = coord_channels[[s - 1 for s in self.spatial_channels]] - return np.concatenate((img, coord_channels), axis=0) + if isinstance(img, torch.Tensor): + out = torch.cat((img, coord_channels), dim=0) + else: + out = np.concatenate((img, coord_channels), axis=0) + return out diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 0d65fdfa29..dc372ff10d 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -53,7 +53,7 @@ ensure_tuple_rep, fall_back_tuple, ) -from monai.utils.enums import InverseKeys +from monai.utils.enums import DataObjects, InverseKeys from monai.utils.module import optional_import nib, _ = optional_import("nibabel") @@ -206,9 +206,7 @@ def __init__( raise ValueError("meta_keys should have the same length as keys.") self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - def __call__( - self, data: Mapping[Union[Hashable, str], Dict[str, np.ndarray]] - ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: + def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: d: Dict = dict(data) for key, mode, padding_mode, align_corners, dtype, meta_key, meta_key_postfix in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.meta_keys, self.meta_key_postfix @@ -222,7 +220,7 @@ def __call__( # using affine fetched from d[affine_key] original_spatial_shape = d[key].shape[1:] d[key], old_affine, new_affine = self.spacing_transform( - data_array=np.asarray(d[key]), + data_array=d[key], affine=meta_data["affine"], mode=mode, padding_mode=padding_mode, @@ -245,7 +243,7 @@ def __call__( meta_data["affine"] = new_affine return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) @@ -354,7 +352,7 @@ def __call__( d[meta_key]["affine"] = new_affine return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -394,14 +392,14 @@ def __init__( super().__init__(keys, allow_missing_keys) self.rotator = Rotate90(k, spatial_axes) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) d[key] = self.rotator(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): _ = self.get_most_recent_transform(d, key) @@ -460,7 +458,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self._rand_k = self.R.randint(self.max_k) + 1 super().randomize(None) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Mapping: self.randomize() d = dict(data) @@ -471,7 +469,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> 
Mapping[Hashable, np. self.push_transform(d, key, extra_info={"rand_k": self._rand_k}) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -533,7 +531,7 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): self.push_transform( @@ -547,7 +545,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.resizer(d[key], mode=mode, align_corners=align_corners) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -583,7 +581,6 @@ def __init__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, device: Optional[torch.device] = None, allow_missing_keys: bool = False, ) -> None: @@ -610,8 +607,6 @@ def __init__( Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample It also can be a sequence of string, each element corresponds to a key in ``keys``. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. @@ -626,15 +621,12 @@ def __init__( translate_params=translate_params, scale_params=scale_params, spatial_size=spatial_size, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): orig_size = d[key].shape[1:] @@ -651,7 +643,7 @@ def __call__( ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -695,7 +687,6 @@ def __init__( mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, cache_grid: bool = False, - as_tensor_output: bool = True, device: Optional[torch.device] = None, allow_missing_keys: bool = False, ) -> None: @@ -731,8 +722,6 @@ def __init__( cache_grid: whether to cache the identity sampling grid. If the spatial size is not dynamically defined by input image, enabling this option could accelerate the transform. 
- as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. @@ -750,7 +739,6 @@ def __init__( scale_range=scale_range, spatial_size=spatial_size, cache_grid=cache_grid, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -767,9 +755,7 @@ def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) self.rand_affine.randomize() - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) self.randomize() @@ -783,7 +769,7 @@ def __call__( if do_resampling: # need to prepare grid grid = self.rand_affine.get_identity_grid(sp_size) if self._do_transform: # add some random factors - grid = self.rand_affine.rand_affine_grid(grid=grid) + grid = deepcopy(self.rand_affine.rand_affine_grid(grid=grid)) affine = self.rand_affine.rand_affine_grid.get_transformation_matrix() # type: ignore[assignment] for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): @@ -799,23 +785,17 @@ def __call__( # do the transform if do_resampling: d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) - # if not doing transform and and spatial size is unchanged, only need to do numpy/torch conversion - else: - if self.rand_affine.resampler.as_tensor_output and not isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]) - elif not self.rand_affine.resampler.as_tensor_output and isinstance(d[key], torch.Tensor): - d[key] = d[key].detach().cpu().numpy() # type: ignore[union-attr] return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # if transform was not performed and spatial size is None, nothing to do. if not transform[InverseKeys.DO_TRANSFORM] and self.rand_affine.spatial_size is None: - out: Union[np.ndarray, torch.Tensor] = d[key] + out: DataObjects.Images = d[key] else: orig_size = transform[InverseKeys.ORIG_SIZE] # Create inverse transform @@ -857,7 +837,6 @@ def __init__( scale_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, device: Optional[torch.device] = None, allow_missing_keys: bool = False, ) -> None: @@ -894,8 +873,6 @@ def __init__( Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample It also can be a sequence of string, each element corresponds to a key in ``keys``. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. 
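Since `as_tensor_output` disappears from these dictionary-transform signatures, callers migrate by simply dropping the argument; outputs now follow the input type. A hedged usage sketch with illustrative parameters, assuming this branch:

import torch
from monai.transforms import RandAffined

xform = RandAffined(
    keys=("image",),
    prob=1.0,
    rotate_range=(0.26,),   # roughly 15 degrees; illustrative
    # as_tensor_output=True,  # removed by this diff: output follows input type
)
out = xform({"image": torch.zeros(1, 8, 8)})
assert isinstance(out["image"], torch.Tensor)  # tensor in -> tensor out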
@@ -914,7 +891,6 @@ def __init__( translate_range=translate_range, scale_range=scale_range, spatial_size=spatial_size, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -931,9 +907,7 @@ def randomize(self, spatial_size: Sequence[int]) -> None: super().randomize(None) self.rand_2d_elastic.randomize(spatial_size) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, data[self.keys[0]].shape[1:]) @@ -976,7 +950,6 @@ def __init__( scale_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, device: Optional[torch.device] = None, allow_missing_keys: bool = False, ) -> None: @@ -1014,8 +987,6 @@ def __init__( Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample It also can be a sequence of string, each element corresponds to a key in ``keys``. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. @@ -1034,7 +1005,6 @@ def __init__( translate_range=translate_range, scale_range=scale_range, spatial_size=spatial_size, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -1051,9 +1021,7 @@ def randomize(self, grid_size: Sequence[int]) -> None: super().randomize(None) self.rand_3d_elastic.randomize(grid_size) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, data[self.keys[0]].shape[1:]) @@ -1094,20 +1062,17 @@ def __init__( super().__init__(keys, allow_missing_keys) self.flipper = Flip(spatial_axis=spatial_axis) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Dict) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) d[key] = self.flipper(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Dict) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): _ = self.get_most_recent_transform(d, key) - # Might need to convert to numpy - if isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]).cpu().numpy() # Inverse is same as forward d[key] = self.flipper(d[key]) # Remove the applied transform @@ -1143,7 +1108,7 @@ def __init__( self.flipper = Flip(spatial_axis=spatial_axis) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Dict) -> DataObjects.Dict: self.randomize(None) d = dict(data) for key in self.key_iterator(d): @@ -1152,15 +1117,12 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, 
np.nda self.push_transform(d, key) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Dict) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform[InverseKeys.DO_TRANSFORM]: - # Might need to convert to numpy - if isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]).cpu().numpy() # Inverse is same as forward d[key] = self.flipper(d[key]) # Remove the applied transform @@ -1187,11 +1149,11 @@ def __init__(self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys: RandomizableTransform.__init__(self, prob) self._axis: Optional[int] = None - def randomize(self, data: np.ndarray) -> None: + def randomize(self, data: DataObjects.Images) -> None: super().randomize(None) self._axis = self.R.randint(data.ndim - 1) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Dict) -> DataObjects.Dict: self.randomize(data=data[self.keys[0]]) flipper = Flip(spatial_axis=self._axis) @@ -1202,16 +1164,13 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.push_transform(d, key, extra_info={"axis": self._axis}) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Dict) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform[InverseKeys.DO_TRANSFORM]: flipper = Flip(spatial_axis=transform[InverseKeys.EXTRA_INFO]["axis"]) - # Might need to convert to numpy - if isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]).cpu().numpy() # Inverse is same as forward d[key] = flipper(d[key]) # Remove the applied transform @@ -1266,7 +1225,7 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype @@ -1293,7 +1252,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) @@ -1400,7 +1359,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: self.randomize() d = dict(data) angle: Union[Sequence[float], float] = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z) @@ -1436,7 +1395,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def inverse(self, data: 
Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict:
         d = deepcopy(dict(data))
         for key, dtype in self.key_iterator(d, self.dtype):
             transform = self.get_most_recent_transform(d, key)
@@ -1509,7 +1468,7 @@ def __init__(
         self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
         self.zoomer = Zoom(zoom=zoom, keep_size=keep_size)

-    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict:
         d = dict(data)
         for key, mode, padding_mode, align_corners in self.key_iterator(
             d, self.mode, self.padding_mode, self.align_corners
@@ -1531,7 +1490,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
             )
         return d

-    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict:
         d = deepcopy(dict(data))
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
@@ -1619,7 +1578,7 @@ def randomize(self, data: Optional[Any] = None) -> None:
         super().randomize(None)
         self._zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)]

-    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict:
         # match the spatial dim of first item
         self.randomize()
         d = dict(data)
@@ -1654,7 +1613,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
             )
         return d

-    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+    def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict:
         d = deepcopy(dict(data))
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
@@ -1700,9 +1659,7 @@ def __init__(self, keys: KeysCollection, spatial_channels: Sequence[int], allow_
         super().__init__(keys, allow_missing_keys)
         self.add_coordinate_channels = AddCoordinateChannels(spatial_channels)

-    def __call__(
-        self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]
-    ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]:
+    def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]:
         d = dict(data)
         for key in self.key_iterator(d):
             d[key] = self.add_coordinate_channels(d[key])
diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py
index 681c0ba9ec..ee3b1c092c 100644
--- a/monai/transforms/transform.py
+++ b/monai/transforms/transform.py
@@ -14,7 +14,7 @@
 import logging
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple, TypeVar, Union
+from typing import Any, Callable, Generator, Hashable, Iterable, List, Optional, Tuple, TypeVar, Union

 import numpy as np
 import torch
@@ -22,8 +22,17 @@
 from monai import transforms
 from monai.config import KeysCollection
 from monai.utils import MAX_SEED, ensure_tuple
-
-__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"]
+from monai.utils.enums import DataObjects
+
+__all__ = [
+    "ThreadUnsafe",
+    "apply_transform",
+    "Randomizable",
+    "RandomizableTransform",
+    "Transform",
+    "MapTransform",
+    "NumpyTransform",
+    "TorchTransform",
+]

 ReturnType = TypeVar("ReturnType")
@@ -233,6 +242,21 @@ def __call__(self, data: Any):
         raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")


+class TorchTransform(Transform):
+    """Transforms that inherit from this class use torch under the hood.
+    If the input is not a ``torch.Tensor``, it is converted to one, and the output
+    is converted back to the original type at the end."""
+
+    pass
+
+
+class NumpyTransform(Transform):
+    """Transforms that inherit from this class use numpy under the hood.
+    If the input image is a ``torch.Tensor``, it is converted to a numpy array, and
+    the output is converted back to a ``torch.Tensor`` at the end."""
+
+    pass
+
+
 class RandomizableTransform(Randomizable, Transform):
     """
     An interface for handling random state locally, currently based on a class variable `R`,
@@ -343,7 +367,7 @@ def __call__(self, data):

     def key_iterator(
         self,
-        data: Dict[Hashable, Any],
+        data: DataObjects.Dict,
         *extra_iterables: Optional[Iterable],
     ) -> Generator:
         """
diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py
index 4e0141652f..eeb04997af 100644
--- a/monai/transforms/utility/array.py
+++ b/monai/transforms/utility/array.py
@@ -17,14 +17,15 @@
 import sys
 import time
 import warnings
-from typing import Callable, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch

-from monai.config import DtypeLike, NdarrayTensor
-from monai.transforms.transform import Randomizable, RandomizableTransform, Transform
+from monai.config import DtypeLike
+from monai.transforms.transform import NumpyTransform, Randomizable, RandomizableTransform, TorchTransform, Transform
 from monai.transforms.utils import (
+    _unravel_index,
     convert_to_numpy,
     convert_to_tensor,
     extreme_points_to_image,
@@ -32,7 +33,9 @@
     map_binary_to_indices,
     map_classes_to_indices,
 )
-from monai.utils import ensure_tuple, issequenceiterable, min_version, optional_import
+from monai.utils import ensure_tuple, min_version, optional_import
+from monai.utils.enums import DataObjects
+from monai.utils.misc import convert_data_type, dtype_convert, is_module_ver_at_least

 PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
 pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray")
@@ -69,7 +72,7 @@
 ]


-class Identity(Transform):
+class Identity(TorchTransform, NumpyTransform):
     """
-    Convert the input to an np.ndarray, if input data is np.ndarray or subclasses, return unchanged data.
-    As the output value is same as input, it can be used as a testing tool to verify the transform chain,
+    Return the input data unchanged. As the output value is the same as the input,
+    it can be used as a testing tool to verify the transform chain.
     """

-    def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
+    def __call__(self, img: DataObjects.Images) -> DataObjects.Images:
         """
         Apply the transform to `img`.
         """
-        return np.asanyarray(img)
+        return img


-class AsChannelFirst(Transform):
+class AsChannelFirst(TorchTransform, NumpyTransform):
     """
     Change the channel dimension of the image to the first dimension.
@@ -105,14 +108,19 @@ def __init__(self, channel_dim: int = -1) -> None:
             raise AssertionError("invalid channel dimension.")
         self.channel_dim = channel_dim

-    def __call__(self, img: np.ndarray) -> np.ndarray:
+    def __call__(self, img: DataObjects.Images) -> DataObjects.Images:
         """
         Apply the transform to `img`.
""" - return np.moveaxis(img, self.channel_dim, 0) + if isinstance(img, torch.Tensor) and hasattr(torch, "moveaxis"): + return torch.moveaxis(img, self.channel_dim, 0) + + img, orig_type, orig_device = convert_data_type(img, np.ndarray) + img = np.moveaxis(img, self.channel_dim, 0) # type: ignore + return convert_data_type(img, orig_type, orig_device)[0] -class AsChannelLast(Transform): +class AsChannelLast(TorchTransform, NumpyTransform): """ Change the channel dimension of the image to the last dimension. @@ -132,14 +140,19 @@ def __init__(self, channel_dim: int = 0) -> None: raise AssertionError("invalid channel dimension.") self.channel_dim = channel_dim - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ - return np.moveaxis(img, self.channel_dim, -1) + if isinstance(img, torch.Tensor) and hasattr(torch, "moveaxis"): + return torch.moveaxis(img, self.channel_dim, -1) + + img, orig_type, orig_device = convert_data_type(img, np.ndarray) + img = np.moveaxis(img, self.channel_dim, -1) # type: ignore + return convert_data_type(img, orig_type, orig_device)[0] -class AddChannel(Transform): +class AddChannel(TorchTransform, NumpyTransform): """ Adds a 1-length channel dimension to the input image. @@ -153,14 +166,14 @@ class AddChannel(Transform): transforms. """ - def __call__(self, img: NdarrayTensor): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ return img[None] -class EnsureChannelFirst(Transform): +class EnsureChannelFirst(TorchTransform): """ Automatically adjust or add the channel dimension of input data to ensure `channel_first` shape. It extracts the `original_channel_dim` info from provided meta_data dictionary. @@ -175,7 +188,7 @@ def __init__(self, strict_check: bool = True): """ self.strict_check = strict_check - def __call__(self, img: np.ndarray, meta_dict: Optional[Mapping] = None): + def __call__(self, img: DataObjects.Images, meta_dict: Optional[Dict] = None) -> DataObjects.Images: """ Apply the transform to `img`. """ @@ -199,7 +212,7 @@ def __call__(self, img: np.ndarray, meta_dict: Optional[Mapping] = None): return AsChannelFirst(channel_dim=channel_dim)(img) -class RepeatChannel(Transform): +class RepeatChannel(TorchTransform, NumpyTransform): """ Repeat channel data to construct expected input shape for models. The `repeats` count includes the origin data, for example: @@ -214,14 +227,15 @@ def __init__(self, repeats: int) -> None: raise AssertionError("repeats count must be greater than 0.") self.repeats = repeats - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`, assuming `img` is a "channel-first" array. """ - return np.repeat(img, self.repeats, 0) + repeeat_fn = torch.repeat_interleave if isinstance(img, torch.Tensor) else np.repeat + return repeeat_fn(img, self.repeats, 0) # type: ignore -class RemoveRepeatedChannel(Transform): +class RemoveRepeatedChannel(TorchTransform, NumpyTransform): """ RemoveRepeatedChannel data to undo RepeatChannel The `repeats` count specifies the deletion of the origin data, for example: @@ -237,17 +251,17 @@ def __init__(self, repeats: int) -> None: self.repeats = repeats - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`, assuming `img` is a "channel-first" array. 
""" - if np.shape(img)[0] < 2: + if img.shape[0] < 2: raise AssertionError("Image must have more than one channel") - return np.array(img[:: self.repeats, :]) + return img[:: self.repeats, :] -class SplitChannel(Transform): +class SplitChannel(TorchTransform, NumpyTransform): """ Split Numpy array or PyTorch Tensor data according to the channel dim. It can help applying different following transforms to different channels. @@ -260,7 +274,7 @@ class SplitChannel(Transform): def __init__(self, channel_dim: int = 0) -> None: self.channel_dim = channel_dim - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> List[Union[np.ndarray, torch.Tensor]]: + def __call__(self, img: DataObjects.Images) -> List[DataObjects.Images]: n_classes = img.shape[self.channel_dim] if n_classes <= 1: raise RuntimeError("input image does not contain multiple channels.") @@ -274,7 +288,7 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> List[Union[np.ndarra return outputs -class CastToType(Transform): +class CastToType(TorchTransform, NumpyTransform): """ Cast the Numpy data to specified numpy data type, or cast the PyTorch Tensor to specified PyTorch data type. @@ -288,8 +302,8 @@ def __init__(self, dtype=np.float32) -> None: self.dtype = dtype def __call__( - self, img: Union[np.ndarray, torch.Tensor], dtype: Optional[Union[DtypeLike, torch.dtype]] = None - ) -> Union[np.ndarray, torch.Tensor]: + self, img: DataObjects.Images, dtype: Optional[Union[DtypeLike, torch.dtype]] = None + ) -> DataObjects.Images: """ Apply the transform to `img`, assuming `img` is a numpy array or PyTorch Tensor. @@ -300,30 +314,29 @@ def __call__( TypeError: When ``img`` type is not in ``Union[numpy.ndarray, torch.Tensor]``. """ + dtype = dtype_convert(dtype or self.dtype, type(img)) if isinstance(img, np.ndarray): - return img.astype(self.dtype if dtype is None else dtype) # type: ignore + return img.astype(dtype) # type: ignore if isinstance(img, torch.Tensor): - return torch.as_tensor(img, dtype=self.dtype if dtype is None else dtype) + return img.to(dtype=dtype) # type: ignore raise TypeError(f"img must be one of (numpy.ndarray, torch.Tensor) but is {type(img).__name__}.") -class ToTensor(Transform): +class ToTensor(TorchTransform): """ Converts the input image to a tensor without applying any other transformations. """ - def __call__(self, img) -> torch.Tensor: + def __call__(self, img: Union[DataObjects.Images, Sequence]) -> torch.Tensor: """ Apply the transform to `img` and make it contiguous. """ - if isinstance(img, torch.Tensor): - return img.contiguous() - if issequenceiterable(img): - # numpy array with 0 dims is also sequence iterable - if not (isinstance(img, np.ndarray) and img.ndim == 0): - # `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims - img = np.ascontiguousarray(img) - return torch.as_tensor(img) + if isinstance(img, Sequence): + img = torch.Tensor(img) + else: + img, *_ = convert_data_type(img, torch.Tensor) + img = img.contiguous() # type: ignore + return img class EnsureType(Transform): @@ -357,21 +370,21 @@ def __call__(self, data): return convert_to_tensor(data) if self.data_type == "tensor" else convert_to_numpy(data) -class ToNumpy(Transform): +class ToNumpy(TorchTransform, NumpyTransform): """ Converts the input data to numpy array, can support list or tuple of numbers and PyTorch Tensor. 
""" - def __call__(self, img) -> np.ndarray: + def __call__(self, img: Union[DataObjects.Images, Sequence]) -> np.ndarray: """ Apply the transform to `img` and make it contiguous. """ - if isinstance(img, torch.Tensor): - img = img.detach().cpu().numpy() - elif has_cp and isinstance(img, cp_ndarray): - img = cp.asnumpy(img) + array: np.ndarray + if not isinstance(img, Sequence): + array, *_ = convert_data_type(img, np.ndarray) # type: ignore + else: + array = np.asarray(img) - array: np.ndarray = np.asarray(img) return np.ascontiguousarray(array) if array.ndim > 0 else array @@ -380,7 +393,7 @@ class ToCupy(Transform): Converts the input data to CuPy array, can support list or tuple of numbers, NumPy and PyTorch Tensor. """ - def __call__(self, img): + def __call__(self, img: Union[DataObjects.Images, Sequence]): """ Apply the transform to `img` and make it contiguous. """ @@ -400,12 +413,10 @@ def __call__(self, img): """ if isinstance(img, PILImageImage): return img - if isinstance(img, torch.Tensor): - img = img.detach().cpu().numpy() - return pil_image_fromarray(img) + return pil_image_fromarray(ToNumpy()(img)) -class Transpose(Transform): +class Transpose(TorchTransform, NumpyTransform): """ Transposes the input image based on the given `indices` dimension ordering. """ @@ -413,14 +424,16 @@ class Transpose(Transform): def __init__(self, indices: Optional[Sequence[int]]) -> None: self.indices = None if indices is None else tuple(indices) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ + if isinstance(img, torch.Tensor): + return img.permute(self.indices or tuple(range(img.ndim)[::-1])) return img.transpose(self.indices) # type: ignore -class SqueezeDim(Transform): +class SqueezeDim(TorchTransform, NumpyTransform): """ Squeeze a unitary dimension. """ @@ -439,15 +452,20 @@ def __init__(self, dim: Optional[int] = 0) -> None: raise TypeError(f"dim must be None or a int but is {type(dim).__name__}.") self.dim = dim - def __call__(self, img: NdarrayTensor) -> NdarrayTensor: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: img: numpy arrays with required dimension `dim` removed """ - return img.squeeze(self.dim) # type: ignore + if self.dim is None: + return img.squeeze() + # for pytorch/numpy unification + if img.shape[self.dim] != 1: + raise ValueError("Can only squeeze singleton dimension") + return img.squeeze(self.dim) -class DataStats(Transform): +class DataStats(TorchTransform, NumpyTransform): """ Utility transform to show the statistics of data for debug or analysis. It can be inserted into any place of a transform chain and check results of previous transforms. @@ -503,14 +521,14 @@ def __init__( def __call__( self, - img: NdarrayTensor, + img: DataObjects.Images, prefix: Optional[str] = None, data_type: Optional[bool] = None, data_shape: Optional[bool] = None, value_range: Optional[bool] = None, data_value: Optional[bool] = None, additional_info: Optional[Callable] = None, - ) -> NdarrayTensor: + ) -> DataObjects.Images: """ Apply the transform to `img`, optionally take arguments similar to the class constructor. """ @@ -538,7 +556,7 @@ def __call__( return img -class SimulateDelay(Transform): +class SimulateDelay(TorchTransform, NumpyTransform): """ This is a pass through transform to be used for testing purposes. 
It allows adding fake behaviors that are useful for testing purposes to simulate @@ -559,7 +577,7 @@ def __init__(self, delay_time: float = 0.0) -> None: super().__init__() self.delay_time: float = delay_time - def __call__(self, img: NdarrayTensor, delay_time: Optional[float] = None) -> NdarrayTensor: + def __call__(self, img: DataObjects.Images, delay_time: Optional[float] = None) -> DataObjects.Images: """ Args: img: data remain unchanged throughout this transform. @@ -597,7 +615,7 @@ def __init__(self, func: Optional[Callable] = None) -> None: raise TypeError(f"func must be None or callable but is {type(func).__name__}.") self.func = func - def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable] = None): + def __call__(self, img: DataObjects.Images, func: Optional[Callable] = None): """ Apply `self.func` to `img`. @@ -640,7 +658,7 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable return super().__call__(img=img, func=func) if self._do_transform else img -class LabelToMask(Transform): +class LabelToMask(TorchTransform, NumpyTransform): """ Convert labels to mask for other tasks. A typical usage is to convert segmentation labels to mask data to pre-process images and then feed the images into classification network. @@ -667,9 +685,19 @@ def __init__( # pytype: disable=annotation-type-mismatch self.select_labels = ensure_tuple(select_labels) self.merge_channels = merge_channels + @staticmethod + def _in1d(x, y): + if isinstance(x, np.ndarray): + return np.in1d(x, y) + else: + return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1) + def __call__( - self, img: np.ndarray, select_labels: Optional[Union[Sequence[int], int]] = None, merge_channels: bool = False - ): + self, + img: DataObjects.Images, + select_labels: Optional[Union[Sequence[int], int]] = None, + merge_channels: bool = False, + ) -> DataObjects.Images: """ Args: select_labels: labels to generate mask from. for 1 channel label, the `select_labels` @@ -686,12 +714,21 @@ def __call__( if img.shape[0] > 1: data = img[[*select_labels]] else: - data = np.where(np.in1d(img, select_labels), True, False).reshape(img.shape) + where = np.where if isinstance(img, np.ndarray) else torch.where + if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): + data = where(self._in1d(img, select_labels), True, False).reshape(img.shape) + else: + data = where(self._in1d(img, select_labels), 1, 0).reshape(img.shape) - return np.any(data, axis=0, keepdims=True) if (merge_channels or self.merge_channels) else data + if merge_channels or self.merge_channels: + if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): + return data.any(0)[None] + else: + return data.to(torch.uint8).any(0)[None].to(bool) # type: ignore + return data -class FgBgToIndices(Transform): +class FgBgToIndices(TorchTransform, NumpyTransform): def __init__(self, image_threshold: float = 0.0, output_shape: Optional[Sequence[int]] = None) -> None: """ Compute foreground and background of the input label data, return the indices. 
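 
         Example (an illustrative sketch; with this patch a torch label yields
         torch index tensors, a numpy label yields numpy indices)::
 
             import torch
 
             label = torch.tensor([[[0, 1], [1, 0]]]) > 0
             fg, bg = FgBgToIndices()(label)  # fg -> [1, 2], bg -> [0, 3] (flattened)
 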
@@ -711,10 +748,10 @@ def __init__(self, image_threshold: float = 0.0, output_shape: Optional[Sequence def __call__( self, - label: np.ndarray, - image: Optional[np.ndarray] = None, + label: DataObjects.Images, + image: Optional[DataObjects.Images] = None, output_shape: Optional[Sequence[int]] = None, - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> Tuple[DataObjects.Images, DataObjects.Images]: """ Args: label: input data to compute foreground and background indices. @@ -727,13 +764,14 @@ def __call__( output_shape = self.output_shape fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold) if output_shape is not None: - fg_indices = np.stack([np.unravel_index(i, output_shape) for i in fg_indices]) - bg_indices = np.stack([np.unravel_index(i, output_shape) for i in bg_indices]) + stack = np.stack if isinstance(label, np.ndarray) else torch.stack + fg_indices = stack([_unravel_index(i, output_shape) for i in fg_indices]) # type: ignore + bg_indices = stack([_unravel_index(i, output_shape) for i in bg_indices]) # type: ignore return fg_indices, bg_indices -class ClassesToIndices(Transform): +class ClassesToIndices(TorchTransform, NumpyTransform): def __init__( self, num_classes: Optional[int] = None, @@ -760,10 +798,10 @@ def __init__( def __call__( self, - label: np.ndarray, - image: Optional[np.ndarray] = None, + label: DataObjects.Images, + image: Optional[DataObjects.Images] = None, output_shape: Optional[Sequence[int]] = None, - ) -> List[np.ndarray]: + ) -> List[DataObjects.Images]: """ Args: label: input data to compute the indices of every class. @@ -776,12 +814,15 @@ def __call__( output_shape = self.output_shape indices = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) if output_shape is not None: - indices = [np.stack([np.unravel_index(i, output_shape) for i in array]) for array in indices] + indices = [ + np.stack([np.unravel_index(i.cpu() if isinstance(i, torch.Tensor) else i, output_shape) for i in array]) + for array in indices + ] return indices -class ConvertToMultiChannelBasedOnBratsClasses(Transform): +class ConvertToMultiChannelBasedOnBratsClasses(TorchTransform, NumpyTransform): """ Convert labels to multi channels based on brats18 classes: label 1 is the necrotic and non-enhancing tumor core @@ -791,22 +832,24 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform): and ET (Enhancing tumor). """ - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: # if img has channel dim, squeeze it if img.ndim == 4 and img.shape[0] == 1: - img = np.squeeze(img, axis=0) + img = img.squeeze(0) result = [] # merge labels 1 (tumor non-enh) and 4 (tumor enh) to TC - result.append(np.logical_or(img == 1, img == 4)) + result.append((img == 1) | (img == 4)) # | is np.logical_or # merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT - result.append(np.logical_or(np.logical_or(img == 1, img == 4), img == 2)) + result.append((img == 1) | (img == 4) | (img == 2)) # label 4 is ET result.append(img == 4) - return np.stack(result, axis=0) + if isinstance(img, np.ndarray): + return np.stack(result, axis=0) + return torch.stack(result, dim=0) -class AddExtremePointsChannel(Randomizable, Transform): +class AddExtremePointsChannel(Randomizable, TorchTransform): """ Add extreme points of label to the image as a new channel. This transform generates extreme point from label and applies a gaussian filter. 
The pixel values in points image are rescaled @@ -831,17 +874,17 @@ def __init__(self, background: int = 0, pert: float = 0.0) -> None: self._pert = pert self._points: List[Tuple[int, ...]] = [] - def randomize(self, label: np.ndarray) -> None: + def randomize(self, label: DataObjects.Images) -> None: self._points = get_extreme_points(label, rand_state=self.R, background=self._background, pert=self._pert) def __call__( self, - img: np.ndarray, - label: Optional[np.ndarray] = None, + img: DataObjects.Images, + label: Optional[DataObjects.Images] = None, sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 3.0, rescale_min: float = -1.0, rescale_max: float = 1.0, - ): + ) -> DataObjects.Images: """ Args: img: the image that we want to add new channel to. @@ -858,14 +901,23 @@ def __call__( if label.shape[0] != 1: raise ValueError("Only supports single channel labels!") + img_t: torch.Tensor + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor) # type: ignore + # Generate extreme points self.randomize(label[0, :]) - points_image = extreme_points_to_image( - points=self._points, label=label, sigma=sigma, rescale_min=rescale_min, rescale_max=rescale_max + points_image_t = extreme_points_to_image( + points=self._points, + label=label, + sigma=sigma, + rescale_min=rescale_min, + rescale_max=rescale_max, + device=img_t.device, ) - - return np.concatenate([img, points_image], axis=0) + out_t = torch.cat((img_t, points_image_t), dim=0) + out, *_ = convert_data_type(out_t, orig_type, orig_device) + return out class TorchVision: diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 75be9685c4..9d3214d790 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -23,7 +23,7 @@ import numpy as np import torch -from monai.config import DtypeLike, KeysCollection, NdarrayTensor +from monai.config import DtypeLike, KeysCollection from monai.data.utils import no_collation from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform @@ -56,7 +56,7 @@ ) from monai.transforms.utils import extreme_points_to_image, get_extreme_points, tensor_to_numpy from monai.utils import ensure_tuple, ensure_tuple_rep -from monai.utils.enums import InverseKeys +from monai.utils.enums import DataObjects, InverseKeys __all__ = [ "AddChannelD", @@ -171,9 +171,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.identity = Identity() - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) for key in self.key_iterator(d): d[key] = self.identity(d[key]) @@ -196,7 +194,7 @@ def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_ke super().__init__(keys, allow_missing_keys) self.converter = AsChannelFirst(channel_dim=channel_dim) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -219,7 +217,7 @@ def __init__(self, keys: KeysCollection, channel_dim: int = 0, allow_missing_key super().__init__(keys, allow_missing_keys) self.converter = 
AsChannelLast(channel_dim=channel_dim) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -241,7 +239,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.adder = AddChannel() - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.adder(d[key]) @@ -280,7 +278,7 @@ def __init__( self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - def __call__(self, data) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key, meta_key, meta_key_postfix in zip(self.keys, self.meta_keys, self.meta_key_postfix): d[key] = self.adjuster(d[key], d[meta_key or f"{key}_{meta_key_postfix}"]) @@ -303,7 +301,7 @@ def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.repeater = RepeatChannel(repeats) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.repeater(d[key]) @@ -326,7 +324,7 @@ def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.repeater = RemoveRepeatedChannel(repeats) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.repeater(d[key]) @@ -363,9 +361,7 @@ def __init__( self.output_postfixes = output_postfixes self.splitter = SplitChannel(channel_dim=channel_dim) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) for key in self.key_iterator(d): rets = self.splitter(d[key]) @@ -405,9 +401,7 @@ def __init__( self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.converter = CastToType() - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) for key, dtype in self.key_iterator(d, self.dtype): d[key] = self.converter(d[key], dtype=dtype) @@ -430,14 +424,14 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.converter = ToTensor() - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) d[key] = self.converter(d[key]) return d - def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): # Create inverse 
transform @@ -506,7 +500,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.converter = ToNumpy() - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -528,7 +522,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.converter = ToCupy() - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -550,7 +544,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.converter = ToPIL() - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -568,7 +562,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.transform = Transpose(indices) - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.transform(d[key]) @@ -577,7 +571,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: self.push_transform(d, key, extra_info={"indices": indices}) return d - def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -629,7 +623,7 @@ def __init__(self, keys: KeysCollection, dim: int = 0, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.converter = SqueezeDim(dim=dim) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -687,7 +681,7 @@ def __init__( self.logger_handler = logger_handler self.printer = DataStats(logger_handler=logger_handler) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key, prefix, data_type, data_shape, value_range, data_value, additional_info in self.key_iterator( d, self.prefix, self.data_type, self.data_shape, self.value_range, self.data_value, self.additional_info @@ -725,7 +719,7 @@ def __init__( self.delay_time = ensure_tuple_rep(delay_time, len(self.keys)) self.delayer = SimulateDelay() - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key, delay_time in self.key_iterator(d, self.delay_time): d[key] = self.delayer(d[key], delay_time=delay_time) @@ -985,7 +979,7 @@ def __init__( # pytype: disable=annotation-type-mismatch super().__init__(keys, allow_missing_keys) self.converter = LabelToMask(select_labels=select_labels, merge_channels=merge_channels) - def __call__(self, data: 
Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -1029,7 +1023,7 @@ def __init__( self.image_key = image_key self.converter = FgBgToIndices(image_threshold, output_shape) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) image = d[self.image_key] if self.image_key else None for key in self.key_iterator(d): @@ -1096,7 +1090,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False): super().__init__(keys, allow_missing_keys) self.converter = ConvertToMultiChannelBasedOnBratsClasses() - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -1142,7 +1136,7 @@ def __init__( self.rescale_min = rescale_min self.rescale_max = rescale_max - def randomize(self, label: np.ndarray) -> None: + def randomize(self, label: DataObjects.Images) -> None: self.points = get_extreme_points(label, rand_state=self.R, background=self.background, pert=self.pert) def __call__(self, data): @@ -1162,8 +1156,12 @@ def __call__(self, data): sigma=self.sigma, rescale_min=self.rescale_min, rescale_max=self.rescale_max, + device=img.device if isinstance(img, torch.Tensor) else None, ) - d[key] = np.concatenate([img, points_image], axis=0) + if isinstance(img, np.ndarray): + d[key] = np.concatenate([img, points_image], axis=0) + else: + d[key] = torch.cat((img, points_image), dim=0) return d @@ -1275,7 +1273,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.mapper = MapLabelValue(orig_labels=orig_labels, target_labels=target_labels, dtype=dtype) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.mapper(d[key]) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 2da7b688cb..f66181a3cb 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -35,6 +35,8 @@ min_version, optional_import, ) +from monai.utils.enums import DataObjects +from monai.utils.misc import convert_data_type, is_module_ver_at_least measure, _ = optional_import("skimage.measure", "0.14.2", min_version) cp, has_cp = optional_import("cupy") @@ -100,7 +102,7 @@ def in_bounds(x: float, y: float, margin: float, maxx: float, maxy: float) -> bo return bool(margin <= x < (maxx - margin) and margin <= y < (maxy - margin)) -def is_empty(img: Union[np.ndarray, torch.Tensor]) -> bool: +def is_empty(img: DataObjects.Images) -> bool: """ Returns True if `img` is empty, that is its maximum value is not greater than its minimum. """ @@ -124,15 +126,20 @@ def zero_margins(img: np.ndarray, margin: int) -> bool: return not np.any(img[:, :margin, :]) and not np.any(img[:, -margin:, :]) -def rescale_array(arr: np.ndarray, minv: float = 0.0, maxv: float = 1.0, dtype: DtypeLike = np.float32): +def rescale_array( + arr: DataObjects.Images, + minv: float = 0.0, + maxv: float = 1.0, + dtype: Optional[Union[DtypeLike, torch.dtype]] = np.float32, +) -> DataObjects.Images: """ Rescale the values of numpy array `arr` to be from `minv` to `maxv`. 
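 
     Example (an illustrative sketch; with this patch a tensor input stays a
     tensor on its original device)::
 
         import torch
 
         out = rescale_array(torch.arange(5, dtype=torch.float32))  # values span [0.0, 1.0]
 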
""" if dtype is not None: - arr = arr.astype(dtype) + arr, *_ = convert_data_type(arr, dtype=dtype) - mina = np.min(arr) - maxa = np.max(arr) + mina = arr.min() + maxa = arr.max() if mina == maxa: return arr * minv @@ -244,10 +251,10 @@ def resize_center(img: np.ndarray, *resize_dims: Optional[int], fill_value: floa def map_binary_to_indices( - label: np.ndarray, - image: Optional[np.ndarray] = None, + label: DataObjects.Images, + image: Optional[DataObjects.Images] = None, image_threshold: float = 0.0, -) -> Tuple[np.ndarray, np.ndarray]: +) -> Tuple[DataObjects.Images, DataObjects.Images]: """ Compute the foreground and background of input label data, return the indices after fattening. For example: @@ -262,26 +269,35 @@ def map_binary_to_indices( determine the valid image content area and select background only in this area. """ + + def _nonzero(x): + if isinstance(x, np.ndarray): + return np.nonzero(x)[0] + return torch.nonzero(x).flatten() + # Prepare fg/bg indices if label.shape[0] > 1: label = label[1:] # for One-Hot format data, remove the background channel - label_flat = np.any(label, axis=0).ravel() # in case label has multiple dimensions - fg_indices = np.nonzero(label_flat)[0] + label_flat = label.any(0).ravel() # in case label has multiple dimensions + fg_indices = _nonzero(label_flat) if image is not None: - img_flat = np.any(image > image_threshold, axis=0).ravel() - bg_indices = np.nonzero(np.logical_and(img_flat, ~label_flat))[0] + img_flat = (image > image_threshold).any(0).ravel() + img_flat, *_ = convert_data_type( + img_flat, type(label), device=label.device if isinstance(label, torch.Tensor) else None + ) + bg_indices = _nonzero(img_flat & ~label_flat) else: - bg_indices = np.nonzero(~label_flat)[0] + bg_indices = _nonzero(~label_flat) return fg_indices, bg_indices def map_classes_to_indices( - label: np.ndarray, + label: DataObjects.Images, num_classes: Optional[int] = None, - image: Optional[np.ndarray] = None, + image: Optional[DataObjects.Images] = None, image_threshold: float = 0.0, -) -> List[np.ndarray]: +) -> List[DataObjects.Images]: """ Filter out indices of every class of the input label data, return the indices after fattening. It can handle both One-Hot format label and Argmax format label, must provide `num_classes` for @@ -301,11 +317,17 @@ def map_classes_to_indices( determine the valid image content area and select class indices only in this area. 
""" - img_flat: Optional[np.ndarray] = None + img_flat: Optional[DataObjects.Images] = None + # if older pytorch, switch to numpy + if not hasattr(label, "ravel"): + label, *_ = convert_data_type(label, np.ndarray) + if image is not None: + image, *_ = convert_data_type(image, np.ndarray) + if image is not None: - img_flat = np.any(image > image_threshold, axis=0).ravel() + img_flat = (image > image_threshold).any(axis=0).ravel() # type: ignore - indices: List[np.ndarray] = [] + indices: List[DataObjects.Images] = [] # assuming the first dimension is channel channels = len(label) @@ -316,16 +338,27 @@ def map_classes_to_indices( num_classes_ = num_classes for c in range(num_classes_): - label_flat = np.any(label[c : c + 1] if channels > 1 else label == c, axis=0).ravel() - label_flat = np.logical_and(img_flat, label_flat) if img_flat is not None else label_flat - indices.append(np.nonzero(label_flat)[0]) + label_flat = (label[c : c + 1] if channels > 1 else label == c).any(axis=0).ravel() # type: ignore + label_flat = img_flat & label_flat if img_flat is not None else label_flat + idx = label_flat.nonzero() + indices.append(idx[0] if isinstance(idx, tuple) else idx.squeeze(1)) return indices +def _unravel_index(idx, shape): + if isinstance(idx, torch.Tensor): + coord = [] + for dim in reversed(shape): + coord.insert(0, idx % dim) + idx = torch.div(idx, dim, rounding_mode="floor") + return torch.stack(coord) + return np.unravel_index(np.asarray(idx, dtype=int), shape) + + def weighted_patch_samples( spatial_size: Union[int, Sequence[int]], - w: np.ndarray, + w: DataObjects.Images, n_samples: int = 1, r_state: Optional[np.random.RandomState] = None, ) -> List: @@ -355,16 +388,28 @@ def weighted_patch_samples( v = w[s] # weight map in the 'valid' mode v_size = v.shape v = v.ravel() - if np.any(v < 0): + if (v < 0).any(): v -= np.min(v) # shifting to non-negative - v = v.cumsum() - if not v[-1] or not np.isfinite(v[-1]) or v[-1] < 0: # uniform sampling + v = v.cumsum(0) + idx: DataObjects.Images + if not v[-1] or not torch.as_tensor(v[-1]).isfinite() or v[-1] < 0: # uniform sampling idx = r_state.randint(0, len(v), size=n_samples) + if isinstance(v, torch.Tensor): + idx = torch.as_tensor(idx, device=v.device) else: - idx = v.searchsorted(r_state.random(n_samples) * v[-1], side="right") + r: DataObjects.Images + r = r_state.random(n_samples) + if isinstance(v, np.ndarray): + idx = v.searchsorted(r * v[-1], side="right") + else: + r = torch.as_tensor(r, device=v.device) + idx = torch.searchsorted(v, r * v[-1], right=True) # compensate 'valid' mode + diff: DataObjects.Images diff = np.minimum(win_size, img_size) // 2 - return [np.unravel_index(i, v_size) + diff for i in np.asarray(idx, dtype=int)] + if isinstance(v, torch.Tensor): + diff = torch.as_tensor(diff, device=v.device) + return [_unravel_index(i, v_size) + diff for i in idx] def correct_crop_centers( @@ -410,8 +455,8 @@ def generate_pos_neg_label_crop_centers( num_samples: int, pos_ratio: float, label_spatial_shape: Sequence[int], - fg_indices: np.ndarray, - bg_indices: np.ndarray, + fg_indices: DataObjects.Images, + bg_indices: DataObjects.Images, rand_state: Optional[np.random.RandomState] = None, ) -> List[List[np.ndarray]]: """ @@ -436,11 +481,12 @@ def generate_pos_neg_label_crop_centers( rand_state = np.random.random.__self__ # type: ignore centers = [] - fg_indices, bg_indices = np.asarray(fg_indices), np.asarray(bg_indices) - if fg_indices.size == 0 and bg_indices.size == 0: + fg_indices = np.asarray(fg_indices) if 
isinstance(fg_indices, Sequence) else fg_indices + bg_indices = np.asarray(bg_indices) if isinstance(bg_indices, Sequence) else bg_indices + if len(fg_indices) == 0 and len(bg_indices) == 0: raise ValueError("No sampling location available.") - if fg_indices.size == 0 or bg_indices.size == 0: + if len(fg_indices) == 0 or len(bg_indices) == 0: warnings.warn( f"N foreground {len(fg_indices)}, N background {len(bg_indices)}," "unable to generate class balanced samples." @@ -450,7 +496,8 @@ def generate_pos_neg_label_crop_centers( for _ in range(num_samples): indices_to_use = fg_indices if rand_state.rand() < pos_ratio else bg_indices random_int = rand_state.randint(len(indices_to_use)) - center = np.unravel_index(indices_to_use[random_int], label_spatial_shape) + idx = indices_to_use[random_int] + center = _unravel_index(idx, label_spatial_shape) # shift center to range of valid centers center_ori = list(center) centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape)) @@ -462,7 +509,7 @@ def generate_label_classes_crop_centers( spatial_size: Union[Sequence[int], int], num_samples: int, label_spatial_shape: Sequence[int], - indices: List[np.ndarray], + indices: List[DataObjects.Images], ratios: Optional[List[Union[float, int]]] = None, rand_state: Optional[np.random.RandomState] = None, ) -> List[List[np.ndarray]]: @@ -491,8 +538,6 @@ def generate_label_classes_crop_centers( if any([i < 0 for i in ratios_]): raise ValueError("ratios should not contain negative number.") - # ensure indices are numpy array - indices = [np.asarray(i) for i in indices] for i, array in enumerate(indices): if len(array) == 0: warnings.warn(f"no available indices of class {i} to crop, set the crop ratio of this class to zero.") @@ -504,7 +549,8 @@ def generate_label_classes_crop_centers( # randomly select the indices of a class based on the ratios indices_to_use = indices[i] random_int = rand_state.randint(len(indices_to_use)) - center = np.unravel_index(indices_to_use[random_int], label_spatial_shape) + idx, *_ = convert_data_type(indices_to_use[random_int], np.ndarray) + center = np.unravel_index(idx, label_spatial_shape) # shift center to range of valid centers center_ori = list(center) centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape)) @@ -657,7 +703,7 @@ def create_translate(spatial_dims: int, shift: Union[Sequence[float], float]) -> def generate_spatial_bounding_box( - img: np.ndarray, + img: DataObjects.Images, select_fn: Callable = is_positive, channel_indices: Optional[IndexSelection] = None, margin: Union[Sequence[int], int] = 0, @@ -681,8 +727,12 @@ def generate_spatial_bounding_box( of image. if None, select foreground on the whole image. margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims. 
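 
     Example (an illustrative sketch; the same half-open box is returned for
     numpy and, with this patch, torch inputs)::
 
         import numpy as np
 
         img = np.zeros((1, 8, 8))
         img[0, 2:5, 3:6] = 1
         start, end = generate_spatial_bounding_box(img)  # -> [2, 3], [5, 6]
 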
""" + # argmax changed for more recent pytorch versions + if isinstance(img, torch.Tensor) and not is_module_ver_at_least(torch, (1, 7, 0)): + img = img.cpu().numpy() + data = img[list(ensure_tuple(channel_indices))] if channel_indices is not None else img - data = np.any(select_fn(data), axis=0) + data = select_fn(data).any(0) ndim = len(data.shape) margin = ensure_tuple_rep(margin, ndim) for m in margin: @@ -693,19 +743,33 @@ def generate_spatial_bounding_box( box_end = [0] * ndim for di, ax in enumerate(itertools.combinations(reversed(range(ndim)), ndim - 1)): - dt = data.any(axis=ax) - if not np.any(dt): + dt = data + if len(ax) != 0: + if isinstance(dt, np.ndarray): + dt = dt.any(ax) + # pytorch can't handle multiple dimensions to `any` so loop across them + # this works because the dimensions will be reverse sorted. + else: + for i in ax: + dt = dt.any(i) + + if not dt.any(): # if no foreground, return all zero bounding box coords return [0] * ndim, [0] * ndim - min_d = max(np.argmax(dt) - margin[di], 0) - max_d = max(data.shape[di] - max(np.argmax(dt[::-1]) - margin[di], 0), min_d + 1) + dt = dt if isinstance(dt, np.ndarray) else dt.int() + rev_dt = dt[::-1] if isinstance(dt, np.ndarray) else dt.flip(0) + + min_d = max(dt.argmax() - margin[di], 0) + max_d = max(data.shape[di] - max(rev_dt.argmax() - margin[di], 0), min_d + 1) box_start[di], box_end[di] = min_d, max_d return box_start, box_end -def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Optional[int] = None) -> torch.Tensor: +def get_largest_connected_component_mask( + img: DataObjects.Images, connectivity: Optional[int] = None +) -> DataObjects.Images: """ Gets the largest connected component mask of an image. @@ -715,17 +779,20 @@ def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Option Accepted values are ranging from 1 to input.ndim. If ``None``, a full connectivity of ``input.ndim`` is used. """ - img_arr = img.detach().cpu().numpy() - largest_cc = np.zeros(shape=img_arr.shape, dtype=img_arr.dtype) - img_arr = measure.label(img_arr, connectivity=connectivity) - if img_arr.max() != 0: - largest_cc[...] = img_arr == (np.argmax(np.bincount(img_arr.flat)[1:]) + 1) + img_np: np.ndarray + img_np, orig_type, orig_device = convert_data_type(img, np.ndarray) # type: ignore + largest_cc = np.zeros_like(img_np) + img_np = measure.label(img_np, connectivity=connectivity) - return torch.as_tensor(largest_cc, device=img.device) + if img_np.max() != 0: + largest_cc[...] = img_np == (np.argmax(np.bincount(img_np.flat)[1:]) + 1) + + out, *_ = convert_data_type(largest_cc, orig_type, orig_device) + return out def get_extreme_points( - img: np.ndarray, rand_state: Optional[np.random.RandomState] = None, background: int = 0, pert: float = 0.0 + img: DataObjects.Images, rand_state: Optional[np.random.RandomState] = None, background: int = 0, pert: float = 0.0 ) -> List[Tuple[int, ...]]: """ Generate extreme points from an image. 
These are used to generate initial segmentation @@ -749,9 +816,12 @@ def get_extreme_points( """ if rand_state is None: rand_state = np.random.random.__self__ # type: ignore - indices = np.where(img != background) + where = np.where if isinstance(img, np.ndarray) else torch.where + indices = where(img != background) if np.size(indices[0]) == 0: raise ValueError("get_extreme_points: no foreground object in mask!") + if isinstance(img, torch.Tensor): + indices = tuple(i.cpu() for i in indices) def _get_point(val, dim): """ @@ -773,19 +843,20 @@ def _get_point(val, dim): points = [] for i in range(img.ndim): - points.append(tuple(_get_point(np.min(indices[i][...]), i))) - points.append(tuple(_get_point(np.max(indices[i][...]), i))) + points.append(tuple(_get_point(indices[i].min(), i))) + points.append(tuple(_get_point(indices[i].max(), i))) return points def extreme_points_to_image( points: List[Tuple[int, ...]], - label: np.ndarray, + label: DataObjects.Images, sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 0.0, rescale_min: float = -1.0, rescale_max: float = 1.0, -): + device: Optional[Union[torch.device, str]] = None, +) -> torch.Tensor: """ Please refer to :py:class:`monai.transforms.AddExtremePointsChannel` for the usage. @@ -801,20 +872,21 @@ def extreme_points_to_image( use it for all spatial dimensions. rescale_min: minimum value of output data. rescale_max: maximum value of output data. + device: device on which to return result """ # points to image - points_image = torch.zeros(label.shape[1:], dtype=torch.float) + points_image = torch.zeros(label.shape[1:], dtype=torch.float, device=device) for p in points: points_image[p] = 1.0 # add channel and add batch points_image = points_image.unsqueeze(0).unsqueeze(0) - gaussian_filter = GaussianFilter(label.ndim - 1, sigma=sigma) - points_image = gaussian_filter(points_image).squeeze(0).detach().numpy() + gaussian_filter = GaussianFilter(label.ndim - 1, sigma=sigma).to(device) + points_image = gaussian_filter(points_image).squeeze(0).detach() # rescale the points image to [rescale_min, rescale_max] - min_intensity = np.min(points_image) - max_intensity = np.max(points_image) + min_intensity = points_image.min() + max_intensity = points_image.max() points_image = (points_image - min_intensity) / (max_intensity - min_intensity) points_image = points_image * (rescale_max - rescale_min) + rescale_min return points_image diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index af3cd87652..a1be196874 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -19,6 +19,7 @@ BlendMode, ChannelMatching, CommonKeys, + DataObjects, ForwardMode, GridSampleMode, GridSamplePadMode, @@ -37,7 +38,9 @@ from .misc import ( MAX_SEED, ImageMetaKey, + convert_data_type, copy_to_device, + dtype_convert, dtype_numpy_to_torch, dtype_torch_to_numpy, ensure_tuple, @@ -47,6 +50,7 @@ first, get_seed, has_option, + is_module_ver_at_least, is_scalar, is_scalar_tensor, issequenceiterable, diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 014363e14f..398da78741 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -10,6 +10,10 @@ # limitations under the License. 
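 
 # An illustrative usage sketch (editorial, not part of the patch) for the new
 # `device` argument of `extreme_points_to_image` in monai/transforms/utils.py
 # above: the heatmap is built directly on the label's device.
 #
 #     import torch
 #     from monai.transforms.utils import extreme_points_to_image, get_extreme_points
 #
 #     label = torch.zeros(1, 5, 5)
 #     label[0, 1:4, 2] = 1
 #     points = get_extreme_points(label[0])  # extreme points of the spatial map
 #     heat = extreme_points_to_image(points, label, sigma=1.0, device=label.device)
 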
from enum import Enum +from typing import Any, Dict, Hashable, Mapping, Union + +import numpy as np +import torch __all__ = [ "NumpyPadMode", @@ -29,6 +33,7 @@ "InverseKeys", "CommonKeys", "ForwardMode", + "DataObjects", ] @@ -233,3 +238,11 @@ class CommonKeys: LABEL = "label" PRED = "pred" LOSS = "loss" + + +class DataObjects: + """Common classes used for arrays/tensors and then their usage in dict/mappings.""" + + Images = Union[torch.Tensor, np.ndarray] + Dict = Dict[Hashable, Any] + Mapping = Mapping[Hashable, Any] diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 86dc55aa9e..c8041e653c 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -22,7 +22,9 @@ import numpy as np import torch -from monai.utils.module import get_torch_version_tuple +from monai.config.type_definitions import DtypeLike +from monai.utils.enums import DataObjects +from monai.utils.module import get_torch_version_tuple, version_leq __all__ = [ "zip_with", @@ -41,9 +43,12 @@ "list_to_dict", "dtype_torch_to_numpy", "dtype_numpy_to_torch", + "dtype_convert", "MAX_SEED", "copy_to_device", "ImageMetaKey", + "convert_data_type", + "is_module_ver_at_least", ] _seed = None @@ -125,6 +130,10 @@ def ensure_tuple_rep(tup: Any, dim: int) -> Tuple[Any, ...]: ValueError: Sequence must have length 3, got length 2. """ + if isinstance(tup, torch.Tensor): + tup = tup.cpu().numpy() + if isinstance(tup, np.ndarray): + tup = tup.tolist() if not issequenceiterable(tup): return (tup,) * dim if len(tup) == dim: @@ -300,17 +309,17 @@ def _parse_var(s): _torch_to_np_dtype = { - torch.bool: bool, - torch.uint8: np.uint8, - torch.int8: np.int8, - torch.int16: np.int16, - torch.int32: np.int32, - torch.int64: np.int64, - torch.float16: np.float16, - torch.float32: np.float32, - torch.float64: np.float64, - torch.complex64: np.complex64, - torch.complex128: np.complex128, + torch.bool: np.dtype(bool), + torch.uint8: np.dtype(np.uint8), + torch.int8: np.dtype(np.int8), + torch.int16: np.dtype(np.int16), + torch.int32: np.dtype(np.int32), + torch.int64: np.dtype(np.int64), + torch.float16: np.dtype(np.float16), + torch.float32: np.dtype(np.float32), + torch.float64: np.dtype(np.float64), + torch.complex64: np.dtype(np.complex64), + torch.complex128: np.dtype(np.complex128), } _np_to_torch_dtype = {value: key for key, value in _torch_to_np_dtype.items()} @@ -322,9 +331,86 @@ def dtype_torch_to_numpy(dtype): def dtype_numpy_to_torch(dtype): """Convert a numpy dtype to its torch equivalent.""" + # np dtypes can be given as np.float32 and np.dtype(np.float32) so unify them + dtype = np.dtype(dtype) if type(dtype) == type else dtype return _np_to_torch_dtype[dtype] +def dtype_convert(dtype, data_type): + """Convert to the `dtype` that corresponds to `data_type`. + + Example: + im = torch.Tensor((1)) + dtype = dtype_convert(np.float32, type(im)) + """ + if data_type is torch.Tensor: + if type(dtype) is torch.dtype: + return dtype + return dtype_numpy_to_torch(dtype) + else: + if type(dtype) is not torch.dtype: + return dtype + return dtype_torch_to_numpy(dtype) + + +def convert_data_type( + data: Any, + output_type: Optional[type] = None, + device: Optional[torch.device] = None, + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, +) -> Tuple[DataObjects.Images, type, Optional[torch.device]]: + """Convert to `torch.Tensor`/`np.ndarray` from `torch.Tensor`/`np.ndarray`/`float`/`int` etc. 
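+
+    A round-trip sketch (illustrative)::
+
+        x = np.zeros((2, 3), dtype=np.float32)
+        x_t, orig_type, orig_device = convert_data_type(x, torch.Tensor)
+        x_np, *_ = convert_data_type(x_t, orig_type, orig_device)  # back to float32 ndarray
+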
+ + Args: + data: data to be converted + output_type: `torch.Tensor` or `np.ndarray` (if blank, unchanged) + device: if output is `torch.Tensor`, select device (if blank, unchanged) + dtype: dtype of output data. Converted to correct library type (e.g., + `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`). + If left blank, it remains unchanged. + + Returns: + modified data, orig_type, orig_device + """ + orig_type = type(data) + orig_device = data.device if isinstance(data, torch.Tensor) else None + + output_type = output_type or orig_type + + def get_dtype(data: Any): + if hasattr(data, "dtype"): + return data.dtype + # need recursion + if isinstance(data, Sequence): + return get_dtype(data[0]) + # objects like float don't have dtype, so return their type + return type(data) + + dtype = dtype_convert(dtype or get_dtype(data), output_type) + + if output_type is torch.Tensor: + if orig_type is np.ndarray: + if (np.array(data.strides) < 0).any(): # copy if -ve stride + data = data.copy() + data = torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data)) + else: + data = torch.as_tensor(data) + if dtype != data.dtype: + data = data.to(dtype) # type: ignore + elif output_type is np.ndarray: + if orig_type is torch.Tensor: + data = data.detach().cpu().numpy() # type: ignore + else: + data = np.array(data) + if dtype != data.dtype: + data = data.astype(dtype) # type: ignore + + if isinstance(data, torch.Tensor) and device is not None: + data = data.to(device) + + return data, orig_type, orig_device + + def copy_to_device( obj: Any, device: Optional[Union[str, torch.device]], @@ -379,3 +465,8 @@ def has_option(obj, keywords: Union[str, Sequence[str]]) -> bool: return False sig = inspect.signature(obj) return all(key in sig.parameters for key in ensure_tuple(keywords)) + + +def is_module_ver_at_least(module, required_version_tuple): + test_ver = ".".join(map(str, required_version_tuple)) + return module.__version__ != test_ver and version_leq(test_ver, module.__version__) diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py index 4a17607320..85a19b4c0f 100644 --- a/monai/visualize/img2tensorboard.py +++ b/monai/visualize/img2tensorboard.py @@ -9,14 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Dict, Optional, Sequence import numpy as np import torch -from monai.config import NdarrayTensor from monai.transforms import rescale_array from monai.utils import optional_import +from monai.utils.enums import DataObjects PIL, _ = optional_import("PIL") GifImage, _ = optional_import("PIL.GifImagePlugin", name="Image") @@ -32,7 +32,7 @@ __all__ = ["make_animated_gif_summary", "add_animated_gif", "add_animated_gif_no_channels", "plot_2d_or_3d_image"] -def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], scale_factor: float = 1.0) -> Summary: +def _image3_animated_gif(tag: str, image: DataObjects.Images, scale_factor: float = 1.0) -> Summary: """Function to actually create the animated gif. 
     Args:
@@ -61,7 +61,7 @@ def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], scale
 
 def make_animated_gif_summary(
     tag: str,
-    image: Union[np.ndarray, torch.Tensor],
+    image: DataObjects.Images,
     max_out: int = 3,
     animation_axes: Sequence[int] = (3,),
     image_axes: Sequence[int] = (1, 2),
@@ -96,7 +96,7 @@ def make_animated_gif_summary(
         image = image[tuple(slicing)]
 
     for it_i in range(min(max_out, list(image.shape)[0])):
-        one_channel_img: Union[torch.Tensor, np.ndarray] = (
+        one_channel_img: DataObjects.Images = (
             image[it_i, :, :, :].squeeze(dim=0) if isinstance(image, torch.Tensor) else image[it_i, :, :, :]
         )
         summary_op = _image3_animated_gif(tag + suffix.format(it_i), one_channel_img, scale_factor)
@@ -106,7 +106,7 @@ def make_animated_gif_summary(
 def add_animated_gif(
     writer: SummaryWriter,
     tag: str,
-    image_tensor: Union[np.ndarray, torch.Tensor],
+    image_tensor: DataObjects.Images,
     max_out: int,
     scale_factor: float,
     global_step: Optional[int] = None,
@@ -133,7 +133,7 @@ def add_animated_gif(
 def add_animated_gif_no_channels(
     writer: SummaryWriter,
     tag: str,
-    image_tensor: Union[np.ndarray, torch.Tensor],
+    image_tensor: DataObjects.Images,
     max_out: int,
     scale_factor: float,
     global_step: Optional[int] = None,
@@ -160,7 +160,7 @@ def add_animated_gif_no_channels(
 
 def plot_2d_or_3d_image(
-    data: Union[NdarrayTensor, List[NdarrayTensor]],
+    data: DataObjects.Images,
     step: int,
     writer: SummaryWriter,
     index: int = 0,
@@ -188,7 +188,7 @@ def plot_2d_or_3d_image(
     d: np.ndarray = data_index.detach().cpu().numpy() if isinstance(data_index, torch.Tensor) else data_index
 
     if d.ndim == 2:
-        d = rescale_array(d, 0, 1)
+        d = rescale_array(d, 0, 1)  # type: ignore
         dataformats = "HW"
         writer.add_image(f"{tag}_{dataformats}", d, step, dataformats=dataformats)
         return
diff --git a/tests/rb_test_transforms.py b/tests/rb_test_transforms.py
new file mode 100644
index 0000000000..83eeb73756
--- /dev/null
+++ b/tests/rb_test_transforms.py
@@ -0,0 +1,61 @@
+from inspect import getmembers, isclass
+
+from monai import transforms
+from monai.transforms import MapTransform, Transform
+from monai.transforms.transform import NumpyTransform, TorchTransform
+
+
+class Colours:
+    red = "91"
+    green = "92"
+    yellow = "93"
+    light_purple = "94"
+    purple = "95"
+    cyan = "96"
+    light_gray = "97"
+    black = "98"
+
+
+def print_colour(t, colour):
+    print(f"\033[{colour}m{t}\033[00m")
+
+
+tr_total = 0
+tr_t_or_np = 0
+tr_t = 0
+tr_np = 0
+tr_todo = 0
+tr_uncategorised = 0
+for n, obj in getmembers(transforms):
+    if isclass(obj) and issubclass(obj, Transform) and not issubclass(obj, MapTransform):
+        if n in [
+            "Transform",
+            "InvertibleTransform",
+            "Lambda",
+            "LoadImage",
+            "Compose",
+            "RandomizableTransform",
+            "NumpyTransform",
+            "TorchTransform",
+            "ToPIL",
+            "ToCupy",
+        ]:
+            continue
+        tr_total += 1
+        if issubclass(obj, TorchTransform) and issubclass(obj, NumpyTransform):
+            tr_t_or_np += 1
+            print_colour(f"TorchOrNumpy: {n}", Colours.green)
+        elif issubclass(obj, TorchTransform):
+            tr_t += 1
+            print_colour(f"Torch: {n}", Colours.green)
+        elif issubclass(obj, NumpyTransform):
+            tr_np += 1
+            print_colour(f"Numpy: {n}", Colours.yellow)
+        else:
+            tr_uncategorised += 1
+            print_colour(f"Uncategorised: {n}", Colours.red)
+print("Total number of transforms:", tr_total)
+print_colour(f"Number of transforms allowing both torch and numpy: {tr_t_or_np}", Colours.green)
+print_colour(f"Number of TorchTransform: {tr_t}", Colours.green)
+print_colour(f"Number of NumpyTransform: {tr_np}", Colours.yellow)
+print_colour(f"Number of uncategorised: {tr_uncategorised}", Colours.red) diff --git a/tests/test_activations.py b/tests/test_activations.py index 7d8b3e4c38..66867daf3a 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -11,63 +11,100 @@ import unittest +import numpy as np import torch from parameterized import parameterized from monai.networks.layers.factories import Act from monai.transforms import Activations - -TEST_CASE_1 = [ - {"sigmoid": True, "softmax": False, "other": None}, - torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), - torch.tensor([[[0.5000, 0.7311], [0.8808, 0.9526]]]), - (1, 2, 2), -] - -TEST_CASE_2 = [ - {"sigmoid": False, "softmax": True, "other": None}, - torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), - torch.tensor([[[0.1192, 0.1192]], [[0.8808, 0.8808]]]), - (2, 1, 2), -] - -TEST_CASE_3 = [ - {"sigmoid": False, "softmax": False, "other": torch.tanh}, - torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), - torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]]), - (1, 2, 2), -] - -TEST_CASE_4 = [ - "swish", - torch.tensor([[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]], dtype=torch.float32), - torch.tensor( - [[[-4.54e-04, -2.68e-03, -1.48e-02, -7.19e-02, -2.38e-01], [0.00e00, 1.76e00, 3.93e00, 5.99e00, 8.00e00]]] - ), - (1, 2, 5), -] - -TEST_CASE_5 = [ - "memswish", - torch.tensor([[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]], dtype=torch.float32), - torch.tensor( - [[[-4.54e-04, -2.68e-03, -1.48e-02, -7.19e-02, -2.38e-01], [0.00e00, 1.76e00, 3.93e00, 5.99e00, 8.00e00]]] - ), - (1, 2, 5), -] - -TEST_CASE_6 = [ - "mish", - torch.tensor([[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]], dtype=torch.float32), - torch.tensor( - [[[-4.54e-04, -2.68e-03, -1.49e-02, -7.26e-02, -2.53e-01], [0.00e00, 1.94e00, 4.00e00, 6.00e00, 8.00e00]]] - ), - (1, 2, 5), -] +from tests.utils import TEST_NDARRAYS + +TESTS1, TESTS2 = [], [] +for p in TEST_NDARRAYS: + TESTS1.append( + [ + {"sigmoid": True, "softmax": False, "other": None}, + p(torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])), + p(torch.tensor([[[0.5000, 0.7311], [0.8808, 0.9526]]])), + (1, 2, 2), + ] + ) + TESTS1.append( + [ + {"sigmoid": False, "softmax": True, "other": None}, + p(torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])), + p(torch.tensor([[[0.1192, 0.1192]], [[0.8808, 0.8808]]])), + (2, 1, 2), + ] + ) + TESTS1.append( + [ + {"sigmoid": False, "softmax": False, "other": torch.tanh}, + p(torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])), + p(torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]])), + (1, 2, 2), + ] + ) + +for p in TEST_NDARRAYS: + # these tests require torch.Tensor + if p == np.array: + continue + TESTS2.append( + [ + "swish", + p(torch.tensor([[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]], dtype=torch.float32)), + p( + torch.tensor( + [ + [ + [-4.54e-04, -2.68e-03, -1.48e-02, -7.19e-02, -2.38e-01], + [0.00e00, 1.76e00, 3.93e00, 5.99e00, 8.00e00], + ] + ] + ) + ), + (1, 2, 5), + ] + ) + TESTS2.append( + [ + "memswish", + p(torch.tensor([[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]], dtype=torch.float32)), + p( + torch.tensor( + [ + [ + [-4.54e-04, -2.68e-03, -1.48e-02, -7.19e-02, -2.38e-01], + [0.00e00, 1.76e00, 3.93e00, 5.99e00, 8.00e00], + ] + ] + ) + ), + (1, 2, 5), + ] + ) + TESTS2.append( + [ + "mish", + p(torch.tensor([[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]], dtype=torch.float32)), + p( + torch.tensor( + [ + [ + [-4.54e-04, -2.68e-03, -1.49e-02, -7.26e-02, -2.53e-01], + [0.00e00, 1.94e00, 4.00e00, 6.00e00, 8.00e00], + ] + ] + ) + ), + (1, 2, 5), + ] + ) class TestActivations(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, 
TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS1) def test_value_shape(self, input_param, img, out, expected_shape): result = Activations(**input_param)(img) @@ -81,7 +118,7 @@ def _compare(ret, out, shape): else: _compare(result, out, expected_shape) - @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS2) def test_monai_activations_value_shape(self, input_param, img, out, expected_shape): act = Act[input_param]() result = act(img) diff --git a/tests/test_activationsd.py b/tests/test_activationsd.py index 355c50f389..4ed550fabf 100644 --- a/tests/test_activationsd.py +++ b/tests/test_activationsd.py @@ -11,48 +11,64 @@ import unittest +import numpy as np import torch from parameterized import parameterized from monai.transforms import Activationsd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": ["pred", "label"], "sigmoid": False, "softmax": [True, False], "other": None}, - {"pred": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), "label": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])}, - { - "pred": torch.tensor([[[0.1192, 0.1192]], [[0.8808, 0.8808]]]), - "label": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), - }, - (2, 1, 2), -] - -TEST_CASE_2 = [ - {"keys": ["pred", "label"], "sigmoid": False, "softmax": False, "other": [torch.tanh, None]}, - {"pred": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), "label": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])}, - { - "pred": torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]]), - "label": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), - }, - (1, 2, 2), -] - -TEST_CASE_3 = [ - {"keys": "pred", "sigmoid": False, "softmax": False, "other": torch.tanh}, - {"pred": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])}, - {"pred": torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]])}, - (1, 2, 2), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["pred", "label"], "sigmoid": False, "softmax": [True, False], "other": None}, + { + "pred": p(torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])), + "label": p(torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])), + }, + { + "pred": torch.tensor([[[0.1192, 0.1192]], [[0.8808, 0.8808]]]), + "label": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), + }, + (2, 1, 2), + ] + ) + TESTS.append( + [ + {"keys": ["pred", "label"], "sigmoid": False, "softmax": False, "other": [torch.tanh, None]}, + {"pred": p(torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])), "label": p(torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]))}, + { + "pred": torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]]), + "label": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), + }, + (1, 2, 2), + ] + ) + TESTS.append( + [ + {"keys": "pred", "sigmoid": False, "softmax": False, "other": torch.tanh}, + {"pred": p(torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]))}, + {"pred": torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]])}, + (1, 2, 2), + ] + ) class TestActivationsd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value_shape(self, input_param, test_input, output, expected_shape): result = Activationsd(**input_param)(test_input) - torch.testing.assert_allclose(result["pred"], output["pred"]) - self.assertTupleEqual(result["pred"].shape, expected_shape) - if "label" in result: - torch.testing.assert_allclose(result["label"], output["label"]) - self.assertTupleEqual(result["label"].shape, expected_shape) + for k in ("label", "pred"): + if k not in result: + continue + i, r, o = test_input[k], result[k], output[k] + self.assertEqual(type(i), type(r)) + if isinstance(r, 
torch.Tensor): + self.assertEqual(r.device, i.device) + r = r.cpu() + np.testing.assert_allclose(r, o, rtol=1e-4, atol=1e-5) + self.assertTupleEqual(r.shape, expected_shape) if __name__ == "__main__": diff --git a/tests/test_add_coordinate_channels.py b/tests/test_add_coordinate_channels.py index 3399008e02..e90dca4688 100644 --- a/tests/test_add_coordinate_channels.py +++ b/tests/test_add_coordinate_channels.py @@ -12,32 +12,38 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import AddCoordinateChannels +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"spatial_channels": (1, 2, 3)}, np.random.randint(0, 2, size=(1, 3, 3, 3)), (4, 3, 3, 3)] - -TEST_CASE_2 = [{"spatial_channels": (1,)}, np.random.randint(0, 2, size=(1, 3, 3, 3)), (2, 3, 3, 3)] - -TEST_CASE_ERROR_3 = [{"spatial_channels": (3,)}, np.random.randint(0, 2, size=(1, 3, 3))] - -TEST_CASE_ERROR_4 = [{"spatial_channels": (0, 1, 2)}, np.random.randint(0, 2, size=(1, 3, 3))] +TESTS, TEST_CASES_ERROR_1, TEST_CASES_ERROR_2 = [], [], [] +for p in TEST_NDARRAYS: + TESTS.append([{"spatial_channels": (1, 2, 3)}, p(np.random.randint(0, 2, size=(1, 3, 3, 3))), (4, 3, 3, 3)]) + TESTS.append([{"spatial_channels": (1,)}, p(np.random.randint(0, 2, size=(1, 3, 3, 3))), (2, 3, 3, 3)]) + TEST_CASES_ERROR_1.append([{"spatial_channels": (3,)}, p(np.random.randint(0, 2, size=(1, 3, 3)))]) + TEST_CASES_ERROR_2.append([{"spatial_channels": (0, 1, 2)}, p(np.random.randint(0, 2, size=(1, 3, 3)))]) class TestAddCoordinateChannels(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input, expected_shape): result = AddCoordinateChannels(**input_param)(input) + self.assertEqual(type(result), type(input)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, input.device) + input = input.cpu() + result = result.cpu() self.assertEqual(list(result.shape), list(expected_shape)) np.testing.assert_array_equal(input[0, ...], result[0, ...]) - @parameterized.expand([TEST_CASE_ERROR_3]) + @parameterized.expand(TEST_CASES_ERROR_1) def test_max_channel(self, input_param, input): with self.assertRaises(ValueError): AddCoordinateChannels(**input_param)(input) - @parameterized.expand([TEST_CASE_ERROR_4]) + @parameterized.expand(TEST_CASES_ERROR_2) def test_channel_dim(self, input_param, input): with self.assertRaises(ValueError): AddCoordinateChannels(**input_param)(input) diff --git a/tests/test_add_coordinate_channelsd.py b/tests/test_add_coordinate_channelsd.py index 0fa6aae1c9..cb504983e6 100644 --- a/tests/test_add_coordinate_channelsd.py +++ b/tests/test_add_coordinate_channelsd.py @@ -12,40 +12,56 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import AddCoordinateChannelsd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"spatial_channels": (1, 2, 3), "keys": ["img"]}, - {"img": np.random.randint(0, 2, size=(1, 3, 3, 3))}, - (4, 3, 3, 3), -] +TESTS, TEST_CASES_ERROR_1, TEST_CASES_ERROR_2 = [], [], [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"spatial_channels": (1, 2, 3), "keys": ["img"]}, + {"img": p(np.random.randint(0, 2, size=(1, 3, 3, 3)))}, + (4, 3, 3, 3), + ] + ) + TESTS.append( + [ + {"spatial_channels": (1,), "keys": ["img"]}, + {"img": p(np.random.randint(0, 2, size=(1, 3, 3, 3)))}, + (2, 3, 3, 3), + ] + ) -TEST_CASE_2 = [ - {"spatial_channels": (1,), "keys": ["img"]}, - {"img": np.random.randint(0, 2, 
size=(1, 3, 3, 3))}, - (2, 3, 3, 3), -] - -TEST_CASE_ERROR_3 = [{"spatial_channels": (3,), "keys": ["img"]}, {"img": np.random.randint(0, 2, size=(1, 3, 3))}] - -TEST_CASE_ERROR_4 = [{"spatial_channels": (0, 1, 2), "keys": ["img"]}, {"img": np.random.randint(0, 2, size=(1, 3, 3))}] + TEST_CASES_ERROR_1.append( + [{"spatial_channels": (3,), "keys": ["img"]}, {"img": p(np.random.randint(0, 2, size=(1, 3, 3)))}] + ) + TEST_CASES_ERROR_2.append( + [{"spatial_channels": (0, 1, 2), "keys": ["img"]}, {"img": p(np.random.randint(0, 2, size=(1, 3, 3)))}] + ) class TestAddCoordinateChannels(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input, expected_shape): - result = AddCoordinateChannelsd(**input_param)(input) - self.assertEqual(list(result["img"].shape), list(expected_shape)) - np.testing.assert_array_equal(input["img"][0, ...], result["img"][0, ...]) + result = AddCoordinateChannelsd(**input_param)(input)["img"] + input = input["img"] + self.assertEqual(type(result), type(input)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, input.device) + input = input.cpu() + result = result.cpu() + self.assertEqual(result.shape, expected_shape) + np.testing.assert_array_equal(input[0, ...], result[0, ...]) - @parameterized.expand([TEST_CASE_ERROR_3]) + @parameterized.expand(TEST_CASES_ERROR_1) def test_max_channel(self, input_param, input): with self.assertRaises(ValueError): AddCoordinateChannelsd(**input_param)(input) - @parameterized.expand([TEST_CASE_ERROR_4]) + @parameterized.expand(TEST_CASES_ERROR_2) def test_channel_dim(self, input_param, input): with self.assertRaises(ValueError): AddCoordinateChannelsd(**input_param)(input) diff --git a/tests/test_add_extreme_points_channel.py b/tests/test_add_extreme_points_channel.py index ecf2c83d3c..c062f809ba 100644 --- a/tests/test_add_extreme_points_channel.py +++ b/tests/test_add_extreme_points_channel.py @@ -12,54 +12,67 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import AddExtremePointsChannel +from tests.utils import TEST_NDARRAYS IMG_CHANNEL = 3 -TEST_CASE_1 = [ - { - "img": np.zeros((IMG_CHANNEL, 4, 3)), - "label": np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]]), - "sigma": 1.0, - "rescale_min": 0.0, - "rescale_max": 1.0, - }, - np.array( - [ - [0.38318458, 0.98615628, 0.85551184], - [0.35422316, 0.94430935, 1.0], - [0.46000731, 0.57319659, 0.46000722], - [0.64577687, 0.38318464, 0.0], - ] - ), -] +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: -TEST_CASE_2 = [ - { - "img": np.zeros((IMG_CHANNEL, 4, 3)), - "label": np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]]), - "sigma": 1.0, - "rescale_min": 0.0, - "rescale_max": 1.0, - }, - np.array( - [ - [0.44628328, 0.80495411, 0.44628328], - [0.6779086, 1.0, 0.67790854], - [0.33002687, 0.62079221, 0.33002687], - [0.0, 0.31848389, 0.0], - ] - ), -] + TESTS.append( + [ + { + "img": p(np.zeros((IMG_CHANNEL, 4, 3))), + "label": q(np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]])), + "sigma": 1.0, + "rescale_min": 0.0, + "rescale_max": 1.0, + }, + np.array( + [ + [0.38318458, 0.98615628, 0.85551184], + [0.35422316, 0.94430935, 1.0], + [0.46000731, 0.57319659, 0.46000722], + [0.64577687, 0.38318464, 0.0], + ] + ), + ] + ) + TESTS.append( + [ + { + "img": p(np.zeros((IMG_CHANNEL, 4, 3))), + "label": q(np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]])), + "sigma": 1.0, + "rescale_min": 
0.0, + "rescale_max": 1.0, + }, + np.array( + [ + [0.44628328, 0.80495411, 0.44628328], + [0.6779086, 1.0, 0.67790854], + [0.33002687, 0.62079221, 0.33002687], + [0.0, 0.31848389, 0.0], + ] + ), + ] + ) class TestAddExtremePointsChannel(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_correct_results(self, input_data, expected): add_extreme_points_channel = AddExtremePointsChannel() result = add_extreme_points_channel(**input_data) + self.assertEqual(type(input_data["img"]), type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, input_data["img"].device) + result = result.cpu() np.testing.assert_allclose(result[IMG_CHANNEL], expected, rtol=1e-4) diff --git a/tests/test_add_extreme_points_channeld.py b/tests/test_add_extreme_points_channeld.py index e33bb0838c..30cdb5ee82 100644 --- a/tests/test_add_extreme_points_channeld.py +++ b/tests/test_add_extreme_points_channeld.py @@ -12,45 +12,63 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import AddExtremePointsChanneld +from tests.utils import TEST_NDARRAYS IMG_CHANNEL = 3 -TEST_CASE_1 = [ - {"img": np.zeros((IMG_CHANNEL, 4, 3)), "label": np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]])}, - np.array( - [ - [0.38318458, 0.98615628, 0.85551184], - [0.35422316, 0.94430935, 1.0], - [0.46000731, 0.57319659, 0.46000722], - [0.64577687, 0.38318464, 0.0], - ] - ), -] - -TEST_CASE_2 = [ - {"img": np.zeros((IMG_CHANNEL, 4, 3)), "label": np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]])}, - np.array( - [ - [0.44628328, 0.80495411, 0.44628328], - [0.6779086, 1.0, 0.67790854], - [0.33002687, 0.62079221, 0.33002687], - [0.0, 0.31848389, 0.0], - ] - ), -] +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + TESTS.append( + [ + { + "img": p(np.zeros((IMG_CHANNEL, 4, 3))), + "label": q(np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]])), + }, + np.array( + [ + [0.38318458, 0.98615628, 0.85551184], + [0.35422316, 0.94430935, 1.0], + [0.46000731, 0.57319659, 0.46000722], + [0.64577687, 0.38318464, 0.0], + ] + ), + ] + ) + TESTS.append( + [ + { + "img": p(np.zeros((IMG_CHANNEL, 4, 3))), + "label": q(np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]])), + }, + np.array( + [ + [0.44628328, 0.80495411, 0.44628328], + [0.6779086, 1.0, 0.67790854], + [0.33002687, 0.62079221, 0.33002687], + [0.0, 0.31848389, 0.0], + ] + ), + ] + ) class TestAddExtremePointsChanneld(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_correct_results(self, input_data, expected): add_extreme_points_channel = AddExtremePointsChanneld( keys="img", label_key="label", sigma=1.0, rescale_min=0.0, rescale_max=1.0 ) - result = add_extreme_points_channel(input_data) - np.testing.assert_allclose(result["img"][IMG_CHANNEL], expected, rtol=1e-4) + result = add_extreme_points_channel(input_data)["img"] + + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, input_data["img"].device) + result = result.cpu() + np.testing.assert_allclose(result[IMG_CHANNEL], expected, rtol=1e-4) if __name__ == "__main__": diff --git a/tests/test_affine.py b/tests/test_affine.py index dd82d72e23..b33fce8104 100644 --- a/tests/test_affine.py +++ b/tests/test_affine.py @@ -16,77 +16,143 @@ from parameterized import parameterized from monai.transforms import Affine +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - 
dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(9).reshape((1, 3, 3)), "spatial_size": (-1, 0)}, - np.arange(9).reshape(1, 3, 3), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None, image_only=True), - {"img": np.arange(9).reshape((1, 3, 3)), "spatial_size": (-1, 0)}, - np.arange(9).reshape(1, 3, 3), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.arange(4).reshape(1, 2, 2), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2)), "spatial_size": (4, 4)}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(rotate_params=[np.pi / 2], padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2)), "spatial_size": (4, 4)}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(27).reshape((1, 3, 3, 3)), "spatial_size": (-1, 0, 0)}, - np.arange(27).reshape(1, 3, 3, 3), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(8).reshape((1, 2, 2, 2)), "spatial_size": (4, 4, 4)}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(9).reshape((1, 3, 3))), "spatial_size": (-1, 0)}, + p(np.arange(9).reshape(1, 3, 3)), ] - ), - ], - [ - dict(rotate_params=[np.pi / 2], padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(8).reshape((1, 2, 2, 2)), "spatial_size": (4, 4, 4)}, - np.array( + ) + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 6.0, 4.0, 0.0], [0.0, 7.0, 5.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(padding_mode="zeros", device=device, image_only=True), + {"img": p(np.arange(9).reshape((1, 3, 3))), "spatial_size": (-1, 0)}, + p(np.arange(9).reshape(1, 3, 3)), ] - ), - ], -] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.arange(4).reshape(1, 2, 2)), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2))), "spatial_size": (4, 4)}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(rotate_params=[np.pi / 2], padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2))), "spatial_size": (4, 4)}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + 
dict(padding_mode="zeros", device=device), + {"img": p(np.arange(27).reshape((1, 3, 3, 3))), "spatial_size": (-1, 0, 0)}, + p(np.arange(27).reshape(1, 3, 3, 3)), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(8).reshape((1, 2, 2, 2))), "spatial_size": (4, 4, 4)}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 2.0, 3.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 4.0, 5.0, 0.0], + [0.0, 6.0, 7.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + dict(rotate_params=[np.pi / 2], padding_mode="zeros", device=device), + {"img": p(np.arange(8).reshape((1, 2, 2, 2))), "spatial_size": (4, 4, 4)}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0], + [0.0, 3.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 6.0, 4.0, 0.0], + [0.0, 7.0, 5.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) class TestAffine(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_affine(self, input_param, input_data, expected_val): g = Affine(**input_param) result = g(**input_data) if isinstance(result, tuple): result = result[0] - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) + self.assertEqual(type(result), type(expected_val)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, expected_val.device) + result = result.cpu() + expected_val = expected_val.cpu() np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) diff --git a/tests/test_affine_grid.py b/tests/test_affine_grid.py index 24772b9a21..23e12ba6f0 100644 --- a/tests/test_affine_grid.py +++ b/tests/test_affine_grid.py @@ -16,88 +16,110 @@ from parameterized import parameterized from monai.transforms import AffineGrid +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - {"as_tensor_output": False, "device": torch.device("cpu:0")}, - {"spatial_size": (2, 2)}, - np.array([[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]), - ], - [ - {"as_tensor_output": True, "device": None}, - {"spatial_size": (2, 2)}, - torch.tensor([[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]), - ], - [{"as_tensor_output": False, "device": None}, {"grid": np.ones((3, 3, 3))}, np.ones((3, 3, 3))], - [{"as_tensor_output": True, "device": torch.device("cpu:0")}, {"grid": np.ones((3, 3, 3))}, torch.ones((3, 3, 3))], - [{"as_tensor_output": False, "device": None}, {"grid": torch.ones((3, 3, 3))}, np.ones((3, 3, 3))], - [ - {"as_tensor_output": True, "device": torch.device("cpu:0")}, - {"grid": torch.ones((3, 3, 3))}, - torch.ones((3, 3, 3)), - ], - [ - { - "rotate_params": (1.0, 1.0), - "scale_params": (-20, 10), - "as_tensor_output": True, - "device": torch.device("cpu:0"), - }, - {"grid": torch.ones((3, 3, 3))}, - torch.tensor( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [[-19.2208, 
-19.2208, -19.2208], [-19.2208, -19.2208, -19.2208], [-19.2208, -19.2208, -19.2208]], - [[-11.4264, -11.4264, -11.4264], [-11.4264, -11.4264, -11.4264], [-11.4264, -11.4264, -11.4264]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + {"device": device}, + {"spatial_size": (2, 2)}, + np.array([[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]), ] - ), - ], - [ - { - "rotate_params": (1.0, 1.0, 1.0), - "scale_params": (-20, 10), - "as_tensor_output": True, - "device": torch.device("cpu:0"), - }, - {"grid": torch.ones((4, 3, 3, 3))}, - torch.tensor( + ) + + TESTS.append([{"device": device}, {"grid": p(np.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))]) + TESTS.append([{"device": device}, {"grid": p(torch.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))]) + TESTS.append( + [ + { + "rotate_params": (1.0, 1.0), + "scale_params": (-20, 10), + "device": device, + }, + {"grid": p(torch.ones((3, 3, 3)))}, + p( + torch.tensor( + [ + [ + [-19.2208, -19.2208, -19.2208], + [-19.2208, -19.2208, -19.2208], + [-19.2208, -19.2208, -19.2208], + ], + [ + [-11.4264, -11.4264, -11.4264], + [-11.4264, -11.4264, -11.4264], + [-11.4264, -11.4264, -11.4264], + ], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ) + ), + ] + ) + TESTS.append( [ - [ - [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], - [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], - [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], - ], - [ - [[-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381]], - [[-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381]], - [[-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381]], - ], - [ - [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], - [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], - [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], - ], - [ - [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], - [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], - [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], - ], + { + "rotate_params": (1.0, 1.0, 1.0), + "scale_params": (-20, 10), + "device": device, + }, + {"grid": p(torch.ones((4, 3, 3, 3)))}, + p( + torch.tensor( + [ + [ + [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], + [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], + [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], + ], + [ + [ + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + ], + [ + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + ], + [ + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + ], + ], + [ + [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], + [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], + [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], + ], + [ + [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 
1.0000]], + [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], + [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], + ], + ] + ) + ), ] - ), - ], -] + ) class TestAffineGrid(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_affine_grid(self, input_param, input_data, expected_val): g = AffineGrid(**input_param) result, _ = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + if "grid" in input_data: + self.assertEqual(type(result), type(expected_val)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, expected_val.device) + result = result.cpu() + expected_val = expected_val.cpu() + np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_affined.py b/tests/test_affined.py index 850f12905d..5bbb384537 100644 --- a/tests/test_affined.py +++ b/tests/test_affined.py @@ -16,84 +16,145 @@ from parameterized import parameterized from monai.transforms import Affined +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - dict(keys="img", padding_mode="zeros", as_tensor_output=False, spatial_size=(-1, 0), device=None), - {"img": np.arange(9).reshape((1, 3, 3))}, - np.arange(9).reshape(1, 3, 3), - ], - [ - dict(keys="img", padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.arange(4).reshape(1, 2, 2), - ], - [ - dict(keys="img", padding_mode="zeros", spatial_size=(4, 4), as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict( - keys="img", - rotate_params=[np.pi / 2], - padding_mode="zeros", - spatial_size=(4, 4), - as_tensor_output=False, - device=None, - ), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0, 0), as_tensor_output=False, device=None), - {"img": np.arange(27).reshape((1, 3, 3, 3))}, - np.arange(27).reshape(1, 3, 3, 3), - ], - [ - dict(keys="img", padding_mode="zeros", spatial_size=(4, 4, 4), as_tensor_output=False, device=None), - {"img": np.arange(8).reshape((1, 2, 2, 2))}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0), device=device), + {"img": p(np.arange(9).reshape((1, 3, 3)))}, + p(np.arange(9).reshape(1, 3, 3)), ] - ), - ], - [ - dict( - keys="img", - rotate_params=[np.pi / 2], - padding_mode="zeros", - spatial_size=(4, 4, 4), - as_tensor_output=False, - device=None, - ), - {"img": np.arange(8).reshape((1, 2, 2, 
2))}, - np.array( + ) + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 6.0, 4.0, 0.0], [0.0, 7.0, 5.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(keys="img", padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.arange(4).reshape(1, 2, 2)), ] - ), - ], -] + ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(4, 4), device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict( + keys="img", + rotate_params=[np.pi / 2], + padding_mode="zeros", + spatial_size=(4, 4), + device=device, + ), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0, 0), device=device), + {"img": p(np.arange(27).reshape((1, 3, 3, 3)))}, + p(np.arange(27).reshape(1, 3, 3, 3)), + ] + ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(4, 4, 4), device=device), + {"img": p(np.arange(8).reshape((1, 2, 2, 2)))}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 2.0, 3.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 4.0, 5.0, 0.0], + [0.0, 6.0, 7.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + dict( + keys="img", + rotate_params=[np.pi / 2], + padding_mode="zeros", + spatial_size=(4, 4, 4), + device=device, + ), + {"img": p(np.arange(8).reshape((1, 2, 2, 2)))}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0], + [0.0, 3.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 6.0, 4.0, 0.0], + [0.0, 7.0, 5.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) class TestAffined(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_affine(self, input_param, input_data, expected_val): g = Affined(**input_param) result = g(input_data)["img"] - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) + self.assertEqual(type(result), type(expected_val)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, expected_val.device) + result, expected_val = result.cpu(), expected_val.cpu() np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) diff --git a/tests/test_arraydataset.py b/tests/test_arraydataset.py index f6459cc88c..bf16aacb26 100644 --- a/tests/test_arraydataset.py +++ b/tests/test_arraydataset.py @@ -125,7 +125,7 @@ def test_dataloading_img(self, img_transform, expected_shape): dataset = ArrayDataset(test_images, img_transform) self.assertEqual(len(dataset), 2) 
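Every transform test touched by this patch follows the same migration: the enumerated `TEST_CASE_*` constants become a `TESTS` list built in a loop over `tests.utils.TEST_NDARRAYS`, so each case runs once per container type (numpy, CPU torch, and CUDA torch when available), and each test asserts that the transform preserves the input's container type and device before moving data to the CPU for a numpy comparison. The sketch below illustrates that shared idiom; the exact definition of `TEST_NDARRAYS` lives in `tests/utils.py` and may differ, and `assert_matches` is a hypothetical helper named here only for illustration.

```python
# Assumed shape of tests/utils.TEST_NDARRAYS: a tuple of callables that
# convert array-like data into each container type under test.
from functools import partial

import numpy as np
import torch

TEST_NDARRAYS = (np.array, torch.as_tensor) + (
    (partial(torch.as_tensor, device="cuda"),) if torch.cuda.is_available() else ()
)


def assert_matches(result, expected, rtol=1e-4, atol=1e-4):
    """The assertion idiom repeated across the tests above (hypothetical
    helper): container type and device must survive the transform, and the
    values are compared on the CPU."""
    assert type(result) == type(expected)
    if isinstance(result, torch.Tensor):
        assert result.device == expected.device
        result, expected = result.cpu(), expected.cpu()
    np.testing.assert_allclose(result, expected, rtol=rtol, atol=atol)
```

With a helper like this, the per-test boilerplate (`type` check, `device` check, `.cpu()` round trip, `assert_allclose`) collapses to a single call, which is why the same five lines recur in each migrated test above.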
dataset.set_random_state(1234) - n_workers = 0 if sys.platform == "win32" else 2 + n_workers = 2 if sys.platform == "linux" else 0 loader = DataLoader(dataset, batch_size=10, num_workers=n_workers) imgs = next(iter(loader)) # test batching np.testing.assert_allclose(imgs.shape, [2] + list(expected_shape)) @@ -151,7 +151,7 @@ def test_dataloading_img_label(self, img_transform, expected_shape): dataset = ArrayDataset(test_images, img_transform, test_labels, img_transform) self.assertEqual(len(dataset), 2) dataset.set_random_state(1234) - n_workers = 0 if sys.platform == "win32" else 2 + n_workers = 2 if sys.platform == "linux" else 0 loader = DataLoader(dataset, batch_size=10, num_workers=n_workers) data = next(iter(loader)) # test batching np.testing.assert_allclose(data[0].shape, [2] + list(expected_shape)) diff --git a/tests/test_as_channel_first.py b/tests/test_as_channel_first.py index e7d9866ae1..bc9158f277 100644 --- a/tests/test_as_channel_first.py +++ b/tests/test_as_channel_first.py @@ -15,18 +15,19 @@ from parameterized import parameterized from monai.transforms import AsChannelFirst +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"channel_dim": -1}, (4, 1, 2, 3)] - -TEST_CASE_2 = [{"channel_dim": 3}, (4, 1, 2, 3)] - -TEST_CASE_3 = [{"channel_dim": 2}, (3, 1, 2, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"channel_dim": -1}, (4, 1, 2, 3)]) + TESTS.append([p, {"channel_dim": 3}, (4, 1, 2, 3)]) + TESTS.append([p, {"channel_dim": 2}, (3, 1, 2, 4)]) class TestAsChannelFirst(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): - test_data = np.random.randint(0, 2, size=[1, 2, 3, 4]) + @parameterized.expand(TESTS) + def test_shape(self, in_type, input_param, expected_shape): + test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])) result = AsChannelFirst(**input_param)(test_data) self.assertTupleEqual(result.shape, expected_shape) diff --git a/tests/test_as_channel_firstd.py b/tests/test_as_channel_firstd.py index e70c2e1b47..68d33434c1 100644 --- a/tests/test_as_channel_firstd.py +++ b/tests/test_as_channel_firstd.py @@ -15,21 +15,22 @@ from parameterized import parameterized from monai.transforms import AsChannelFirstd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"keys": ["image", "label", "extra"], "channel_dim": -1}, (4, 1, 2, 3)] - -TEST_CASE_2 = [{"keys": ["image", "label", "extra"], "channel_dim": 3}, (4, 1, 2, 3)] - -TEST_CASE_3 = [{"keys": ["image", "label", "extra"], "channel_dim": 2}, (3, 1, 2, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": -1}, (4, 1, 2, 3)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 3}, (4, 1, 2, 3)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 2}, (3, 1, 2, 4)]) class TestAsChannelFirstd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): + @parameterized.expand(TESTS) + def test_shape(self, in_type, input_param, expected_shape): test_data = { - "image": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "label": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "extra": np.random.randint(0, 2, size=[1, 2, 3, 4]), + "image": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "label": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "extra": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), } result = 
AsChannelFirstd(**input_param)(test_data) self.assertTupleEqual(result["image"].shape, expected_shape) diff --git a/tests/test_as_channel_last.py b/tests/test_as_channel_last.py index 6ec6c8d6e6..55a7a08676 100644 --- a/tests/test_as_channel_last.py +++ b/tests/test_as_channel_last.py @@ -15,18 +15,19 @@ from parameterized import parameterized from monai.transforms import AsChannelLast +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"channel_dim": 0}, (2, 3, 4, 1)] - -TEST_CASE_2 = [{"channel_dim": 1}, (1, 3, 4, 2)] - -TEST_CASE_3 = [{"channel_dim": 3}, (1, 2, 3, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"channel_dim": 0}, (2, 3, 4, 1)]) + TESTS.append([p, {"channel_dim": 1}, (1, 3, 4, 2)]) + TESTS.append([p, {"channel_dim": 3}, (1, 2, 3, 4)]) class TestAsChannelLast(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): - test_data = np.random.randint(0, 2, size=[1, 2, 3, 4]) + @parameterized.expand(TESTS) + def test_shape(self, in_type, input_param, expected_shape): + test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])) result = AsChannelLast(**input_param)(test_data) self.assertTupleEqual(result.shape, expected_shape) diff --git a/tests/test_as_channel_lastd.py b/tests/test_as_channel_lastd.py index 2ef4dd4da1..350f639f3f 100644 --- a/tests/test_as_channel_lastd.py +++ b/tests/test_as_channel_lastd.py @@ -15,21 +15,22 @@ from parameterized import parameterized from monai.transforms import AsChannelLastd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"keys": ["image", "label", "extra"], "channel_dim": 0}, (2, 3, 4, 1)] - -TEST_CASE_2 = [{"keys": ["image", "label", "extra"], "channel_dim": 1}, (1, 3, 4, 2)] - -TEST_CASE_3 = [{"keys": ["image", "label", "extra"], "channel_dim": 3}, (1, 2, 3, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 0}, (2, 3, 4, 1)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 1}, (1, 3, 4, 2)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 3}, (1, 2, 3, 4)]) class TestAsChannelLastd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): + @parameterized.expand(TESTS) + def test_shape(self, in_type, input_param, expected_shape): test_data = { - "image": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "label": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "extra": np.random.randint(0, 2, size=[1, 2, 3, 4]), + "image": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "label": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "extra": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), } result = AsChannelLastd(**input_param)(test_data) self.assertTupleEqual(result["image"].shape, expected_shape) diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index ea806be139..3bb9702f86 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -11,45 +11,61 @@ import unittest +import numpy as np import torch from parameterized import parameterized from monai.transforms import AsDiscrete +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"argmax": True, "to_onehot": False, "n_classes": None, "threshold_values": False, "logit_thresh": 0.5}, - torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), - torch.tensor([[[1.0, 1.0]]]), - (1, 1, 2), -] - -TEST_CASE_2 = [ - {"argmax": True, "to_onehot": True, "n_classes": 2, 
"threshold_values": False, "logit_thresh": 0.5}, - torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), - torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]]), - (2, 1, 2), -] - -TEST_CASE_3 = [ - {"argmax": False, "to_onehot": False, "n_classes": None, "threshold_values": True, "logit_thresh": 0.6}, - torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), - torch.tensor([[[0.0, 1.0], [1.0, 1.0]]]), - (1, 2, 2), -] - -TEST_CASE_4 = [ - {"argmax": False, "to_onehot": True, "n_classes": 3}, - torch.tensor(1), - torch.tensor([0.0, 1.0, 0.0]), - (3,), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"argmax": True, "to_onehot": False, "n_classes": None, "threshold_values": False, "logit_thresh": 0.5}, + p([[[0.0, 1.0]], [[2.0, 3.0]]]), + ([[[1.0, 1.0]]]), + (1, 1, 2), + ] + ) + + TESTS.append( + [ + {"argmax": True, "to_onehot": True, "n_classes": 2, "threshold_values": False, "logit_thresh": 0.5}, + p([[[0.0, 1.0]], [[2.0, 3.0]]]), + ([[[0.0, 0.0]], [[1.0, 1.0]]]), + (2, 1, 2), + ] + ) + + TESTS.append( + [ + {"argmax": False, "to_onehot": False, "n_classes": None, "threshold_values": True, "logit_thresh": 0.6}, + p([[[0.0, 1.0], [2.0, 3.0]]]), + ([[[0.0, 1.0], [1.0, 1.0]]]), + (1, 2, 2), + ] + ) + + TESTS.append( + [ + {"argmax": False, "to_onehot": True, "n_classes": 3}, + p(1), + ([0.0, 1.0, 0.0]), + (3,), + ] + ) class TestAsDiscrete(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand(TESTS) def test_value_shape(self, input_param, img, out, expected_shape): result = AsDiscrete(**input_param)(img) - torch.testing.assert_allclose(result, out) + self.assertEqual(type(img), type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(img.device, result.device) + result = result.cpu() + np.testing.assert_allclose(result, out) self.assertTupleEqual(result.shape, expected_shape) diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py index d6a6f3c2a4..d00e6dad24 100644 --- a/tests/test_as_discreted.py +++ b/tests/test_as_discreted.py @@ -11,63 +11,75 @@ import unittest +import numpy as np import torch from parameterized import parameterized from monai.transforms import AsDiscreted +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - { - "keys": ["pred", "label"], - "argmax": [True, False], - "to_onehot": True, - "n_classes": 2, - "threshold_values": False, - "logit_thresh": 0.5, - }, - {"pred": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), "label": torch.tensor([[[0, 1]]])}, - {"pred": torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]]), "label": torch.tensor([[[1.0, 0.0]], [[0.0, 1.0]]])}, - (2, 1, 2), -] - -TEST_CASE_2 = [ - { - "keys": ["pred", "label"], - "argmax": False, - "to_onehot": False, - "n_classes": None, - "threshold_values": [True, False], - "logit_thresh": 0.6, - }, - {"pred": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), "label": torch.tensor([[[0, 1], [1, 1]]])}, - {"pred": torch.tensor([[[0.0, 1.0], [1.0, 1.0]]]), "label": torch.tensor([[[0.0, 1.0], [1.0, 1.0]]])}, - (1, 2, 2), -] - -TEST_CASE_3 = [ - { - "keys": ["pred"], - "argmax": True, - "to_onehot": True, - "n_classes": 2, - "threshold_values": False, - "logit_thresh": 0.5, - }, - {"pred": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])}, - {"pred": torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]])}, - (2, 1, 2), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + { + "keys": ["pred", "label"], + "argmax": [True, False], + "to_onehot": True, + "n_classes": 2, + "threshold_values": False, + "logit_thresh": 0.5, + }, + {"pred": p(torch.tensor([[[0.0, 1.0]], 
[[2.0, 3.0]]])), "label": p(torch.tensor([[[0, 1]]]))}, + {"pred": torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]]), "label": torch.tensor([[[1.0, 0.0]], [[0.0, 1.0]]])}, + (2, 1, 2), + ] + ) + TESTS.append( + [ + { + "keys": ["pred", "label"], + "argmax": False, + "to_onehot": False, + "n_classes": None, + "threshold_values": [True, False], + "logit_thresh": 0.6, + }, + {"pred": p(torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])), "label": p(torch.tensor([[[0, 1], [1, 1]]]))}, + {"pred": torch.tensor([[[0.0, 1.0], [1.0, 1.0]]]), "label": torch.tensor([[[0.0, 1.0], [1.0, 1.0]]])}, + (1, 2, 2), + ] + ) + TESTS.append( + [ + { + "keys": ["pred"], + "argmax": True, + "to_onehot": True, + "n_classes": 2, + "threshold_values": False, + "logit_thresh": 0.5, + }, + {"pred": p(torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]))}, + {"pred": torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]])}, + (2, 1, 2), + ] + ) class TestAsDiscreted(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value_shape(self, input_param, test_input, output, expected_shape): result = AsDiscreted(**input_param)(test_input) - torch.testing.assert_allclose(result["pred"], output["pred"]) - self.assertTupleEqual(result["pred"].shape, expected_shape) - if "label" in result: - torch.testing.assert_allclose(result["label"], output["label"]) - self.assertTupleEqual(result["label"].shape, expected_shape) + for k in ("label", "pred"): + if k not in result: + continue + i, r, o = test_input[k], result[k], output[k] + self.assertEqual(type(i), type(r)) + if isinstance(r, torch.Tensor): + self.assertEqual(i.device, r.device) + r = r.cpu() + np.testing.assert_allclose(r, o) if __name__ == "__main__": diff --git a/tests/test_border_pad.py b/tests/test_border_pad.py index b011601694..cacb60112f 100644 --- a/tests/test_border_pad.py +++ b/tests/test_border_pad.py @@ -16,6 +16,7 @@ from monai.transforms import BorderPad from monai.utils import NumpyPadMode +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [ {"spatial_border": 2, "mode": "constant"}, @@ -45,11 +46,12 @@ class TestBorderPad(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_pad_shape(self, input_param, input_data, expected_val): - padder = BorderPad(**input_param) - result = padder(input_data) - self.assertAlmostEqual(result.shape, expected_val.shape) - result = padder(input_data, mode=input_param["mode"]) - self.assertAlmostEqual(result.shape, expected_val.shape) + for p in TEST_NDARRAYS: + padder = BorderPad(**input_param) + result = padder(p(input_data)) + self.assertAlmostEqual(result.shape, expected_val.shape) + result = padder(p(input_data), mode=input_param["mode"]) + self.assertAlmostEqual(result.shape, expected_val.shape) def test_pad_kwargs(self): padder = BorderPad(spatial_border=2, mode="constant", constant_values=((0, 0), (1, 1), (2, 2))) diff --git a/tests/test_bounding_rect.py b/tests/test_bounding_rect.py index 38585cba18..bef97c47b5 100644 --- a/tests/test_bounding_rect.py +++ b/tests/test_bounding_rect.py @@ -12,38 +12,40 @@ import unittest import numpy as np +import torch from parameterized import parameterized import monai from monai.transforms import BoundingRect +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [(2, 3), [[0, 0], [1, 2]]] +SEED = 1 -TEST_CASE_2 = [(1, 8, 10), [[0, 7, 1, 9]]] - -TEST_CASE_3 = [(2, 16, 20, 18), [[0, 16, 0, 20, 0, 18], [0, 16, 0, 20, 0, 18]]] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {}, (2, 3), [[0, 0], 
[1, 2]]]) + TESTS.append([p, {}, (1, 8, 10), [[0, 7, 1, 9]]]) + TESTS.append([p, {}, (2, 16, 20, 18), [[0, 16, 0, 20, 0, 18], [0, 16, 0, 20, 0, 18]]]) + TESTS.append([p, {"select_fn": lambda x: x < 1}, (2, 3), [[0, 3], [0, 3]]]) class TestBoundingRect(unittest.TestCase): def setUp(self): - monai.utils.set_determinism(1) + monai.utils.set_determinism(SEED) def tearDown(self): monai.utils.set_determinism(None) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_shape, expected): - test_data = np.random.randint(0, 8, size=input_shape) + @parameterized.expand(TESTS) + def test_result(self, im_type, input_args, input_shape, expected): + np.random.seed(SEED) + test_data = im_type(np.random.randint(0, 8, size=input_shape)) test_data = test_data == 7 - result = BoundingRect()(test_data) + result = BoundingRect(**input_args)(test_data) + if isinstance(result, torch.Tensor): + result = result.cpu() np.testing.assert_allclose(result, expected) - def test_select_fn(self): - test_data = np.random.randint(0, 8, size=(2, 3)) - test_data = test_data == 7 - bbox = BoundingRect(select_fn=lambda x: x < 1)(test_data) - np.testing.assert_allclose(bbox, [[0, 3], [0, 3]]) - if __name__ == "__main__": unittest.main() diff --git a/tests/test_bounding_rectd.py b/tests/test_bounding_rectd.py index 6e725ff583..0bd07ee550 100644 --- a/tests/test_bounding_rectd.py +++ b/tests/test_bounding_rectd.py @@ -16,24 +16,28 @@ import monai from monai.transforms import BoundingRectD +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [(2, 3), [[0, 0], [1, 2]]] +SEED = 1 -TEST_CASE_2 = [(1, 8, 10), [[0, 7, 1, 9]]] - -TEST_CASE_3 = [(2, 16, 20, 18), [[0, 16, 0, 20, 0, 18], [0, 16, 0, 20, 0, 18]]] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, (2, 3), [[0, 0], [1, 2]]]) + TESTS.append([p, (1, 8, 10), [[0, 7, 1, 9]]]) + TESTS.append([p, (2, 16, 20, 18), [[0, 16, 0, 20, 0, 18], [0, 16, 0, 20, 0, 18]]]) class TestBoundingRectD(unittest.TestCase): def setUp(self): - monai.utils.set_determinism(1) + monai.utils.set_determinism(SEED) def tearDown(self): monai.utils.set_determinism(None) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_shape, expected): - test_data = np.random.randint(0, 8, size=input_shape) + @parameterized.expand(TESTS) + def test_value(self, im_type, input_shape, expected): + np.random.seed(SEED) + test_data = im_type(np.random.randint(0, 8, size=input_shape)) test_data = test_data == 7 result = BoundingRectD("image")({"image": test_data}) np.testing.assert_allclose(result["image_bbox"], expected) diff --git a/tests/test_center_scale_crop.py b/tests/test_center_scale_crop.py index e28849ce90..cada8e273c 100644 --- a/tests/test_center_scale_crop.py +++ b/tests/test_center_scale_crop.py @@ -16,33 +16,44 @@ from parameterized import parameterized from monai.transforms import CenterScaleCrop +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [{"roi_scale": [0.6, 0.3, -1]}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 1, 3)] +TESTS, TEST_VALUES = [], [] +for p in TEST_NDARRAYS: + TESTS.append([{"roi_scale": [0.6, 0.3, -1]}, p(np.random.randint(0, 2, size=[3, 3, 3, 3])), (3, 2, 1, 3)]) -TEST_CASE_1 = [{"roi_scale": 0.6}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 2, 2)] + TESTS.append([{"roi_scale": 0.6}, p(np.random.randint(0, 2, size=[3, 3, 3, 3])), (3, 2, 2, 2)]) -TEST_CASE_2 = [ - {"roi_scale": [0.4, 0.4]}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - 
np.array([[[1, 2], [2, 3]]]), -] + TESTS.append( + [ + {"roi_scale": 0.5}, + p(torch.randint(0, 2, size=[3, 3, 3, 3])), + (3, 2, 2, 2), + ] + ) -TEST_CASE_3 = [ - {"roi_scale": 0.5}, - torch.randint(0, 2, size=[3, 3, 3, 3], device="cuda" if torch.cuda.is_available() else "cpu"), - (3, 2, 2, 2), -] + TEST_VALUES.append( + [ + {"roi_scale": [0.4, 0.4]}, + p(np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])), + np.array([[[1, 2], [2, 3]]]), + ] + ) class TestCenterScaleCrop(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = CenterScaleCrop(**input_param)(input_data) np.testing.assert_allclose(result.shape, expected_shape) - @parameterized.expand([TEST_CASE_2]) + @parameterized.expand(TEST_VALUES) def test_value(self, input_param, input_data, expected_value): result = CenterScaleCrop(**input_param)(input_data) + self.assertEqual(type(result), type(input_data)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, input_data.device) + result = result.cpu() np.testing.assert_allclose(result, expected_value) diff --git a/tests/test_center_scale_cropd.py b/tests/test_center_scale_cropd.py index 313e8e7f7e..c5b680b26c 100644 --- a/tests/test_center_scale_cropd.py +++ b/tests/test_center_scale_cropd.py @@ -16,34 +16,47 @@ from parameterized import parameterized from monai.transforms import CenterScaleCropd +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [{"keys": "img", "roi_scale": [0.6, 0.3, -1]}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 1, 3)] +TESTS, TEST_VALUES = [], [] +for p in TEST_NDARRAYS: + TESTS.append( + [{"keys": "img", "roi_scale": [0.6, 0.3, -1]}, p(np.random.randint(0, 2, size=[3, 3, 3, 3])), (3, 2, 1, 3)] + ) -TEST_CASE_1 = [{"keys": "img", "roi_scale": 0.6}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 2, 2)] + TESTS.append([{"keys": "img", "roi_scale": 0.6}, p(np.random.randint(0, 2, size=[3, 3, 3, 3])), (3, 2, 2, 2)]) -TEST_CASE_2 = [ - {"keys": "img", "roi_scale": [0.4, 0.4]}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2], [2, 3]]]), -] + TESTS.append( + [ + {"keys": "img", "roi_scale": 0.5}, + p(torch.randint(0, 2, size=[3, 3, 3, 3])), + (3, 2, 2, 2), + ] + ) -TEST_CASE_3 = [ - {"keys": "img", "roi_scale": 0.5}, - torch.randint(0, 2, size=[3, 3, 3, 3], device="cuda" if torch.cuda.is_available() else "cpu"), - (3, 2, 2, 2), -] + TEST_VALUES.append( + [ + {"keys": "img", "roi_scale": [0.4, 0.4]}, + p(np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])), + np.array([[[1, 2], [2, 3]]]), + ] + ) class TestCenterScaleCropd(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = CenterScaleCropd(**input_param)({"img": input_data}) np.testing.assert_allclose(result["img"].shape, expected_shape) - @parameterized.expand([TEST_CASE_2]) + @parameterized.expand(TEST_VALUES) def test_value(self, input_param, input_data, expected_value): - result = CenterScaleCropd(**input_param)({"img": input_data}) - np.testing.assert_allclose(result["img"], expected_value) + result = CenterScaleCropd(**input_param)({"img": input_data})["img"] + self.assertEqual(type(result), type(input_data)) + if isinstance(result, torch.Tensor): + 
self.assertEqual(result.device, input_data.device) + result = result.cpu() + np.testing.assert_allclose(result, expected_value) if __name__ == "__main__": diff --git a/tests/test_center_spatial_crop.py b/tests/test_center_spatial_crop.py index 3e828176a5..51abe7b07f 100644 --- a/tests/test_center_spatial_crop.py +++ b/tests/test_center_spatial_crop.py @@ -16,34 +16,33 @@ from parameterized import parameterized from monai.transforms import CenterSpatialCrop +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [{"roi_size": [2, 2, -1]}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 2, 3)] +TEST_SHAPES, TEST_VALUES = [], [] +for p in TEST_NDARRAYS: + TEST_SHAPES.append([{"roi_size": [2, 2, -1]}, p(np.random.randint(0, 2, size=[3, 3, 3, 3])), (3, 2, 2, 3)]) -TEST_CASE_1 = [{"roi_size": [2, 2, 2]}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 2, 2)] + TEST_SHAPES.append([{"roi_size": [2, 2, 2]}, p(np.random.randint(0, 2, size=[3, 3, 3, 3])), (3, 2, 2, 2)]) -TEST_CASE_2 = [ - {"roi_size": [2, 2]}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2], [2, 3]]]), -] - -TEST_CASE_3 = [ - {"roi_size": [2, 2, 2]}, - torch.randint(0, 2, size=[3, 3, 3, 3], device="cuda" if torch.cuda.is_available() else "cpu"), - (3, 2, 2, 2), -] + TEST_VALUES.append( + [ + {"roi_size": [2, 2]}, + p(np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])), + p(np.array([[[1, 2], [2, 3]]])), + ] + ) class TestCenterSpatialCrop(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) + @parameterized.expand(TEST_SHAPES) def test_shape(self, input_param, input_data, expected_shape): result = CenterSpatialCrop(**input_param)(input_data) np.testing.assert_allclose(result.shape, expected_shape) - @parameterized.expand([TEST_CASE_2]) + @parameterized.expand(TEST_VALUES) def test_value(self, input_param, input_data, expected_value): result = CenterSpatialCrop(**input_param)(input_data) - np.testing.assert_allclose(result, expected_value) + torch.testing.assert_allclose(result, expected_value, rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_classes_to_indices.py b/tests/test_classes_to_indices.py index 0ba3dd094a..a38f21a93d 100644 --- a/tests/test_classes_to_indices.py +++ b/tests/test_classes_to_indices.py @@ -12,66 +12,89 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import ClassesToIndices +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - # test Argmax data - {"num_classes": 3, "image_threshold": 0.0}, - np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), - None, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + # test Argmax data + {"num_classes": 3, "image_threshold": 0.0}, + p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + None, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], + ] + ) -TEST_CASE_2 = [ - {"num_classes": 3, "image_threshold": 60}, - np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), - np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + TESTS.append( + [ + {"num_classes": 3, "image_threshold": 60}, + p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + p(np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]])), + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], + ] + ) -TEST_CASE_3 = [ - # test 
One-Hot data - {"image_threshold": 0.0}, - np.array( + TESTS.append( [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + # test One-Hot data + {"image_threshold": 0.0}, + p( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + ), + None, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], ] - ), - None, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] + ) -TEST_CASE_4 = [ - {"num_classes": None, "image_threshold": 60}, - np.array( + TESTS.append( [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + {"num_classes": None, "image_threshold": 60}, + p( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + ), + p(np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]])), + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], ] - ), - np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + ) -TEST_CASE_5 = [ - # test output_shape - {"num_classes": 3, "image_threshold": 0.0, "output_shape": [3, 3]}, - np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), - None, - [np.array([[0, 0], [1, 1], [2, 2]]), np.array([[0, 1], [1, 2], [2, 0]]), np.array([[0, 2], [1, 0], [2, 1]])], -] + TESTS.append( + [ + # test output_shape + {"num_classes": 3, "image_threshold": 0.0, "output_shape": [3, 3]}, + p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + None, + [ + np.array([[0, 0], [1, 1], [2, 2]]), + np.array([[0, 1], [1, 2], [2, 0]]), + np.array([[0, 2], [1, 0], [2, 1]]), + ], + ] + ) class TestClassesToIndices(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_value(self, input_args, label, image, expected_indices): indices = ClassesToIndices(**input_args)(label, image) for i, e in zip(indices, expected_indices): + i = i.cpu() if isinstance(i, torch.Tensor) else i np.testing.assert_allclose(i, e) diff --git a/tests/test_compose.py b/tests/test_compose.py index 28783cad23..a3c5f655e6 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -12,11 +12,38 @@ import sys import unittest +import numpy as np + from monai.data import DataLoader, Dataset from monai.transforms import AddChannel, Compose -from monai.transforms.transform import Randomizable +from monai.transforms.transform import NumpyTransform, Randomizable, TorchTransform, Transform from monai.utils import set_determinism +from parameterized import parameterized +import math + + +class Tr(Transform): + def __call__(self, x): + return x +class N(Tr, NumpyTransform): + pass +class T(Tr, TorchTransform): + pass +class NT(Tr, NumpyTransform, TorchTransform): + pass + + +TEST_NUMBER_CONVERSIONS = [ + ((N(), N()), 0), + ((T(), T()), 0), + ((NT(), NT()), 0), + ((N(), T()), 1), + ((N(), T(), N()), 2), + ((N(), NT(), T(), T(), NT(), NT(), N()), 2), + ((N(), NT(), T(), Compose([T(), NT(), NT(), N()])), 2), + ((Tr(), N(), T()), math.nan), +] class _RandXform(Randomizable): def randomize(self): @@ -217,6 +244,15 @@ def test_flatten_and_len(self): def test_backwards_compatible_imports(self): from monai.transforms.compose import MapTransform, RandomizableTransform, Transform # noqa: F401 + @parameterized.expand(TEST_NUMBER_CONVERSIONS) + def 
test_get_number_of_conversions(self, transforms, expected): + tr = Compose(transforms) + n = tr.get_number_np_torch_conversions() + if math.isnan(expected): + self.assertTrue(math.isnan(n)) + else: + self.assertEqual(n, expected) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_compute_roc_auc.py b/tests/test_compute_roc_auc.py index 79d62b6436..c4ffc09b47 100644 --- a/tests/test_compute_roc_auc.py +++ b/tests/test_compute_roc_auc.py @@ -18,73 +18,85 @@ from monai.data import decollate_batch from monai.metrics import ROCAUCMetric, compute_roc_auc from monai.transforms import Activations, AsDiscrete, Compose, ToTensor +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), - torch.tensor([[0], [1], [0], [1]]), - True, - True, - "macro", - 0.75, -] - -TEST_CASE_2 = [ - torch.tensor([[0.5], [0.5], [0.2], [8.3]]), - torch.tensor([[0], [1], [0], [1]]), - False, - False, - "macro", - 0.875, -] - -TEST_CASE_3 = [ - torch.tensor([[0.5], [0.5], [0.2], [8.3]]), - torch.tensor([0, 1, 0, 1]), - False, - False, - "macro", - 0.875, -] - -TEST_CASE_4 = [ - torch.tensor([0.5, 0.5, 0.2, 8.3]), - torch.tensor([0, 1, 0, 1]), - False, - False, - "macro", - 0.875, -] - -TEST_CASE_5 = [ - torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), - torch.tensor([[0], [1], [0], [1]]), - True, - True, - "none", - [0.75, 0.75], -] - -TEST_CASE_6 = [ - torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]), - torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]), - True, - False, - "weighted", - 0.56667, -] - -TEST_CASE_7 = [ - torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]), - torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]), - True, - False, - "micro", - 0.62, -] +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + TESTS.append( + [ + p(torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]])), + q(torch.tensor([[0], [1], [0], [1]])), + True, + True, + "macro", + 0.75, + ] + ) + TESTS.append( + [ + p(torch.tensor([[0.5], [0.5], [0.2], [8.3]])), + q(torch.tensor([[0], [1], [0], [1]])), + False, + False, + "macro", + 0.875, + ] + ) + TESTS.append( + [ + p(torch.tensor([[0.5], [0.5], [0.2], [8.3]])), + q(torch.tensor([0, 1, 0, 1])), + False, + False, + "macro", + 0.875, + ] + ) + TESTS.append( + [ + p(torch.tensor([0.5, 0.5, 0.2, 8.3])), + q(torch.tensor([0, 1, 0, 1])), + False, + False, + "macro", + 0.875, + ] + ) + TESTS.append( + [ + p(torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]])), + q(torch.tensor([[0], [1], [0], [1]])), + True, + True, + "none", + [0.75, 0.75], + ] + ) + TESTS.append( + [ + p(torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]])), + q(torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]])), + True, + False, + "weighted", + 0.56667, + ] + ) + TESTS.append( + [ + p(torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]])), + q(torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]])), + True, + False, + "micro", + 0.62, + ] + ) class TestComputeROCAUC(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + @parameterized.expand(TESTS) def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value): y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)]) y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot, n_classes=2)]) @@ -93,7 +105,7 @@ def test_value(self, y_pred, 
y, softmax, to_onehot, average, expected_value): result = compute_roc_auc(y_pred=y_pred, y=y, average=average) np.testing.assert_allclose(expected_value, result, rtol=1e-5) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + @parameterized.expand(TESTS) def test_class_value(self, y_pred, y, softmax, to_onehot, average, expected_value): y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)]) y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot, n_classes=2)]) diff --git a/tests/test_convert_data_type.py b/tests/test_convert_data_type.py new file mode 100644 index 0000000000..bf5179394a --- /dev/null +++ b/tests/test_convert_data_type.py @@ -0,0 +1,44 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import List, Tuple + +import numpy as np +import torch +from parameterized import parameterized + +from monai.utils.misc import convert_data_type, dtype_convert + +TESTS: List[Tuple] = [] +TESTS.append((np.array, torch.Tensor, np.float32, torch.float32)) +TESTS.append((torch.Tensor, np.ndarray, np.float32, torch.float32)) +TESTS.append((np.array, torch.Tensor, torch.float32, np.float32)) +TESTS.append((torch.Tensor, np.ndarray, torch.float32, np.float32)) + + +class TestConvertDataType(unittest.TestCase): + @staticmethod + def get_im(im_type, dtype): + dtype = dtype_convert(dtype, im_type) + lib = torch if im_type is torch.Tensor else np + return lib.zeros((1, 2, 3), dtype=dtype) + + @parameterized.expand(TESTS) + def test_convert_data_type(self, in_type, out_type, in_dtype, out_dtype): + orig_im = self.get_im(in_type, in_dtype) + converted_im, orig_type, _ = convert_data_type(orig_im, out_type) + self.assertEqual(type(orig_im), orig_type) + self.assertEqual(type(converted_im), out_type) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_convert_to_multi_channel.py b/tests/test_convert_to_multi_channel.py index 2f7a38e6e4..a350f8fdd8 100644 --- a/tests/test_convert_to_multi_channel.py +++ b/tests/test_convert_to_multi_channel.py @@ -12,31 +12,49 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import ConvertToMultiChannelBasedOnBratsClasses +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - np.array([[0, 1, 2], [1, 2, 4], [0, 1, 4]]), - np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + p(np.array([[0, 1, 2], [1, 2, 4], [0, 1, 4]])), + np.array( + [ + [[0, 1, 0], [1, 0, 1], [0, 1, 1]], + [[0, 1, 1], [1, 1, 1], [0, 1, 1]], + [[0, 0, 0], [0, 0, 1], [0, 0, 1]], + ] + ), + ] + ) -TEST_CASE_2 = [ - np.array([[[[0, 1], [1, 2]], [[2, 4], [4, 4]]]]), - np.array( + TESTS.append( [ - [[[0, 1], [1, 0]], [[0, 1], [1, 1]]], - [[[0, 1], [1, 1]], [[1, 1], [1, 1]]], - [[[0, 0], [0, 0]], [[0, 1], [1, 1]]], + p(np.array([[[[0, 1], [1, 2]], [[2, 4], 
[4, 4]]]])), + np.array( + [ + [[[0, 1], [1, 0]], [[0, 1], [1, 1]]], + [[[0, 1], [1, 1]], [[1, 1], [1, 1]]], + [[[0, 0], [0, 0]], [[0, 1], [1, 1]]], + ] + ), ] - ), -] + ) class TestConvertToMultiChannel(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_type_shape(self, data, expected_result): result = ConvertToMultiChannelBasedOnBratsClasses()(data) + self.assertEqual(type(result), type(data)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, data.device) + result = result.cpu().numpy() np.testing.assert_equal(result, expected_result) self.assertEqual(f"{result.dtype}", "bool") diff --git a/tests/test_convert_to_multi_channeld.py b/tests/test_convert_to_multi_channeld.py index 945e07e1cd..0fb48dde34 100644 --- a/tests/test_convert_to_multi_channeld.py +++ b/tests/test_convert_to_multi_channeld.py @@ -12,22 +12,39 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import ConvertToMultiChannelBasedOnBratsClassesd - -TEST_CASE = [ - {"keys": "label"}, - {"label": np.array([[0, 1, 2], [1, 2, 4], [0, 1, 4]])}, - np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]), -] +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": "label"}, + {"label": p(np.array([[0, 1, 2], [1, 2, 4], [0, 1, 4]]))}, + np.array( + [ + [[0, 1, 0], [1, 0, 1], [0, 1, 1]], + [[0, 1, 1], [1, 1, 1], [0, 1, 1]], + [[0, 0, 0], [0, 0, 1], [0, 0, 1]], + ] + ), + ] + ) class TestConvertToMultiChanneld(unittest.TestCase): - @parameterized.expand([TEST_CASE]) + @parameterized.expand(TESTS) def test_type_shape(self, keys, data, expected_result): - result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data) - np.testing.assert_equal(result["label"], expected_result) + result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data)["label"] + self.assertEqual(type(result), type(data["label"])) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, data["label"].device) + result = result.cpu().numpy() + np.testing.assert_equal(result, expected_result) + self.assertEqual(f"{result.dtype}", "bool") if __name__ == "__main__": diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py index 8eae8f484e..0bae1f90f3 100644 --- a/tests/test_crop_foreground.py +++ b/tests/test_crop_foreground.py @@ -12,60 +12,79 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import CropForeground +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] - -TEST_CASE_2 = [ - {"select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[3]]]), -] - -TEST_CASE_3 = [ - {"select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] - -TEST_CASE_4 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 
0, 0, 0]]]), - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), -] - -TEST_CASE_5 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), -] - -TEST_CASE_6 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 4}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2, 1, 0], [2, 3, 2, 0], [1, 2, 1, 0], [0, 0, 0, 0]]]), -] - -TEST_CASE_7 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 10}, - np.array([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - np.zeros((1, 0, 0)), -] +TEST_COORDS, TESTS = [], [] + +for p in TEST_NDARRAYS: + TEST_COORDS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + p([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, + p([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]), + p([[[3]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + p([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 4}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + p([[[1, 2, 1, 0], [2, 3, 2, 0], [1, 2, 1, 0], [0, 0, 0, 0]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 10}, + p([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + p(np.zeros((1, 0, 0), dtype=np.int64)), + ] + ) class TestCropForeground(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + @parameterized.expand(TEST_COORDS + TESTS) def test_value(self, argments, image, expected_data): result = CropForeground(**argments)(image) - np.testing.assert_allclose(result, expected_data) + torch.testing.assert_allclose(result, expected_data, rtol=1e-7, atol=0) - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TEST_COORDS) def test_return_coords(self, argments, image, _): argments["return_coords"] = True _, start_coord, end_coord = CropForeground(**argments)(image) diff --git a/tests/test_crop_foregroundd.py b/tests/test_crop_foregroundd.py index 37abfb8c55..14367ebfd7 100644 --- 
a/tests/test_crop_foregroundd.py +++ b/tests/test_crop_foregroundd.py @@ -12,79 +12,126 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import CropForegroundd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - { - "keys": ["img", "label"], - "source_key": "label", - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": 0, - }, - { - "img": np.array([[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]), - "label": np.array([[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]), - }, - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] +TEST_POSITION, TESTS = [], [] +for p in TEST_NDARRAYS: -TEST_CASE_2 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[3]]]), -] - -TEST_CASE_3 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] - -TEST_CASE_4 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), -] - -TEST_CASE_5 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), -] - -TEST_CASE_6 = [ - { - "keys": ["img"], - "source_key": "img", - "select_fn": lambda x: x > 0, - "channel_indices": 0, - "margin": 0, - "k_divisible": [4, 6], - "mode": "edge", - }, - {"img": np.array([[[0, 2, 1, 2, 0], [1, 1, 2, 1, 1], [2, 2, 3, 2, 2], [1, 1, 2, 1, 1], [0, 0, 0, 0, 0]]])}, - np.array([[[0, 2, 1, 2, 0, 0], [1, 1, 2, 1, 1, 1], [2, 2, 3, 2, 2, 2], [1, 1, 2, 1, 1, 1]]]), -] + TEST_POSITION.append( + [ + { + "keys": ["img", "label"], + "source_key": "label", + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": 0, + }, + { + "img": p( + np.array([[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]) + ), + "label": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]) + ), + }, + np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]) + ) + }, + np.array([[[3]]]), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ) + }, + np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": None, "margin": 
1}, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]) + ) + }, + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), + ] + ) + TESTS.append( + [ + { + "keys": ["img"], + "source_key": "img", + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": [2, 1], + }, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]) + ) + }, + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + ] + ) + TESTS.append( + [ + { + "keys": ["img"], + "source_key": "img", + "select_fn": lambda x: x > 0, + "channel_indices": 0, + "margin": 0, + "k_divisible": [4, 6], + "mode": "edge", + }, + { + "img": p( + np.array([[[0, 2, 1, 2, 0], [1, 1, 2, 1, 1], [2, 2, 3, 2, 2], [1, 1, 2, 1, 1], [0, 0, 0, 0, 0]]]) + ) + }, + np.array([[[0, 2, 1, 2, 0, 0], [1, 1, 2, 1, 1, 1], [2, 2, 3, 2, 2, 2], [1, 1, 2, 1, 1, 1]]]), + ] + ) class TestCropForegroundd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) - def test_value(self, argments, image, expected_data): - result = CropForegroundd(**argments)(image) - np.testing.assert_allclose(result["img"], expected_data) + @parameterized.expand(TEST_POSITION + TESTS) + def test_value(self, argments, input_data, expected_data): + result = CropForegroundd(**argments)(input_data) + r, i = result["img"], input_data["img"] + self.assertEqual(type(r), type(i)) + if isinstance(r, torch.Tensor): + self.assertEqual(r.device, i.device) + r = r.cpu() + np.testing.assert_allclose(r, expected_data) - @parameterized.expand([TEST_CASE_1]) - def test_foreground_position(self, argments, image, _): - result = CropForegroundd(**argments)(image) + @parameterized.expand(TEST_POSITION) + def test_foreground_position(self, argments, input_data, _): + result = CropForegroundd(**argments)(input_data) np.testing.assert_allclose(result["foreground_start_coord"], np.array([1, 1])) np.testing.assert_allclose(result["foreground_end_coord"], np.array([4, 4])) argments["start_coord_key"] = "test_start_coord" argments["end_coord_key"] = "test_end_coord" - result = CropForegroundd(**argments)(image) + result = CropForegroundd(**argments)(input_data) np.testing.assert_allclose(result["test_start_coord"], np.array([1, 1])) np.testing.assert_allclose(result["test_end_coord"], np.array([4, 4])) diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 3b159fb5b8..2c9f85fec3 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -48,7 +48,7 @@ def test_values(self): ] ) dataset = CacheDataset(data=datalist, transform=transform, cache_rate=0.5, cache_num=1) - n_workers = 0 if sys.platform == "win32" else 2 + n_workers = 2 if sys.platform == "linux" else 0 dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=n_workers) for d in dataloader: self.assertEqual(d["image"][0], "spleen_19.nii.gz") diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 521d263663..296eaa8f75 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -120,7 +120,8 @@ def check_match(self, in1, in2): def check_decollate(self, dataset): batch_size = 2 - num_workers = 2 + # num workers = 0 for mac or gpu transforms + num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2 loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) diff --git 
a/tests/test_deepgrow_dataset.py b/tests/test_deepgrow_dataset.py index 147d8e7099..d35c4c7e91 100644 --- a/tests/test_deepgrow_dataset.py +++ b/tests/test_deepgrow_dataset.py @@ -21,30 +21,25 @@ from monai.apps.deepgrow.dataset import create_dataset from monai.utils import set_determinism -TEST_CASE_1 = [{"dimension": 2, "pixdim": (1, 1)}, {"length": 3}, 9, 1] - -TEST_CASE_2 = [{"dimension": 2, "pixdim": (1, 1), "limit": 1}, {"length": 3}, 3, 1] - -TEST_CASE_3 = [{"dimension": 2, "pixdim": (1, 1)}, {"length": 1}, 3, 1] - -TEST_CASE_4 = [{"dimension": 3, "pixdim": (1, 1, 1)}, {"length": 1}, 1, 1] - -TEST_CASE_5 = [{"dimension": 3, "pixdim": (1, 1, 1)}, {"length": 1, "image_channel": 4}, 1, 1] - -TEST_CASE_6 = [{"dimension": 2, "pixdim": (1, 1)}, {"length": 1, "image_channel": 4}, 3, 1] - -TEST_CASE_7 = [ - {"dimension": 2, "pixdim": (1, 1), "label_key": None}, - {"length": 1, "image_channel": 4, "with_label": False}, - 40, - None, -] - -TEST_CASE_8 = [ - {"dimension": 3, "pixdim": (1, 1, 1), "label_key": None}, - {"length": 1, "image_channel": 4, "with_label": False}, - 1, - None, +TESTS = [ + [{"dimension": 2, "pixdim": (1, 1)}, {"length": 3}, 9, 1], + [{"dimension": 2, "pixdim": (1, 1), "limit": 1}, {"length": 3}, 3, 1], + [{"dimension": 2, "pixdim": (1, 1)}, {"length": 1}, 3, 1], + [{"dimension": 3, "pixdim": (1, 1, 1)}, {"length": 1}, 1, 1], + [{"dimension": 3, "pixdim": (1, 1, 1)}, {"length": 1, "image_channel": 4}, 1, 1], + [{"dimension": 2, "pixdim": (1, 1)}, {"length": 1, "image_channel": 4}, 3, 1], + [ + {"dimension": 2, "pixdim": (1, 1), "label_key": None}, + {"length": 1, "image_channel": 4, "with_label": False}, + 40, + None, + ], + [ + {"dimension": 3, "pixdim": (1, 1, 1), "label_key": None}, + {"length": 1, "image_channel": 4, "with_label": False}, + 1, + None, + ], ] @@ -79,9 +74,7 @@ def _create_data(self, length=1, image_channel=1, with_label=True): return datalist - @parameterized.expand( - [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] - ) + @parameterized.expand(TESTS) def test_create_dataset(self, args, data_args, expected_length, expected_region): datalist = self._create_data(**data_args) deepgrow_datalist = create_dataset(datalist=datalist, output_dir=self.tempdir, **args) diff --git a/tests/test_detect_envelope.py b/tests/test_detect_envelope.py index ded0290de2..b82d993c10 100644 --- a/tests/test_detect_envelope.py +++ b/tests/test_detect_envelope.py @@ -12,11 +12,18 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import DetectEnvelope from monai.utils import InvalidPyTorchVersionError, OptionalImportError -from tests.utils import SkipIfAtLeastPyTorchVersion, SkipIfBeforePyTorchVersion, SkipIfModule, SkipIfNoModule +from tests.utils import ( + TEST_NDARRAYS, + SkipIfAtLeastPyTorchVersion, + SkipIfBeforePyTorchVersion, + SkipIfModule, + SkipIfNoModule, +) n_samples = 500 hann_windowed_sine = np.sin(2 * np.pi * 10 * np.linspace(0, 1, n_samples)) * np.hanning(n_samples) @@ -125,8 +132,9 @@ class TestDetectEnvelope(unittest.TestCase): ] ) def test_value(self, arguments, image, expected_data, atol): - result = DetectEnvelope(**arguments)(image) - np.testing.assert_allclose(result, expected_data, atol=atol) + for p in TEST_NDARRAYS: + result = DetectEnvelope(**arguments)(p(image)) + torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), atol=atol, rtol=1e-7) @parameterized.expand( [ diff --git a/tests/test_divisible_pad.py 
b/tests/test_divisible_pad.py index e4415a2f22..923370caca 100644 --- a/tests/test_divisible_pad.py +++ b/tests/test_divisible_pad.py @@ -12,27 +12,36 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import DivisiblePad - -# pad first dim to be divisible by 7, the second unchanged. -TEST_CASE_1 = [ - {"k": (7, -1), "mode": "constant"}, - np.zeros((3, 8, 7)), - np.zeros((3, 14, 7)), -] - -# pad all dimensions to be divisible by 5 -TEST_CASE_2 = [ - {"k": 5, "mode": "constant", "method": "end"}, - np.zeros((3, 10, 5, 17)), - np.zeros((3, 10, 5, 20)), -] +from tests.utils import TEST_NDARRAYS + +TESTS = [] + +for p in TEST_NDARRAYS: + # pad first dim to be divisible by 7, the second unchanged. + TESTS.append( + [ + {"k": (7, -1), "mode": "constant"}, + p(np.zeros((3, 8, 7))), + p(np.zeros((3, 14, 7))), + ] + ) + + # pad all dimensions to be divisible by 5 + TESTS.append( + [ + {"k": 5, "mode": "constant", "method": "end"}, + p(np.zeros((3, 10, 5, 17))), + p(np.zeros((3, 10, 5, 20))), + ] + ) class TestDivisiblePad(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_pad_shape(self, input_param, input_data, expected_val): padder = DivisiblePad(**input_param) result = padder(input_data) @@ -42,9 +51,10 @@ def test_pad_shape(self, input_param, input_data, expected_val): def test_pad_kwargs(self): padder = DivisiblePad(k=5, mode="constant", constant_values=((0, 0), (1, 1), (2, 2))) - result = padder(np.zeros((3, 8, 4))) - np.testing.assert_allclose(result[:, :1, :4], np.ones((3, 1, 4))) - np.testing.assert_allclose(result[:, :, 4:5], np.ones((3, 10, 1)) + 1) + for p in TEST_NDARRAYS: + result = padder(p(np.zeros((3, 8, 4)))) + torch.testing.assert_allclose(result[:, :1, :4], p(np.ones((3, 1, 4))), rtol=1e-7, atol=0) + torch.testing.assert_allclose(result[:, :, 4:5], p(np.ones((3, 10, 1)) + 1), rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_dtype_convert.py b/tests/test_dtype_convert.py new file mode 100644 index 0000000000..90ae371a68 --- /dev/null +++ b/tests/test_dtype_convert.py @@ -0,0 +1,36 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
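+# +# NOTE: dtype_convert(dtype, data_type) is assumed to re-express ``dtype`` in +# the flavour of the target container type (e.g. np.float32 -> torch.float32 +# when data_type is torch.Tensor); the cases below smoke-test that the call +# succeeds and yields a usable dtype for every container/dtype pairing.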
+ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.utils.misc import dtype_convert +from tests.utils import TEST_NDARRAYS + +DTYPES = [torch.float32, np.float32, np.dtype(np.float32)] + +TESTS = [] +for im_type in TEST_NDARRAYS: + for im_dtype in DTYPES: + TESTS.append((im_type, im_dtype)) + + +class TestDtypeConvert(unittest.TestCase): + @parameterized.expand(TESTS) + def test_dtype_convert(self, im_type, desired_dtype): + out = dtype_convert(desired_dtype, im_type) + self.assertIsNotNone(out) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_fg_bg_to_indices.py b/tests/test_fg_bg_to_indices.py index 98626c7028..e8451ba334 100644 --- a/tests/test_fg_bg_to_indices.py +++ b/tests/test_fg_bg_to_indices.py @@ -12,57 +12,73 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import FgBgToIndices +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"image_threshold": 0.0, "output_shape": None}, - np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), - None, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 4, 8]), -] +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: -TEST_CASE_2 = [ - {"image_threshold": 0.0, "output_shape": None}, - np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), - np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]]), - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] - -TEST_CASE_3 = [ - {"image_threshold": 1.0, "output_shape": None}, - np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), - np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] - -TEST_CASE_4 = [ - {"image_threshold": 1.0, "output_shape": None}, - np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), - np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] - -TEST_CASE_5 = [ - {"image_threshold": 1.0, "output_shape": [3, 3]}, - np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), - np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), - np.array([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]), - np.array([[0, 0], [2, 2]]), -] + TESTS.append( + [ + {"image_threshold": 0.0, "output_shape": None}, + p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), + None, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 4, 8]), + ] + ) + TESTS.append( + [ + {"image_threshold": 0.0, "output_shape": None}, + p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), + q(np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])), + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) + TESTS.append( + [ + {"image_threshold": 1.0, "output_shape": None}, + p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), + q(np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])), + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) + TESTS.append( + [ + {"image_threshold": 1.0, "output_shape": None}, + p(np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]])), + q(np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])), + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) + TESTS.append( + [ + {"image_threshold": 1.0, "output_shape": [3, 3]}, + p(np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]])), + q(np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])), + np.array([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]), + np.array([[0, 0], [2, 2]]), + ] + ) class TestFgBgToIndices(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_type_shape(self, input_data, label, image, expected_fg, expected_bg): 
fg_indices, bg_indices = FgBgToIndices(**input_data)(label, image) - np.testing.assert_allclose(fg_indices, expected_fg) - np.testing.assert_allclose(bg_indices, expected_bg) + for indices, expected in zip((fg_indices, bg_indices), (expected_fg, expected_bg)): + self.assertEqual(type(indices), type(label)) + if isinstance(indices, torch.Tensor): + self.assertEqual(indices.device, label.device) + indices = indices.cpu() + np.testing.assert_allclose(indices, expected) if __name__ == "__main__": diff --git a/tests/test_fg_bg_to_indicesd.py b/tests/test_fg_bg_to_indicesd.py index ce6ca30f1b..0a47bda6a0 100644 --- a/tests/test_fg_bg_to_indicesd.py +++ b/tests/test_fg_bg_to_indicesd.py @@ -12,52 +12,85 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import FgBgToIndicesd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": "label", "image_key": None, "image_threshold": 0.0, "output_shape": None}, - {"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 4, 8]), -] +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: -TEST_CASE_2 = [ - {"keys": "label", "image_key": "image", "image_threshold": 0.0, "output_shape": None}, - {"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TESTS.append( + [ + {"keys": "label", "image_key": None, "image_threshold": 0.0, "output_shape": None}, + {"label": p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]))}, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 4, 8]), + ] + ) -TEST_CASE_3 = [ - {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None}, - {"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TESTS.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 0.0, "output_shape": None}, + { + "label": p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), + "image": q(np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])), + }, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) -TEST_CASE_4 = [ - {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None}, - {"label": np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TESTS.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None}, + { + "label": p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), + "image": q(np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])), + }, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) -TEST_CASE_5 = [ - {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": [3, 3]}, - {"label": np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, - np.array([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]), - np.array([[0, 0], [2, 2]]), -] + TESTS.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None}, + { + "label": p(np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]])), + "image": q(np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])), + }, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) + + TESTS.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": [3, 
3]}, + { + "label": p(np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]])), + "image": q(np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])), + }, + np.array([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]), + np.array([[0, 0], [2, 2]]), + ] + ) class TestFgBgToIndicesd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_type_shape(self, input_data, data, expected_fg, expected_bg): result = FgBgToIndicesd(**input_data)(data) - np.testing.assert_allclose(result["label_fg_indices"], expected_fg) - np.testing.assert_allclose(result["label_bg_indices"], expected_bg) + for key, expected in zip(("fg", "bg"), (expected_fg, expected_bg)): + r = result[f"label_{key}_indices"] + self.assertEqual(type(r), type(data["label"])) + if isinstance(r, torch.Tensor): + self.assertEqual(r.device, data["label"].device) + r = r.cpu() + np.testing.assert_allclose(r, expected) if __name__ == "__main__": diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index 1314fe3841..5c102989f3 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -16,7 +16,7 @@ import torch.nn.functional as F from monai.losses import FocalLoss -from monai.networks import one_hot +from monai.networks.utils import one_hot_torch from tests.utils import SkipIfBeforePyTorchVersion, test_script_save @@ -60,7 +60,7 @@ def test_consistency_with_cross_entropy_2d_onehot_label(self): x = x.cuda() l = l.cuda() output0 = focal_loss(x, l) - output1 = ce(x, one_hot(l, num_classes=class_num)) + output1 = ce(x, one_hot_torch(l, num_classes=class_num)) a = float(output0.cpu().detach()) b = float(output1.cpu().detach()) if abs(a - b) > max_error: @@ -84,7 +84,7 @@ def test_consistency_with_cross_entropy_classification(self): x = x.cuda() l = l.cuda() output0 = focal_loss(x, l) - output1 = ce(x, one_hot(l, num_classes=class_num)) + output1 = ce(x, one_hot_torch(l, num_classes=class_num)) a = float(output0.cpu().detach()) b = float(output1.cpu().detach()) if abs(a - b) > max_error: @@ -108,7 +108,7 @@ def test_consistency_with_cross_entropy_classification_01(self): x = x.cuda() l = l.cuda() output0 = focal_loss(x, l) - output1 = ce(x, one_hot(l, num_classes=class_num)) + output1 = ce(x, one_hot_torch(l, num_classes=class_num)) a = float(output0.cpu().detach()) b = float(output1.cpu().detach()) if abs(a - b) > max_error: diff --git a/tests/test_gaussian_sharpen.py b/tests/test_gaussian_sharpen.py index 9d078e65e5..66fca04311 100644 --- a/tests/test_gaussian_sharpen.py +++ b/tests/test_gaussian_sharpen.py @@ -11,50 +11,80 @@ import unittest -import numpy as np +import torch from parameterized import parameterized from monai.transforms import GaussianSharpen +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( +TESTS = [] + +for p in TEST_NDARRAYS: + TESTS.append( [ - [[4.1081963, 3.4950666, 4.1081963], [3.7239995, 2.8491793, 3.7239995], [4.569839, 3.9529324, 4.569839]], - [[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]], + {}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [4.1081963, 3.4950666, 4.1081963], + [3.7239995, 2.8491793, 3.7239995], + [4.569839, 3.9529324, 4.569839], + ], + [[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - {"sigma1": 1.0, 
"sigma2": 0.75, "alpha": 20}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[4.513644, 4.869134, 4.513644], [8.467242, 9.4004135, 8.467242], [10.416813, 12.0653515, 10.416813]], - [[15.711488, 17.569994, 15.711488], [21.16811, 23.501041, 21.16811], [21.614658, 24.766209, 21.614658]], + {"sigma1": 1.0, "sigma2": 0.75, "alpha": 20}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [4.513644, 4.869134, 4.513644], + [8.467242, 9.4004135, 8.467242], + [10.416813, 12.0653515, 10.416813], + ], + [ + [15.711488, 17.569994, 15.711488], + [21.16811, 23.501041, 21.16811], + [21.614658, 24.766209, 21.614658], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - {"sigma1": (0.5, 1.0), "sigma2": (0.5, 0.75), "alpha": 20}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[3.3324685, 3.335536, 3.3324673], [7.7666636, 8.16056, 7.7666636], [12.662973, 14.317837, 12.6629715]], - [[15.329051, 16.57557, 15.329051], [19.41665, 20.40139, 19.416655], [24.659554, 27.557873, 24.659554]], + {"sigma1": (0.5, 1.0), "sigma2": (0.5, 0.75), "alpha": 20}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [3.3324685, 3.335536, 3.3324673], + [7.7666636, 8.16056, 7.7666636], + [12.662973, 14.317837, 12.6629715], + ], + [ + [15.329051, 16.57557, 15.329051], + [19.41665, 20.40139, 19.416655], + [24.659554, 27.557873, 24.659554], + ], + ] + ), ] - ), -] + ) class TestGaussianSharpen(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = GaussianSharpen(**argments)(image) - np.testing.assert_allclose(result, expected_data, rtol=1e-4) + torch.testing.assert_allclose(result, expected_data, atol=0, rtol=1e-4) if __name__ == "__main__": diff --git a/tests/test_gaussian_smooth.py b/tests/test_gaussian_smooth.py index e51977fbee..4dc7e35a9f 100644 --- a/tests/test_gaussian_smooth.py +++ b/tests/test_gaussian_smooth.py @@ -11,54 +11,84 @@ import unittest -import numpy as np +import torch from parameterized import parameterized from monai.transforms import GaussianSmooth +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"sigma": 1.5}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( +TESTS = [] + +for p in TEST_NDARRAYS: + TESTS.append( [ - [ - [0.59167546, 0.69312394, 0.59167546], - [0.7956997, 0.93213004, 0.7956997], - [0.7668002, 0.8982755, 0.7668002], - ], - [[1.6105323, 1.8866735, 1.6105323], [1.9892492, 2.3303251, 1.9892492], [1.7856569, 2.091825, 1.7856569]], + {"sigma": 1.5}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore + p( + [ + [ + [0.59167546, 0.69312394, 0.59167546], + [0.7956997, 0.93213004, 0.7956997], + [0.7668002, 0.8982755, 0.7668002], + ], + [ + [1.6105323, 1.8866735, 1.6105323], + [1.9892492, 2.3303251, 1.9892492], + [1.7856569, 2.091825, 1.7856569], + ], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - {"sigma": 0.5}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[0.8424794, 0.99864554, 0.8424794], [1.678146, 1.9892154, 1.678146], [1.9889624, 2.3576462, 1.9889624]], - [[2.966061, 3.5158648, 2.966061], [4.1953645, 4.973038, 4.1953645], [4.112544, 4.8748655, 4.1125436]], + {"sigma": 0.5}, + p([[[1, 1, 1], [2, 2, 2], 
[3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore + p( + [ + [ + [0.8424794, 0.99864554, 0.8424794], + [1.678146, 1.9892154, 1.678146], + [1.9889624, 2.3576462, 1.9889624], + ], + [ + [2.966061, 3.5158648, 2.966061], + [4.1953645, 4.973038, 4.1953645], + [4.112544, 4.8748655, 4.1125436], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - {"sigma": [1.5, 0.5]}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[0.8542037, 1.0125432, 0.8542037], [1.1487541, 1.3616928, 1.1487541], [1.1070318, 1.3122368, 1.1070318]], - [[2.3251305, 2.756128, 2.3251305], [2.8718853, 3.4042323, 2.8718853], [2.5779586, 3.0558217, 2.5779586]], + {"sigma": [1.5, 0.5]}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore + p( + [ + [ + [0.8542037, 1.0125432, 0.8542037], + [1.1487541, 1.3616928, 1.1487541], + [1.1070318, 1.3122368, 1.1070318], + ], + [ + [2.3251305, 2.756128, 2.3251305], + [2.8718853, 3.4042323, 2.8718853], + [2.5779586, 3.0558217, 2.5779586], + ], + ] + ), ] - ), -] + ) class TestGaussianSmooth(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = GaussianSmooth(**argments)(image) - np.testing.assert_allclose(result, expected_data, rtol=1e-4) + torch.testing.assert_allclose(result, expected_data, atol=0, rtol=1e-4) if __name__ == "__main__": diff --git a/tests/test_generate_pos_neg_label_crop_centers.py b/tests/test_generate_pos_neg_label_crop_centers.py index 40181aa9ea..0335a94d97 100644 --- a/tests/test_generate_pos_neg_label_crop_centers.py +++ b/tests/test_generate_pos_neg_label_crop_centers.py @@ -15,25 +15,30 @@ from parameterized import parameterized from monai.transforms import generate_pos_neg_label_crop_centers - -TEST_CASE_1 = [ - { - "spatial_size": [2, 2, 2], - "num_samples": 2, - "pos_ratio": 1.0, - "label_spatial_shape": [3, 3, 3], - "fg_indices": [1, 9, 18], - "bg_indices": [3, 12, 21], - "rand_state": np.random.RandomState(), - }, - list, - 2, - 3, -] +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS + (None,): + TESTS.append( + [ + { + "spatial_size": [2, 2, 2], + "num_samples": 2, + "pos_ratio": 1.0, + "label_spatial_shape": [3, 3, 3], + "fg_indices": [1, 9, 18] if p is None else p([1, 9, 18]), + "bg_indices": [3, 12, 21] if p is None else p([3, 12, 21]), + "rand_state": np.random.RandomState(), + }, + list, + 2, + 3, + ] + ) class TestGeneratePosNegLabelCropCenters(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TESTS) def test_type_shape(self, input_data, expected_type, expected_count, expected_shape): result = generate_pos_neg_label_crop_centers(**input_data) self.assertIsInstance(result, expected_type) diff --git a/tests/test_generate_spatial_bounding_box.py b/tests/test_generate_spatial_bounding_box.py index 32a45d8d1c..d73b9fafcc 100644 --- a/tests/test_generate_spatial_bounding_box.py +++ b/tests/test_generate_spatial_bounding_box.py @@ -15,60 +15,79 @@ from parameterized import parameterized from monai.transforms import generate_spatial_bounding_box +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": 0, - }, - ([1, 1], [4, 4]), -] - -TEST_CASE_2 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], 
[0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 1, - "channel_indices": None, - "margin": 0, - }, - ([2, 2], [3, 3]), -] - -TEST_CASE_3 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": 0, - "margin": 0, - }, - ([1, 1], [4, 4]), -] - -TEST_CASE_4 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": 1, - }, - ([0, 0], [4, 5]), -] - -TEST_CASE_5 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": [2, 1], - }, - ([0, 0], [5, 5]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": 0, + }, + ([1, 1], [4, 4]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 1, + "channel_indices": None, + "margin": 0, + }, + ([2, 2], [3, 3]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": 0, + "margin": 0, + }, + ([1, 1], [4, 4]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": 1, + }, + ([0, 0], [4, 5]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": [2, 1], + }, + ([0, 0], [5, 5]), + ] + ) class TestGenerateSpatialBoundingBox(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_value(self, input_data, expected_box): result = generate_spatial_bounding_box(**input_data) self.assertTupleEqual(result, expected_box) diff --git a/tests/test_get_extreme_points.py b/tests/test_get_extreme_points.py index a334c12415..af655bfb86 100644 --- a/tests/test_get_extreme_points.py +++ b/tests/test_get_extreme_points.py @@ -15,30 +15,36 @@ from parameterized import parameterized from monai.transforms import get_extreme_points - -TEST_CASE_1 = [ - { - "img": np.array([[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]), - "rand_state": np.random, - "background": 0, - "pert": 0.0, - }, - [(0, 1), (3, 0), (3, 0), (1, 2)], -] - -TEST_CASE_2 = [ - { - "img": np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]), - "rand_state": np.random, - "background": 0, - "pert": 0.0, - }, - [(0, 1), (3, 1), (1, 0), (1, 2)], -] +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + { + "img": p(np.array([[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]])), + "rand_state": np.random, + "background": 0, + "pert": 0.0, + }, + [(0, 1), (3, 0), (3, 0), (1, 2)], + ] + ) + TESTS.append( + [ + { + "img": p(np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]])), + "rand_state": 
np.random, + "background": 0, + "pert": 0.0, + }, + [(0, 1), (3, 1), (1, 0), (1, 2)], + ] + ) class TestGetExtremePoints(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_type_shape(self, input_data, expected): result = get_extreme_points(**input_data) self.assertEqual(result, expected) diff --git a/tests/test_gibbs_noise.py b/tests/test_gibbs_noise.py index 83cba56938..2c5e117eaf 100644 --- a/tests/test_gibbs_noise.py +++ b/tests/test_gibbs_noise.py @@ -19,12 +19,15 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import GibbsNoise from monai.utils.misc import set_determinism +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS + +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) class TestGibbsNoise(unittest.TestCase): @@ -36,36 +39,39 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d im = create_test_image(*im_shape, num_objs=4, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] - return torch.Tensor(im) if as_tensor_input else im + return input_type(im) @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = 0.8 - t = GibbsNoise(alpha, as_tensor_output) + t = GibbsNoise(alpha) out1 = t(deepcopy(im)) out2 = t(deepcopy(im)) - np.testing.assert_allclose(out1, out2) - self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray) + self.assertEqual(type(out1), type(im)) + if isinstance(out1, torch.Tensor): + self.assertEqual(out1.device, im.device) + torch.testing.assert_allclose(out1, out2, rtol=1e-7, atol=0) + self.assertIsInstance(out1, type(im)) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = 0.0 t = GibbsNoise(alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(im, out, atol=1e-2) + torch.testing.assert_allclose(im, out, atol=1e-2, rtol=1e-7) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = 1.0 t = GibbsNoise(alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(0 * im, out) + torch.testing.assert_allclose(0 * im, out, rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_gibbs_noised.py b/tests/test_gibbs_noised.py index 0e02feb341..f02052818f 100644 --- a/tests/test_gibbs_noised.py +++ b/tests/test_gibbs_noised.py @@ -19,13 +19,15 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import GibbsNoised from monai.utils.misc import set_determinism +from monai.utils.module import 
optional_import +from tests.utils import TEST_NDARRAYS + +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) - + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) KEYS = ["im", "label"] @@ -38,49 +40,56 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims - return {k: v for k, v in zip(KEYS, ims)} + return {k: input_type(deepcopy(v)) for k, v in zip(KEYS, ims)} @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = 0.8 - t = GibbsNoised(KEYS, alpha, as_tensor_output) + t = GibbsNoised(KEYS, alpha) out1 = t(deepcopy(data)) out2 = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(out1[k], out2[k]) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) + torch.testing.assert_allclose(out1[k], out2[k], rtol=1e-7, atol=0) + self.assertIsInstance(out1[k], type(data[k])) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = 0.0 t = GibbsNoised(KEYS, alpha) out = t(deepcopy(data)) for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() np.testing.assert_allclose(data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = 1.0 t = GibbsNoised(KEYS, alpha) out = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(0 * data[k], out[k]) + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() + np.testing.assert_allclose(0.0 * data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_dict_matches(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_dict_matches(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} alpha = 1.0 t = GibbsNoised(KEYS, alpha) out = t(deepcopy(data)) - np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]]) + torch.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]], rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py index 6e0aa4023e..045a102955 100644 --- a/tests/test_grid_dataset.py +++ b/tests/test_grid_dataset.py @@ -36,7 +36,7 @@ def test_shape(self): 
test_dataset = ["vwxyz", "helloworld", "worldfoobar"] result = GridPatchDataset(dataset=test_dataset, patch_iter=identity_generator, with_coordinates=False) output = [] - n_workers = 0 if sys.platform == "win32" else 2 + n_workers = 2 if sys.platform == "linux" else 0 for item in DataLoader(result, batch_size=3, num_workers=n_workers): output.append("".join(item)) expected = ["vwx", "wor", "yzh", "ldf", "ell", "oob", "owo", "ar", "rld"] diff --git a/tests/test_handler_transform_inverter.py b/tests/test_handler_transform_inverter.py index 385311eba7..e5c54888e7 100644 --- a/tests/test_handler_transform_inverter.py +++ b/tests/test_handler_transform_inverter.py @@ -70,7 +70,7 @@ def test_invert(self): data = [{"image": im_fname, "label": seg_fname} for _ in range(12)] # num workers = 0 for mac or gpu transforms - num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2 + num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2 dataset = CacheDataset(data, transform=transform, progress=False) loader = DataLoader(dataset, num_workers=num_workers, batch_size=5) diff --git a/tests/test_image_dataset.py b/tests/test_image_dataset.py index 3b3c06c87c..8e718461ae 100644 --- a/tests/test_image_dataset.py +++ b/tests/test_image_dataset.py @@ -18,6 +18,7 @@ from monai.data import ImageDataset from monai.transforms import Compose, EnsureChannelFirst, RandAdjustContrast, RandomizableTransform, Spacing +from monai.utils.misc import dtype_convert FILENAMES = ["test1.nii.gz", "test2.nii", "test3.nii.gz"] @@ -85,7 +86,7 @@ def test_dataset(self): # loading no meta, int dataset = ImageDataset(full_names, dtype=np.float16) for d, _ in zip(dataset, ref_data): - self.assertEqual(d.dtype, np.float16) + self.assertEqual(d.dtype, dtype_convert(np.float16, type(d))) # loading with meta, no transform dataset = ImageDataset(full_names, image_only=False) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index f2470d47fd..042c0b5e11 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -64,6 +64,7 @@ allow_missing_keys_mode, convert_inverse_interp_mode, ) +from monai.transforms.utility.dictionary import ToTensord from monai.utils import first, get_seed, optional_import, set_determinism from monai.utils.enums import InverseKeys from tests.utils import make_nifti_image, make_rand_affine @@ -489,7 +490,7 @@ ) ) -TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] +TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose([ToTensord(KEYS), Compose(t[3:])])) for t in TESTS] TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore @@ -651,7 +652,7 @@ def test_inverse_inferred_seg(self, extra_transform): batch_size = 10 # num workers = 0 for mac - num_workers = 2 if sys.platform != "darwin" else 0 + num_workers = 2 if sys.platform == "linux" else 0 transforms = Compose( [ AddChanneld(KEYS), diff --git a/tests/test_inverse_collation.py b/tests/test_inverse_collation.py index c302e04017..334ea5f4b3 100644 --- a/tests/test_inverse_collation.py +++ b/tests/test_inverse_collation.py @@ -56,7 +56,6 @@ prob=0.5, rotate_range=np.pi, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), - as_tensor_output=False, ), ] ] @@ -75,7 +74,6 @@ prob=0.5, rotate_range=np.pi, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), - as_tensor_output=False, ), ] ] @@ -117,7 +115,7 @@ def test_collation(self, _, transform, collate_fn, ndim): modified_transform = Compose([transform, ResizeWithPadOrCropd(KEYS, 100), 
ToTensord(KEYS)]) # num workers = 0 for mac or gpu transforms - num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2 + num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2 dataset = CacheDataset(data, transform=modified_transform, progress=False) loader = DataLoader(dataset, num_workers, batch_size=self.batch_size, collate_fn=collate_fn) diff --git a/tests/test_invertd.py b/tests/test_invertd.py index 5b98653f0a..5a74e90c99 100644 --- a/tests/test_invertd.py +++ b/tests/test_invertd.py @@ -70,7 +70,7 @@ def test_invert(self): data = [{"image": im_fname, "label": seg_fname} for _ in range(12)] # num workers = 0 for mac or gpu transforms - num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2 + num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2 dataset = CacheDataset(data, transform=transform, progress=False) loader = DataLoader(dataset, num_workers=num_workers, batch_size=5) diff --git a/tests/test_k_space_spike_noise.py b/tests/test_k_space_spike_noise.py index 53661d5fcb..66763f286f 100644 --- a/tests/test_k_space_spike_noise.py +++ b/tests/test_k_space_spike_noise.py @@ -20,12 +20,12 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoise from monai.utils.misc import set_determinism +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] +TESTS = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for p in TEST_NDARRAYS: + TESTS.append((shape, p)) class TestKSpaceSpikeNoise(unittest.TestCase): @@ -37,34 +37,44 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, im_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d - im = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] - return torch.Tensor(im) if as_tensor_input else im + im, _ = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) + return im_type(im[None]) - @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_same_result(self, im_shape, im_type): - im = self.get_data(im_shape, as_tensor_input) + im = self.get_data(im_shape, im_type) loc = [0, int(im.shape[1] / 2), 0] if len(im_shape) == 2 else [0, int(im.shape[1] / 2), 0, 0] k_intensity = 10 - t = KSpaceSpikeNoise(loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoise(loc, k_intensity) out1 = t(deepcopy(im)) out2 = t(deepcopy(im)) + self.assertEqual(type(im), type(out1)) + if isinstance(out1, torch.Tensor): + self.assertEqual(im.device, out1.device) + out1 = out1.cpu() + out2 = out2.cpu() + np.testing.assert_allclose(out1, out2) - self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray) - @parameterized.expand(TEST_CASES) - def test_highlighted_kspace_pixel(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_highlighted_kspace_pixel(self, im_shape, im_type): - im = self.get_data(im_shape, as_tensor_input) + im = self.get_data(im_shape, im_type) loc = [0, int(im.shape[1] / 2), 0] if len(im_shape) == 2 else [0, int(im.shape[1] / 2), 0, 0] k_intensity = 10 - t = KSpaceSpikeNoise(loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoise(loc, k_intensity) out = t(im) + 
self.assertEqual(type(im), type(out)) + if isinstance(out, torch.Tensor): + self.assertEqual(im.device, out.device) + out = out.cpu() + n_dims = len(im_shape) out_k = fftshift(fftn(out, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) log_mag = np.log(np.absolute(out_k)) diff --git a/tests/test_k_space_spike_noised.py b/tests/test_k_space_spike_noised.py index e5d2dfb6f8..3fa6a394f3 100644 --- a/tests/test_k_space_spike_noised.py +++ b/tests/test_k_space_spike_noised.py @@ -20,12 +20,12 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoised from monai.utils.misc import set_determinism +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] +TESTS = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for p in TEST_NDARRAYS: + TESTS.append((shape, p)) KEYS = ["image", "label"] @@ -39,55 +39,69 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, im_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) - ims = [im[None] for im in ims] - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims + ims = [im_type(im[None]) for im in ims] return {k: v for k, v in zip(KEYS, ims)} - @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_same_result(self, im_shape, im_type): - data = self.get_data(im_shape, as_tensor_input) + data = self.get_data(im_shape, im_type) loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))] k_intensity = 10 - t = KSpaceSpikeNoised(KEYS, loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoised(KEYS, loc, k_intensity) out1 = t(deepcopy(data)) out2 = t(deepcopy(data)) for k in KEYS: + self.assertEqual(type(out1[k]), type(data[k])) + if isinstance(out1[k], torch.Tensor): + self.assertEqual(out1[k].device, data[k].device) + out1[k] = out1[k].cpu() + out2[k] = out2[k].cpu() np.testing.assert_allclose(out1[k], out2[k]) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) - @parameterized.expand(TEST_CASES) - def test_highlighted_kspace_pixel(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_highlighted_kspace_pixel(self, im_shape, im_type): - data = self.get_data(im_shape, as_tensor_input) + data = self.get_data(im_shape, im_type) loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))] k_intensity = 10 - t = KSpaceSpikeNoised(KEYS, loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoised(KEYS, loc, k_intensity) out = t(data) for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k] = out[k].cpu() + n_dims = len(im_shape) out_k = fftshift(fftn(out[k], axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) log_mag = np.log(np.absolute(out_k)) np.testing.assert_allclose(k_intensity, log_mag[tuple(loc)], 1e-1) - @parameterized.expand(TEST_CASES) - def test_dict_matches(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_dict_matches(self, im_shape, im_type): + data = 
self.get_data(im_shape, im_type) # use same image for both dictionary entries to check same trans is applied to them data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))] k_intensity = 10 - t = KSpaceSpikeNoised(KEYS, loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoised(KEYS, loc, k_intensity) out = t(deepcopy(data)) + for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k] = out[k].cpu() + np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]]) diff --git a/tests/test_keep_largest_connected_component.py b/tests/test_keep_largest_connected_component.py index a8835329ba..6eaa61897d 100644 --- a/tests/test_keep_largest_connected_component.py +++ b/tests/test_keep_largest_connected_component.py @@ -11,10 +11,12 @@ import unittest +import numpy as np import torch from parameterized import parameterized from monai.transforms import KeepLargestConnectedComponent +from tests.utils import TEST_NDARRAYS grid_1 = torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]) grid_2 = torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]) @@ -322,23 +324,24 @@ class TestKeepLargestConnectedComponent(unittest.TestCase): @parameterized.expand(VALID_CASES) - def test_correct_results(self, _, args, tensor, expected): - converter = KeepLargestConnectedComponent(**args) - if torch.cuda.is_available(): - result = converter(tensor.clone().cuda()) - assert torch.allclose(result, expected.cuda()) - else: - result = converter(tensor.clone()) - assert torch.allclose(result, expected) + def test_correct_results(self, _, args, img, expected): + for p in TEST_NDARRAYS: + converter = KeepLargestConnectedComponent(**args) + im = p(img.clone()) + result = converter(im) + self.assertEqual(type(im), type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, im.device) + result = result.cpu() + np.testing.assert_allclose(result, expected) @parameterized.expand(INVALID_CASES) - def test_raise_exception(self, _, args, tensor, expected_error): + def test_raise_exception(self, _, args, img, expected_error): with self.assertRaises(expected_error): - converter = KeepLargestConnectedComponent(**args) - if torch.cuda.is_available(): - _ = converter(tensor.clone().cuda()) - else: - _ = converter(tensor.clone()) + for p in TEST_NDARRAYS: + converter = KeepLargestConnectedComponent(**args) + im = p(img) + _ = converter(im) if __name__ == "__main__": diff --git a/tests/test_label_to_contour.py b/tests/test_label_to_contour.py index 8f8f3cc054..c34ff2e7e0 100644 --- a/tests/test_label_to_contour.py +++ b/tests/test_label_to_contour.py @@ -15,6 +15,7 @@ import torch from monai.transforms import LabelToContour +from tests.utils import TEST_NDARRAYS expected_output_for_cube = np.array( [ @@ -144,34 +145,37 @@ def gen_fixed_img(): class TestContour(unittest.TestCase): def test_contour(self): input_param = {"kernel_type": "Laplace"} - - # check 5-dim input data - test_cube, expected_output = gen_fixed_cube() - for cube in test_cube: - test_result_cube = LabelToContour(**input_param)(cube) - self.assertEqual(test_result_cube.shape, cube.shape) - - test_result_np = test_result_cube.cpu().numpy() - channels = cube.shape[0] - for channel in range(channels): - np.testing.assert_allclose(test_result_np[channel, ...], expected_output) - - 
# check 4-dim input data - test_img, expected_output = gen_fixed_img() - for img in test_img: - channels = img.shape[0] - test_result_img = LabelToContour(**input_param)(img) - self.assertEqual(test_result_img.shape, img.shape) - - test_result_np = test_result_img.cpu().numpy() - for channel in range(channels): - np.testing.assert_allclose(test_result_np[channel, ...], expected_output) - - # check invalid input data - error_input = torch.rand(1, 2) - self.assertRaises(ValueError, LabelToContour(**input_param), error_input) - error_input = torch.rand(1, 2, 3, 4, 5) - self.assertRaises(ValueError, LabelToContour(**input_param), error_input) + for p in TEST_NDARRAYS: + + # check 5-dim input data + test_cube, expected_output = gen_fixed_cube() + for cube in p(test_cube): + test_result_cube = LabelToContour(**input_param)(cube) + self.assertEqual(test_result_cube.shape, cube.shape) + self.assertEqual(type(cube), type(test_result_cube)) + if isinstance(test_result_cube, torch.Tensor): + test_result_cube = test_result_cube.cpu() + channels = cube.shape[0] + for channel in range(channels): + np.testing.assert_allclose(test_result_cube[channel, ...], expected_output) + + # check 4-dim input data + test_img, expected_output = gen_fixed_img() + for img in p(test_img): + channels = img.shape[0] + test_result_img = LabelToContour(**input_param)(img) + self.assertEqual(test_result_img.shape, img.shape) + self.assertEqual(type(test_result_img), type(img)) + if isinstance(test_result_img, torch.Tensor): + test_result_img = test_result_img.cpu() + for channel in range(channels): + np.testing.assert_allclose(test_result_img[channel, ...], expected_output) + + # check invalid input data + error_input = p(torch.rand(1, 2)) + self.assertRaises(ValueError, LabelToContour(**input_param), error_input) + error_input = p(torch.rand(1, 2, 3, 4, 5)) + self.assertRaises(ValueError, LabelToContour(**input_param), error_input) if __name__ == "__main__": diff --git a/tests/test_label_to_contourd.py b/tests/test_label_to_contourd.py index d3795755c7..5b5b8260d1 100644 --- a/tests/test_label_to_contourd.py +++ b/tests/test_label_to_contourd.py @@ -15,6 +15,7 @@ import torch from monai.transforms import LabelToContourd +from tests.utils import TEST_NDARRAYS expected_output_for_cube = np.array( [ @@ -144,34 +145,39 @@ def gen_fixed_img(): class TestContourd(unittest.TestCase): def test_contour(self): input_param = {"keys": "img", "kernel_type": "Laplace"} - - # check 5-dim input data - test_cube, expected_output = gen_fixed_cube() - for cube in test_cube: - test_result_cube = LabelToContourd(**input_param)({"img": cube}) - self.assertEqual(test_result_cube["img"].shape, cube.shape) - - test_result_np = test_result_cube["img"].cpu().numpy() - channels = cube.shape[0] - for channel in range(channels): - np.testing.assert_allclose(test_result_np[channel, ...], expected_output) - - # check 4-dim input data - test_img, expected_output = gen_fixed_img() - for img in test_img: - channels = img.shape[0] - test_result_img = LabelToContourd(**input_param)({"img": img}) - self.assertEqual(test_result_img["img"].shape, img.shape) - - test_result_np = test_result_img["img"].cpu().numpy() - for channel in range(channels): - np.testing.assert_allclose(test_result_np[channel, ...], expected_output) - - # check invalid input data - error_input = {"img": torch.rand(1, 2)} - self.assertRaises(ValueError, LabelToContourd(**input_param), error_input) - error_input = {"img": torch.rand(1, 2, 3, 4, 5)} - self.assertRaises(ValueError, 
LabelToContourd(**input_param), error_input) + for p in TEST_NDARRAYS: + # check 5-dim input data + test_cube, expected_output = gen_fixed_cube() + for cube in p(test_cube): + test_result_cube = LabelToContourd(**input_param)({"img": cube})["img"] + self.assertEqual(test_result_cube.shape, cube.shape) + self.assertEqual(type(test_result_cube), type(cube)) + if isinstance(test_result_cube, torch.Tensor): + self.assertEqual(cube.device, test_result_cube.device) + test_result_cube = test_result_cube.cpu() + + channels = cube.shape[0] + for channel in range(channels): + np.testing.assert_allclose(test_result_cube[channel, ...], expected_output) + + # check 4-dim input data + test_img, expected_output = gen_fixed_img() + for img in p(test_img): + channels = img.shape[0] + test_result_img = LabelToContourd(**input_param)({"img": img})["img"] + self.assertEqual(test_result_img.shape, img.shape) + self.assertEqual(type(test_result_img), type(img)) + if isinstance(test_result_img, torch.Tensor): + test_result_img = test_result_img.cpu() + + for channel in range(channels): + np.testing.assert_allclose(test_result_img[channel, ...], expected_output) + + # check invalid input data + error_input = {"img": p(torch.rand(1, 2))} + self.assertRaises(ValueError, LabelToContourd(**input_param), error_input) + error_input = {"img": p(torch.rand(1, 2, 3, 4, 5))} + self.assertRaises(ValueError, LabelToContourd(**input_param), error_input) if __name__ == "__main__": diff --git a/tests/test_label_to_mask.py b/tests/test_label_to_mask.py index 2a84c7bea6..dc785ba24e 100644 --- a/tests/test_label_to_mask.py +++ b/tests/test_label_to_mask.py @@ -12,45 +12,59 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import LabelToMask +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"select_labels": [2, 3], "merge_channels": False}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array([[[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), -] - -TEST_CASE_2 = [ - {"select_labels": 2, "merge_channels": False}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array([[[0, 0, 0], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), -] - -TEST_CASE_3 = [ - {"select_labels": [1, 2], "merge_channels": False}, - np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), - np.array([[[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), -] - -TEST_CASE_4 = [ - {"select_labels": 2, "merge_channels": False}, - np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), - np.array([[[1, 0, 1], [1, 1, 0]]]), -] - -TEST_CASE_5 = [ - {"select_labels": [1, 2], "merge_channels": True}, - np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), - np.array([[[1, 0, 1], [1, 1, 1]]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"select_labels": [2, 3], "merge_channels": False}, + p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]])), + np.array([[[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), + ] + ) + TESTS.append( + [ + {"select_labels": 2, "merge_channels": False}, + p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]])), + np.array([[[0, 0, 0], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), + ] + ) + TESTS.append( + [ + {"select_labels": [1, 2], "merge_channels": False}, + p(np.array([[[0, 0, 1], 
[0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])), + np.array([[[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), + ] + ) + TESTS.append( + [ + {"select_labels": 2, "merge_channels": False}, + p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])), + np.array([[[1, 0, 1], [1, 1, 0]]]), + ] + ) + TESTS.append( + [ + {"select_labels": [1, 2], "merge_channels": True}, + p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])), + np.array([[[1, 0, 1], [1, 1, 1]]]), + ] + ) class TestLabelToMask(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = LabelToMask(**argments)(image) + self.assertEqual(type(result), type(image)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, image.device) + result = result.cpu() np.testing.assert_allclose(result, expected_data) diff --git a/tests/test_label_to_maskd.py b/tests/test_label_to_maskd.py index f046390c19..59e4ddb047 100644 --- a/tests/test_label_to_maskd.py +++ b/tests/test_label_to_maskd.py @@ -12,46 +12,61 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import LabelToMaskd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": "img", "select_labels": [2, 3], "merge_channels": False}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array([[[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), -] - -TEST_CASE_2 = [ - {"keys": "img", "select_labels": 2, "merge_channels": False}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array([[[0, 0, 0], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), -] - -TEST_CASE_3 = [ - {"keys": "img", "select_labels": [1, 2], "merge_channels": False}, - {"img": np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])}, - np.array([[[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), -] - -TEST_CASE_4 = [ - {"keys": "img", "select_labels": 2, "merge_channels": False}, - {"img": np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])}, - np.array([[[1, 0, 1], [1, 1, 0]]]), -] - -TEST_CASE_5 = [ - {"keys": "img", "select_labels": [1, 2], "merge_channels": True}, - {"img": np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])}, - np.array([[[1, 0, 1], [1, 1, 1]]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": "img", "select_labels": [2, 3], "merge_channels": False}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array([[[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), + ] + ) + TESTS.append( + [ + {"keys": "img", "select_labels": 2, "merge_channels": False}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array([[[0, 0, 0], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), + ] + ) + TESTS.append( + [ + {"keys": "img", "select_labels": [1, 2], "merge_channels": False}, + {"img": p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]))}, + np.array([[[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), + ] + ) + TESTS.append( + [ + {"keys": "img", "select_labels": 2, "merge_channels": False}, + {"img": p(np.array([[[0, 0, 1], [0, 1, 0]], 
[[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]))}, + np.array([[[1, 0, 1], [1, 1, 0]]]), + ] + ) + TESTS.append( + [ + {"keys": "img", "select_labels": [1, 2], "merge_channels": True}, + {"img": p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]))}, + np.array([[[1, 0, 1], [1, 1, 1]]]), + ] + ) class TestLabelToMaskd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) - def test_value(self, argments, image, expected_data): - result = LabelToMaskd(**argments)(image) - np.testing.assert_allclose(result["img"], expected_data) + @parameterized.expand(TESTS) + def test_value(self, argments, input_data, expected_data): + result = LabelToMaskd(**argments)(input_data) + r, i = result["img"], input_data["img"] + self.assertEqual(type(r), type(i)) + if isinstance(r, torch.Tensor): + self.assertEqual(r.device, i.device) + r = r.cpu() + np.testing.assert_allclose(r, expected_data) if __name__ == "__main__": diff --git a/tests/test_load_spacing_orientation.py b/tests/test_load_spacing_orientation.py index 48aac7ec56..0317b48cba 100644 --- a/tests/test_load_spacing_orientation.py +++ b/tests/test_load_spacing_orientation.py @@ -19,6 +19,7 @@ from parameterized import parameterized from monai.transforms import AddChanneld, LoadImaged, Orientationd, Spacingd +from monai.utils.misc import convert_data_type FILES = tuple( os.path.join(os.path.dirname(__file__), "testing_data", filename) @@ -36,7 +37,8 @@ def test_load_spacingd(self, filename): res_dict = Spacingd(keys="image", pixdim=(1, 0.2, 1), diagonal=True, padding_mode="zeros")(data_dict) t1 = time.time() print(f"time monai: {t1 - t}") - anat = nibabel.Nifti1Image(data_dict["image"][0], data_dict["image_meta_dict"]["original_affine"]) + im_np, *_ = convert_data_type(data_dict["image"][0], np.ndarray) + anat = nibabel.Nifti1Image(im_np, data_dict["image_meta_dict"]["original_affine"]) ref = resample_to_output(anat, (1, 0.2, 1), order=1) t2 = time.time() print(f"time scipy: {t2 - t1}") @@ -58,7 +60,8 @@ def test_load_spacingd_rotate(self, filename): res_dict = Spacingd(keys="image", pixdim=(1, 2, 3), diagonal=True, padding_mode="zeros")(data_dict) t1 = time.time() print(f"time monai: {t1 - t}") - anat = nibabel.Nifti1Image(data_dict["image"][0], data_dict["image_meta_dict"]["original_affine"]) + im_np, *_ = convert_data_type(data_dict["image"][0], np.ndarray) + anat = nibabel.Nifti1Image(im_np, data_dict["image_meta_dict"]["original_affine"]) ref = resample_to_output(anat, (1, 2, 3), order=1) t2 = time.time() print(f"time scipy: {t2 - t1}") diff --git a/tests/test_map_classes_to_indices.py b/tests/test_map_classes_to_indices.py index 2320954520..e27a7e540d 100644 --- a/tests/test_map_classes_to_indices.py +++ b/tests/test_map_classes_to_indices.py @@ -12,88 +12,149 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import map_classes_to_indices +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - # test Argmax data - {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), "num_classes": 3, "image": None, "image_threshold": 0.0}, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + # test Argmax data + { + "label": p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + "num_classes": 3, + "image": None, + "image_threshold": 0.0, + }, + [ + np.array([0, 4, 8]), + np.array([1, 5, 6]), + np.array([2, 3, 7]), + ], + ] + ) 
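[`tests/utils.TEST_NDARRAYS` is used throughout this patch but its definition is not shown in the diff. Judging from how each `p` is called, it is presumably a tuple of callables that each convert the test data into one container type; a minimal sketch under that assumption follows — the exact contents, including the CUDA entry, are assumptions rather than the real `tests/utils` implementation.]

# Hypothetical sketch of tests/utils.TEST_NDARRAYS (not part of this diff):
# a tuple of callables, one per container type the transforms should accept.
from functools import partial

import numpy as np
import torch

TEST_NDARRAYS = (np.array, torch.as_tensor)  # assumed minimal form
if torch.cuda.is_available():
    # presumably a GPU-tensor variant is added when CUDA is available, which
    # is why the tests compare devices and call .cpu() before asserting values
    TEST_NDARRAYS = TEST_NDARRAYS + (partial(torch.as_tensor, device="cuda"),)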
-TEST_CASE_2 = [ - { - "label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), - "num_classes": 3, - "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - "image_threshold": 60, - }, - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + TESTS.append( + [ + { + "label": p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + "num_classes": 3, + "image": p(np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]])), + "image_threshold": 60, + }, + [ + np.array([0, 8]), + np.array([1, 5, 6]), + np.array([3]), + ], + ] + ) -TEST_CASE_3 = [ - # test One-Hot data - { - "label": np.array( + TESTS.append( + [ + # test One-Hot data + { + "label": p( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + ), + "image": None, + "image_threshold": 0.0, + }, [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], - ] - ), - "image": None, - "image_threshold": 0.0, - }, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] + np.array([0, 4, 8]), + np.array([1, 5, 6]), + np.array([2, 3, 7]), + ], + ] + ) -TEST_CASE_4 = [ - { - "label": np.array( + TESTS.append( + [ + { + "label": p( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + ), + "num_classes": None, + "image": p(np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]])), + "image_threshold": 60, + }, [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], - ] - ), - "num_classes": None, - "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - "image_threshold": 60, - }, - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + np.array([0, 8]), + np.array([1, 5, 6]), + np.array([3]), + ], + ] + ) -TEST_CASE_5 = [ - # test empty class - {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), "num_classes": 5, "image": None, "image_threshold": 0.0}, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7]), np.array([]), np.array([])], -] + TESTS.append( + [ + # test empty class + { + "label": p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + "num_classes": 5, + "image": None, + "image_threshold": 0.0, + }, + [ + np.array([0, 4, 8]), + np.array([1, 5, 6]), + np.array([2, 3, 7]), + np.array([]), + np.array([]), + ], + ] + ) -TEST_CASE_6 = [ - # test empty class - { - "label": np.array( + TESTS.append( + [ + # test empty class + { + "label": p( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + ] + ) + ), + "image": None, + "image_threshold": 0.0, + }, [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - ] - ), - "image": None, - "image_threshold": 0.0, - }, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7]), np.array([]), np.array([])], -] + np.array([0, 4, 8]), + np.array([1, 5, 6]), + np.array([2, 3, 7]), + np.array([]), + np.array([]), + ], + ] + ) class TestMapClassesToIndices(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS) def test_value(self, input_data, expected_indices): indices = 
map_classes_to_indices(**input_data) for i, e in zip(indices, expected_indices): + i = i.cpu() if isinstance(i, torch.Tensor) else i np.testing.assert_allclose(i, e) diff --git a/tests/test_mask_intensity.py b/tests/test_mask_intensity.py index 3131abe8bf..0086409c16 100644 --- a/tests/test_mask_intensity.py +++ b/tests/test_mask_intensity.py @@ -11,35 +11,45 @@ import unittest -import numpy as np +import torch from parameterized import parameterized from monai.transforms import MaskIntensity - -TEST_CASE_1 = [ - {"mask_data": np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]])}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]]), -] - -TEST_CASE_2 = [ - {"mask_data": np.array([[[0, 0, 0], [0, 5, 0], [0, 0, 0]]])}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]]), -] - -TEST_CASE_3 = [ - {"mask_data": np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [0, 1, 0], [0, 1, 0]]])}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]), -] +from tests.utils import TEST_NDARRAYS + +TEST_CASES = [] + +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + + TEST_CASES.append( + [ + {"mask_data": p([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]])}, + q([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + q([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]]), + ] + ) + TEST_CASES.append( + [ + {"mask_data": p([[[0, 0, 0], [0, 5, 0], [0, 0, 0]]])}, + q([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + q([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]]), + ] + ) + TEST_CASES.append( + [ + {"mask_data": p([[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [0, 1, 0], [0, 1, 0]]])}, + q([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + q([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]), + ] + ) class TestMaskIntensity(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TEST_CASES) def test_value(self, argments, image, expected_data): result = MaskIntensity(**argments)(image) - np.testing.assert_allclose(result, expected_data) + torch.testing.assert_allclose(result, expected_data, rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_mean_ensemble.py b/tests/test_mean_ensemble.py index 7e08846beb..a4595fe1ac 100644 --- a/tests/test_mean_ensemble.py +++ b/tests/test_mean_ensemble.py @@ -16,6 +16,7 @@ from parameterized import parameterized from monai.transforms import MeanEnsemble +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [ {"weights": None}, @@ -57,8 +58,22 @@ class TestMeanEnsemble(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_value(self, input_param, img, expected_value): - result = MeanEnsemble(**input_param)(img) - torch.testing.assert_allclose(result, expected_value) + for p in TEST_NDARRAYS: + if isinstance(img, list): + im = [p(i) for i in img] + im_type = type(im[0]) + im_device = im[0].device if isinstance(im[0], torch.Tensor) else None + else: + im = p(img) + im_type = type(im) + im_device = im.device if isinstance(im, torch.Tensor) else None + + result = 
MeanEnsemble(**input_param)(im) + self.assertEqual(im_type, type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, im_device) + result = result.cpu() + np.testing.assert_allclose(result, expected_value) def test_cuda_value(self): img = torch.stack([torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2) + 2]) diff --git a/tests/test_mean_ensembled.py b/tests/test_mean_ensembled.py index ea77ef18a0..4505d34944 100644 --- a/tests/test_mean_ensembled.py +++ b/tests/test_mean_ensembled.py @@ -16,49 +16,66 @@ from parameterized import parameterized from monai.transforms import MeanEnsembled +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": None}, - {"pred0": torch.ones(2, 2, 2), "pred1": torch.ones(2, 2, 2) + 2}, - torch.ones(2, 2, 2) + 1, -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": None}, + {"pred0": p(torch.ones(2, 2, 2)), "pred1": p(torch.ones(2, 2, 2) + 2)}, + torch.ones(2, 2, 2) + 1, + ] + ) -TEST_CASE_2 = [ - {"keys": "output", "weights": None}, - {"output": torch.stack([torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2])}, - torch.ones(2, 2, 2) + 1, -] + TESTS.append( + [ + {"keys": "output", "weights": None}, + {"output": p(torch.stack([torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2]))}, + torch.ones(2, 2, 2) + 1, + ] + ) -TEST_CASE_3 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": [1, 3]}, - {"pred0": torch.ones(2, 2, 2, 2), "pred1": torch.ones(2, 2, 2, 2) + 2}, - torch.ones(2, 2, 2, 2) * 2.5, -] + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": [1, 3]}, + {"pred0": p(torch.ones(2, 2, 2, 2)), "pred1": p(torch.ones(2, 2, 2, 2) + 2)}, + torch.ones(2, 2, 2, 2) * 2.5, + ] + ) -TEST_CASE_4 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": [[1, 3], [3, 1]]}, - {"pred0": torch.ones(2, 2, 2), "pred1": torch.ones(2, 2, 2) + 2}, - torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1), -] + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": [[1, 3], [3, 1]]}, + {"pred0": p(torch.ones(2, 2, 2)), "pred1": p(torch.ones(2, 2, 2) + 2)}, + torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1), + ] + ) -TEST_CASE_5 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": np.array([[[1, 3]], [[3, 1]]])}, - {"pred0": torch.ones(2, 2, 2, 2), "pred1": torch.ones(2, 2, 2, 2) + 2}, - torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1), -] + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": np.array([[[1, 3]], [[3, 1]]])}, + {"pred0": p(torch.ones(2, 2, 2, 2)), "pred1": p(torch.ones(2, 2, 2, 2) + 2)}, + torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1), + ] + ) -TEST_CASE_6 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": torch.tensor([[[1, 3]], [[3, 1]]])}, - {"pred0": torch.ones(2, 2, 2, 2), "pred1": torch.ones(2, 2, 2, 2) + 2}, - torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1), -] + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": torch.tensor([[[1, 3]], [[3, 1]]])}, + {"pred0": p(torch.ones(2, 2, 2, 2)), "pred1": p(torch.ones(2, 2, 2, 2) + 2)}, + torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1), + ] + ) class TestMeanEnsembled(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, 
TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS) def test_value(self, input_param, data, expected_value): result = MeanEnsembled(**input_param)(data) - torch.testing.assert_allclose(result["output"], expected_value) + if isinstance(result["output"], torch.Tensor): + result["output"] = result["output"].cpu() + np.testing.assert_allclose(result["output"], expected_value) def test_cuda_value(self): img = torch.stack([torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2) + 2]) diff --git a/tests/test_normalize_intensity.py b/tests/test_normalize_intensity.py index 9d474faea7..2d26eb6dd3 100644 --- a/tests/test_normalize_intensity.py +++ b/tests/test_normalize_intensity.py @@ -12,70 +12,108 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import NormalizeIntensity -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D -TEST_CASES = [ - [{"nonzero": True}, np.array([0.0, 3.0, 0.0, 4.0]), np.array([0.0, -1.0, 0.0, 1.0])], - [ - {"subtrahend": np.array([3.5, 3.5, 3.5, 3.5]), "divisor": np.array([0.5, 0.5, 0.5, 0.5]), "nonzero": True}, - np.array([0.0, 3.0, 0.0, 4.0]), - np.array([0.0, -1.0, 0.0, 1.0]), - ], - [{"nonzero": True}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])], - [{"nonzero": False}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])], - [{"nonzero": False}, np.array([1, 1, 1, 1]), np.array([0.0, 0.0, 0.0, 0.0])], - [ - {"nonzero": False, "channel_wise": True, "subtrahend": [1, 2, 3]}, - np.ones((3, 2, 2)), - np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-2.0, -2.0], [-2.0, -2.0]]]), - ], - [ - {"nonzero": True, "channel_wise": True, "subtrahend": [1, 2, 3], "divisor": [0, 0, 2]}, - np.ones((3, 2, 2)), - np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-1.0, -1.0], [-1.0, -1.0]]]), - ], - [ - {"nonzero": True, "channel_wise": False, "subtrahend": 2, "divisor": 0}, - np.ones((3, 2, 2)), - np.ones((3, 2, 2)) * -1.0, - ], - [ - {"nonzero": True, "channel_wise": False, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": 0}, - np.ones((3, 2, 2)), - np.ones((3, 2, 2)) * 0.5, - ], - [ - {"nonzero": True, "channel_wise": True, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": [0, 1, 0]}, - np.ones((3, 2, 2)), - np.ones((3, 2, 2)) * 0.5, - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"nonzero": True}, np.array([0.0, 3.0, 0.0, 4.0]), np.array([0.0, -1.0, 0.0, 1.0])]) + TESTS.append( + [ + p, + {"subtrahend": np.array([3.5, 3.5, 3.5, 3.5]), "divisor": np.array([0.5, 0.5, 0.5, 0.5]), "nonzero": True}, + np.array([0.0, 3.0, 0.0, 4.0]), + np.array([0.0, -1.0, 0.0, 1.0]), + ] + ) + TESTS.append([p, {"nonzero": True}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])]) + TESTS.append([p, {"nonzero": False}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])]) + TESTS.append([p, {"nonzero": False}, np.array([1, 1, 1, 1]), np.array([0.0, 0.0, 0.0, 0.0])]) + TESTS.append( + [ + p, + {"nonzero": False, "channel_wise": True, "subtrahend": [1, 2, 3]}, + np.ones((3, 2, 2)), + np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-2.0, -2.0], [-2.0, -2.0]]]), + ] + ) + TESTS.append( + [ + p, + {"nonzero": True, "channel_wise": True, "subtrahend": [1, 2, 3], "divisor": [0, 0, 2]}, + np.ones((3, 2, 2)), + np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-1.0, -1.0], [-1.0, -1.0]]]), + ] + ) + TESTS.append( + [ + p, + {"nonzero": True, 
"channel_wise": False, "subtrahend": 2, "divisor": 0}, + np.ones((3, 2, 2)), + np.ones((3, 2, 2)) * -1.0, + ] + ) + TESTS.append( + [ + p, + {"nonzero": True, "channel_wise": False, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": 0}, + np.ones((3, 2, 2)), + np.ones((3, 2, 2)) * 0.5, + ] + ) + TESTS.append( + [ + p, + {"nonzero": True, "channel_wise": True, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": [0, 1, 0]}, + np.ones((3, 2, 2)), + np.ones((3, 2, 2)) * 0.5, + ] + ) class TestNormalizeIntensity(NumpyImageTestCase2D): - def test_default(self): + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_default(self, im_type): + im = im_type(self.imt.copy()) normalizer = NormalizeIntensity() - normalized = normalizer(self.imt.copy()) - self.assertTrue(normalized.dtype == np.float32) + normalized = normalizer(im) + self.assertEqual(type(im), type(normalized)) + if isinstance(normalized, torch.Tensor): + self.assertEqual(im.device, normalized.device) + normalized = normalized.cpu() + self.assertTrue(normalized.dtype in (np.float32, torch.float32)) expected = (self.imt - np.mean(self.imt)) / np.std(self.imt) np.testing.assert_allclose(normalized, expected, rtol=1e-3) - @parameterized.expand(TEST_CASES) - def test_nonzero(self, input_param, input_data, expected_data): + @parameterized.expand(TESTS) + def test_nonzero(self, in_type, input_param, input_data, expected_data): normalizer = NormalizeIntensity(**input_param) - np.testing.assert_allclose(expected_data, normalizer(input_data)) + im = in_type(input_data) + normalized = normalizer(im) + self.assertEqual(type(im), type(normalized)) + if isinstance(normalized, torch.Tensor): + self.assertEqual(im.device, normalized.device) + normalized = normalized.cpu() + np.testing.assert_allclose(expected_data, normalized) - def test_channel_wise(self): + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_channel_wise(self, im_type): normalizer = NormalizeIntensity(nonzero=True, channel_wise=True) - input_data = np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]) + input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]])) expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]]) - np.testing.assert_allclose(expected, normalizer(input_data)) + normalized = normalizer(input_data) + self.assertEqual(type(input_data), type(normalized)) + if isinstance(normalized, torch.Tensor): + self.assertEqual(input_data.device, normalized.device) + normalized = normalized.cpu() + np.testing.assert_allclose(expected, normalized) - def test_value_errors(self): - input_data = np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]) + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_value_errors(self, im_type): + input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]])) normalizer = NormalizeIntensity(nonzero=True, channel_wise=True, subtrahend=[1]) with self.assertRaises(ValueError): normalizer(input_data) diff --git a/tests/test_normalize_intensityd.py b/tests/test_normalize_intensityd.py index 482d1a3f5b..2a45260b65 100644 --- a/tests/test_normalize_intensityd.py +++ b/tests/test_normalize_intensityd.py @@ -12,54 +12,80 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import NormalizeIntensityd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D -TEST_CASE_1 = [ - {"keys": ["img"], "nonzero": True}, - {"img": np.array([0.0, 3.0, 0.0, 4.0])}, - np.array([0.0, -1.0, 
0.0, 1.0]), -] - -TEST_CASE_2 = [ - { - "keys": ["img"], - "subtrahend": np.array([3.5, 3.5, 3.5, 3.5]), - "divisor": np.array([0.5, 0.5, 0.5, 0.5]), - "nonzero": True, - }, - {"img": np.array([0.0, 3.0, 0.0, 4.0])}, - np.array([0.0, -1.0, 0.0, 1.0]), -] - -TEST_CASE_3 = [ - {"keys": ["img"], "nonzero": True}, - {"img": np.array([0.0, 0.0, 0.0, 0.0])}, - np.array([0.0, 0.0, 0.0, 0.0]), -] +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["img"], "nonzero": True}, + {"img": p(np.array([0.0, 3.0, 0.0, 4.0]))}, + np.array([0.0, -1.0, 0.0, 1.0]), + ] + ) + TESTS.append( + [ + { + "keys": ["img"], + "subtrahend": q(np.array([3.5, 3.5, 3.5, 3.5])), + "divisor": q(np.array([0.5, 0.5, 0.5, 0.5])), + "nonzero": True, + }, + {"img": p(np.array([0.0, 3.0, 0.0, 4.0]))}, + np.array([0.0, -1.0, 0.0, 1.0]), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "nonzero": True}, + {"img": p(np.array([0.0, 0.0, 0.0, 0.0]))}, + np.array([0.0, 0.0, 0.0, 0.0]), + ] + ) class TestNormalizeIntensityd(NumpyImageTestCase2D): - def test_image_normalize_intensityd(self): + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_image_normalize_intensityd(self, im_type): key = "img" + im = im_type(self.imt) normalizer = NormalizeIntensityd(keys=[key]) - normalized = normalizer({key: self.imt}) + normalized = normalizer({key: im})[key] expected = (self.imt - np.mean(self.imt)) / np.std(self.imt) - np.testing.assert_allclose(normalized[key], expected, rtol=1e-3) + self.assertEqual(type(im), type(normalized)) + if isinstance(normalized, torch.Tensor): + self.assertEqual(im.device, normalized.device) + normalized = normalized.cpu() + np.testing.assert_allclose(normalized, expected, rtol=1e-3) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_nonzero(self, input_param, input_data, expected_data): + key = "img" normalizer = NormalizeIntensityd(**input_param) - np.testing.assert_allclose(expected_data, normalizer(input_data)["img"]) + normalized = normalizer(input_data)[key] + self.assertEqual(type(input_data[key]), type(normalized)) + if isinstance(normalized, torch.Tensor): + self.assertEqual(input_data[key].device, normalized.device) + normalized = normalized.cpu() + np.testing.assert_allclose(normalized, expected_data) - def test_channel_wise(self): + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_channel_wise(self, im_type): key = "img" normalizer = NormalizeIntensityd(keys=key, nonzero=True, channel_wise=True) - input_data = {key: np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]])} + input_data = {key: im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]))} + normalized = normalizer(input_data)[key] + self.assertEqual(type(input_data[key]), type(normalized)) + if isinstance(normalized, torch.Tensor): + self.assertEqual(input_data[key].device, normalized.device) + normalized = normalized.cpu() expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]]) - np.testing.assert_allclose(expected, normalizer(input_data)[key]) + np.testing.assert_allclose(expected, normalized) if __name__ == "__main__": diff --git a/tests/test_orientation.py b/tests/test_orientation.py index aa7f33a469..875ad078bc 100644 --- a/tests/test_orientation.py +++ b/tests/test_orientation.py @@ -13,112 +13,164 @@ import nibabel as nib import numpy as np +import torch from parameterized import parameterized from monai.transforms import Orientation, create_rotate, create_translate +from tests.utils import 
TEST_NDARRAYS -TEST_CASES = [ - [ - {"axcodes": "RAS"}, - np.arange(12).reshape((2, 1, 2, 3)), - {"affine": np.eye(4)}, - np.arange(12).reshape((2, 1, 2, 3)), - "RAS", - ], - [ - {"axcodes": "ALS"}, - np.arange(12).reshape((2, 1, 2, 3)), - {"affine": np.diag([-1, -1, 1, 1])}, - np.array([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]), - "ALS", - ], - [ - {"axcodes": "RAS"}, - np.arange(12).reshape((2, 1, 2, 3)), - {"affine": np.diag([-1, -1, 1, 1])}, - np.array([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]), - "RAS", - ], - [ - {"axcodes": "AL"}, - np.arange(6).reshape((2, 1, 3)), - {"affine": np.eye(3)}, - np.array([[[0], [1], [2]], [[3], [4], [5]]]), - "AL", - ], - [{"axcodes": "L"}, np.arange(6).reshape((2, 3)), {"affine": np.eye(2)}, np.array([[2, 1, 0], [5, 4, 3]]), "L"], - [{"axcodes": "L"}, np.arange(6).reshape((2, 3)), {"affine": np.eye(2)}, np.array([[2, 1, 0], [5, 4, 3]]), "L"], - [{"axcodes": "L"}, np.arange(6).reshape((2, 3)), {"affine": np.diag([-1, 1])}, np.arange(6).reshape((2, 3)), "L"], - [ - {"axcodes": "LPS"}, - np.arange(12).reshape((2, 1, 2, 3)), - { - "affine": create_translate(3, (10, 20, 30)) - @ create_rotate(3, (np.pi / 2, np.pi / 2, np.pi / 4)) - @ np.diag([-1, 1, 1, 1]) - }, - np.array([[[[2, 5]], [[1, 4]], [[0, 3]]], [[[8, 11]], [[7, 10]], [[6, 9]]]]), - "LPS", - ], - [ - {"as_closest_canonical": True}, - np.arange(12).reshape((2, 1, 2, 3)), - { - "affine": create_translate(3, (10, 20, 30)) - @ create_rotate(3, (np.pi / 2, np.pi / 2, np.pi / 4)) - @ np.diag([-1, 1, 1, 1]) - }, - np.array([[[[0, 3]], [[1, 4]], [[2, 5]]], [[[6, 9]], [[7, 10]], [[8, 11]]]]), - "RAS", - ], - [ - {"as_closest_canonical": True}, - np.arange(6).reshape((1, 2, 3)), - {"affine": create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])}, - np.array([[[3, 0], [4, 1], [5, 2]]]), - "RA", - ], - [ - {"axcodes": "LP"}, - np.arange(6).reshape((1, 2, 3)), - {"affine": create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])}, - np.array([[[2, 5], [1, 4], [0, 3]]]), - "LP", - ], - [ - {"axcodes": "LPID", "labels": tuple(zip("LPIC", "RASD"))}, - np.zeros((1, 2, 3, 4, 5)), - {"affine": np.diag([-1, -0.2, -1, 1, 1])}, - np.zeros((1, 2, 3, 4, 5)), - "LPID", - ], - [ - {"as_closest_canonical": True, "labels": tuple(zip("LPIC", "RASD"))}, - np.zeros((1, 2, 3, 4, 5)), - {"affine": np.diag([-1, -0.2, -1, 1, 1])}, - np.zeros((1, 2, 3, 4, 5)), - "RASD", - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + TESTS.append( + [ + {"axcodes": "RAS"}, + p(np.arange(12).reshape((2, 1, 2, 3))), + {"affine": q(np.eye(4))}, + np.arange(12).reshape((2, 1, 2, 3)), + "RAS", + ] + ) + TESTS.append( + [ + {"axcodes": "ALS"}, + p(np.arange(12).reshape((2, 1, 2, 3))), + {"affine": q(np.diag([-1, -1, 1, 1]))}, + np.array([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]), + "ALS", + ] + ) + TESTS.append( + [ + {"axcodes": "RAS"}, + p(np.arange(12).reshape((2, 1, 2, 3))), + {"affine": q(np.diag([-1, -1, 1, 1]))}, + np.array([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]), + "RAS", + ] + ) + TESTS.append( + [ + {"axcodes": "AL"}, + p(np.arange(6).reshape((2, 1, 3))), + {"affine": q(np.eye(3))}, + np.array([[[0], [1], [2]], [[3], [4], [5]]]), + "AL", + ] + ) + TESTS.append( + [ + {"axcodes": "L"}, + p(np.arange(6).reshape((2, 3))), + {"affine": q(np.eye(2))}, + np.array([[2, 1, 0], [5, 4, 3]]), + "L", + ] + ) + TESTS.append( + [ + {"axcodes": "L"}, + p(np.arange(6).reshape((2, 3))), + {"affine": q(np.eye(2))}, 
+                np.array([[2, 1, 0], [5, 4, 3]]),
+                "L",
+            ]
+        )
+        TESTS.append(
+            [
+                {"axcodes": "L"},
+                p(np.arange(6).reshape((2, 3))),
+                {"affine": q(np.diag([-1, 1]))},
+                np.arange(6).reshape((2, 3)),
+                "L",
+            ]
+        )
+        TESTS.append(
+            [
+                {"axcodes": "LPS"},
+                p(np.arange(12).reshape((2, 1, 2, 3))),
+                {
+                    "affine": q(
+                        create_translate(3, (10, 20, 30))
+                        @ create_rotate(3, (np.pi / 2, np.pi / 2, np.pi / 4))
+                        @ np.diag([-1, 1, 1, 1])
+                    )
+                },
+                np.array([[[[2, 5]], [[1, 4]], [[0, 3]]], [[[8, 11]], [[7, 10]], [[6, 9]]]]),
+                "LPS",
+            ]
+        )
+        TESTS.append(
+            [
+                {"as_closest_canonical": True},
+                p(np.arange(12).reshape((2, 1, 2, 3))),
+                {
+                    "affine": q(
+                        create_translate(3, (10, 20, 30))
+                        @ create_rotate(3, (np.pi / 2, np.pi / 2, np.pi / 4))
+                        @ np.diag([-1, 1, 1, 1])
+                    )
+                },
+                np.array([[[[0, 3]], [[1, 4]], [[2, 5]]], [[[6, 9]], [[7, 10]], [[8, 11]]]]),
+                "RAS",
+            ]
+        )
+        TESTS.append(
+            [
+                {"as_closest_canonical": True},
+                p(np.arange(6).reshape((1, 2, 3))),
+                {"affine": q(create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1]))},
+                np.array([[[3, 0], [4, 1], [5, 2]]]),
+                "RA",
+            ]
+        )
+        TESTS.append(
+            [
+                {"axcodes": "LP"},
+                p(np.arange(6).reshape((1, 2, 3))),
+                {"affine": q(create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1]))},
+                np.array([[[2, 5], [1, 4], [0, 3]]]),
+                "LP",
+            ]
+        )
+        TESTS.append(
+            [
+                {"axcodes": "LPID", "labels": tuple(zip("LPIC", "RASD"))},
+                p(np.zeros((1, 2, 3, 4, 5))),
+                {"affine": q(np.diag([-1, -0.2, -1, 1, 1]))},
+                np.zeros((1, 2, 3, 4, 5)),
+                "LPID",
+            ]
+        )
+        TESTS.append(
+            [
+                {"as_closest_canonical": True, "labels": tuple(zip("LPIC", "RASD"))},
+                p(np.zeros((1, 2, 3, 4, 5))),
+                {"affine": q(np.diag([-1, -0.2, -1, 1, 1]))},
+                np.zeros((1, 2, 3, 4, 5)),
+                "RASD",
+            ]
+        )

-ILL_CASES = [
-    # no axcodes or as_cloest_canonical
-    [{}, np.arange(6).reshape((2, 3)), "L"],
-    # too short axcodes
-    [{"axcodes": "RA"}, np.arange(12).reshape((2, 1, 2, 3)), {"affine": np.eye(4)}],
-]
+ILL_CASES = []
+for p in TEST_NDARRAYS:
+    for q in TEST_NDARRAYS:
+        # no axcodes or as_closest_canonical
+        ILL_CASES.append([{}, p(np.arange(6).reshape((2, 3))), "L"])
+        # too short axcodes
+        ILL_CASES.append([{"axcodes": "RA"}, p(np.arange(12).reshape((2, 1, 2, 3))), {"affine": q(np.eye(4))}])


 class TestOrientationCase(unittest.TestCase):
-    @parameterized.expand(TEST_CASES)
+    @parameterized.expand(TESTS)
     def test_ornt(self, init_param, img, data_param, expected_data, expected_code):
         ornt = Orientation(**init_param)
         res = ornt(img, **data_param)
-        if not isinstance(res, tuple):
-            np.testing.assert_allclose(res, expected_data)
-            return
+        res = [i.cpu() if isinstance(i, torch.Tensor) else i for i in res]
         np.testing.assert_allclose(res[0], expected_data)
         original_affine = data_param["affine"]
+        if isinstance(original_affine, torch.Tensor):
+            original_affine = original_affine.cpu()
         np.testing.assert_allclose(original_affine, res[1])
         new_code = nib.orientations.aff2axcodes(res[2], labels=ornt.labels)
         self.assertEqual("".join(new_code), expected_code)
diff --git a/tests/test_orientationd.py b/tests/test_orientationd.py
index 452172ce9b..01e647decb 100644
--- a/tests/test_orientationd.py
+++ b/tests/test_orientationd.py
@@ -13,25 +13,34 @@
 import nibabel as nib
 import numpy as np
+from parameterized import parameterized

 from monai.transforms import Orientationd
+from tests.utils import TEST_NDARRAYS
+
+TESTS = []
+for p in TEST_NDARRAYS:
+    for q in TEST_NDARRAYS:
+        TESTS.append((p, q))


 class TestOrientationdCase(unittest.TestCase):
-    def
test_orntd(self): - data = {"seg": np.ones((2, 1, 2, 3)), "seg_meta_dict": {"affine": np.eye(4)}} + @parameterized.expand(TESTS) + def test_orntd(self, im_type, affine_type): + data = {"seg": im_type(np.ones((2, 1, 2, 3))), "seg_meta_dict": {"affine": affine_type(np.eye(4))}} ornt = Orientationd(keys="seg", axcodes="RAS") res = ornt(data) np.testing.assert_allclose(res["seg"].shape, (2, 1, 2, 3)) code = nib.aff2axcodes(res["seg_meta_dict"]["affine"], ornt.ornt_transform.labels) self.assertEqual(code, ("R", "A", "S")) - def test_orntd_3d(self): + @parameterized.expand(TESTS) + def test_orntd_3d(self, im_type, affine_type): data = { - "seg": np.ones((2, 1, 2, 3)), - "img": np.ones((2, 1, 2, 3)), - "seg_meta_dict": {"affine": np.eye(4)}, - "img_meta_dict": {"affine": np.eye(4)}, + "seg": im_type(np.ones((2, 1, 2, 3))), + "img": im_type(np.ones((2, 1, 2, 3))), + "seg_meta_dict": {"affine": affine_type(np.eye(4))}, + "img_meta_dict": {"affine": affine_type(np.eye(4))}, } ornt = Orientationd(keys=("img", "seg"), axcodes="PLI") res = ornt(data) @@ -42,12 +51,13 @@ def test_orntd_3d(self): code = nib.aff2axcodes(res["img_meta_dict"]["affine"], ornt.ornt_transform.labels) self.assertEqual(code, ("P", "L", "I")) - def test_orntd_2d(self): + @parameterized.expand(TESTS) + def test_orntd_2d(self, im_type, affine_type): data = { - "seg": np.ones((2, 1, 3)), - "img": np.ones((2, 1, 3)), - "seg_meta_dict": {"affine": np.eye(4)}, - "img_meta_dict": {"affine": np.eye(4)}, + "seg": im_type(np.ones((2, 1, 3))), + "img": im_type(np.ones((2, 1, 3))), + "seg_meta_dict": {"affine": affine_type(np.eye(4))}, + "img_meta_dict": {"affine": affine_type(np.eye(4))}, } ornt = Orientationd(keys=("img", "seg"), axcodes="PLI") res = ornt(data) @@ -57,12 +67,13 @@ def test_orntd_2d(self): code = nib.aff2axcodes(res["img_meta_dict"]["affine"], ornt.ornt_transform.labels) self.assertEqual(code, ("P", "L", "S")) - def test_orntd_1d(self): + @parameterized.expand(TESTS) + def test_orntd_1d(self, im_type, affine_type): data = { - "seg": np.ones((2, 3)), - "img": np.ones((2, 3)), - "seg_meta_dict": {"affine": np.eye(4)}, - "img_meta_dict": {"affine": np.eye(4)}, + "seg": im_type(np.ones((2, 3))), + "img": im_type(np.ones((2, 3))), + "seg_meta_dict": {"affine": affine_type(np.eye(4))}, + "img_meta_dict": {"affine": affine_type(np.eye(4))}, } ornt = Orientationd(keys=("img", "seg"), axcodes="L") res = ornt(data) @@ -72,12 +83,13 @@ def test_orntd_1d(self): code = nib.aff2axcodes(res["img_meta_dict"]["affine"], ornt.ornt_transform.labels) self.assertEqual(code, ("L", "A", "S")) - def test_orntd_canonical(self): + @parameterized.expand(TESTS) + def test_orntd_canonical(self, im_type, affine_type): data = { - "seg": np.ones((2, 1, 2, 3)), - "img": np.ones((2, 1, 2, 3)), - "seg_meta_dict": {"affine": np.eye(4)}, - "img_meta_dict": {"affine": np.eye(4)}, + "seg": im_type(np.ones((2, 1, 2, 3))), + "img": im_type(np.ones((2, 1, 2, 3))), + "seg_meta_dict": {"affine": affine_type(np.eye(4))}, + "img_meta_dict": {"affine": affine_type(np.eye(4))}, } ornt = Orientationd(keys=("img", "seg"), as_closest_canonical=True) res = ornt(data) @@ -88,8 +100,9 @@ def test_orntd_canonical(self): code = nib.aff2axcodes(res["img_meta_dict"]["affine"], ornt.ornt_transform.labels) self.assertEqual(code, ("R", "A", "S")) - def test_orntd_no_metadata(self): - data = {"seg": np.ones((2, 1, 2, 3))} + @parameterized.expand(TESTS) + def test_orntd_no_metadata(self, im_type, _): + data = {"seg": im_type(np.ones((2, 1, 2, 3)))} ornt = Orientationd(keys="seg", 
axcodes="RAS") res = ornt(data) np.testing.assert_allclose(res["seg"].shape, (2, 1, 2, 3)) diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py index a8c544558f..d60b280317 100644 --- a/tests/test_pad_collation.py +++ b/tests/test_pad_collation.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import random import unittest from typing import List, Tuple @@ -18,6 +17,7 @@ from parameterized import parameterized from monai.data import CacheDataset, DataLoader +from monai.data.dataset import Dataset from monai.data.utils import decollate_batch, pad_list_data_collate from monai.transforms import ( PadListDataCollate, @@ -30,58 +30,72 @@ RandZoom, RandZoomd, ) -from monai.utils import set_determinism +from tests.utils import TEST_NDARRAYS TESTS: List[Tuple] = [] -for pad_collate in [ - lambda x: pad_list_data_collate(batch=x, method="end", mode="constant", constant_values=1), - PadListDataCollate(method="end", mode="constant", constant_values=1), -]: - TESTS.append((dict, pad_collate, RandSpatialCropd("image", roi_size=[8, 7], random_size=True))) - TESTS.append((dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False))) - TESTS.append((dict, pad_collate, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) - TESTS.append((dict, pad_collate, RandRotate90d("image", prob=1, max_k=2))) - - TESTS.append((list, pad_collate, RandSpatialCrop(roi_size=[8, 7], random_size=True))) - TESTS.append((list, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False))) - TESTS.append((list, pad_collate, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) - TESTS.append((list, pad_collate, RandRotate90(prob=1, max_k=2))) - - -class _Dataset(torch.utils.data.Dataset): - def __init__(self, images, labels, transforms): - self.images = images - self.labels = labels - self.transforms = transforms - - def __len__(self): - return len(self.images) - +for p in TEST_NDARRAYS: + for include_label in (True, False): + for pad_collate in [ + lambda x: pad_list_data_collate(batch=x, method="end", mode="constant", constant_values=1), + PadListDataCollate(method="end", mode="constant", constant_values=1), + ]: + TESTS.append( + (dict, p, include_label, pad_collate, RandSpatialCropd("im", roi_size=[8, 7], random_size=True)) + ) + TESTS.append( + (dict, p, include_label, pad_collate, RandRotated("im", prob=1, range_x=np.pi, keep_size=False)) + ) + TESTS.append( + ( + dict, + p, + include_label, + pad_collate, + RandZoomd("im", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False), + ) + ) + TESTS.append((dict, p, include_label, pad_collate, RandRotate90d("im", prob=1, max_k=2))) + + TESTS.append((list, p, include_label, pad_collate, RandSpatialCrop(roi_size=[8, 7], random_size=True))) + TESTS.append((list, p, include_label, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False))) + TESTS.append( + (list, p, include_label, pad_collate, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)) + ) + TESTS.append((list, p, include_label, pad_collate, RandRotate90(prob=1, max_k=2))) + + +class TupleDataset(Dataset): def __getitem__(self, index): - return self.transforms(self.images[index]), self.labels[index] + return self.transform(self.data[index][0]), self.data[index][1] class TestPadCollation(unittest.TestCase): - def setUp(self) -> None: - set_determinism(seed=0) + @staticmethod + def get_data(t_type, im_type, include_label): # image is non square to throw rotation 
errors
-        im = np.arange(0, 10 * 9).reshape(1, 10, 9)
+        im = im_type(np.arange(0, 10 * 9).reshape(1, 10, 9))
         num_elements = 20
-        self.dict_data = [{"image": im} for _ in range(num_elements)]
-        self.list_data = [im for _ in range(num_elements)]
-        self.list_labels = [random.randint(0, 1) for _ in range(num_elements)]
-
-    def tearDown(self) -> None:
-        set_determinism(None)
+        out = []
+        for _ in range(num_elements):
+            label = np.random.randint(0, 2)
+            if t_type is dict:
+                out.append({"im": im, "label": label} if include_label else {"im": im})
+            else:
+                out.append((im, label) if include_label else im)
+        return out

     @parameterized.expand(TESTS)
-    def test_pad_collation(self, t_type, collate_method, transform):
+    def test_pad_collation(self, t_type, im_type, include_label, collate_method, transform):
+
+        input_data = self.get_data(t_type, im_type, include_label)

-        if t_type == dict:
-            dataset = CacheDataset(self.dict_data, transform, progress=False)
+        if t_type is dict:
+            dataset = CacheDataset(input_data, transform, progress=False)
+        elif isinstance(input_data[0], tuple):
+            dataset = TupleDataset(input_data, transform)
         else:
-            dataset = _Dataset(self.list_data, self.list_labels, transform)
+            dataset = Dataset(input_data, transform)

         # Default collation should raise an error
         loader_fail = DataLoader(dataset, batch_size=10)
@@ -93,8 +107,14 @@ def test_pad_collation(self, t_type, collate_method, transform):
         loader = DataLoader(dataset, batch_size=10, collate_fn=collate_method)
         # check collation in forward direction
         for data in loader:
-            if t_type == dict:
-                decollated_data = decollate_batch(data)
+            d = data["im"] if isinstance(data, dict) else data
+            i = input_data[0]["im"] if isinstance(data, dict) else input_data[0]
+            if isinstance(i, torch.Tensor):
+                self.assertEqual(d.device, i.device)
+
+            decollated_data = decollate_batch(data)
+            # if a dictionary, do the inverse
+            if t_type is dict:
                 for d in decollated_data:
                     PadListDataCollate.inverse(d)
diff --git a/tests/test_patch_dataset.py b/tests/test_patch_dataset.py
index 4f6e9a25fd..119da0bd2c 100644
--- a/tests/test_patch_dataset.py
+++ b/tests/test_patch_dataset.py
@@ -32,7 +32,7 @@ def test_shape(self):
         result = PatchDataset(dataset=test_dataset, patch_func=identity, samples_per_image=n_per_image)

         output = []
-        n_workers = 0 if sys.platform == "win32" else 2
+        n_workers = 2 if sys.platform == "linux" else 0
         for item in DataLoader(result, batch_size=3, num_workers=n_workers):
             output.append("".join(item))
         expected = ["vwx", "yzh", "ell", "owo", "rld"]
diff --git a/tests/test_png_rw.py b/tests/test_png_rw.py
index 265b31b83b..3557d28a84 100644
--- a/tests/test_png_rw.py
+++ b/tests/test_png_rw.py
@@ -14,67 +14,80 @@
 import unittest

 import numpy as np
+from parameterized import parameterized
 from PIL import Image

 from monai.data import write_png
+from tests.utils import TEST_NDARRAYS
+
+TESTS = []
+for p in TEST_NDARRAYS:
+    TESTS.append((p,))


 class TestPngWrite(unittest.TestCase):
-    def test_write_gray(self):
+    @parameterized.expand(TESTS)
+    def test_write_gray(self, in_type):
         with tempfile.TemporaryDirectory() as out_dir:
             image_name = os.path.join(out_dir, "test.png")
             img = np.random.rand(2, 3)
             img_save_val = (255 * img).astype(np.uint8)
-            write_png(img, image_name, scale=255)
+            write_png(in_type(img), image_name, scale=255)
             out = np.asarray(Image.open(image_name))
             out = np.moveaxis(out, 0, 1)
             np.testing.assert_allclose(out, img_save_val)

-    def test_write_gray_1height(self):
+    @parameterized.expand(TESTS)
+    def
test_write_gray_1height(self, in_type): with tempfile.TemporaryDirectory() as out_dir: image_name = os.path.join(out_dir, "test.png") img = np.random.rand(1, 3) img_save_val = (65535 * img).astype(np.uint16) - write_png(img, image_name, scale=65535) + write_png(in_type(img), image_name, scale=65535) out = np.asarray(Image.open(image_name)) out = np.moveaxis(out, 0, 1) np.testing.assert_allclose(out, img_save_val) - def test_write_gray_1channel(self): + @parameterized.expand(TESTS) + def test_write_gray_1channel(self, in_type): with tempfile.TemporaryDirectory() as out_dir: image_name = os.path.join(out_dir, "test.png") img = np.random.rand(2, 3, 1) img_save_val = (255 * img).astype(np.uint8).squeeze(2) - write_png(img, image_name, scale=255) + write_png(in_type(img), image_name, scale=255) out = np.asarray(Image.open(image_name)) out = np.moveaxis(out, 0, 1) np.testing.assert_allclose(out, img_save_val) - def test_write_rgb(self): + @parameterized.expand(TESTS) + def test_write_rgb(self, in_type): with tempfile.TemporaryDirectory() as out_dir: image_name = os.path.join(out_dir, "test.png") img = np.random.rand(2, 3, 3) img_save_val = (255 * img).astype(np.uint8) - write_png(img, image_name, scale=255) + write_png(in_type(img), image_name, scale=255) out = np.asarray(Image.open(image_name)) out = np.moveaxis(out, 0, 1) np.testing.assert_allclose(out, img_save_val) - def test_write_2channels(self): + @parameterized.expand(TESTS) + def test_write_2channels(self, in_type): with tempfile.TemporaryDirectory() as out_dir: image_name = os.path.join(out_dir, "test.png") img = np.random.rand(2, 3, 2) img_save_val = (255 * img).astype(np.uint8) - write_png(img, image_name, scale=255) + write_png(in_type(img), image_name, scale=255) out = np.asarray(Image.open(image_name)) out = np.moveaxis(out, 0, 1) np.testing.assert_allclose(out, img_save_val) - def test_write_output_shape(self): + @parameterized.expand(TESTS) + def test_write_output_shape(self, in_type): with tempfile.TemporaryDirectory() as out_dir: image_name = os.path.join(out_dir, "test.png") img = np.random.rand(2, 2, 3) - write_png(img, image_name, (4, 4), scale=255) + write_png(in_type(img), image_name, (4, 4), scale=255) out = np.asarray(Image.open(image_name)) np.testing.assert_allclose(out.shape, (4, 4, 3)) diff --git a/tests/test_probnms.py b/tests/test_probnms.py index e51d1017d8..e43b7ddcfb 100644 --- a/tests/test_probnms.py +++ b/tests/test_probnms.py @@ -16,83 +16,80 @@ from parameterized import parameterized from monai.transforms.post.array import ProbNMS +from tests.utils import TEST_NDARRAYS -probs_map_1 = np.random.rand(100, 100).clip(0, 0.5) -TEST_CASES_2D_1 = [{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, probs_map_1, []] +TESTS = [] +for p in TEST_NDARRAYS: + probs_map_1 = p(np.random.rand(100, 100).clip(0, 0.5)) + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, probs_map_1, []]) -probs_map_2 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_2[33, 33] = 0.7 -probs_map_2[66, 66] = 0.9 -expected_2 = [[0.9, 66, 66], [0.7, 33, 33]] -TEST_CASES_2D_2 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": [10, 10]}, - probs_map_2, - expected_2, -] - -probs_map_3 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_3[56, 58] = 0.7 -probs_map_3[60, 66] = 0.8 -probs_map_3[66, 66] = 0.9 -expected_3 = [[0.9, 66, 66], [0.8, 60, 66]] -TEST_CASES_2D_3 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": (10, 20)}, - probs_map_3, - expected_3, -] - -probs_map_4 = np.random.rand(100, 
100).clip(0, 0.5) -probs_map_4[33, 33] = 0.7 -probs_map_4[66, 66] = 0.9 -expected_4 = [[0.9, 66, 66]] -TEST_CASES_2D_4 = [ - {"spatial_dims": 2, "prob_threshold": 0.8, "box_size": 10}, - probs_map_4, - expected_4, -] - -probs_map_5 = np.random.rand(100, 100).clip(0, 0.5) -TEST_CASES_2D_5 = [{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, probs_map_5, []] + probs_map_2 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_2[33, 33] = 0.7 + probs_map_2[66, 66] = 0.9 + expected_2 = [[0.9, 66, 66], [0.7, 33, 33]] + TESTS.append( + [ + {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": [10, 10]}, + probs_map_2, + expected_2, + ] + ) -probs_map_6 = torch.as_tensor(np.random.rand(100, 100).clip(0, 0.5)) -TEST_CASES_2D_6 = [{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, probs_map_6, []] + probs_map_3 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_3[56, 58] = 0.7 + probs_map_3[60, 66] = 0.8 + probs_map_3[66, 66] = 0.9 + expected_3 = [[0.9, 66, 66], [0.8, 60, 66]] + TESTS.append( + [ + {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": (10, 20)}, + probs_map_3, + expected_3, + ] + ) -probs_map_7 = torch.as_tensor(np.random.rand(100, 100).clip(0, 0.5)) -probs_map_7[33, 33] = 0.7 -probs_map_7[66, 66] = 0.9 -if torch.cuda.is_available(): - probs_map_7 = probs_map_7.cuda() -expected_7 = [[0.9, 66, 66], [0.7, 33, 33]] -TEST_CASES_2D_7 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, - probs_map_7, - expected_7, -] + probs_map_4 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_4[33, 33] = 0.7 + probs_map_4[66, 66] = 0.9 + expected_4 = [[0.9, 66, 66]] + TESTS.append( + [ + {"spatial_dims": 2, "prob_threshold": 0.8, "box_size": 10}, + probs_map_4, + expected_4, + ] + ) -probs_map_3d = torch.rand([50, 50, 50]).uniform_(0, 0.5) -probs_map_3d[25, 25, 25] = 0.7 -probs_map_3d[45, 45, 45] = 0.9 -expected_3d = [[0.9, 45, 45, 45], [0.7, 25, 25, 25]] -TEST_CASES_3D = [ - {"spatial_dims": 3, "prob_threshold": 0.5, "box_size": (10, 10, 10)}, - probs_map_3d, - expected_3d, -] + probs_map_5 = p(np.random.rand(100, 100).clip(0, 0.5)) + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, probs_map_5, []]) + probs_map_6 = p((np.random.rand(100, 100).clip(0, 0.5))) + probs_map_6[33, 33] = 0.7 + probs_map_6[66, 66] = 0.9 + expected_6 = [[0.9, 66, 66], [0.7, 33, 33]] + TESTS.append( + [ + {"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, + probs_map_6, + expected_6, + ] + ) -class TestProbNMS(unittest.TestCase): - @parameterized.expand( + probs_map_3d = p(torch.rand([50, 50, 50]).uniform_(0, 0.5)) + probs_map_3d[25, 25, 25] = 0.7 + probs_map_3d[45, 45, 45] = 0.9 + expected_3d = [[0.9, 45, 45, 45], [0.7, 25, 25, 25]] + TESTS.append( [ - TEST_CASES_2D_1, - TEST_CASES_2D_2, - TEST_CASES_2D_3, - TEST_CASES_2D_4, - TEST_CASES_2D_5, - TEST_CASES_2D_6, - TEST_CASES_2D_7, - TEST_CASES_3D, + {"spatial_dims": 3, "prob_threshold": 0.5, "box_size": (10, 10, 10)}, + probs_map_3d, + expected_3d, ] ) + + +class TestProbNMS(unittest.TestCase): + @parameterized.expand(TESTS) def test_output(self, class_args, probs_map, expected): nms = ProbNMS(**class_args) output = nms(probs_map) diff --git a/tests/test_probnmsd.py b/tests/test_probnmsd.py index 5b75d4310f..c9ae7c6a46 100644 --- a/tests/test_probnmsd.py +++ b/tests/test_probnmsd.py @@ -10,91 +10,89 @@ # limitations under the License. 
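# The recipe in every converted test below is the same: wrap each input in `p(...)`
# for every constructor `p` in TEST_NDARRAYS, so a single hand-written case runs as
# a numpy array, a CPU torch.Tensor and, when a GPU is visible, a CUDA tensor, with
# the assertions checking that output type and device match the input. A minimal
# sketch of the helper these files import from tests/utils (the exact names and the
# CUDA handling are assumptions; the shipped helper may differ):
#
#     from functools import partial
#     from typing import Callable, Tuple
#
#     import numpy as np
#     import torch
#
#     TEST_NDARRAYS: Tuple[Callable, ...] = (np.array, torch.as_tensor)
#     if torch.cuda.is_available():
#         # also exercise every case on GPU tensors
#         TEST_NDARRAYS = TEST_NDARRAYS + (partial(torch.tensor, device="cuda"),)
#
# Building TESTS inside `for p in TEST_NDARRAYS:` loops is what replaces the static
# TEST_CASE_* constants throughout this patch.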
import unittest +from typing import List import numpy as np import torch from parameterized import parameterized -from monai.transforms.post.dictionary import ProbNMSD +from monai.transforms.post.dictionary import ProbNMSd +from tests.utils import TEST_NDARRAYS -probs_map_1 = np.random.rand(100, 100).clip(0, 0.5) -TEST_CASES_2D_1 = [{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, {"prob_map": probs_map_1}, []] +TESTS: List[List] = [] +for p in TEST_NDARRAYS: + probs_map_1 = p(np.random.rand(100, 100).clip(0, 0.5)) + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, {"prob_map": probs_map_1}, []]) -probs_map_2 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_2[33, 33] = 0.7 -probs_map_2[66, 66] = 0.9 -expected_2 = [[0.9, 66, 66], [0.7, 33, 33]] -TEST_CASES_2D_2 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": [10, 10]}, - {"prob_map": probs_map_2}, - expected_2, -] - -probs_map_3 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_3[56, 58] = 0.7 -probs_map_3[60, 66] = 0.8 -probs_map_3[66, 66] = 0.9 -expected_3 = [[0.9, 66, 66], [0.8, 60, 66]] -TEST_CASES_2D_3 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": (10, 20)}, - {"prob_map": probs_map_3}, - expected_3, -] - -probs_map_4 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_4[33, 33] = 0.7 -probs_map_4[66, 66] = 0.9 -expected_4 = [[0.9, 66, 66]] -TEST_CASES_2D_4 = [ - {"spatial_dims": 2, "prob_threshold": 0.8, "box_size": 10}, - {"prob_map": probs_map_4}, - expected_4, -] - -probs_map_5 = np.random.rand(100, 100).clip(0, 0.5) -TEST_CASES_2D_5 = [{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, {"prob_map": probs_map_5}, []] + probs_map_2 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_2[33, 33] = 0.7 + probs_map_2[66, 66] = 0.9 + expected_2 = [[0.9, 66, 66], [0.7, 33, 33]] + TESTS.append( + [ + {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": [10, 10]}, + {"prob_map": probs_map_2}, + expected_2, + ] + ) -probs_map_6 = torch.as_tensor(np.random.rand(100, 100).clip(0, 0.5)) -TEST_CASES_2D_6 = [{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, {"prob_map": probs_map_6}, []] + probs_map_3 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_3[56, 58] = 0.7 + probs_map_3[60, 66] = 0.8 + probs_map_3[66, 66] = 0.9 + expected_3 = [[0.9, 66, 66], [0.8, 60, 66]] + TESTS.append( + [ + {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": (10, 20)}, + {"prob_map": probs_map_3}, + expected_3, + ] + ) -probs_map_7 = torch.as_tensor(np.random.rand(100, 100).clip(0, 0.5)) -probs_map_7[33, 33] = 0.7 -probs_map_7[66, 66] = 0.9 -if torch.cuda.is_available(): - probs_map_7 = probs_map_7.cuda() -expected_7 = [[0.9, 66, 66], [0.7, 33, 33]] -TEST_CASES_2D_7 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, - {"prob_map": probs_map_7}, - expected_7, -] + probs_map_4 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_4[33, 33] = 0.7 + probs_map_4[66, 66] = 0.9 + expected_4 = [[0.9, 66, 66]] + TESTS.append( + [ + {"spatial_dims": 2, "prob_threshold": 0.8, "box_size": 10}, + {"prob_map": probs_map_4}, + expected_4, + ] + ) -probs_map_3d = torch.rand([50, 50, 50]).uniform_(0, 0.5) -probs_map_3d[25, 25, 25] = 0.7 -probs_map_3d[45, 45, 45] = 0.9 -expected_3d = [[0.9, 45, 45, 45], [0.7, 25, 25, 25]] -TEST_CASES_3D = [ - {"spatial_dims": 3, "prob_threshold": 0.5, "box_size": (10, 10, 10)}, - {"prob_map": probs_map_3d}, - expected_3d, -] + probs_map_5 = p(np.random.rand(100, 100).clip(0, 0.5)) + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, 
"sigma": 0.1}, {"prob_map": probs_map_5}, []]) + probs_map_6 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_6[33, 33] = 0.7 + probs_map_6[66, 66] = 0.9 + expected_6 = [[0.9, 66, 66], [0.7, 33, 33]] + TESTS.append( + [ + {"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, + {"prob_map": probs_map_6}, + expected_6, + ] + ) -class TestProbNMS(unittest.TestCase): - @parameterized.expand( + probs_map_3d = p(torch.rand([50, 50, 50]).uniform_(0, 0.5)) + probs_map_3d[25, 25, 25] = 0.7 + probs_map_3d[45, 45, 45] = 0.9 + expected_3d = [[0.9, 45, 45, 45], [0.7, 25, 25, 25]] + TESTS.append( [ - TEST_CASES_2D_1, - TEST_CASES_2D_2, - TEST_CASES_2D_3, - TEST_CASES_2D_4, - TEST_CASES_2D_5, - TEST_CASES_2D_6, - TEST_CASES_2D_7, - TEST_CASES_3D, + {"spatial_dims": 3, "prob_threshold": 0.5, "box_size": (10, 10, 10)}, + {"prob_map": probs_map_3d}, + expected_3d, ] ) + + +class TestProbNMS(unittest.TestCase): + @parameterized.expand(TESTS) def test_output(self, class_args, probs_map, expected): - nms = ProbNMSD(keys="prob_map", **class_args) + nms = ProbNMSd(keys="prob_map", **class_args) output = nms(probs_map) np.testing.assert_allclose(output["prob_map"], expected) diff --git a/tests/test_rand_affine.py b/tests/test_rand_affine.py index 1e1a23bc09..1a0502df3f 100644 --- a/tests/test_rand_affine.py +++ b/tests/test_rand_affine.py @@ -16,114 +16,136 @@ from parameterized import parameterized from monai.transforms import RandAffine +from monai.utils.misc import convert_data_type +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - dict(as_tensor_output=False, device=None), - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - dict(as_tensor_output=False, device=None, spatial_size=-1), - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - dict(as_tensor_output=False, device=None), - {"img": torch.arange(27).reshape((3, 3, 3)), "spatial_size": (2, 2)}, - np.array([[[2.0, 3.0], [5.0, 6.0]], [[11.0, 12.0], [14.0, 15.0]], [[20.0, 21.0], [23.0, 24.0]]]), - ], - [ - dict(as_tensor_output=True, device=None), - {"img": torch.ones((1, 3, 3, 3)), "spatial_size": (2, 2, 2)}, - torch.ones((1, 2, 2, 2)), - ], - [ - dict(as_tensor_output=True, device=None, spatial_size=(2, 2, 2), cache_grid=True), - {"img": torch.ones((1, 3, 3, 3))}, - torch.ones((1, 2, 2, 2)), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - padding_mode="zeros", - spatial_size=(2, 2, 2), - device=None, - ), - {"img": torch.ones((1, 3, 3, 3)), "mode": "bilinear"}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - padding_mode="zeros", - spatial_size=(2, 2, 2), - cache_grid=True, - device=None, - ), - {"img": torch.ones((1, 3, 3, 3)), "mode": "bilinear"}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=True, - device=None, - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "spatial_size": (3, 3)}, - torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 
1], - scale_range=[0.1, 0.2], - spatial_size=(3, 3), - cache_grid=True, - as_tensor_output=True, - device=None, - ), - {"img": torch.arange(64).reshape((1, 8, 8))}, - torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]), - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + dict(device=device), + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=-1), + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), + ] + ) + TESTS.append( + [ + dict(device=device), + {"img": p(torch.arange(27).reshape((3, 3, 3))), "spatial_size": (2, 2)}, + p(np.array([[[2.0, 3.0], [5.0, 6.0]], [[11.0, 12.0], [14.0, 15.0]], [[20.0, 21.0], [23.0, 24.0]]])), + ] + ) + TESTS.append( + [ + dict(device=device), + {"img": p(torch.ones((1, 3, 3, 3))), "spatial_size": (2, 2, 2)}, + p(torch.ones((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2, 2), cache_grid=True), + {"img": p(torch.ones((1, 3, 3, 3)))}, + p(torch.ones((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + padding_mode="zeros", + spatial_size=(2, 2, 2), + device=device, + ), + {"img": p(torch.ones((1, 3, 3, 3))), "mode": "bilinear"}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + padding_mode="zeros", + spatial_size=(2, 2, 2), + cache_grid=True, + device=device, + ), + {"img": p(torch.ones((1, 3, 3, 3))), "mode": "bilinear"}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "spatial_size": (3, 3)}, + p( + torch.tensor( + [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]] + ) + ), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + cache_grid=True, + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8)))}, + p( + torch.tensor( + [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]] + ) + ), + ] + ) -ARR_NUMPY = np.arange(9 * 10).reshape(1, 9, 10) -ARR_TORCH = torch.Tensor(ARR_NUMPY) TEST_CASES_SKIPPED_CONSISTENCY = [] -for im in (ARR_NUMPY, ARR_TORCH): - for as_tensor_output in (True, False): - for in_dtype_is_int in (True, False): - TEST_CASES_SKIPPED_CONSISTENCY.append((im, as_tensor_output, in_dtype_is_int)) +for p in TEST_NDARRAYS: + for in_dtype in (np.int32, np.float32): + TEST_CASES_SKIPPED_CONSISTENCY.append((p(np.arange(9 * 10).reshape(1, 9, 10)), in_dtype)) class TestRandAffine(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_affine(self, input_param, input_data, expected_val): g = RandAffine(**input_param) g.set_random_state(123) result = g(**input_data) if input_param.get("cache_grid", False): self.assertTrue(g._cached_grid is not None) - 
self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) + self.assertEqual(type(result), type(expected_val)) if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + self.assertEqual(result.device, expected_val.device) + result, expected_val = result.cpu(), expected_val.cpu() + np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) def test_ill_cache(self): with self.assertWarns(UserWarning): @@ -132,15 +154,11 @@ def test_ill_cache(self): RandAffine(cache_grid=True, spatial_size=(1, 1, -1)) @parameterized.expand(TEST_CASES_SKIPPED_CONSISTENCY) - def test_skipped_transform_consistency(self, im, as_tensor_output, in_dtype_is_int): - t1 = RandAffine(prob=0, as_tensor_output=as_tensor_output) - t2 = RandAffine(prob=1, spatial_size=(10, 11), as_tensor_output=as_tensor_output) + def test_skipped_transform_consistency(self, im, in_dtype): + t1 = RandAffine(prob=0) + t2 = RandAffine(prob=1, spatial_size=(10, 11)) - # change dtype to int32 or float32 - if in_dtype_is_int: - im = im.astype("int32") if isinstance(im, np.ndarray) else im.int() - else: - im = im.astype("float32") if isinstance(im, np.ndarray) else im.float() + im, *_ = convert_data_type(im, dtype=in_dtype) out1 = t1(im) out2 = t2(im) diff --git a/tests/test_rand_affine_grid.py b/tests/test_rand_affine_grid.py index 605d0a30ba..3b81a924ba 100644 --- a/tests/test_rand_affine_grid.py +++ b/tests/test_rand_affine_grid.py @@ -16,182 +16,196 @@ from parameterized import parameterized from monai.transforms import RandAffineGrid +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [{"as_tensor_output": False, "device": None}, {"grid": torch.ones((3, 3, 3))}, np.ones((3, 3, 3))], - [ - {"rotate_range": (1, 2), "translate_range": (3, 3, 3)}, - {"grid": torch.arange(0, 27).reshape((3, 3, 3))}, - torch.tensor( - np.array( - [ - [ - [-32.81998, -33.910976, -35.001972], - [-36.092968, -37.183964, -38.27496], - [-39.36596, -40.456955, -41.54795], - ], - [[2.1380205, 3.1015975, 4.0651755], [5.028752, 5.9923296, 6.955907], [7.919484, 8.883063, 9.84664]], - [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0], [24.0, 25.0, 26.0]], - ] - ) - ), - ], - [ - {"translate_range": (3, 3, 3), "as_tensor_output": False, "device": torch.device("cpu:0")}, - {"spatial_size": (3, 3, 3)}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append([{"device": device}, {"grid": p(torch.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))]) + TESTS.append( [ - [ - [ - [0.17881513, 0.17881513, 0.17881513], - [0.17881513, 0.17881513, 0.17881513], - [0.17881513, 0.17881513, 0.17881513], - ], - [ - [1.1788151, 1.1788151, 1.1788151], - [1.1788151, 1.1788151, 1.1788151], - [1.1788151, 1.1788151, 1.1788151], - ], - [ - [2.1788151, 2.1788151, 2.1788151], - [2.1788151, 2.1788151, 2.1788151], - [2.1788151, 2.1788151, 2.1788151], - ], - ], - [ - [ - [-2.283164, -2.283164, -2.283164], - [-1.283164, -1.283164, -1.283164], - [-0.28316402, -0.28316402, -0.28316402], - ], - [ - [-2.283164, -2.283164, -2.283164], - [-1.283164, -1.283164, -1.283164], - [-0.28316402, -0.28316402, -0.28316402], - ], - [ - [-2.283164, -2.283164, -2.283164], - [-1.283164, -1.283164, -1.283164], - [-0.28316402, -0.28316402, -0.28316402], - ], - ], - [ - [ - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, 
-0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - ], - [ - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - ], - [ - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - ], - ], - [ - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - ], - ] - ), - ], - [ - {"rotate_range": (1.0, 1.0, 1.0), "shear_range": (0.1,), "scale_range": (1.2,)}, - {"grid": torch.arange(0, 108).reshape((4, 3, 3, 3))}, - torch.tensor( - np.array( - [ - [ - [ - [-9.4201e00, -8.1672e00, -6.9143e00], - [-5.6614e00, -4.4085e00, -3.1556e00], - [-1.9027e00, -6.4980e-01, 6.0310e-01], - ], - [ - [1.8560e00, 3.1089e00, 4.3618e00], - [5.6147e00, 6.8676e00, 8.1205e00], - [9.3734e00, 1.0626e01, 1.1879e01], - ], + {"rotate_range": (1, 2), "translate_range": (3, 3, 3)}, + {"grid": p(torch.arange(0, 27).reshape((3, 3, 3)))}, + p( + np.array( [ - [1.3132e01, 1.4385e01, 1.5638e01], - [1.6891e01, 1.8144e01, 1.9397e01], - [2.0650e01, 2.1902e01, 2.3155e01], - ], - ], - [ - [ - [9.9383e-02, -4.8845e-01, -1.0763e00], - [-1.6641e00, -2.2519e00, -2.8398e00], - [-3.4276e00, -4.0154e00, -4.6032e00], - ], - [ - [-5.1911e00, -5.7789e00, -6.3667e00], - [-6.9546e00, -7.5424e00, -8.1302e00], - [-8.7180e00, -9.3059e00, -9.8937e00], - ], - [ - [-1.0482e01, -1.1069e01, -1.1657e01], - [-1.2245e01, -1.2833e01, -1.3421e01], - [-1.4009e01, -1.4596e01, -1.5184e01], - ], - ], + [ + [-32.81998, -33.910976, -35.001972], + [-36.092968, -37.183964, -38.27496], + [-39.36596, -40.456955, -41.54795], + ], + [ + [2.1380205, 3.1015975, 4.0651755], + [5.028752, 5.9923296, 6.955907], + [7.919484, 8.883063, 9.84664], + ], + [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0], [24.0, 25.0, 26.0]], + ] + ) + ), + ] + ) + TESTS.append( + [ + {"translate_range": (3, 3, 3), "device": device}, + {"spatial_size": (3, 3, 3)}, + np.array( [ [ - [5.9635e01, 6.1199e01, 6.2764e01], - [6.4328e01, 6.5892e01, 6.7456e01], - [6.9021e01, 7.0585e01, 7.2149e01], - ], - [ - [7.3714e01, 7.5278e01, 7.6842e01], - [7.8407e01, 7.9971e01, 8.1535e01], - [8.3099e01, 8.4664e01, 8.6228e01], + [ + [0.17881513, 0.17881513, 0.17881513], + [0.17881513, 0.17881513, 0.17881513], + [0.17881513, 0.17881513, 0.17881513], + ], + [ + [1.1788151, 1.1788151, 1.1788151], + [1.1788151, 1.1788151, 1.1788151], + [1.1788151, 1.1788151, 1.1788151], + ], + [ + [2.1788151, 2.1788151, 2.1788151], + [2.1788151, 2.1788151, 2.1788151], + [2.1788151, 2.1788151, 2.1788151], + ], ], [ - [8.7792e01, 8.9357e01, 9.0921e01], - [9.2485e01, 9.4049e01, 9.5614e01], - [9.7178e01, 9.8742e01, 1.0031e02], + [ + [-2.283164, -2.283164, -2.283164], + [-1.283164, -1.283164, -1.283164], + [-0.28316402, -0.28316402, -0.28316402], + ], + [ + [-2.283164, -2.283164, -2.283164], + [-1.283164, -1.283164, -1.283164], + [-0.28316402, -0.28316402, -0.28316402], + ], + [ + [-2.283164, -2.283164, -2.283164], + [-1.283164, -1.283164, -1.283164], + [-0.28316402, -0.28316402, -0.28316402], + ], ], - ], - [ [ - [8.1000e01, 8.2000e01, 8.3000e01], - [8.4000e01, 8.5000e01, 8.6000e01], - [8.7000e01, 8.8000e01, 8.9000e01], + [ + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + ], + [ + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + ], + [ + [-2.6388912, -1.6388912, -0.6388912], + 
[-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + ], ], [ - [9.0000e01, 9.1000e01, 9.2000e01], - [9.3000e01, 9.4000e01, 9.5000e01], - [9.6000e01, 9.7000e01, 9.8000e01], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], ], + ] + ), + ] + ) + TESTS.append( + [ + {"device": device, "rotate_range": (1.0, 1.0, 1.0), "shear_range": (0.1,), "scale_range": (1.2,)}, + {"grid": p(torch.arange(0, 108).reshape((4, 3, 3, 3)))}, + p( + np.array( [ - [9.9000e01, 1.0000e02, 1.0100e02], - [1.0200e02, 1.0300e02, 1.0400e02], - [1.0500e02, 1.0600e02, 1.0700e02], - ], - ], - ] - ) - ), - ], -] + [ + [ + [-9.4201e00, -8.1672e00, -6.9143e00], + [-5.6614e00, -4.4085e00, -3.1556e00], + [-1.9027e00, -6.4980e-01, 6.0310e-01], + ], + [ + [1.8560e00, 3.1089e00, 4.3618e00], + [5.6147e00, 6.8676e00, 8.1205e00], + [9.3734e00, 1.0626e01, 1.1879e01], + ], + [ + [1.3132e01, 1.4385e01, 1.5638e01], + [1.6891e01, 1.8144e01, 1.9397e01], + [2.0650e01, 2.1902e01, 2.3155e01], + ], + ], + [ + [ + [9.9383e-02, -4.8845e-01, -1.0763e00], + [-1.6641e00, -2.2519e00, -2.8398e00], + [-3.4276e00, -4.0154e00, -4.6032e00], + ], + [ + [-5.1911e00, -5.7789e00, -6.3667e00], + [-6.9546e00, -7.5424e00, -8.1302e00], + [-8.7180e00, -9.3059e00, -9.8937e00], + ], + [ + [-1.0482e01, -1.1069e01, -1.1657e01], + [-1.2245e01, -1.2833e01, -1.3421e01], + [-1.4009e01, -1.4596e01, -1.5184e01], + ], + ], + [ + [ + [5.9635e01, 6.1199e01, 6.2764e01], + [6.4328e01, 6.5892e01, 6.7456e01], + [6.9021e01, 7.0585e01, 7.2149e01], + ], + [ + [7.3714e01, 7.5278e01, 7.6842e01], + [7.8407e01, 7.9971e01, 8.1535e01], + [8.3099e01, 8.4664e01, 8.6228e01], + ], + [ + [8.7792e01, 8.9357e01, 9.0921e01], + [9.2485e01, 9.4049e01, 9.5614e01], + [9.7178e01, 9.8742e01, 1.0031e02], + ], + ], + [ + [ + [8.1000e01, 8.2000e01, 8.3000e01], + [8.4000e01, 8.5000e01, 8.6000e01], + [8.7000e01, 8.8000e01, 8.9000e01], + ], + [ + [9.0000e01, 9.1000e01, 9.2000e01], + [9.3000e01, 9.4000e01, 9.5000e01], + [9.6000e01, 9.7000e01, 9.8000e01], + ], + [ + [9.9000e01, 1.0000e02, 1.0100e02], + [1.0200e02, 1.0300e02, 1.0400e02], + [1.0500e02, 1.0600e02, 1.0700e02], + ], + ], + ] + ) + ), + ] + ) class TestRandAffineGrid(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_affine_grid(self, input_param, input_data, expected_val): g = RandAffineGrid(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + if "grid" in input_data: + self.assertEqual(type(result), type(expected_val)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, expected_val.device) + result = result.cpu() + expected_val = expected_val.cpu() + np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py index d2f8a60665..a0078aae1e 100644 --- a/tests/test_rand_affined.py +++ b/tests/test_rand_affined.py @@ -17,179 +17,188 @@ from monai.transforms import RandAffined from monai.utils import GridSampleMode +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - dict(as_tensor_output=False, device=None, 
spatial_size=None, keys=("img", "seg")), - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - dict(as_tensor_output=False, device=None, spatial_size=(2, 2), keys=("img", "seg")), - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.ones((3, 2, 2)), - ], - [ - dict(as_tensor_output=False, device=None, spatial_size=(2, 2), cache_grid=True, keys=("img", "seg")), - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.ones((3, 2, 2)), - ], - [ - dict(as_tensor_output=True, device=None, spatial_size=(2, 2, 2), keys=("img", "seg")), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - torch.ones((1, 2, 2, 2)), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - spatial_size=(2, 2, 2), - padding_mode="zeros", - device=None, - keys=("img", "seg"), - mode="bilinear", - ), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=False, - spatial_size=(2, 2, 2), - padding_mode="zeros", - device=None, - cache_grid=True, - keys=("img", "seg"), - mode="bilinear", - ), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - np.array([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=True, - spatial_size=(3, 3), - keys=("img", "seg"), - device=None, - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]), - ], - [ - dict( - prob=0.9, - mode=("bilinear", "nearest"), - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=False, - spatial_size=(3, 3), - keys=("img", "seg"), - device=torch.device("cpu:0"), - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - { - "img": np.array( - [ - [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], - ] - ] - ), - "seg": np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]]), - }, - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - spatial_size=(2, 2, 2), - padding_mode="zeros", - device=None, - keys=("img", "seg"), - mode=GridSampleMode.BILINEAR, - ), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=False, - spatial_size=(3, 3), - keys=("img", "seg"), - device=torch.device("cpu:0"), - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - { - "img": np.array( - [ - [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 
29.753046], - ] - ] - ), - "seg": np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]]), - }, - ], - [ - dict( - prob=0.9, - mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=False, - spatial_size=(3, 3), - cache_grid=True, - keys=("img", "seg"), - device=torch.device("cpu:0"), - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - { - "img": np.array( - [ - [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], - ] - ] - ), - "seg": np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]]), - }, - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + dict(device=device, spatial_size=None, keys=("img", "seg")), + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2), keys=("img", "seg")), + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p(np.ones((3, 2, 2))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2), cache_grid=True, keys=("img", "seg")), + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p(np.ones((3, 2, 2))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2, 2), keys=("img", "seg")), + {"img": p(torch.ones((1, 3, 3, 3))), "seg": p(torch.ones((1, 3, 3, 3)))}, + p(torch.ones((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + spatial_size=(2, 2, 2), + padding_mode="zeros", + device=device, + keys=("img", "seg"), + mode="bilinear", + ), + {"img": p(torch.ones((1, 3, 3, 3))), "seg": p(torch.ones((1, 3, 3, 3)))}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + p( + torch.tensor( + [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]] + ) + ), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + mode=("bilinear", "nearest"), + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + { + "img": p( + np.array( + [ + [ + [18.736153, 15.581954, 12.4277525], + [27.398798, 24.244598, 21.090399], + [36.061443, 32.90724, 29.753046], + ] + ] + ) + ), + "seg": p(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + }, + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + spatial_size=(2, 2, 2), + padding_mode="zeros", + device=device, + keys=("img", "seg"), + mode=GridSampleMode.BILINEAR, + ), + {"img": p(torch.ones((1, 3, 3, 3))), "seg": p(torch.ones((1, 3, 3, 3)))}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 
1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + { + "img": p( + np.array( + [ + [ + [18.736153, 15.581954, 12.4277525], + [27.398798, 24.244598, 21.090399], + [36.061443, 32.90724, 29.753046], + ] + ] + ) + ), + "seg": p(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + }, + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + cache_grid=True, + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + { + "img": p( + np.array( + [ + [ + [18.736153, 15.581954, 12.4277525], + [27.398798, 24.244598, 21.090399], + [36.061443, 32.90724, 29.753046], + ] + ] + ) + ), + "seg": p(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + }, + ] + ) class TestRandAffined(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_affined(self, input_param, input_data, expected_val): g = RandAffined(**input_param).set_random_state(123) res = g(input_data) @@ -200,23 +209,20 @@ def test_rand_affined(self, input_param, input_data, expected_val): if "_transforms" in key: continue expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) + self.assertEqual(type(result), type(expected)) if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + self.assertEqual(result.device, expected.device) + result, expected = result.cpu(), expected.cpu() + np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) def test_ill_cache(self): with self.assertWarns(UserWarning): # spatial size is None - RandAffined( - as_tensor_output=False, device=None, spatial_size=None, prob=1.0, cache_grid=True, keys=("img", "seg") - ) + RandAffined(device=device, spatial_size=None, prob=1.0, cache_grid=True, keys=("img", "seg")) with self.assertWarns(UserWarning): # spatial size is dynamic RandAffined( - as_tensor_output=False, - device=None, + device=device, spatial_size=(2, -1), prob=1.0, cache_grid=True, diff --git a/tests/test_rand_coarse_dropout.py b/tests/test_rand_coarse_dropout.py index 235a391567..a7c38a6c66 100644 --- a/tests/test_rand_coarse_dropout.py +++ b/tests/test_rand_coarse_dropout.py @@ -12,41 +12,54 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandCoarseDropout from monai.utils import fall_back_tuple +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ - {"holes": 2, "spatial_size": [2, 2, 2], "fill_value": 5, "prob": 1.0}, - np.random.randint(0, 2, size=[3, 3, 3, 4]), - (3, 3, 3, 4), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"holes": 2, "spatial_size": [2, 2, 2], "fill_value": 5, "prob": 1.0}, + p(np.random.randint(0, 2, 
size=[3, 3, 3, 4])), + (3, 3, 3, 4), + ] + ) -TEST_CASE_1 = [ - {"holes": 1, "spatial_size": [1, 2, 3], "fill_value": 5, "max_holes": 5, "prob": 1.0}, - np.random.randint(0, 2, size=[3, 3, 3, 4]), - (3, 3, 3, 4), -] + TESTS.append( + [ + {"holes": 1, "spatial_size": [1, 2, 3], "fill_value": 5, "max_holes": 5, "prob": 1.0}, + p(np.random.randint(0, 2, size=[3, 3, 3, 4])), + (3, 3, 3, 4), + ] + ) -TEST_CASE_2 = [ - {"holes": 2, "spatial_size": [2, 2, 2], "fill_value": 5, "max_spatial_size": [4, 4, 3], "prob": 1.0}, - np.random.randint(0, 2, size=[3, 3, 3, 4]), - (3, 3, 3, 4), -] + TESTS.append( + [ + {"holes": 2, "spatial_size": [2, 2, 2], "fill_value": 5, "max_spatial_size": [4, 4, 3], "prob": 1.0}, + p(np.random.randint(0, 2, size=[3, 3, 3, 4])), + (3, 3, 3, 4), + ] + ) -TEST_CASE_3 = [ - {"holes": 2, "spatial_size": [2, -1, 2], "fill_value": 5, "max_spatial_size": [4, 4, -1], "prob": 1.0}, - np.random.randint(0, 2, size=[3, 3, 3, 4]), - (3, 3, 3, 4), -] + TESTS.append( + [ + {"holes": 2, "spatial_size": [2, -1, 2], "fill_value": 5, "max_spatial_size": [4, 4, -1], "prob": 1.0}, + p(np.random.randint(0, 2, size=[3, 3, 3, 4])), + (3, 3, 3, 4), + ] + ) class TestRandCoarseDropout(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, input_param, input_data, expected_shape): dropout = RandCoarseDropout(**input_param) result = dropout(input_data) + np.testing.assert_equal(result.shape, expected_shape) holes = input_param.get("holes") max_holes = input_param.get("max_holes") spatial_size = fall_back_tuple(input_param.get("spatial_size"), input_data.shape[1:]) @@ -60,6 +73,10 @@ def test_value(self, input_param, input_data, expected_shape): for h in dropout.hole_coords: data = result[h] + self.assertEqual(type(data), type(input_data)) + if isinstance(data, torch.Tensor): + self.assertEqual(data.device, input_data.device) + data = data.cpu() np.testing.assert_allclose(data, input_param.get("fill_value", 0)) if max_spatial_size is None: self.assertTupleEqual(data.shape[1:], tuple(spatial_size)) diff --git a/tests/test_rand_coarse_dropoutd.py b/tests/test_rand_coarse_dropoutd.py index d189a80f56..1511590ed6 100644 --- a/tests/test_rand_coarse_dropoutd.py +++ b/tests/test_rand_coarse_dropoutd.py @@ -12,55 +12,69 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandCoarseDropoutd from monai.utils import fall_back_tuple +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ - {"keys": "img", "holes": 2, "spatial_size": [2, 2, 2], "fill_value": 5, "prob": 1.0}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 4])}, - (3, 3, 3, 4), -] +TESTS = [] +for p in TEST_NDARRAYS: -TEST_CASE_1 = [ - {"keys": "img", "holes": 1, "spatial_size": [1, 2, 3], "fill_value": 5, "max_holes": 5, "prob": 1.0}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 4])}, - (3, 3, 3, 4), -] + TESTS.append( + [ + {"keys": "img", "holes": 2, "spatial_size": [2, 2, 2], "fill_value": 5, "prob": 1.0}, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 4]))}, + (3, 3, 3, 4), + ] + ) -TEST_CASE_2 = [ - { - "keys": "img", - "holes": 2, - "spatial_size": [2, 2, 2], - "fill_value": 5, - "max_spatial_size": [4, 4, 3], - "prob": 1.0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 4])}, - (3, 3, 3, 4), -] + TESTS.append( + [ + {"keys": "img", "holes": 1, "spatial_size": [1, 2, 3], "fill_value": 5, "max_holes": 5, "prob": 1.0}, + {"img": p(np.random.randint(0, 2, size=[3, 
3, 3, 4]))}, + (3, 3, 3, 4), + ] + ) -TEST_CASE_3 = [ - { - "keys": "img", - "holes": 2, - "spatial_size": [2, -1, 2], - "fill_value": 5, - "max_spatial_size": [4, 4, -1], - "prob": 1.0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 4])}, - (3, 3, 3, 4), -] + TESTS.append( + [ + { + "keys": "img", + "holes": 2, + "spatial_size": [2, 2, 2], + "fill_value": 5, + "max_spatial_size": [4, 4, 3], + "prob": 1.0, + }, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 4]))}, + (3, 3, 3, 4), + ] + ) + + TESTS.append( + [ + { + "keys": "img", + "holes": 2, + "spatial_size": [2, -1, 2], + "fill_value": 5, + "max_spatial_size": [4, 4, -1], + "prob": 1.0, + }, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 4]))}, + (3, 3, 3, 4), + ] + ) class TestRandCoarseDropoutd(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, input_param, input_data, expected_shape): dropout = RandCoarseDropoutd(**input_param) result = dropout(input_data)["img"] + np.testing.assert_equal(result.shape, expected_shape) holes = input_param.get("holes") max_holes = input_param.get("max_holes") spatial_size = fall_back_tuple(input_param.get("spatial_size"), input_data["img"].shape[1:]) @@ -74,6 +88,10 @@ def test_value(self, input_param, input_data, expected_shape): for h in dropout.hole_coords: data = result[h] + self.assertEqual(type(data), type(input_data["img"])) + if isinstance(data, torch.Tensor): + self.assertEqual(data.device, input_data["img"].device) + data = data.cpu() np.testing.assert_allclose(data, input_param.get("fill_value", 0)) if max_spatial_size is None: self.assertTupleEqual(data.shape[1:], tuple(spatial_size)) diff --git a/tests/test_rand_crop_by_label_classes.py b/tests/test_rand_crop_by_label_classes.py index b21f971042..d562a44a6d 100644 --- a/tests/test_rand_crop_by_label_classes.py +++ b/tests/test_rand_crop_by_label_classes.py @@ -15,68 +15,77 @@ from parameterized import parameterized from monai.transforms import ClassesToIndices, RandCropByLabelClasses +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ +TESTS_INDICES, TESTS_SHAPE = [], [] +for p in TEST_NDARRAYS: # One-Hot label - { - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "num_classes": None, - "spatial_size": [2, 2, -1], - "ratios": [1, 1, 1], - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - list, - (3, 2, 2, 3), -] + TESTS_INDICES.append( + [ + { + "label": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "num_classes": None, + "spatial_size": [2, 2, -1], + "ratios": [1, 1, 1], + "num_samples": 2, + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image_threshold": 0, + }, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + list, + (3, 2, 2, 3), + ] + ) -TEST_CASE_1 = [ - # Argmax label - { - "label": np.random.randint(0, 2, size=[1, 3, 3, 3]), - "num_classes": 2, - "spatial_size": [2, 2, 2], - "ratios": [1, 1], - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - list, - (3, 2, 2, 2), -] + TESTS_INDICES.append( + [ + # Argmax label + { + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + "num_classes": 2, + "spatial_size": [2, 2, 2], + "ratios": [1, 1], + "num_samples": 2, + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image_threshold": 0, + }, + {"img": 
p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + list, + (3, 2, 2, 2), + ] + ) -TEST_CASE_2 = [ - # provide label at runtime - { - "label": None, - "num_classes": 2, - "spatial_size": [2, 2, 2], - "ratios": [1, 1], - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - { - "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[1, 3, 3, 3]), - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - }, - list, - (3, 2, 2, 2), -] + TESTS_SHAPE.append( + [ + # provide label at runtime + { + "label": None, + "num_classes": 2, + "spatial_size": [2, 2, 2], + "ratios": [1, 1], + "num_samples": 2, + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image_threshold": 0, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + }, + list, + (3, 2, 2, 2), + ] + ) class TestRandCropByLabelClasses(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS_INDICES + TESTS_SHAPE) def test_type_shape(self, input_param, input_data, expected_type, expected_shape): result = RandCropByLabelClasses(**input_param)(**input_data) self.assertIsInstance(result, expected_type) self.assertTupleEqual(result[0].shape, expected_shape) - @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) + @parameterized.expand(TESTS_INDICES) def test_indices(self, input_param, input_data, expected_type, expected_shape): input_param["indices"] = ClassesToIndices(num_classes=input_param["num_classes"])(input_param["label"]) result = RandCropByLabelClasses(**input_param)(**input_data) diff --git a/tests/test_rand_crop_by_label_classesd.py b/tests/test_rand_crop_by_label_classesd.py index 829096953b..27fe3425dd 100644 --- a/tests/test_rand_crop_by_label_classesd.py +++ b/tests/test_rand_crop_by_label_classesd.py @@ -15,52 +15,59 @@ from parameterized import parameterized from monai.transforms import ClassesToIndicesd, RandCropByLabelClassesd +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ - # One-Hot label - { - "keys": "img", - "label_key": "label", - "num_classes": None, - "spatial_size": [2, 2, -1], - "ratios": [1, 1, 1], - "num_samples": 2, - "image_key": "image", - "image_threshold": 0, - }, - { - "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - }, - list, - (3, 2, 2, 3), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + # One-Hot label + { + "keys": "img", + "label_key": "label", + "num_classes": None, + "spatial_size": [2, 2, -1], + "ratios": [1, 1, 1], + "num_samples": 2, + "image_key": "image", + "image_threshold": 0, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + }, + list, + (3, 2, 2, 3), + ] + ) -TEST_CASE_1 = [ - # Argmax label - { - "keys": "img", - "label_key": "label", - "num_classes": 2, - "spatial_size": [2, 2, 2], - "ratios": [1, 1], - "num_samples": 2, - "image_key": "image", - "image_threshold": 0, - }, - { - "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[1, 3, 3, 3]), - }, - list, - (3, 2, 2, 2), -] + TESTS.append( + [ + # Argmax label + { + "keys": "img", + "label_key": "label", + 
"num_classes": 2, + "spatial_size": [2, 2, 2], + "ratios": [1, 1], + "num_samples": 2, + "image_key": "image", + "image_threshold": 0, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + }, + list, + (3, 2, 2, 2), + ] + ) class TestRandCropByLabelClassesd(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) + @parameterized.expand(TESTS) def test_type_shape(self, input_param, input_data, expected_type, expected_shape): result = RandCropByLabelClassesd(**input_param)(input_data) self.assertIsInstance(result, expected_type) diff --git a/tests/test_rand_crop_by_pos_neg_label.py b/tests/test_rand_crop_by_pos_neg_label.py index e0f669ab3f..a81976dea1 100644 --- a/tests/test_rand_crop_by_pos_neg_label.py +++ b/tests/test_rand_crop_by_pos_neg_label.py @@ -10,68 +10,93 @@ # limitations under the License. import unittest +from copy import deepcopy import numpy as np from parameterized import parameterized from monai.transforms import RandCropByPosNegLabel +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ - { - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "spatial_size": [2, 2, -1], - "pos": 1, - "neg": 1, - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - list, - (3, 2, 2, 3), -] +TESTS = [] +TESTS.append( + [ + { + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "spatial_size": [2, 2, -1], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image_threshold": 0, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, + (3, 2, 2, 3), + ] +) +TESTS.append( + [ + { + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "spatial_size": [2, 2, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image_threshold": 0, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, + (3, 2, 2, 2), + ] +) +TESTS.append( + [ + { + "label": None, + "spatial_size": [2, 2, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image_threshold": 0, + }, + { + "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + }, + (3, 2, 2, 2), + ] +) -TEST_CASE_1 = [ - { - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "spatial_size": [2, 2, 2], - "pos": 1, - "neg": 1, - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - list, - (3, 2, 2, 2), -] -TEST_CASE_2 = [ - { - "label": None, - "spatial_size": [2, 2, 2], - "pos": 1, - "neg": 1, - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - { - "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - }, - list, - (3, 2, 2, 2), -] +class TestRandCropByPosNegLabel(unittest.TestCase): + @staticmethod + def convert_data_type(im_type, d, keys=("img", "image", "label")): + out = deepcopy(d) + for k, v in out.items(): + if k in keys and isinstance(v, np.ndarray): + out[k] = im_type(v) + return out + @parameterized.expand(TESTS) + def test_type_shape(self, input_param, 
input_data, expected_shape): + results = [] + for p in TEST_NDARRAYS: + input_param_mod = self.convert_data_type(p, input_param) + input_data_mod = self.convert_data_type(p, input_data) + cropper = RandCropByPosNegLabel(**input_param_mod) + cropper.set_random_state(0) + result = cropper(**input_data_mod) -class TestRandCropByPosNegLabel(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) - def test_type_shape(self, input_param, input_data, expected_type, expected_shape): - result = RandCropByPosNegLabel(**input_param)(**input_data) - self.assertIsInstance(result, expected_type) - self.assertTupleEqual(result[0].shape, expected_shape) + self.assertIsInstance(result, list) + self.assertTupleEqual(result[0].shape, expected_shape) + + # check for same results across numpy, torch.Tensor and torch.cuda.Tensor + result = np.asarray([i if isinstance(i, np.ndarray) else i.cpu().numpy() for i in result]) + results.append(np.asarray(result)) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) if __name__ == "__main__": diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py index 17a3e117bb..03ec87bb4d 100644 --- a/tests/test_rand_crop_by_pos_neg_labeld.py +++ b/tests/test_rand_crop_by_pos_neg_labeld.py @@ -10,11 +10,13 @@ # limitations under the License. import unittest +from copy import deepcopy import numpy as np from parameterized import parameterized from monai.transforms import RandCropByPosNegLabeld +from tests.utils import TEST_NDARRAYS TEST_CASE_0 = [ { @@ -33,7 +35,6 @@ "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), "image_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, }, - list, (3, 3, 2, 2), ] @@ -54,7 +55,6 @@ "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), "label_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, }, - list, (3, 2, 2, 2), ] @@ -75,25 +75,36 @@ "label": np.ones([3, 3, 3, 3]), "extra_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, }, - list, (3, 2, 2, 2), ] class TestRandCropByPosNegLabeld(unittest.TestCase): + @staticmethod + def convert_data_type(im_type, d, keys=("img", "image", "label")): + out = deepcopy(d) + for k, v in out.items(): + if k in keys and isinstance(v, np.ndarray): + out[k] = im_type(v) + return out + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) - def test_type_shape(self, input_param, input_data, expected_type, expected_shape): - result = RandCropByPosNegLabeld(**input_param)(input_data) - self.assertIsInstance(result, expected_type) - self.assertTupleEqual(result[0]["image"].shape, expected_shape) - self.assertTupleEqual(result[0]["extra"].shape, expected_shape) - self.assertTupleEqual(result[0]["label"].shape, expected_shape) - _len = len(tuple(input_data.keys())) - self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys())) - for i, item in enumerate(result): - self.assertEqual(item["image_meta_dict"]["patch_index"], i) - self.assertEqual(item["label_meta_dict"]["patch_index"], i) - self.assertEqual(item["extra_meta_dict"]["patch_index"], i) + def test_type_shape(self, input_param, input_data, expected_shape): + for p in TEST_NDARRAYS: + input_param_mod = self.convert_data_type(p, input_param) + input_data_mod = self.convert_data_type(p, input_data) + cropper = RandCropByPosNegLabeld(**input_param_mod) + cropper.set_random_state(0) + result = cropper(input_data_mod) + + self.assertIsInstance(result, list) + + _len = len(tuple(input_data.keys())) + 
self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys())) + for k in ("image", "extra", "label"): + self.assertTupleEqual(result[0][k].shape, expected_shape) + for i, item in enumerate(result): + self.assertEqual(item[k + "_meta_dict"]["patch_index"], i) if __name__ == "__main__": diff --git a/tests/test_rand_elastic_2d.py b/tests/test_rand_elastic_2d.py index fbfb7d5761..830d0f39e9 100644 --- a/tests/test_rand_elastic_2d.py +++ b/tests/test_rand_elastic_2d.py @@ -16,90 +16,105 @@ from parameterized import parameterized from monai.transforms import Rand2DElastic +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "as_tensor_output": False, "device": None}, - {"img": torch.ones((3, 3, 3)), "spatial_size": (2, 2)}, - np.ones((3, 2, 2)), - ], - [ - {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "as_tensor_output": False, "device": None}, - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - { - "spacing": (0.3, 0.3), - "magnitude_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "device": None, - "padding_mode": "zeros", - }, - {"img": torch.ones((3, 3, 3)), "spatial_size": (2, 2), "mode": "bilinear"}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], + {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "device": device}, + {"img": p(torch.ones((3, 3, 3))), "spatial_size": (2, 2)}, + p(np.ones((3, 2, 2))), ] - ), - ], - [ - { - "spacing": (1.0, 1.0), - "magnitude_range": (1.0, 1.0), - "scale_range": [1.2, 2.2], - "prob": 0.9, - "padding_mode": "border", - "as_tensor_output": True, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3))}, - torch.tensor( + ) + TESTS.append( [ - [[3.0793, 2.6141], [4.0568, 5.9978]], - [[12.0793, 11.6141], [13.0568, 14.9978]], - [[21.0793, 20.6141], [22.0568, 23.9978]], + {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "device": device}, + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), ] - ), - ], - [ - { - "spacing": (0.3, 0.3), - "magnitude_range": (0.1, 0.2), - "translate_range": [-0.01, 0.01], - "scale_range": [0.01, 0.02], - "prob": 0.9, - "as_tensor_output": False, - "device": "cuda" if torch.cuda.is_available() else "cpu", - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.array( + ) + TESTS.append( [ - [[1.3584113, 1.9251312], [5.626623, 6.642721]], - [[10.358411, 10.925131], [14.626623, 15.642721]], - [[19.358412, 19.92513], [23.626623, 24.642721]], + { + "spacing": (0.3, 0.3), + "magnitude_range": (1.0, 2.0), + "prob": 0.9, + "device": device, + "padding_mode": "zeros", + }, + {"img": p(torch.ones((3, 3, 3))), "spatial_size": (2, 2), "mode": "bilinear"}, + p( + np.array( + [ + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + ] + ) + ), ] - ), - ], -] + ) + TESTS.append( + [ + { + "spacing": (1.0, 1.0), + "magnitude_range": (1.0, 1.0), + "scale_range": [1.2, 2.2], + "prob": 0.9, + "padding_mode": "border", + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + torch.tensor( + [ 
+ [[3.0793, 2.6141], [4.0568, 5.9978]], + [[12.0793, 11.6141], [13.0568, 14.9978]], + [[21.0793, 20.6141], [22.0568, 23.9978]], + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "spacing": (0.3, 0.3), + "magnitude_range": (0.1, 0.2), + "translate_range": [-0.01, 0.01], + "scale_range": [0.01, 0.02], + "prob": 0.9, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + np.array( + [ + [[1.3584113, 1.9251312], [5.626623, 6.642721]], + [[10.358411, 10.925131], [14.626623, 15.642721]], + [[19.358412, 19.92513], [23.626623, 24.642721]], + ] + ) + ), + ] + ) class TestRand2DElastic(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_2d_elastic(self, input_param, input_data, expected_val): g = Rand2DElastic(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) + self.assertEqual(type(result), type(expected_val)) if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + self.assertEqual(result.device, expected_val.device) + result, expected_val = result.cpu(), expected_val.cpu() + np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_elastic_3d.py b/tests/test_rand_elastic_3d.py index c63282d571..6add569723 100644 --- a/tests/test_rand_elastic_3d.py +++ b/tests/test_rand_elastic_3d.py @@ -16,69 +16,93 @@ from parameterized import parameterized from monai.transforms import Rand3DElastic +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - { - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": -1, - }, - {"img": torch.arange(72).reshape((2, 3, 3, 4))}, - np.arange(72).reshape((2, 3, 3, 4)), - ], - [ - { - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - }, - {"img": torch.ones((2, 3, 3, 3)), "spatial_size": (2, 2, 2)}, - np.ones((2, 2, 2, 2)), - ], - [ - { - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "device": None, - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "spatial_size": (2, 2, 2)}, - np.array([[[[6.4939356, 7.50289], [9.518351, 10.522849]], [[15.512375, 16.523542], [18.531467, 19.53646]]]]), - ], - [ - { - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "rotate_range": [1, 1, 1], - "as_tensor_output": False, - "device": "cuda" if torch.cuda.is_available() else "cpu", - "spatial_size": (2, 2, 2), - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "mode": "bilinear"}, - np.array([[[[5.0069294, 9.463932], [9.287769, 13.739735]], [[12.319424, 16.777205], [16.594296, 21.045748]]]]), - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + { + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": -1, + }, + {"img": p(torch.arange(72).reshape((2, 3, 3, 4)))}, + p(np.arange(72).reshape((2, 3, 3, 4))), + ] + ) + TESTS.append( + [ + { + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + }, + 
{"img": p(torch.ones((2, 3, 3, 3))), "spatial_size": (2, 2, 2)}, + p(np.ones((2, 2, 2, 2))), + ] + ) + TESTS.append( + [ + { + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "device": device, + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "spatial_size": (2, 2, 2)}, + p( + np.array( + [ + [ + [[6.4939356, 7.50289], [9.518351, 10.522849]], + [[15.512375, 16.523542], [18.531467, 19.53646]], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "rotate_range": [1, 1, 1], + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "mode": "bilinear"}, + p( + np.array( + [ + [ + [[5.0069294, 9.463932], [9.287769, 13.739735]], + [[12.319424, 16.777205], [16.594296, 21.045748]], + ] + ] + ) + ), + ] + ) class TestRand3DElastic(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_3d_elastic(self, input_param, input_data, expected_val): g = Rand3DElastic(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) + self.assertEqual(type(result), type(expected_val)) if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + self.assertEqual(result.device, expected_val.device) + result, expected_val = result.cpu(), expected_val.cpu() + np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_elasticd_2d.py b/tests/test_rand_elasticd_2d.py index f8eb026088..5dd5dabda2 100644 --- a/tests/test_rand_elasticd_2d.py +++ b/tests/test_rand_elasticd_2d.py @@ -16,127 +16,147 @@ from parameterized import parameterized from monai.transforms import Rand2DElasticd +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (1.0, 2.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.ones((3, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (0.3, 0.3), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": -1, - }, - {"img": torch.arange(4).reshape((1, 2, 2)), "seg": torch.arange(4).reshape((1, 2, 2))}, - np.arange(4).reshape((1, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "padding_mode": "zeros", - "device": None, - "spatial_size": (2, 2), - "mode": "bilinear", - }, - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (1.0, 2.0), + "prob": 0.0, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p(np.ones((3, 2, 2))), ] - ), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (1.0, 1.0), - 
"magnitude_range": (1.0, 1.0), - "scale_range": [1.2, 2.2], - "prob": 0.9, - "padding_mode": "border", - "as_tensor_output": True, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - torch.tensor( + ) + TESTS.append( [ - [[3.0793, 2.6141], [4.0568, 5.9978]], - [[12.0793, 11.6141], [13.0568, 14.9978]], - [[21.0793, 20.6141], [22.0568, 23.9978]], + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (0.3, 0.3), + "prob": 0.0, + "device": device, + "spatial_size": -1, + }, + {"img": p(torch.arange(4).reshape((1, 2, 2))), "seg": p(torch.arange(4).reshape((1, 2, 2)))}, + p(np.arange(4).reshape((1, 2, 2))), ] - ), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (0.1, 0.2), - "translate_range": [-0.01, 0.01], - "scale_range": [0.01, 0.02], - "prob": 0.9, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - np.array( + ) + TESTS.append( [ - [[1.3584113, 1.9251312], [5.626623, 6.642721]], - [[10.358411, 10.925131], [14.626623, 15.642721]], - [[19.358412, 19.92513], [23.626623, 24.642721]], + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (1.0, 2.0), + "prob": 0.9, + "padding_mode": "zeros", + "device": device, + "spatial_size": (2, 2), + "mode": "bilinear", + }, + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p( + np.array( + [ + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + ] + ) + ), ] - ), - ], - [ - { - "keys": ("img", "seg"), - "mode": ("bilinear", "nearest"), - "spacing": (0.3, 0.3), - "magnitude_range": (0.1, 0.2), - "translate_range": [-0.01, 0.01], - "scale_range": [0.01, 0.02], - "prob": 0.9, - "as_tensor_output": True, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - { - "img": torch.tensor( - [ - [[1.3584, 1.9251], [5.6266, 6.6427]], - [[10.3584, 10.9251], [14.6266, 15.6427]], - [[19.3584, 19.9251], [23.6266, 24.6427]], - ] - ), - "seg": torch.tensor([[[0.0, 2.0], [6.0, 8.0]], [[9.0, 11.0], [15.0, 17.0]], [[18.0, 20.0], [24.0, 26.0]]]), - }, - ], -] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "spacing": (1.0, 1.0), + "magnitude_range": (1.0, 1.0), + "scale_range": [1.2, 2.2], + "prob": 0.9, + "padding_mode": "border", + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + torch.tensor( + [ + [[3.0793, 2.6141], [4.0568, 5.9978]], + [[12.0793, 11.6141], [13.0568, 14.9978]], + [[21.0793, 20.6141], [22.0568, 23.9978]], + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (0.1, 0.2), + "translate_range": [-0.01, 0.01], + "scale_range": [0.01, 0.02], + "prob": 0.9, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + np.array( + [ + [[1.3584113, 1.9251312], [5.626623, 6.642721]], + [[10.358411, 10.925131], [14.626623, 15.642721]], + [[19.358412, 19.92513], [23.626623, 24.642721]], + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "mode": ("bilinear", "nearest"), + "spacing": (0.3, 0.3), + "magnitude_range": 
(0.1, 0.2), + "translate_range": [-0.01, 0.01], + "scale_range": [0.01, 0.02], + "prob": 0.9, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + { + "img": p( + torch.tensor( + [ + [[1.3584, 1.9251], [5.6266, 6.6427]], + [[10.3584, 10.9251], [14.6266, 15.6427]], + [[19.3584, 19.9251], [23.6266, 24.6427]], + ] + ) + ), + "seg": p( + torch.tensor( + [[[0.0, 2.0], [6.0, 8.0]], [[9.0, 11.0], [15.0, 17.0]], [[18.0, 20.0], [24.0, 26.0]]] + ) + ), + }, + ] + ) class TestRand2DElasticd(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_2d_elasticd(self, input_param, input_data, expected_val): g = Rand2DElasticd(**input_param) g.set_random_state(123) @@ -144,11 +164,11 @@ def test_rand_2d_elasticd(self, input_param, input_data, expected_val): for key in res: result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) + self.assertEqual(type(result), type(expected)) if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + self.assertEqual(result.device, expected.device) + result, expected = result.cpu(), expected.cpu() + np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_elasticd_3d.py b/tests/test_rand_elasticd_3d.py index 47ab814882..dfa960d374 100644 --- a/tests/test_rand_elasticd_3d.py +++ b/tests/test_rand_elasticd_3d.py @@ -16,98 +16,128 @@ from parameterized import parameterized from monai.transforms import Rand3DElasticd +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2, 2), - }, - {"img": torch.ones((2, 3, 3, 3)), "seg": torch.ones((2, 3, 3, 3))}, - np.ones((2, 2, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, -1, -1), - }, - {"img": torch.ones((2, 3, 3, 3)), "seg": torch.ones((2, 3, 3, 3))}, - np.ones((2, 2, 3, 3)), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": -1, - }, - {"img": torch.arange(8).reshape((1, 2, 2, 2)), "seg": torch.arange(8).reshape((1, 2, 2, 2))}, - np.arange(8).reshape((1, 2, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2, 2), - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))}, - np.array([[[[6.4939356, 7.50289], [9.518351, 10.522849]], [[15.512375, 16.523542], [18.531467, 19.53646]]]]), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "rotate_range": [1, 1, 1], - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2, 2), - "mode": "bilinear", - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 
3, 3))}, - np.array([[[[5.0069294, 9.463932], [9.287769, 13.739735]], [[12.319424, 16.777205], [16.594296, 21.045748]]]]), - ], - [ - { - "keys": ("img", "seg"), - "mode": ("bilinear", "nearest"), - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "rotate_range": [1, 1, 1], - "as_tensor_output": True, - "device": torch.device("cpu:0"), - "spatial_size": (2, 2, 2), - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))}, - { - "img": torch.tensor([[[[5.0069, 9.4639], [9.2878, 13.7397]], [[12.3194, 16.7772], [16.5943, 21.0457]]]]), - "seg": torch.tensor([[[[4.0, 14.0], [7.0, 14.0]], [[9.0, 19.0], [12.0, 22.0]]]]), - }, - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.ones((2, 3, 3, 3))), "seg": p(torch.ones((2, 3, 3, 3)))}, + p(np.ones((2, 2, 2, 2))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": (2, -1, -1), + }, + {"img": p(torch.ones((2, 3, 3, 3))), "seg": p(torch.ones((2, 3, 3, 3)))}, + p(np.ones((2, 2, 3, 3))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": -1, + }, + {"img": p(torch.arange(8).reshape((1, 2, 2, 2))), "seg": p(torch.arange(8).reshape((1, 2, 2, 2)))}, + p(np.arange(8).reshape((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "seg": p(torch.arange(27).reshape((1, 3, 3, 3)))}, + p( + np.array( + [ + [ + [[6.4939356, 7.50289], [9.518351, 10.522849]], + [[15.512375, 16.523542], [18.531467, 19.53646]], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "rotate_range": [1, 1, 1], + "device": device, + "spatial_size": (2, 2, 2), + "mode": "bilinear", + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "seg": p(torch.arange(27).reshape((1, 3, 3, 3)))}, + p( + np.array( + [ + [ + [[5.0069294, 9.463932], [9.287769, 13.739735]], + [[12.319424, 16.777205], [16.594296, 21.045748]], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "mode": ("bilinear", "nearest"), + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "rotate_range": [1, 1, 1], + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "seg": p(torch.arange(27).reshape((1, 3, 3, 3)))}, + { + "img": p( + torch.tensor( + [[[[5.0069, 9.4639], [9.2878, 13.7397]], [[12.3194, 16.7772], [16.5943, 21.0457]]]] + ) + ), + "seg": p(torch.tensor([[[[4.0, 14.0], [7.0, 14.0]], [[9.0, 19.0], [12.0, 22.0]]]])), + }, + ] + ) class TestRand3DElasticd(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_3d_elasticd(self, input_param, input_data, expected_val): g = Rand3DElasticd(**input_param) g.set_random_state(123) @@ -115,11 +145,11 @@ def test_rand_3d_elasticd(self, input_param, 
input_data, expected_val): for key in res: result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) + self.assertEqual(type(result), type(expected)) if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + self.assertEqual(result.device, expected.device) + result, expected = result.cpu(), expected.cpu() + np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_gaussian_noise.py b/tests/test_rand_gaussian_noise.py index 96f1fa8e6d..d376add460 100644 --- a/tests/test_rand_gaussian_noise.py +++ b/tests/test_rand_gaussian_noise.py @@ -12,35 +12,32 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandGaussianNoise -from tests.utils import NumpyImageTestCase2D, TorchImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append(("test_zero_mean", p, 0, 0.1)) + TESTS.append(("test_non_zero_mean", p, 1, 0.5)) -class TestRandGaussianNoise(NumpyImageTestCase2D): - @parameterized.expand([("test_zero_mean", 0, 0.1), ("test_non_zero_mean", 1, 0.5)]) - def test_correct_results(self, _, mean, std): - seed = 0 - gaussian_fn = RandGaussianNoise(prob=1.0, mean=mean, std=std) - gaussian_fn.set_random_state(seed) - noised = gaussian_fn(self.imt) - np.random.seed(seed) - np.random.random() - expected = self.imt + np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape) - np.testing.assert_allclose(expected, noised, atol=1e-5) - -class TestRandGaussianNoiseTorch(TorchImageTestCase2D): - @parameterized.expand([("test_zero_mean", 0, 0.1), ("test_non_zero_mean", 1, 0.5)]) - def test_correct_results(self, _, mean, std): +class TestRandGaussianNoise(NumpyImageTestCase2D): + @parameterized.expand(TESTS) + def test_correct_results(self, _, im_type, mean, std): seed = 0 gaussian_fn = RandGaussianNoise(prob=1.0, mean=mean, std=std) gaussian_fn.set_random_state(seed) - noised = gaussian_fn(self.imt) + im = im_type(self.imt) + noised = gaussian_fn(im) np.random.seed(seed) np.random.random() expected = self.imt + np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape) + self.assertEqual(type(im), type(noised)) + if isinstance(noised, torch.Tensor): + noised = noised.cpu() np.testing.assert_allclose(expected, noised, atol=1e-5) diff --git a/tests/test_rand_gaussian_noised.py b/tests/test_rand_gaussian_noised.py index 442a85ca77..4b0d2a311a 100644 --- a/tests/test_rand_gaussian_noised.py +++ b/tests/test_rand_gaussian_noised.py @@ -12,41 +12,35 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandGaussianNoised -from tests.utils import NumpyImageTestCase2D, TorchImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D -TEST_CASE_0 = ["test_zero_mean", ["img1", "img2"], 0, 0.1] -TEST_CASE_1 = ["test_non_zero_mean", ["img1", "img2"], 1, 0.5] -TEST_CASES = [TEST_CASE_0, TEST_CASE_1] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append(["test_zero_mean", p, ["img1", "img2"], 0, 0.1]) + TESTS.append(["test_non_zero_mean", p, ["img1", "img2"], 1, 0.5]) seed = 0 -def test_numpy_or_torch(keys, mean, std, imt): - gaussian_fn = 
RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std) - gaussian_fn.set_random_state(seed) - noised = gaussian_fn({k: imt for k in keys}) - np.random.seed(seed) - np.random.random() - for k in keys: - expected = imt + np.random.normal(mean, np.random.uniform(0, std), size=imt.shape) - np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5) - - -# Test with numpy -class TestRandGaussianNoisedNumpy(NumpyImageTestCase2D): - @parameterized.expand(TEST_CASES) - def test_correct_results(self, _, keys, mean, std): - test_numpy_or_torch(keys, mean, std, self.imt) - - -# Test with torch -class TestRandGaussianNoisedTorch(TorchImageTestCase2D): - @parameterized.expand(TEST_CASES) - def test_correct_results(self, _, keys, mean, std): - test_numpy_or_torch(keys, mean, std, self.imt) +class TestRandGaussianNoised(NumpyImageTestCase2D): + @parameterized.expand(TESTS) + def test_correct_results(self, _, im_type, keys, mean, std): + gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std) + gaussian_fn.set_random_state(seed) + im = im_type(self.imt) + noised = gaussian_fn({k: im for k in keys}) + np.random.seed(seed) + np.random.random() + for k in keys: + expected = self.imt + np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape) + self.assertEqual(type(im), type(noised[k])) + if isinstance(noised[k], torch.Tensor): + noised[k] = noised[k].cpu() + np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5) if __name__ == "__main__": diff --git a/tests/test_rand_gaussian_sharpen.py b/tests/test_rand_gaussian_sharpen.py index 909f96f56b..f511555f21 100644 --- a/tests/test_rand_gaussian_sharpen.py +++ b/tests/test_rand_gaussian_sharpen.py @@ -11,88 +11,129 @@ import unittest -import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandGaussianSharpen +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"prob": 1.0}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( +TESTS = [] + +for p in TEST_NDARRAYS: + TESTS.append( [ - [[5.2919216, 5.5854445, 5.29192], [11.3982, 12.62332, 11.398202], [14.870525, 17.323769, 14.870527]], - [[20.413757, 22.767355, 20.413757], [28.495504, 31.558315, 28.495499], [29.99236, 34.505676, 29.992361]], + {"prob": 1.0}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [5.2919216, 5.5854445, 5.29192], + [11.3982, 12.62332, 11.398202], + [14.870525, 17.323769, 14.870527], + ], + [ + [20.413757, 22.767355, 20.413757], + [28.495504, 31.558315, 28.495499], + [29.99236, 34.505676, 29.992361], + ], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - { - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": 0.4, - "sigma2_y": 0.4, - "sigma2_z": 0.4, - "prob": 1.0, - }, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[4.1071496, 3.597953, 4.1071477], [10.062014, 9.825114, 10.0620165], [14.698058, 15.818766, 14.698058]], - [[18.211048, 18.16049, 18.211048], [25.155039, 24.56279, 25.155039], [28.801964, 30.381308, 28.801964]], + { + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": 0.4, + "sigma2_y": 0.4, + "sigma2_z": 0.4, + "prob": 1.0, + }, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [4.1071496, 3.597953, 4.1071477], + [10.062014, 9.825114, 10.0620165], + [14.698058, 15.818766, 14.698058], + ], + [ + 
[18.211048, 18.16049, 18.211048], + [25.155039, 24.56279, 25.155039], + [28.801964, 30.381308, 28.801964], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - { - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": (0.5, 0.75), - "sigma2_y": (0.5, 0.75), - "sigma2_z": (0.5, 0.75), - "prob": 1.0, - }, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[4.81077, 4.4237204, 4.81077], [12.061236, 12.298177, 12.061236], [17.362553, 19.201174, 17.362553]], - [[21.440754, 22.142393, 21.440754], [30.15308, 30.745445, 30.153086], [33.99255, 36.919838, 33.99255]], + { + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": (0.5, 0.75), + "sigma2_y": (0.5, 0.75), + "sigma2_z": (0.5, 0.75), + "prob": 1.0, + }, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [4.81077, 4.4237204, 4.81077], + [12.061236, 12.298177, 12.061236], + [17.362553, 19.201174, 17.362553], + ], + [ + [21.440754, 22.142393, 21.440754], + [30.15308, 30.745445, 30.153086], + [33.99255, 36.919838, 33.99255], + ], + ] + ), ] - ), -] + ) -TEST_CASE_4 = [ - { - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": (0.5, 0.75), - "sigma2_y": (0.5, 0.75), - "sigma2_z": (0.5, 0.75), - "approx": "scalespace", - "prob": 1.0, - }, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[4.430213, 3.2278745, 4.4302144], [10.325399, 8.507457, 10.325399], [17.494898, 16.5609, 17.494894]], - [[20.87405, 18.06946, 20.87405], [25.813503, 21.268656, 25.8135], [33.93874, 31.402481, 33.938725]], + { + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": (0.5, 0.75), + "sigma2_y": (0.5, 0.75), + "sigma2_z": (0.5, 0.75), + "approx": "scalespace", + "prob": 1.0, + }, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + p( + [ + [ + [4.430213, 3.2278745, 4.4302144], + [10.325399, 8.507457, 10.325399], + [17.494898, 16.5609, 17.494894], + ], + [[20.87405, 18.06946, 20.87405], [25.813503, 21.268656, 25.8135], [33.93874, 31.402481, 33.938725]], + ] + ), ] - ), -] + ) class TestRandGaussianSharpen(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): converter = RandGaussianSharpen(**argments) converter.set_random_state(seed=0) result = converter(image) - np.testing.assert_allclose(result, expected_data, rtol=1e-4) + torch.testing.assert_allclose(result, expected_data, atol=0, rtol=1e-4) + self.assertIsInstance(result, type(image)) if __name__ == "__main__": diff --git a/tests/test_rand_gibbs_noise.py b/tests/test_rand_gibbs_noise.py index 94948c5a0d..15cadea0e2 100644 --- a/tests/test_rand_gibbs_noise.py +++ b/tests/test_rand_gibbs_noise.py @@ -19,12 +19,15 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandGibbsNoise from monai.utils.misc import set_determinism +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS + +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for input_type 
in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) class TestRandGibbsNoise(unittest.TestCase): @@ -36,50 +39,50 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d im = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] - return torch.Tensor(im) if as_tensor_input else im + return input_type(im) @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_0_prob(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.5, 1.0] - t = RandGibbsNoise(0.0, alpha, as_tensor_output) + t = RandGibbsNoise(0.0, alpha) out = t(im) - np.testing.assert_allclose(im, out) + torch.testing.assert_allclose(im, out, rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.5, 0.8] - t = RandGibbsNoise(1.0, alpha, as_tensor_output) + t = RandGibbsNoise(1.0, alpha) t.set_random_state(42) out1 = t(deepcopy(im)) t.set_random_state(42) out2 = t(deepcopy(im)) - np.testing.assert_allclose(out1, out2) - self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray) + torch.testing.assert_allclose(out1, out2, rtol=1e-7, atol=0) + self.assertIsInstance(out1, type(im)) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.0, 0.0] t = RandGibbsNoise(1.0, alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(im, out, atol=1e-2) + torch.testing.assert_allclose(im, out, atol=1e-2, rtol=1e-7) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [1.0, 1.0] t = RandGibbsNoise(1.0, alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(0 * im, out) + torch.testing.assert_allclose(0 * im, out, rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_alpha(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_alpha(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.5, 0.51] t = RandGibbsNoise(1.0, alpha) _ = t(deepcopy(im)) diff --git a/tests/test_rand_gibbs_noised.py b/tests/test_rand_gibbs_noised.py index 986f4c02ae..b8bac67b81 100644 --- a/tests/test_rand_gibbs_noised.py +++ b/tests/test_rand_gibbs_noised.py @@ -19,12 +19,15 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandGibbsNoised from monai.utils.misc import set_determinism +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS + +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for input_type in 
TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) KEYS = ["im", "label"] @@ -38,70 +41,76 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims - return {k: v for k, v in zip(KEYS, ims)} + return {k: input_type(v) for k, v in zip(KEYS, ims)} @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_0_prob(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.5, 1.0] - t = RandGibbsNoised(KEYS, 0.0, alpha, as_tensor_output) + t = RandGibbsNoised(KEYS, 0.0, alpha) out = t(data) for k in KEYS: - np.testing.assert_allclose(data[k], out[k]) + torch.testing.assert_allclose(data[k], out[k], rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.5, 0.8] - t = RandGibbsNoised(KEYS, 1.0, alpha, as_tensor_output) + t = RandGibbsNoised(KEYS, 1.0, alpha) t.set_random_state(42) out1 = t(deepcopy(data)) t.set_random_state(42) out2 = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(out1[k], out2[k]) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) + torch.testing.assert_allclose(out1[k], out2[k], rtol=1e-7, atol=0) + self.assertIsInstance(out1[k], type(data[k])) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.0, 0.0] t = RandGibbsNoised(KEYS, 1.0, alpha) out = t(deepcopy(data)) for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() np.testing.assert_allclose(data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [1.0, 1.0] t = RandGibbsNoised(KEYS, 1.0, alpha) out = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(0 * data[k], out[k]) + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() + np.testing.assert_allclose(0.0 * data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_dict_matches(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_dict_matches(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) # use same image for both dictionary entries to check same trans is applied to them data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} alpha = [0.5, 1.0] t = RandGibbsNoised(KEYS, 1.0, alpha) out = t(deepcopy(data)) - 
np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]]) + torch.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]], rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_alpha(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_alpha(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.5, 0.51] t = RandGibbsNoised(KEYS, 1.0, alpha) _ = t(deepcopy(data)) - self.assertGreaterEqual(t.sampled_alpha, 0.5) - self.assertLessEqual(t.sampled_alpha, 0.51) + self.assertTrue(0.5 <= t.sampled_alpha <= 0.51) if __name__ == "__main__": diff --git a/tests/test_rand_histogram_shift.py b/tests/test_rand_histogram_shift.py index b258cc5a7e..59c8e51128 100644 --- a/tests/test_rand_histogram_shift.py +++ b/tests/test_rand_histogram_shift.py @@ -12,35 +12,48 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandHistogramShift - -TEST_CASES = [ - [ - {"num_control_points": 5, "prob": 0.0}, - {"img": np.arange(8).reshape((1, 2, 2, 2))}, - np.arange(8).reshape((1, 2, 2, 2)), - ], - [ - {"num_control_points": 5, "prob": 0.9}, - {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)}, - np.array([[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]]), - ], - [ - {"num_control_points": (5, 20), "prob": 0.9}, - {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)}, - np.array([[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]]), - ], -] +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"num_control_points": 5, "prob": 0.0}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2)))}, + np.arange(8).reshape((1, 2, 2, 2)), + ] + ) + TESTS.append( + [ + {"num_control_points": 5, "prob": 0.9}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32))}, + np.array([[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]]), + ] + ) + TESTS.append( + [ + {"num_control_points": (5, 20), "prob": 0.9}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32))}, + np.array([[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]]), + ] + ) class TestRandHistogramShift(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_histogram_shift(self, input_param, input_data, expected_val): g = RandHistogramShift(**input_param) g.set_random_state(123) result = g(**input_data) + im_in = input_data["img"] + self.assertEqual(type(result), type(im_in)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, im_in.device) + result = result.cpu() np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) diff --git a/tests/test_rand_histogram_shiftd.py b/tests/test_rand_histogram_shiftd.py index 806e4f5cf2..a1b11a0905 100644 --- a/tests/test_rand_histogram_shiftd.py +++ b/tests/test_rand_histogram_shiftd.py @@ -12,47 +12,67 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandHistogramShiftD +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - {"keys": ("img",), "num_control_points": 5, "prob": 0.0}, - {"img": np.arange(8).reshape((1, 2, 2, 2)), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - {"img": np.arange(8).reshape((1, 2, 2, 2)), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - ], - [ - {"keys": ("img",), 
"num_control_points": 5, "prob": 0.9}, - {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - { - "img": np.array( - [[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]] - ), - "seg": np.ones(8).reshape((1, 2, 2, 2)), - }, - ], - [ - {"keys": ("img",), "num_control_points": (5, 20), "prob": 0.9}, - {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - { - "img": np.array( - [[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]] - ), - "seg": np.ones(8).reshape((1, 2, 2, 2)), - }, - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ("img",), "num_control_points": 5, "prob": 0.0}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2))), "seg": p(np.ones(8).reshape((1, 2, 2, 2)))}, + {"img": np.arange(8).reshape((1, 2, 2, 2)), "seg": np.ones(8).reshape((1, 2, 2, 2))}, + ] + ) + TESTS.append( + [ + {"keys": ("img",), "num_control_points": 5, "prob": 0.9}, + { + "img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)), + "seg": p(np.ones(8).reshape((1, 2, 2, 2))), + }, + { + "img": np.array( + [[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]] + ), + "seg": np.ones(8).reshape((1, 2, 2, 2)), + }, + ] + ) + TESTS.append( + [ + {"keys": ("img",), "num_control_points": (5, 20), "prob": 0.9}, + { + "img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)), + "seg": p(np.ones(8).reshape((1, 2, 2, 2))), + }, + { + "img": np.array( + [[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]] + ), + "seg": np.ones(8).reshape((1, 2, 2, 2)), + }, + ] + ) class TestRandHistogramShiftD(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_histogram_shiftd(self, input_param, input_data, expected_val): g = RandHistogramShiftD(**input_param) g.set_random_state(123) res = g(input_data) for key in res: result = res[key] + im_in = input_data[key] + self.assertEqual(type(result), type(im_in)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, im_in.device) + result = result.cpu() + expected = expected_val[key] if isinstance(expected_val, dict) else expected_val np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) diff --git a/tests/test_rand_k_space_spike_noise.py b/tests/test_rand_k_space_spike_noise.py index ba9156c5b2..1c9ca9c1d5 100644 --- a/tests/test_rand_k_space_spike_noise.py +++ b/tests/test_rand_k_space_spike_noise.py @@ -19,13 +19,13 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoise, RandKSpaceSpikeNoise from monai.utils.misc import set_determinism +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] +TESTS = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - for channel_wise in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input, channel_wise)) + for p in TEST_NDARRAYS: + for channel_wise in (True, False): + TESTS.append((shape, p, channel_wise)) class TestRandKSpaceSpikeNoise(unittest.TestCase): @@ -37,44 +37,55 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, im_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d im = 
create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] - return torch.Tensor(im) if as_tensor_input else im + return im_type(im) - @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_0_prob(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 15] - t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise, as_tensor_output) + t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise) out = t(im) + self.assertEqual(type(im), type(out)) + if isinstance(out, torch.Tensor): + self.assertEqual(out.device, im.device) + im, out = im.cpu(), out.cpu() np.testing.assert_allclose(im, out) - @parameterized.expand(TEST_CASES) - def test_1_prob(self, im_shape, as_tensor_output, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_1_prob(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 14] - t = RandKSpaceSpikeNoise(1.0, intensity_range, channel_wise, as_tensor_output) + t = RandKSpaceSpikeNoise(1.0, intensity_range, channel_wise) out = t(im) - base_t = KSpaceSpikeNoise(t.sampled_locs, [14], as_tensor_output) + base_t = KSpaceSpikeNoise(t.sampled_locs, [14]) out = out - base_t(im) + self.assertEqual(type(im), type(out)) + if isinstance(out, torch.Tensor): + self.assertEqual(out.device, im.device) + im, out = im.cpu(), out.cpu() np.testing.assert_allclose(out, im * 0) - @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_same_result(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 15] - t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise, as_tensor_output) + t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise) t.set_random_state(42) out1 = t(deepcopy(im)) t.set_random_state(42) out2 = t(deepcopy(im)) + self.assertEqual(type(im), type(out1)) + if isinstance(out1, torch.Tensor): + self.assertEqual(out1.device, im.device) + out1, out2 = out1.cpu(), out2.cpu() np.testing.assert_allclose(out1, out2) - self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray) - @parameterized.expand(TEST_CASES) - def test_intensity(self, im_shape, _, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_intensity(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 14.1] t = RandKSpaceSpikeNoise(1.0, intensity_range, channel_wise) _ = t(deepcopy(im)) diff --git a/tests/test_rand_k_space_spike_noised.py b/tests/test_rand_k_space_spike_noised.py index 3cb49f1c08..93b6ebfad8 100644 --- a/tests/test_rand_k_space_spike_noised.py +++ b/tests/test_rand_k_space_spike_noised.py @@ -19,12 +19,12 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandKSpaceSpikeNoised from monai.utils.misc import set_determinism +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] +TESTS = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for 
p in TEST_NDARRAYS: + TESTS.append((shape, p)) KEYS = ["image", "label"] @@ -38,17 +38,16 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, im_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) - ims = [im[None] for im in ims] - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims + ims = [im_type(im[None]) for im in ims] return {k: v for k, v in zip(KEYS, ims)} - @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_same_result(self, im_shape, im_type): - data = self.get_data(im_shape, as_tensor_input) + data = self.get_data(im_shape, im_type) intensity_range = (13, 15) t = RandKSpaceSpikeNoised( @@ -58,7 +57,6 @@ def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): img_intensity_range=intensity_range, label_intensity_range=intensity_range, channel_wise=True, - as_tensor_output=as_tensor_output, ) t.set_rand_state(42) out1 = t(deepcopy(data)) @@ -67,12 +65,16 @@ def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): out2 = t(deepcopy(data)) for k in KEYS: + self.assertEqual(type(out1[k]), type(data[k])) + if isinstance(out1[k], torch.Tensor): + self.assertEqual(out1[k].device, data[k].device) + out1[k] = out1[k].cpu() + out2[k] = out2[k].cpu() np.testing.assert_allclose(out1[k], out2[k], atol=1e-10) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) - @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_0_prob(self, im_shape, im_type): + data = self.get_data(im_shape, im_type) intensity_range = (13, 15) t1 = RandKSpaceSpikeNoised( KEYS, @@ -81,7 +83,6 @@ def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): img_intensity_range=intensity_range, label_intensity_range=intensity_range, channel_wise=True, - as_tensor_output=as_tensor_output, ) t2 = RandKSpaceSpikeNoised( @@ -91,19 +92,25 @@ def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): img_intensity_range=intensity_range, label_intensity_range=intensity_range, channel_wise=True, - as_tensor_output=as_tensor_output, ) out1 = t1(data) out2 = t2(data) for k in KEYS: + self.assertEqual(type(out1[k]), type(data[k])) + if isinstance(out1[k], torch.Tensor): + self.assertEqual(out1[k].device, data[k].device) + out1[k] = out1[k].cpu() + out2[k] = out2[k].cpu() + data[k] = data[k].cpu() + np.testing.assert_allclose(data[k], out1[k]) np.testing.assert_allclose(data[k], out2[k]) - @parameterized.expand(TEST_CASES) - def test_intensity(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_intensity(self, im_shape, im_type): - data = self.get_data(im_shape, as_tensor_input) + data = self.get_data(im_shape, im_type) image_range = (15, 15.1) label_range = (14, 14.1) t = RandKSpaceSpikeNoised( @@ -113,7 +120,6 @@ def test_intensity(self, im_shape, as_tensor_output, as_tensor_input): img_intensity_range=image_range, label_intensity_range=label_range, channel_wise=True, - as_tensor_output=True, ) _ = t(data) @@ -122,9 +128,9 @@ def test_intensity(self, im_shape, as_tensor_output, as_tensor_input): self.assertGreaterEqual(t.t_label.sampled_k_intensity[0], 14) 
         self.assertLessEqual(t.t_label.sampled_k_intensity[0], 14.1)
 
-    @parameterized.expand(TEST_CASES)
-    def test_same_transformation(self, im_shape, _, as_tensor_input):
-        data = self.get_data(im_shape, as_tensor_input)
+    @parameterized.expand(TESTS)
+    def test_same_transformation(self, im_shape, im_type):
+        data = self.get_data(im_shape, im_type)
         # use same image for both dictionary entries to check same trans is applied to them
         data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])}
 
@@ -138,11 +144,16 @@ def test_same_transformation(self, im_shape, _, as_tensor_input):
             label_intensity_range=label_range,
             channel_wise=True,
             common_sampling=True,
-            as_tensor_output=True,
         )
 
         out = t(deepcopy(data))
+        for k in KEYS:
+            self.assertEqual(type(out[k]), type(data[k]))
+            if isinstance(out[k], torch.Tensor):
+                self.assertEqual(out[k].device, data[k].device)
+                out[k] = out[k].cpu()
+
         np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]])
diff --git a/tests/test_rand_rician_noise.py b/tests/test_rand_rician_noise.py
index 6504fd9069..7ec5fc4dc4 100644
--- a/tests/test_rand_rician_noise.py
+++ b/tests/test_rand_rician_noise.py
@@ -12,36 +12,25 @@
 import unittest
 
 import numpy as np
+import torch
 from parameterized import parameterized
 
 from monai.transforms import RandRicianNoise
-from tests.utils import NumpyImageTestCase2D, TorchImageTestCase2D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D
 
-
-class TestRandRicianNoise(NumpyImageTestCase2D):
-    @parameterized.expand([("test_zero_mean", 0, 0.1), ("test_non_zero_mean", 1, 0.5)])
-    def test_correct_results(self, _, mean, std):
-        seed = 0
-        rician_fn = RandRicianNoise(prob=1.0, mean=mean, std=std)
-        rician_fn.set_random_state(seed)
-        noised = rician_fn(self.imt)
-        np.random.seed(seed)
-        np.random.random()
-        _std = np.random.uniform(0, std)
-        expected = np.sqrt(
-            (self.imt + np.random.normal(mean, _std, size=self.imt.shape)) ** 2
-            + np.random.normal(mean, _std, size=self.imt.shape) ** 2
-        )
-        np.testing.assert_allclose(expected, noised, atol=1e-5)
+TESTS = []
+for p in TEST_NDARRAYS:
+    TESTS.append(("test_zero_mean", p, 0, 0.1))
+    TESTS.append(("test_non_zero_mean", p, 1, 0.5))
 
-class TestRandRicianNoiseTorch(TorchImageTestCase2D):
-    @parameterized.expand([("test_zero_mean", 0, 0.1), ("test_non_zero_mean", 1, 0.5)])
-    def test_correct_results(self, _, mean, std):
+
+class TestRandRicianNoise(NumpyImageTestCase2D):
+    @parameterized.expand(TESTS)
+    def test_correct_results(self, _, in_type, mean, std):
         seed = 0
         rician_fn = RandRicianNoise(prob=1.0, mean=mean, std=std)
         rician_fn.set_random_state(seed)
-        noised = rician_fn(self.imt)
+        noised = rician_fn(in_type(self.imt))
         np.random.seed(seed)
         np.random.random()
         _std = np.random.uniform(0, std)
@@ -49,6 +38,8 @@ def test_correct_results(self, _, mean, std):
             (self.imt + np.random.normal(mean, _std, size=self.imt.shape)) ** 2
             + np.random.normal(mean, _std, size=self.imt.shape) ** 2
         )
+        if isinstance(noised, torch.Tensor):
+            noised = noised.cpu()
         np.testing.assert_allclose(expected, noised, atol=1e-5)
diff --git a/tests/test_rand_rician_noised.py b/tests/test_rand_rician_noised.py
index 3dbfce154d..666edc3d3b 100644
--- a/tests/test_rand_rician_noised.py
+++ b/tests/test_rand_rician_noised.py
@@ -12,48 +12,38 @@
 import unittest
 
 import numpy as np
+import torch
 from parameterized import parameterized
 
 from monai.transforms import RandRicianNoised
-from tests.utils import NumpyImageTestCase2D, TorchImageTestCase2D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D
 
-TEST_CASE_0 = ["test_zero_mean", ["img1", "img2"], 0, 0.1]
-TEST_CASE_1 = ["test_non_zero_mean", ["img1", "img2"], 1, 0.5]
-TEST_CASES = [TEST_CASE_0, TEST_CASE_1]
+TESTS = []
+for p in TEST_NDARRAYS:
+    TESTS.append(["test_zero_mean", p, ["img1", "img2"], 0, 0.1])
+    TESTS.append(["test_non_zero_mean", p, ["img1", "img2"], 1, 0.5])
 
 seed = 0
 
 
-def test_numpy_or_torch(keys, mean, std, imt):
-    rician_fn = RandRicianNoised(keys=keys, global_prob=1.0, prob=1.0, mean=mean, std=std)
-    rician_fn.set_random_state(seed)
-    rician_fn.rand_rician_noise.set_random_state(seed)
-    noised = rician_fn({k: imt for k in keys})
-    np.random.seed(seed)
-    np.random.random()
-    np.random.seed(seed)
-    for k in keys:
-        np.random.random()
-        _std = np.random.uniform(0, std)
-        expected = np.sqrt(
-            (imt + np.random.normal(mean, _std, size=imt.shape)) ** 2
-            + np.random.normal(mean, _std, size=imt.shape) ** 2
-        )
-        np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5)
-
-
 # Test with numpy
 class TestRandRicianNoisedNumpy(NumpyImageTestCase2D):
-    @parameterized.expand(TEST_CASES)
-    def test_correct_results(self, _, keys, mean, std):
-        test_numpy_or_torch(keys, mean, std, self.imt)
-
-
-# Test with torch
-class TestRandRicianNoisedTorch(TorchImageTestCase2D):
-    @parameterized.expand(TEST_CASES)
-    def test_correct_results(self, _, keys, mean, std):
-        test_numpy_or_torch(keys, mean, std, self.imt)
+    @parameterized.expand(TESTS)
+    def test_correct_results(self, _, in_type, keys, mean, std):
+        rician_fn = RandRicianNoised(keys=keys, global_prob=1.0, prob=1.0, mean=mean, std=std)
+        rician_fn.set_random_state(seed)
+        noised = rician_fn({k: in_type(self.imt) for k in keys})
+        np.random.seed(seed)
+        for k in keys:
+            np.random.random()
+            _std = np.random.uniform(0, std)
+            expected = np.sqrt(
+                (self.imt + np.random.normal(mean, _std, size=self.imt.shape)) ** 2
+                + np.random.normal(mean, _std, size=self.imt.shape) ** 2
+            )
+            if isinstance(noised[k], torch.Tensor):
+                noised[k] = noised[k].cpu()
+            np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_rand_rotate.py b/tests/test_rand_rotate.py
index 0ff8508a0f..4817e81735 100644
--- a/tests/test_rand_rotate.py
+++ b/tests/test_rand_rotate.py
@@ -10,25 +10,60 @@
 # limitations under the License.
 
 import unittest
+from typing import List, Tuple
 
 import numpy as np
 import scipy.ndimage
+import torch
 from parameterized import parameterized
 
 from monai.transforms import RandRotate
-from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D
 
+TEST_CASES_2D: List[Tuple] = []
+for p in TEST_NDARRAYS:
+    TEST_CASES_2D.append((p, np.pi / 2, True, "bilinear", "border", False))
+    TEST_CASES_2D.append((p, np.pi / 4, True, "nearest", "border", False))
+    TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", True))
+    TEST_CASES_2D.append((p, (-np.pi / 4, 0), False, "nearest", "zeros", True))
 
-class TestRandRotate2D(NumpyImageTestCase2D):
-    @parameterized.expand(
-        [
-            (np.pi / 2, True, "bilinear", "border", False),
-            (np.pi / 4, True, "nearest", "border", False),
-            (np.pi, False, "nearest", "zeros", True),
-            ((-np.pi / 4, 0), False, "nearest", "zeros", True),
-        ]
+TEST_CASES_3D: List[Tuple] = []
+for p in TEST_NDARRAYS:
+    TEST_CASES_3D.append(
+        (p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109))
+    )
+    TEST_CASES_3D.append(
+        (
+            p,
+            np.pi / 4,
+            (-np.pi / 9, np.pi / 4.5),
+            (np.pi / 9, np.pi / 6),
+            False,
+            "nearest",
+            "border",
+            True,
+            (1, 89, 105, 104),
+        )
+    )
+    TEST_CASES_3D.append(
+        (
+            p,
+            0.0,
+            (2 * np.pi, 2.06 * np.pi),
+            (-np.pi / 180, np.pi / 180),
+            True,
+            "nearest",
+            "zeros",
+            True,
+            (1, 48, 64, 80),
+        )
     )
-    def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_corners):
+    TEST_CASES_3D.append((p, (-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90)))
+
+
+class TestRandRotate2D(NumpyImageTestCase2D):
+    @parameterized.expand(TEST_CASES_2D)
+    def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners):
         rotate_fn = RandRotate(
             range_x=degrees,
             prob=1.0,
@@ -38,7 +73,7 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor
             align_corners=align_corners,
         )
         rotate_fn.set_random_state(243)
-        rotated = rotate_fn(self.imt[0])
+        rotated = rotate_fn(im_type(self.imt[0]))
 
         _order = 0 if mode == "nearest" else 1
         if mode == "border":
@@ -52,38 +87,14 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor
             self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False
         )
         expected = np.stack(expected).astype(np.float32)
+        rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated
         good = np.sum(np.isclose(expected, rotated[0], atol=1e-3))
         self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels")
 
 
 class TestRandRotate3D(NumpyImageTestCase3D):
-    @parameterized.expand(
-        [
-            (np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)),
-            (
-                np.pi / 4,
-                (-np.pi / 9, np.pi / 4.5),
-                (np.pi / 9, np.pi / 6),
-                False,
-                "nearest",
-                "border",
-                True,
-                (1, 89, 105, 104),
-            ),
-            (
-                0.0,
-                (2 * np.pi, 2.06 * np.pi),
-                (-np.pi / 180, np.pi / 180),
-                True,
-                "nearest",
-                "zeros",
-                True,
-                (1, 48, 64, 80),
-            ),
-            ((-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90)),
-        ]
-    )
-    def test_correct_results(self, x, y, z, keep_size, mode, padding_mode, align_corners, expected):
+    @parameterized.expand(TEST_CASES_3D)
+    def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected):
         rotate_fn = RandRotate(
             range_x=x,
             range_y=y,
@@ -95,8 +106,8 @@ def test_correct_results(self, x, y, z, keep_size, mode, padding_mode, align_cor
             align_corners=align_corners,
         )
         rotate_fn.set_random_state(243)
-        rotated = rotate_fn(self.imt[0])
-        np.testing.assert_allclose(rotated.shape, expected)
+        rotated = rotate_fn(im_type(self.imt[0]))
+        torch.testing.assert_allclose(rotated.shape, expected, rtol=1e-7, atol=0)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_rand_rotate90.py b/tests/test_rand_rotate90.py
index 50a1b28e53..dededbba8d 100644
--- a/tests/test_rand_rotate90.py
+++ b/tests/test_rand_rotate90.py
@@ -12,51 +12,56 @@
 import unittest
 
 import numpy as np
+import torch
 
 from monai.transforms import RandRotate90
-from tests.utils import NumpyImageTestCase2D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D
 
 
 class TestRandRotate90(NumpyImageTestCase2D):
     def test_default(self):
         rotate = RandRotate90()
-        rotate.set_random_state(123)
-        rotated = rotate(self.imt[0])
-        expected = []
-        for channel in self.imt[0]:
-            expected.append(np.rot90(channel, 0, (0, 1)))
-        expected = np.stack(expected)
-        self.assertTrue(np.allclose(rotated, expected))
+        for p in TEST_NDARRAYS:
+            rotate.set_random_state(123)
+            rotated = rotate(p(self.imt[0]))
+            expected = []
+            for channel in self.imt[0]:
+                expected.append(np.rot90(channel, 0, (0, 1)))
+            expected = p(np.stack(expected))
+            torch.testing.assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8)
 
     def test_k(self):
         rotate = RandRotate90(max_k=2)
-        rotate.set_random_state(234)
-        rotated = rotate(self.imt[0])
-        expected = []
-        for channel in self.imt[0]:
-            expected.append(np.rot90(channel, 0, (0, 1)))
-        expected = np.stack(expected)
-        self.assertTrue(np.allclose(rotated, expected))
+        for p in TEST_NDARRAYS:
+            rotate.set_random_state(234)
+            rotated = rotate(p(self.imt[0]))
+            expected = []
+            for channel in self.imt[0]:
+                expected.append(np.rot90(channel, 0, (0, 1)))
+            expected = p(np.stack(expected))
+            torch.testing.assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8)
 
     def test_spatial_axes(self):
         rotate = RandRotate90(spatial_axes=(0, 1))
-        rotate.set_random_state(234)
-        rotated = rotate(self.imt[0])
-        expected = []
-        for channel in self.imt[0]:
-            expected.append(np.rot90(channel, 0, (0, 1)))
-        expected = np.stack(expected)
-        self.assertTrue(np.allclose(rotated, expected))
+        for p in TEST_NDARRAYS:
+            rotate.set_random_state(234)
+            rotated = rotate(p(self.imt[0]))
+            expected = []
+            for channel in self.imt[0]:
+                expected.append(np.rot90(channel, 0, (0, 1)))
+            expected = p(np.stack(expected))
+            torch.testing.assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8)
 
     def test_prob_k_spatial_axes(self):
         rotate = RandRotate90(prob=1.0, max_k=2, spatial_axes=(0, 1))
-        rotate.set_random_state(234)
-        rotated = rotate(self.imt[0])
-        expected = []
-        for channel in self.imt[0]:
-            expected.append(np.rot90(channel, 1, (0, 1)))
-        expected = np.stack(expected)
-        self.assertTrue(np.allclose(rotated, expected))
+        for p in TEST_NDARRAYS:
+            rotate.set_random_state(234)
+            rotated = rotate(p(self.imt[0]))
+            expected = []
+            for channel in self.imt[0]:
+                expected.append(np.rot90(channel, 1, (0, 1)))
+            expected = p(np.stack(expected))
+            torch.testing.assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_rand_rotated.py b/tests/test_rand_rotated.py
index 47b4b7107e..4c9a27f668 100644
--- a/tests/test_rand_rotated.py
+++ b/tests/test_rand_rotated.py
@@ -10,26 +10,104 @@
 # limitations under the License.
import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import RandRotated from monai.utils import GridSampleMode, GridSamplePadMode -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, np.pi / 2, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4, True, "nearest", "border", False)) + TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", True)) + TEST_CASES_2D.append((p, (-np.pi / 4, 0), False, "nearest", "zeros", True)) -class TestRandRotated2D(NumpyImageTestCase2D): - @parameterized.expand( - [ - (np.pi / 2, True, "bilinear", "border", False), - (np.pi / 4, True, "nearest", "border", False), - (np.pi, False, "nearest", "zeros", True), - ((-np.pi / 4, 0), False, "nearest", "zeros", True), - ] + +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append( + (p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)) ) - def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_corners): + TEST_CASES_3D.append( + ( + p, + np.pi / 2, + -np.pi / 6, + (0.0, np.pi), + False, + GridSampleMode.NEAREST, + GridSamplePadMode.BORDER, + False, + (1, 87, 104, 109), + ) + ) + TEST_CASES_3D.append( + ( + p, + np.pi / 4, + (-np.pi / 9, np.pi / 4.5), + (np.pi / 9, np.pi / 6), + False, + "nearest", + "border", + True, + (1, 89, 105, 104), + ) + ) + TEST_CASES_3D.append( + ( + p, + np.pi / 4, + (-np.pi / 9, np.pi / 4.5), + (np.pi / 9, np.pi / 6), + False, + GridSampleMode.NEAREST, + GridSamplePadMode.BORDER, + True, + (1, 89, 105, 104), + ) + ) + TEST_CASES_3D.append( + ( + p, + 0.0, + (2 * np.pi, 2.06 * np.pi), + (-np.pi / 180, np.pi / 180), + True, + "nearest", + "zeros", + True, + (1, 48, 64, 80), + ) + ) + TEST_CASES_3D.append( + ( + p, + 0.0, + (2 * np.pi, 2.06 * np.pi), + (-np.pi / 180, np.pi / 180), + True, + GridSampleMode.NEAREST, + GridSamplePadMode.ZEROS, + True, + (1, 48, 64, 80), + ) + ) + TEST_CASES_3D.append((p, (-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90))) + TEST_CASES_3D.append( + (p, (-np.pi / 4, 0), 0, 0, False, GridSampleMode.NEAREST, GridSamplePadMode.ZEROS, False, (1, 48, 77, 90)) + ) + + +class TestRandRotated2D(NumpyImageTestCase2D): + @parameterized.expand(TEST_CASES_2D) + def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners): rotate_fn = RandRotated( "img", range_x=degrees, @@ -40,7 +118,7 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor align_corners=align_corners, ) rotate_fn.set_random_state(243) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) _order = 0 if mode == "nearest" else 1 if padding_mode == "border": @@ -53,70 +131,16 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor expected = scipy.ndimage.rotate( self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v expected = np.stack(expected).astype(np.float32) good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - 
expected.size), 5, "diff at most 5 pixels") class TestRandRotated3D(NumpyImageTestCase3D): - @parameterized.expand( - [ - (np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)), - ( - np.pi / 2, - -np.pi / 6, - (0.0, np.pi), - False, - GridSampleMode.NEAREST, - GridSamplePadMode.BORDER, - False, - (1, 87, 104, 109), - ), - ( - np.pi / 4, - (-np.pi / 9, np.pi / 4.5), - (np.pi / 9, np.pi / 6), - False, - "nearest", - "border", - True, - (1, 89, 105, 104), - ), - ( - np.pi / 4, - (-np.pi / 9, np.pi / 4.5), - (np.pi / 9, np.pi / 6), - False, - GridSampleMode.NEAREST, - GridSamplePadMode.BORDER, - True, - (1, 89, 105, 104), - ), - ( - 0.0, - (2 * np.pi, 2.06 * np.pi), - (-np.pi / 180, np.pi / 180), - True, - "nearest", - "zeros", - True, - (1, 48, 64, 80), - ), - ( - 0.0, - (2 * np.pi, 2.06 * np.pi), - (-np.pi / 180, np.pi / 180), - True, - GridSampleMode.NEAREST, - GridSamplePadMode.ZEROS, - True, - (1, 48, 64, 80), - ), - ((-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90)), - ((-np.pi / 4, 0), 0, 0, False, GridSampleMode.NEAREST, GridSamplePadMode.ZEROS, False, (1, 48, 77, 90)), - ] - ) - def test_correct_shapes(self, x, y, z, keep_size, mode, padding_mode, align_corners, expected): + @parameterized.expand(TEST_CASES_3D) + def test_correct_shapes(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected): rotate_fn = RandRotated( "img", range_x=x, @@ -129,7 +153,7 @@ def test_correct_shapes(self, x, y, z, keep_size, mode, padding_mode, align_corn align_corners=align_corners, ) rotate_fn.set_random_state(243) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) np.testing.assert_allclose(rotated["img"].shape, expected) diff --git a/tests/test_rand_scale_crop.py b/tests/test_rand_scale_crop.py index db5487ebff..a0870a12b5 100644 --- a/tests/test_rand_scale_crop.py +++ b/tests/test_rand_scale_crop.py @@ -12,9 +12,11 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandScaleCrop +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [ {"roi_scale": [1.0, 1.0, -1.0], "random_center": True}, @@ -55,22 +57,45 @@ class TestRandScaleCrop(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, input_param, input_data, expected_shape): - result = RandScaleCrop(**input_param)(input_data) - self.assertTupleEqual(result.shape, expected_shape) + results = [] + for p in TEST_NDARRAYS: + im = p(input_data) + cropper = RandScaleCrop(**input_param) + cropper.set_random_state(0) + result = cropper(im) + if isinstance(result, torch.Tensor): + result = result.cpu().numpy() + self.assertTupleEqual(result.shape, expected_shape) + results.append(result) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) @parameterized.expand([TEST_CASE_3]) def test_value(self, input_param, input_data): - cropper = RandScaleCrop(**input_param) - result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - np.testing.assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + for p in TEST_NDARRAYS: + im = p(input_data) + cropper = RandScaleCrop(**input_param) + cropper.set_random_state(0) + result = cropper(im) + if isinstance(result, torch.Tensor): + result = result.cpu().numpy() + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + 
np.testing.assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_random_shape(self, input_param, input_data, expected_shape): - cropper = RandScaleCrop(**input_param) - cropper.set_random_state(seed=123) - result = cropper(input_data) - self.assertTupleEqual(result.shape, expected_shape) + results = [] + for p in TEST_NDARRAYS: + im = p(input_data) + cropper = RandScaleCrop(**input_param) + cropper.set_random_state(123) + result = cropper(im) + if isinstance(result, torch.Tensor): + result = result.cpu().numpy() + self.assertTupleEqual(result.shape, expected_shape) + results.append(result) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) if __name__ == "__main__": diff --git a/tests/test_rand_scale_cropd.py b/tests/test_rand_scale_cropd.py index 265c6c467d..0f04bab484 100644 --- a/tests/test_rand_scale_cropd.py +++ b/tests/test_rand_scale_cropd.py @@ -12,9 +12,11 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandScaleCropd +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [ {"keys": "img", "roi_scale": [1.0, 1.0, -1.0], "random_center": True}, @@ -61,22 +63,55 @@ class TestRandScaleCropd(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, input_param, input_data, expected_shape): - result = RandScaleCropd(**input_param)(input_data) - self.assertTupleEqual(result["img"].shape, expected_shape) + results = [] + for p in TEST_NDARRAYS: + data = {"img": p(input_data["img"])} + cropper = RandScaleCropd(**input_param) + cropper.set_random_state(0) + result = cropper(data)["img"] + self.assertTupleEqual(result.shape, expected_shape) + self.assertEqual(type(data["img"]), type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, data["img"].device) + result = result.cpu().numpy() + results.append(result) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) @parameterized.expand([TEST_CASE_3]) def test_value(self, input_param, input_data): - cropper = RandScaleCropd(**input_param) - result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - np.testing.assert_allclose(result["img"], input_data["img"][:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + results = [] + for p in TEST_NDARRAYS: + data = {"img": p(input_data["img"])} + cropper = RandScaleCropd(**input_param) + cropper.set_random_state(0) + result = cropper(data)["img"] + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + self.assertEqual(type(data["img"]), type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, data["img"].device) + result = result.cpu().numpy() + np.testing.assert_allclose(result, input_data["img"][:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + results.append(result) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_random_shape(self, input_param, input_data, expected_shape): - cropper = RandScaleCropd(**input_param) - cropper.set_random_state(seed=123) - result = cropper(input_data) - self.assertTupleEqual(result["img"].shape, expected_shape) + results = [] + for p in TEST_NDARRAYS: + data = {"img": p(input_data["img"])} + cropper = RandScaleCropd(**input_param) + cropper.set_random_state(seed=123) + result = cropper(data)["img"] + 
self.assertTupleEqual(result.shape, expected_shape) + self.assertEqual(type(data["img"]), type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, data["img"].device) + result = result.cpu().numpy() + results.append(result) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) if __name__ == "__main__": diff --git a/tests/test_rand_scale_intensity.py b/tests/test_rand_scale_intensity.py index 2126301758..74a0be164e 100644 --- a/tests/test_rand_scale_intensity.py +++ b/tests/test_rand_scale_intensity.py @@ -12,19 +12,21 @@ import unittest import numpy as np +import torch from monai.transforms import RandScaleIntensity -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D class TestRandScaleIntensity(NumpyImageTestCase2D): def test_value(self): - scaler = RandScaleIntensity(factors=0.5, prob=1.0) - scaler.set_random_state(seed=0) - result = scaler(self.imt) - np.random.seed(0) - expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32) - np.testing.assert_allclose(result, expected) + for p in TEST_NDARRAYS: + scaler = RandScaleIntensity(factors=0.5, prob=1.0) + scaler.set_random_state(seed=0) + result = scaler(p(self.imt)) + np.random.seed(0) + expected = p((self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32)) + torch.testing.assert_allclose(result, expected, rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_rand_spatial_crop.py b/tests/test_rand_spatial_crop.py index 01e057e589..2897ce59f5 100644 --- a/tests/test_rand_spatial_crop.py +++ b/tests/test_rand_spatial_crop.py @@ -12,9 +12,11 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandSpatialCrop +from tests.utils import TEST_NDARRAYS TEST_CASE_0 = [ {"roi_size": [3, 3, -1], "random_center": True}, @@ -51,22 +53,51 @@ class TestRandSpatialCrop(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) def test_shape(self, input_param, input_data, expected_shape): - result = RandSpatialCrop(**input_param)(input_data) - self.assertTupleEqual(result.shape, expected_shape) + results = [] + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS + (None,): + im = p(input_data) + input_param_mod = {k: q(v) if q is not None else v for k, v in input_param.items()} + cropper = RandSpatialCrop(**input_param_mod) + cropper.set_random_state(0) + result = cropper(im) + if isinstance(result, torch.Tensor): + result = result.cpu().numpy() + self.assertTupleEqual(result.shape, expected_shape) + results.append(result) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) @parameterized.expand([TEST_CASE_3]) def test_value(self, input_param, input_data): - cropper = RandSpatialCrop(**input_param) - result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - np.testing.assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS + (None,): + im = p(input_data) + input_param_mod = {k: q(v) if q is not None else v for k, v in input_param.items()} + cropper = RandSpatialCrop(**input_param_mod) + cropper.set_random_state(0) + result = cropper(im) + if isinstance(result, torch.Tensor): + result = result.cpu().numpy() + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + np.testing.assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : 
roi[1][1]]) @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) def test_random_shape(self, input_param, input_data, expected_shape): - cropper = RandSpatialCrop(**input_param) - cropper.set_random_state(seed=123) - result = cropper(input_data) - self.assertTupleEqual(result.shape, expected_shape) + results = [] + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS + (None,): + im = p(input_data) + input_param_mod = {k: q(v) if q is not None else v for k, v in input_param.items()} + cropper = RandSpatialCrop(**input_param_mod) + cropper.set_random_state(123) + result = cropper(im) + if isinstance(result, torch.Tensor): + result = result.cpu().numpy() + self.assertTupleEqual(result.shape, expected_shape) + results.append(result) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) if __name__ == "__main__": diff --git a/tests/test_rand_spatial_crop_samples.py b/tests/test_rand_spatial_crop_samples.py index 0ade9bbbba..d84247b4da 100644 --- a/tests/test_rand_spatial_crop_samples.py +++ b/tests/test_rand_spatial_crop_samples.py @@ -12,63 +12,80 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandSpatialCropSamples +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"roi_size": [3, 3, 3], "num_samples": 4, "random_center": True, "random_size": False}, - np.arange(192).reshape(3, 4, 4, 4), - [(3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3)], - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( [ - [ - [[21, 22, 23], [25, 26, 27], [29, 30, 31]], - [[37, 38, 39], [41, 42, 43], [45, 46, 47]], - [[53, 54, 55], [57, 58, 59], [61, 62, 63]], - ], - [ - [[85, 86, 87], [89, 90, 91], [93, 94, 95]], - [[101, 102, 103], [105, 106, 107], [109, 110, 111]], - [[117, 118, 119], [121, 122, 123], [125, 126, 127]], - ], - [ - [[149, 150, 151], [153, 154, 155], [157, 158, 159]], - [[165, 166, 167], [169, 170, 171], [173, 174, 175]], - [[181, 182, 183], [185, 186, 187], [189, 190, 191]], - ], + {"roi_size": [3, 3, 3], "num_samples": 4, "random_center": True, "random_size": False}, + p(np.arange(192).reshape(3, 4, 4, 4)), + [(3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3)], + np.array( + [ + [ + [[21, 22, 23], [25, 26, 27], [29, 30, 31]], + [[37, 38, 39], [41, 42, 43], [45, 46, 47]], + [[53, 54, 55], [57, 58, 59], [61, 62, 63]], + ], + [ + [[85, 86, 87], [89, 90, 91], [93, 94, 95]], + [[101, 102, 103], [105, 106, 107], [109, 110, 111]], + [[117, 118, 119], [121, 122, 123], [125, 126, 127]], + ], + [ + [[149, 150, 151], [153, 154, 155], [157, 158, 159]], + [[165, 166, 167], [169, 170, 171], [173, 174, 175]], + [[181, 182, 183], [185, 186, 187], [189, 190, 191]], + ], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - {"roi_size": [3, 3, 3], "num_samples": 8, "random_center": False, "random_size": True}, - np.arange(192).reshape(3, 4, 4, 4), - [(3, 4, 4, 3), (3, 4, 3, 3), (3, 3, 4, 4), (3, 4, 4, 4), (3, 3, 3, 4), (3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3)], - np.array( + TESTS.append( [ + {"roi_size": [3, 3, 3], "num_samples": 8, "random_center": False, "random_size": True}, + p(np.arange(192).reshape(3, 4, 4, 4)), [ - [[21, 22, 23], [25, 26, 27], [29, 30, 31]], - [[37, 38, 39], [41, 42, 43], [45, 46, 47]], - [[53, 54, 55], [57, 58, 59], [61, 62, 63]], - ], - [ - [[85, 86, 87], [89, 90, 91], [93, 94, 95]], - [[101, 102, 103], [105, 106, 107], [109, 110, 111]], - [[117, 118, 119], [121, 122, 123], [125, 126, 127]], - ], - [ - [[149, 150, 151], [153, 154, 155], [157, 158, 159]], - [[165, 166, 167], [169, 
170, 171], [173, 174, 175]], - [[181, 182, 183], [185, 186, 187], [189, 190, 191]], + (3, 4, 4, 3), + (3, 4, 3, 3), + (3, 3, 4, 4), + (3, 4, 4, 4), + (3, 3, 3, 4), + (3, 3, 3, 3), + (3, 3, 3, 3), + (3, 3, 3, 3), ], + np.array( + [ + [ + [[21, 22, 23], [25, 26, 27], [29, 30, 31]], + [[37, 38, 39], [41, 42, 43], [45, 46, 47]], + [[53, 54, 55], [57, 58, 59], [61, 62, 63]], + ], + [ + [[85, 86, 87], [89, 90, 91], [93, 94, 95]], + [[101, 102, 103], [105, 106, 107], [109, 110, 111]], + [[117, 118, 119], [121, 122, 123], [125, 126, 127]], + ], + [ + [[149, 150, 151], [153, 154, 155], [157, 158, 159]], + [[165, 166, 167], [169, 170, 171], [173, 174, 175]], + [[181, 182, 183], [185, 186, 187], [189, 190, 191]], + ], + ] + ), ] - ), -] + ) class TestRandSpatialCropSamples(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape, expected_last_item): xform = RandSpatialCropSamples(**input_param) xform.set_random_state(1234) @@ -77,7 +94,10 @@ def test_shape(self, input_param, input_data, expected_shape, expected_last_item np.testing.assert_equal(len(result), input_param["num_samples"]) for item, expected in zip(result, expected_shape): self.assertTupleEqual(item.shape, expected) - np.testing.assert_allclose(result[-1], expected_last_item) + r = result[-1] + if isinstance(r, torch.Tensor): + r = r.cpu() + np.testing.assert_allclose(r, expected_last_item) if __name__ == "__main__": diff --git a/tests/test_rand_spatial_crop_samplesd.py b/tests/test_rand_spatial_crop_samplesd.py index 3f5eee7b27..47979838a4 100644 --- a/tests/test_rand_spatial_crop_samplesd.py +++ b/tests/test_rand_spatial_crop_samplesd.py @@ -12,58 +12,74 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import Compose, RandSpatialCropSamplesd, ToTensord +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": ["img", "seg"], "num_samples": 4, "roi_size": [2, 2, 2], "random_center": True}, - {"img": np.arange(81).reshape(3, 3, 3, 3), "seg": np.arange(81, 0, -1).reshape(3, 3, 3, 3)}, - [(3, 3, 3, 2), (3, 2, 2, 2), (3, 3, 3, 2), (3, 3, 2, 2)], - { - "img": np.array( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["img", "seg"], "num_samples": 4, "roi_size": [2, 2, 2], "random_center": True}, + {"img": p(np.arange(81).reshape(3, 3, 3, 3)), "seg": p(np.arange(81, 0, -1).reshape(3, 3, 3, 3))}, + [(3, 3, 3, 2), (3, 2, 2, 2), (3, 3, 3, 2), (3, 3, 2, 2)], + { + "img": np.array( + [ + [[[0, 1], [3, 4]], [[9, 10], [12, 13]], [[18, 19], [21, 22]]], + [[[27, 28], [30, 31]], [[36, 37], [39, 40]], [[45, 46], [48, 49]]], + [[[54, 55], [57, 58]], [[63, 64], [66, 67]], [[72, 73], [75, 76]]], + ] + ), + "seg": np.array( + [ + [[[81, 80], [78, 77]], [[72, 71], [69, 68]], [[63, 62], [60, 59]]], + [[[54, 53], [51, 50]], [[45, 44], [42, 41]], [[36, 35], [33, 32]]], + [[[27, 26], [24, 23]], [[18, 17], [15, 14]], [[9, 8], [6, 5]]], + ] + ), + }, + ] + ) + TESTS.append( + [ + {"keys": ["img", "seg"], "num_samples": 8, "roi_size": [2, 2, 3], "random_center": False}, + {"img": p(np.arange(81).reshape(3, 3, 3, 3)), "seg": p(np.arange(81, 0, -1).reshape(3, 3, 3, 3))}, [ - [[[0, 1], [3, 4]], [[9, 10], [12, 13]], [[18, 19], [21, 22]]], - [[[27, 28], [30, 31]], [[36, 37], [39, 40]], [[45, 46], [48, 49]]], - [[[54, 55], [57, 58]], [[63, 64], [66, 67]], [[72, 73], [75, 76]]], - ] - ), - "seg": np.array( - [ - [[[81, 80], [78, 77]], [[72, 71], [69, 68]], [[63, 62], [60, 
59]]], - [[[54, 53], [51, 50]], [[45, 44], [42, 41]], [[36, 35], [33, 32]]], - [[[27, 26], [24, 23]], [[18, 17], [15, 14]], [[9, 8], [6, 5]]], - ] - ), - }, -] - -TEST_CASE_2 = [ - {"keys": ["img", "seg"], "num_samples": 8, "roi_size": [2, 2, 3], "random_center": False}, - {"img": np.arange(81).reshape(3, 3, 3, 3), "seg": np.arange(81, 0, -1).reshape(3, 3, 3, 3)}, - [(3, 3, 3, 3), (3, 2, 3, 3), (3, 2, 2, 3), (3, 2, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3), (3, 2, 2, 3), (3, 3, 2, 3)], - { - "img": np.array( - [ - [[[0, 1, 2], [3, 4, 5]], [[9, 10, 11], [12, 13, 14]], [[18, 19, 20], [21, 22, 23]]], - [[[27, 28, 29], [30, 31, 32]], [[36, 37, 38], [39, 40, 41]], [[45, 46, 47], [48, 49, 50]]], - [[[54, 55, 56], [57, 58, 59]], [[63, 64, 65], [66, 67, 68]], [[72, 73, 74], [75, 76, 77]]], - ] - ), - "seg": np.array( - [ - [[[81, 80, 79], [78, 77, 76]], [[72, 71, 70], [69, 68, 67]], [[63, 62, 61], [60, 59, 58]]], - [[[54, 53, 52], [51, 50, 49]], [[45, 44, 43], [42, 41, 40]], [[36, 35, 34], [33, 32, 31]]], - [[[27, 26, 25], [24, 23, 22]], [[18, 17, 16], [15, 14, 13]], [[9, 8, 7], [6, 5, 4]]], - ] - ), - }, -] + (3, 3, 3, 3), + (3, 2, 3, 3), + (3, 2, 2, 3), + (3, 2, 3, 3), + (3, 3, 3, 3), + (3, 3, 3, 3), + (3, 2, 2, 3), + (3, 3, 2, 3), + ], + { + "img": np.array( + [ + [[[0, 1, 2], [3, 4, 5]], [[9, 10, 11], [12, 13, 14]], [[18, 19, 20], [21, 22, 23]]], + [[[27, 28, 29], [30, 31, 32]], [[36, 37, 38], [39, 40, 41]], [[45, 46, 47], [48, 49, 50]]], + [[[54, 55, 56], [57, 58, 59]], [[63, 64, 65], [66, 67, 68]], [[72, 73, 74], [75, 76, 77]]], + ] + ), + "seg": np.array( + [ + [[[81, 80, 79], [78, 77, 76]], [[72, 71, 70], [69, 68, 67]], [[63, 62, 61], [60, 59, 58]]], + [[[54, 53, 52], [51, 50, 49]], [[45, 44, 43], [42, 41, 40]], [[36, 35, 34], [33, 32, 31]]], + [[[27, 26, 25], [24, 23, 22]], [[18, 17, 16], [15, 14, 13]], [[9, 8, 7], [6, 5, 4]]], + ] + ), + }, + ] + ) class TestRandSpatialCropSamplesd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_shape(self, input_param, input_data, expected_shape, expected_last): + @parameterized.expand(TESTS) + def test_value(self, input_param, input_data, expected_shape, expected_last): xform = RandSpatialCropSamplesd(**input_param) xform.set_random_state(1234) result = xform(input_data) @@ -73,24 +89,30 @@ def test_shape(self, input_param, input_data, expected_shape, expected_last): for i, item in enumerate(result): self.assertEqual(item["img_meta_dict"]["patch_index"], i) self.assertEqual(item["seg_meta_dict"]["patch_index"], i) + self.assertEqual(type(item["img"]), type(input_data["img"])) + if isinstance(item["img"], torch.Tensor): + self.assertEqual(item["img"].device, input_data["img"].device) + item["img"] = item["img"].cpu() + item["seg"] = item["seg"].cpu() np.testing.assert_allclose(item["img"], expected_last["img"]) np.testing.assert_allclose(item["seg"], expected_last["seg"]) def test_deep_copy(self): - data = {"img": np.ones((1, 10, 11, 12))} - num_samples = 3 - sampler = RandSpatialCropSamplesd( - keys=["img"], - roi_size=(3, 3, 3), - num_samples=num_samples, - random_center=True, - random_size=False, - ) - transform = Compose([ToTensord(keys="img"), sampler]) - samples = transform(data) - self.assertEqual(len(samples), num_samples) - for sample in samples: - self.assertEqual(len(sample["img_transforms"]), len(transform)) + for p in TEST_NDARRAYS: + data = {"img": p(np.ones((1, 10, 11, 12)))} + num_samples = 3 + sampler = RandSpatialCropSamplesd( + keys=["img"], + roi_size=(3, 3, 3), + num_samples=num_samples, + 
random_center=True, + random_size=False, + ) + transform = Compose([ToTensord(keys="img"), sampler]) + samples = transform(data) + self.assertEqual(len(samples), num_samples) + for sample in samples: + self.assertEqual(len(sample["img_transforms"]), len(transform)) if __name__ == "__main__": diff --git a/tests/test_rand_std_shift_intensity.py b/tests/test_rand_std_shift_intensity.py index 9aff50ab66..a036734b2b 100644 --- a/tests/test_rand_std_shift_intensity.py +++ b/tests/test_rand_std_shift_intensity.py @@ -12,20 +12,23 @@ import unittest import numpy as np +import torch from monai.transforms import RandStdShiftIntensity -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D class TestRandStdShiftIntensity(NumpyImageTestCase2D): def test_value(self): - shifter = RandStdShiftIntensity(factors=1.0, prob=1.0) - shifter.set_random_state(seed=0) - result = shifter(self.imt) - np.random.seed(0) - factor = np.random.uniform(low=-1.0, high=1.0) - expected = self.imt + factor * np.std(self.imt) - np.testing.assert_allclose(result, expected, rtol=1e-5) + for p in TEST_NDARRAYS: + shifter = RandStdShiftIntensity(factors=1.0, prob=1.0) + shifter.set_random_state(seed=0) + result = shifter(p(self.imt)) + np.random.seed(0) + factor = np.random.uniform(low=-1.0, high=1.0) + offset = factor * np.std(self.imt) + expected = p(self.imt + offset) + torch.testing.assert_allclose(result, expected, atol=0, rtol=1e-5) if __name__ == "__main__": diff --git a/tests/test_rand_weighted_crop.py b/tests/test_rand_weighted_crop.py index 39a9439122..5ee815b711 100644 --- a/tests/test_rand_weighted_crop.py +++ b/tests/test_rand_weighted_crop.py @@ -12,127 +12,161 @@ import unittest import numpy as np +import torch +from parameterized.parameterized import parameterized from monai.transforms.croppad.array import RandWeightedCrop -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D -class TestRandWeightedCrop2D(NumpyImageTestCase2D): - def test_rand_weighted_crop_small_roi(self): - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCrop((10, 12), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 40, 31] = 1 - weight[0, 80, 21] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 10, 12)) - np.testing.assert_allclose(np.asarray(crop.centers), [[80, 21], [30, 17], [40, 31]]) - - def test_rand_weighted_crop_default_roi(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCrop((10, -1), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 40, 31] = 1 - weight[0, 80, 21] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 10, 64)) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32], [105, 32], [20, 32]]) - - def test_rand_weighted_crop_large_roi(self): - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCrop((10000, 400), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 10, 1] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 128, 64)) - np.testing.assert_allclose(np.asarray(crop.centers), [[64, 32], [64, 32], [64, 32]]) - for res in result: - 
np.testing.assert_allclose(res, self.segn[0]) +def get_data(ndim): + im_gen = NumpyImageTestCase2D() if ndim == 2 else NumpyImageTestCase3D() + im_gen.setUp() + return im_gen.imt[0], im_gen.seg1[0], im_gen.segn[0] - def test_rand_weighted_crop_bad_w(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCrop((20, 40), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = np.inf - weight[0, 10, 1] = -np.inf - weight[0, 10, 20] = -np.nan - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 20, 40)) - np.testing.assert_allclose(np.asarray(crop.centers), [[63, 37], [31, 43], [66, 20]]) +IMT_2D, SEG1_2D, SEGN_2D = get_data(ndim=2) +IMT_3D, SEG1_3D, SEGN_3D = get_data(ndim=3) -class TestRandWeightedCrop(NumpyImageTestCase3D): - def test_rand_weighted_crop_small_roi(self): - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCrop((8, 10, 12), n_samples) - weight = np.zeros_like(img) - weight[0, 5, 30, 17] = 1.1 - weight[0, 8, 40, 31] = 1 - weight[0, 11, 23, 21] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 8, 10, 12)) - np.testing.assert_allclose(np.asarray(crop.centers), [[11, 23, 21], [5, 30, 17], [8, 40, 31]]) - def test_rand_weighted_crop_default_roi(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCrop((10, -1, -1), n_samples) - weight = np.zeros_like(img) - weight[0, 7, 17] = 1.1 - weight[0, 13, 31] = 1.1 - weight[0, 24, 21] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32, 40], [41, 32, 40], [20, 32, 40]]) +TESTS = [] +for p in TEST_NDARRAYS: + im = SEG1_2D + weight = np.zeros_like(im) + weight[0, 30, 17] = 1.1 + weight[0, 40, 31] = 1 + weight[0, 80, 21] = 1 + TESTS.append( + [ + "small roi 2d", + dict(spatial_size=(10, 12), num_samples=3), + p(im), + p(weight), + (1, 10, 12), + [[80, 21], [30, 17], [40, 31]], + ] + ) + im = IMT_2D + TESTS.append( + [ + "default roi 2d", + dict(spatial_size=(10, -1), num_samples=3), + p(im), + p(weight), + (1, 10, 64), + [[14, 32], [105, 32], [20, 32]], + ] + ) + im = SEGN_2D + weight = np.zeros_like(im) + weight[0, 30, 17] = 1.1 + weight[0, 10, 1] = 1 + TESTS.append( + [ + "large roi 2d", + dict(spatial_size=(10000, 400), num_samples=3), + p(im), + p(weight), + (1, 128, 64), + [[64, 32], [64, 32], [64, 32]], + ] + ) + im = IMT_2D + weight = np.zeros_like(im) + weight[0, 30, 17] = np.inf + weight[0, 10, 1] = -np.inf + weight[0, 10, 20] = -np.nan + TESTS.append( + [ + "bad w 2d", + dict(spatial_size=(20, 40), num_samples=3), + p(im), + p(weight), + (1, 20, 40), + [[63, 37], [31, 43], [66, 20]], + ] + ) + im = SEG1_3D + weight = np.zeros_like(im) + weight[0, 5, 30, 17] = 1.1 + weight[0, 8, 40, 31] = 1 + weight[0, 11, 23, 21] = 1 + TESTS.append( + [ + "small roi 3d", + dict(spatial_size=(8, 10, 12), num_samples=3), + p(im), + p(weight), + (1, 8, 10, 12), + [[11, 23, 21], [5, 30, 17], [8, 40, 31]], + ] + ) + im = IMT_3D + weight = np.zeros_like(im) + weight[0, 7, 17] = 1.1 + weight[0, 13, 31] = 1.1 + weight[0, 24, 21] = 1 + TESTS.append( + [ + "default roi 3d", + dict(spatial_size=(10, -1, -1), num_samples=3), + p(im), + p(weight), + (1, 10, 64, 80), + [[14, 32, 40], [41, 32, 40], [20, 32, 40]], + ] + ) + im = SEGN_3D + weight = 
np.zeros_like(im) + weight[0, 30, 17, 20] = 1.1 + weight[0, 10, 1, 17] = 1 + TESTS.append( + [ + "large roi 3d", + dict(spatial_size=(10000, 400, 80), num_samples=3), + p(im), + p(weight), + (1, 48, 64, 80), + [[24, 32, 40], [24, 32, 40], [24, 32, 40]], + ] + ) + im = IMT_3D + weight = np.zeros_like(im) + weight[0, 30, 17] = np.inf + weight[0, 10, 1] = -np.inf + weight[0, 10, 20] = -np.nan + TESTS.append( + [ + "bad w 3d", + dict(spatial_size=(48, 64, 80), num_samples=3), + p(im), + p(weight), + (1, 48, 64, 80), + [[24, 32, 40], [24, 32, 40], [24, 32, 40]], + ] + ) - def test_rand_weighted_crop_large_roi(self): - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCrop((10000, 400, 80), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17, 20] = 1.1 - weight[0, 10, 1, 17] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) - for res in result: - np.testing.assert_allclose(res, self.segn[0]) - def test_rand_weighted_crop_bad_w(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCrop((48, 64, 80), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = np.inf - weight[0, 10, 1] = -np.inf - weight[0, 10, 20] = -np.nan +class TestRandWeightedCrop(unittest.TestCase): + @parameterized.expand(TESTS) + def test_rand_weighted_crop(self, _, input_params, img, weight, expected_shape, expected_vals): + crop = RandWeightedCrop(**input_params) crop.set_random_state(10) result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) + self.assertTrue(len(result) == input_params["num_samples"]) + np.testing.assert_allclose(result[0].shape, expected_shape) + for c, e in zip(crop.centers, expected_vals): + if isinstance(c, torch.Tensor): + c = c.cpu() + np.testing.assert_allclose(c, e) + # if desired ROI is larger than image, check image is unchanged + if all(s >= i for i, s in zip(img.shape[1:], input_params["spatial_size"])): + for res in result: + self.assertEqual(type(img), type(res)) + if isinstance(img, torch.Tensor): + self.assertEqual(res.device, img.device) + res = res.cpu() + np.testing.assert_allclose(res, img if isinstance(img, np.ndarray) else img.cpu()) if __name__ == "__main__": diff --git a/tests/test_rand_weighted_cropd.py b/tests/test_rand_weighted_cropd.py index 367ce3beb9..e77ee19b63 100644 --- a/tests/test_rand_weighted_cropd.py +++ b/tests/test_rand_weighted_cropd.py @@ -14,148 +14,168 @@ import numpy as np from monai.transforms.croppad.dictionary import RandWeightedCropd -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D class TestRandWeightedCrop(NumpyImageTestCase2D): def test_rand_weighted_crop_small_roi(self): - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCropd("img", "w", (10, 12), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 40, 31] = 1 - weight[0, 80, 21] = 1 - crop.set_random_state(10) - d = {"img": img, "w": weight} - result = crop(d) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 12)) - np.testing.assert_allclose(np.asarray(crop.centers), [[80, 21], [30, 17], [40, 
31]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.seg1[0] + n_samples = 3 + crop = RandWeightedCropd("img", "w", (10, 12), n_samples) + weight = np.zeros_like(img) + weight[0, 30, 17] = 1.1 + weight[0, 40, 31] = 1 + weight[0, 80, 21] = 1 + crop.set_random_state(10) + d = {"img": p(img), "w": q(weight)} + result = crop(d) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 12)) + np.testing.assert_allclose(np.asarray(crop.centers), [[80, 21], [30, 17], [40, 31]]) def test_rand_weighted_crop_default_roi(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd("im", "weight", (10, -1), n_samples, "coords") - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 40, 31] = 1 - weight[0, 80, 21] = 1 - crop.set_random_state(10) - data = {"im": img, "weight": weight, "others": np.nan} - result = crop(data) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["im"].shape, (1, 10, 64)) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32], [105, 32], [20, 32]]) - np.testing.assert_allclose(result[1]["coords"], [105, 32]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd("im", "weight", (10, -1), n_samples, "coords") + weight = np.zeros_like(img) + weight[0, 30, 17] = 1.1 + weight[0, 40, 31] = 1 + weight[0, 80, 21] = 1 + crop.set_random_state(10) + data = {"im": p(img), "weight": q(weight), "others": np.nan} + result = crop(data) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["im"].shape, (1, 10, 64)) + np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32], [105, 32], [20, 32]]) + np.testing.assert_allclose(result[1]["coords"], [105, 32]) def test_rand_weighted_crop_large_roi(self): - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "weight", (10000, 400), n_samples, "location") - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 10, 1] = 1 - crop.set_random_state(10) - data = {"img": img, "seg": self.imt[0], "weight": weight} - result = crop(data) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 128, 64)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 128, 64)) - np.testing.assert_allclose(np.asarray(crop.centers), [[64, 32], [64, 32], [64, 32]]) - np.testing.assert_allclose(result[1]["location"], [64, 32]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.segn[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "weight", (10000, 400), n_samples, "location") + weight = np.zeros_like(img) + weight[0, 30, 17] = 1.1 + weight[0, 10, 1] = 1 + crop.set_random_state(10) + data = {"img": p(img), "seg": p(self.imt[0]), "weight": q(weight)} + result = crop(data) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 128, 64)) + np.testing.assert_allclose(result[0]["seg"].shape, (1, 128, 64)) + np.testing.assert_allclose(np.asarray(crop.centers), [[64, 32], [64, 32], [64, 32]]) + np.testing.assert_allclose(result[1]["location"], [64, 32]) def test_rand_weighted_crop_bad_w(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (20, 40), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = np.inf - weight[0, 10, 1] = -np.inf - weight[0, 10, 20] = -np.nan - crop.set_random_state(10) - result = crop({"img": img, "seg": self.segn[0], "w": weight}) - 
self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 20, 40)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 20, 40)) - np.testing.assert_allclose(np.asarray(crop.centers), [[63, 37], [31, 43], [66, 20]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "w", (20, 40), n_samples) + weight = np.zeros_like(img) + weight[0, 30, 17] = np.inf + weight[0, 10, 1] = -np.inf + weight[0, 10, 20] = -np.nan + crop.set_random_state(10) + result = crop({"img": p(img), "seg": p(self.segn[0]), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 20, 40)) + np.testing.assert_allclose(result[0]["seg"].shape, (1, 20, 40)) + np.testing.assert_allclose(np.asarray(crop.centers), [[63, 37], [31, 43], [66, 20]]) class TestRandWeightedCrop3D(NumpyImageTestCase3D): def test_rand_weighted_crop_small_roi(self): - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCropd("img", "w", (8, 10, 12), n_samples) - weight = np.zeros_like(img) - weight[0, 5, 30, 17] = 1.1 - weight[0, 8, 40, 31] = 1 - weight[0, 11, 23, 21] = 1 - crop.set_random_state(10) - result = crop({"img": img, "w": weight}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 8, 10, 12)) - np.testing.assert_allclose(np.asarray(crop.centers), [[11, 23, 21], [5, 30, 17], [8, 40, 31]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.seg1[0] + n_samples = 3 + crop = RandWeightedCropd("img", "w", (8, 10, 12), n_samples) + weight = np.zeros_like(img) + weight[0, 5, 30, 17] = 1.1 + weight[0, 8, 40, 31] = 1 + weight[0, 11, 23, 21] = 1 + crop.set_random_state(10) + result = crop({"img": p(img), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 8, 10, 12)) + np.testing.assert_allclose(np.asarray(crop.centers), [[11, 23, 21], [5, 30, 17], [8, 40, 31]]) def test_rand_weighted_crop_default_roi(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) - weight = np.zeros_like(img) - weight[0, 7, 17] = 1.1 - weight[0, 13, 31] = 1.1 - weight[0, 24, 21] = 1 - crop.set_random_state(10) - result = crop({"img": img, "seg": self.segn[0], "w": weight}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32, 40], [41, 32, 40], [20, 32, 40]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) + weight = np.zeros_like(img) + weight[0, 7, 17] = 1.1 + weight[0, 13, 31] = 1.1 + weight[0, 24, 21] = 1 + crop.set_random_state(10) + result = crop({"img": p(img), "seg": p(self.segn[0]), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 64, 80)) + np.testing.assert_allclose(result[0]["seg"].shape, (1, 10, 64, 80)) + np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32, 40], [41, 32, 40], [20, 32, 40]]) def test_rand_weighted_crop_large_roi(self): - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCropd("img", "w", (10000, 400, 80), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17, 20] = 1.1 - weight[0, 
10, 1, 17] = 1 - crop.set_random_state(10) - result = crop({"img": img, "w": weight}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.segn[0] + n_samples = 3 + crop = RandWeightedCropd("img", "w", (10000, 400, 80), n_samples) + weight = np.zeros_like(img) + weight[0, 30, 17, 20] = 1.1 + weight[0, 10, 1, 17] = 1 + crop.set_random_state(10) + result = crop({"img": p(img), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) + np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) def test_rand_weighted_crop_bad_w(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (48, 64, 80), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = np.inf - weight[0, 10, 1] = -np.inf - weight[0, 10, 20] = -np.nan - crop.set_random_state(10) - result = crop({"img": img, "seg": self.segn[0], "w": weight}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "w", (48, 64, 80), n_samples) + weight = np.zeros_like(img) + weight[0, 30, 17] = np.inf + weight[0, 10, 1] = -np.inf + weight[0, 10, 20] = -np.nan + crop.set_random_state(10) + result = crop({"img": p(img), "seg": p(self.segn[0]), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) + np.testing.assert_allclose(result[0]["seg"].shape, (1, 48, 64, 80)) + np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) def test_rand_weighted_crop_patch_index(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) - weight = np.zeros_like(img) - weight[0, 7, 17] = 1.1 - weight[0, 13, 31] = 1.1 - weight[0, 24, 21] = 1 - crop.set_random_state(10) - result = crop({"img": img, "seg": self.segn[0], "w": weight, "img_meta_dict": {"affine": None}}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32, 40], [41, 32, 40], [20, 32, 40]]) - for i in range(n_samples): - np.testing.assert_allclose(result[i]["img"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(result[i]["seg"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(result[i]["img_meta_dict"]["patch_index"], i) - np.testing.assert_allclose(result[i]["seg_meta_dict"]["patch_index"], i) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) + weight = np.zeros_like(img) + weight[0, 7, 17] = 1.1 + weight[0, 13, 31] = 1.1 + weight[0, 24, 21] = 1 + crop.set_random_state(10) + result = crop( + {"img": p(img), "seg": p(self.segn[0]), "w": q(weight), "img_meta_dict": {"affine": None}} + ) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32, 40], [41, 32, 40], [20, 32, 40]]) + for i in 
range(n_samples): + np.testing.assert_allclose(result[i]["img"].shape, (1, 10, 64, 80)) + np.testing.assert_allclose(result[i]["seg"].shape, (1, 10, 64, 80)) + np.testing.assert_allclose(result[i]["img_meta_dict"]["patch_index"], i) + np.testing.assert_allclose(result[i]["seg_meta_dict"]["patch_index"], i) if __name__ == "__main__": diff --git a/tests/test_rand_zoom.py b/tests/test_rand_zoom.py index 35cf30bcb1..b0e280080a 100644 --- a/tests/test_rand_zoom.py +++ b/tests/test_rand_zoom.py @@ -12,12 +12,13 @@ import unittest import numpy as np +import torch from parameterized import parameterized from scipy.ndimage import zoom as zoom_scipy from monai.transforms import RandZoom from monai.utils import GridSampleMode, InterpolateMode -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D VALID_CASES = [(0.8, 1.2, "nearest", False), (0.8, 1.2, InterpolateMode.NEAREST, False)] @@ -25,29 +26,34 @@ class TestRandZoom(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) def test_correct_results(self, min_zoom, max_zoom, mode, keep_size): - random_zoom = RandZoom( - prob=1.0, - min_zoom=min_zoom, - max_zoom=max_zoom, - mode=mode, - keep_size=keep_size, - ) - random_zoom.set_random_state(1234) - zoomed = random_zoom(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False)) - expected = np.stack(expected).astype(np.float32) - np.testing.assert_allclose(zoomed, expected, atol=1.0) + for p in TEST_NDARRAYS: + random_zoom = RandZoom( + prob=1.0, + min_zoom=min_zoom, + max_zoom=max_zoom, + mode=mode, + keep_size=keep_size, + ) + random_zoom.set_random_state(1234) + zoomed = random_zoom(p(self.imt[0])) + if isinstance(zoomed, torch.Tensor): + zoomed = zoomed.cpu().numpy() + expected = [] + for channel in self.imt[0]: + expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False)) + expected = np.stack(expected).astype(np.float32) + np.testing.assert_allclose(zoomed, expected, atol=1.0) def test_keep_size(self): - random_zoom = RandZoom(prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True) - zoomed = random_zoom(self.imt[0]) - self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) - zoomed = random_zoom(self.imt[0]) - self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) - zoomed = random_zoom(self.imt[0]) - self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) + for p in TEST_NDARRAYS: + im = p(self.imt[0]) + random_zoom = RandZoom(prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True) + zoomed = random_zoom(im) + self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) + zoomed = random_zoom(im) + self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) + zoomed = random_zoom(im) + self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) @parameterized.expand( [ @@ -57,23 +63,25 @@ def test_keep_size(self): ] ) def test_invalid_inputs(self, _, min_zoom, max_zoom, mode, raises): - with self.assertRaises(raises): - random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode) - random_zoom(self.imt[0]) + for p in TEST_NDARRAYS: + with self.assertRaises(raises): + random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode) + random_zoom(p(self.imt[0])) def test_auto_expand_3d(self): - random_zoom = RandZoom( - prob=1.0, - min_zoom=[0.8, 0.7], - max_zoom=[1.2, 1.3], - mode="nearest", - keep_size=False, - ) 
- random_zoom.set_random_state(1234) - test_data = np.random.randint(0, 2, size=[2, 2, 3, 4]) - zoomed = random_zoom(test_data) - np.testing.assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2) - np.testing.assert_allclose(zoomed.shape, (2, 2, 3, 3)) + for p in TEST_NDARRAYS: + random_zoom = RandZoom( + prob=1.0, + min_zoom=[0.8, 0.7], + max_zoom=[1.2, 1.3], + mode="nearest", + keep_size=False, + ) + random_zoom.set_random_state(1234) + test_data = p(np.random.randint(0, 2, size=[2, 2, 3, 4])) + zoomed = random_zoom(test_data) + np.testing.assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2) + np.testing.assert_allclose(zoomed.shape, (2, 2, 3, 3)) if __name__ == "__main__": diff --git a/tests/test_random_bias_field.py b/tests/test_random_bias_field.py index 5aeeb79874..8df21a3a40 100644 --- a/tests/test_random_bias_field.py +++ b/tests/test_random_bias_field.py @@ -12,9 +12,11 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandBiasField +from tests.utils import TEST_NDARRAYS TEST_CASES_2D = [{}, (3, 32, 32)] TEST_CASES_3D = [{}, (3, 32, 32, 32)] @@ -25,36 +27,60 @@ class TestRandBiasField(unittest.TestCase): @parameterized.expand([TEST_CASES_2D, TEST_CASES_3D]) def test_output_shape(self, class_args, img_shape): - for degree in [1, 2, 3]: - bias_field = RandBiasField(degree=degree, **class_args) - img = np.random.rand(*img_shape) - output = bias_field(img) - np.testing.assert_equal(output.shape, img_shape) - np.testing.assert_equal(output.dtype, bias_field.dtype) + for p in TEST_NDARRAYS: + for degree in [1, 2, 3]: + bias_field = RandBiasField(degree=degree, **class_args) + img = p(np.random.rand(*img_shape)) + output = bias_field(img) + + self.assertEqual(type(img), type(output)) + if isinstance(output, torch.Tensor): + self.assertEqual(output.device, img.device) + output = output.cpu().numpy() - img_zero = np.zeros([*img_shape]) - output_zero = bias_field(img_zero) - np.testing.assert_equal(output_zero, img_zero) + np.testing.assert_equal(output.shape, img_shape) + np.testing.assert_equal(output.dtype, bias_field.dtype) + + img_zero = np.zeros([*img_shape]) + output_zero = bias_field(img_zero) + np.testing.assert_equal(output_zero, img_zero) @parameterized.expand([TEST_CASES_2D_ZERO_RANGE]) def test_zero_range(self, class_args, img_shape): - bias_field = RandBiasField(**class_args) - img = np.ones(img_shape) - output = bias_field(img) - np.testing.assert_allclose(output, np.ones(img_shape), rtol=1e-3) + for p in TEST_NDARRAYS: + bias_field = RandBiasField(**class_args) + img = p(np.ones(img_shape)) + output = bias_field(img) + + self.assertEqual(type(img), type(output)) + if isinstance(output, torch.Tensor): + self.assertEqual(output.device, img.device) + output = output.cpu().numpy() + np.testing.assert_allclose(output, np.ones(img_shape), rtol=1e-3) @parameterized.expand([TEST_CASES_2D_ONES]) def test_one_range_input(self, class_args, expected): - bias_field = RandBiasField(**class_args) - img = np.ones([1, 2, 2]) - output = bias_field(img) - np.testing.assert_allclose(output, expected.astype(bias_field.dtype), rtol=1e-3) + for p in TEST_NDARRAYS: + bias_field = RandBiasField(**class_args) + img = p(np.ones([1, 2, 2])) + output = bias_field(img) + self.assertEqual(type(img), type(output)) + if isinstance(output, torch.Tensor): + self.assertEqual(output.device, img.device) + output = output.cpu().numpy() + np.testing.assert_allclose(output, 
expected.astype(bias_field.dtype), rtol=1e-3)

     def test_zero_prob(self):
-        bias_field = RandBiasField(prob=0.0)
-        img = np.random.rand(3, 32, 32)
-        output = bias_field(img)
-        np.testing.assert_equal(output, img)
+        for p in TEST_NDARRAYS:
+            bias_field = RandBiasField(prob=0.0)
+            img = p(np.random.rand(3, 32, 32))
+            output = bias_field(img)
+            self.assertEqual(type(img), type(output))
+            if isinstance(output, torch.Tensor):
+                self.assertEqual(output.device, img.device)
+                output = output.cpu().numpy()
+                img = img.cpu().numpy()
+            np.testing.assert_equal(output, img)


 if __name__ == "__main__":
diff --git a/tests/test_random_bias_fieldd.py b/tests/test_random_bias_fieldd.py
index aa2e206de9..f0df49231f 100644
--- a/tests/test_random_bias_fieldd.py
+++ b/tests/test_random_bias_fieldd.py
@@ -12,6 +12,7 @@
 import unittest
 
 import numpy as np
+import torch
 from parameterized import parameterized
 
 from monai.transforms import RandBiasFieldd
@@ -32,6 +33,12 @@ def test_output_shape(self, class_args, img_shape):
         bias_field = RandBiasFieldd(keys=[key], **class_args)
         img = np.random.rand(*img_shape)
         output = bias_field({key: img})
+
+        self.assertEqual(type(img), type(output[key]))
+        if isinstance(output[key], torch.Tensor):
+            self.assertEqual(output[key].device, img.device)
+            output[key] = output[key].cpu().numpy()
+
         np.testing.assert_equal(output[key].shape, img_shape)
         np.testing.assert_equal(output[key].dtype, bias_field.rand_bias_field.dtype)
@@ -41,7 +48,13 @@ def test_zero_range(self, class_args, img_shape):
         bias_field = RandBiasFieldd(keys=[key], **class_args)
         img = np.ones(img_shape)
         output = bias_field({key: img})
-        np.testing.assert_allclose(output[key], np.ones(img_shape))
+
+        self.assertEqual(type(img), type(output[key]))
+        if isinstance(output[key], torch.Tensor):
+            self.assertEqual(output[key].device, img.device)
+            output[key] = output[key].cpu().numpy()
+
+        np.testing.assert_allclose(output[key], np.ones(img_shape), rtol=1e-3)
 
     @parameterized.expand([TEST_CASES_2D_ONES])
     def test_one_range_input(self, class_args, expected):
@@ -49,13 +62,25 @@ def test_one_range_input(self, class_args, expected):
         bias_field = RandBiasFieldd(keys=[key], **class_args)
         img = np.ones([1, 2, 2])
         output = bias_field({key: img})
-        np.testing.assert_allclose(output[key], expected.astype(bias_field.rand_bias_field.dtype), rtol=1e-3)
+
+        self.assertEqual(type(img), type(output[key]))
+        if isinstance(output[key], torch.Tensor):
+            self.assertEqual(output[key].device, img.device)
+            output[key] = output[key].cpu().numpy()
+
+        np.testing.assert_equal(output[key], expected.astype(bias_field.rand_bias_field.dtype))
 
     def test_zero_prob(self):
         key = "img"
         bias_field = RandBiasFieldd(keys=[key], prob=0.0)
         img = np.random.rand(3, 32, 32)
         output = bias_field({key: img})
+
+        self.assertEqual(type(img), type(output[key]))
+        if isinstance(output[key], torch.Tensor):
+            self.assertEqual(output[key].device, img.device)
+            output[key] = output[key].cpu().numpy()
+
         np.testing.assert_equal(output[key], img)
 
diff --git a/tests/test_remove_repeated_channel.py b/tests/test_remove_repeated_channel.py
index 070e0e2b8d..ebbe6c730c 100644
--- a/tests/test_remove_repeated_channel.py
+++ b/tests/test_remove_repeated_channel.py
@@ -12,15 +12,18 @@
 import unittest
 
 import numpy as np
+import torch
 from parameterized import parameterized
 
 from monai.transforms import RemoveRepeatedChannel
 
-TEST_CASE_1 = [{"repeats": 2}, np.array([[1, 2], [1, 2], [3, 4], [3, 4]]), (2, 2)]
+TEST_CASES = []
+for q in (torch.Tensor, np.array):
+    TEST_CASES.append([{"repeats": 2}, q([[1, 2], [1, 2], [3, 4], [3, 4]]), (2, 2)])  # type: ignore


 class TestRemoveRepeatedChannel(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_1])
+    @parameterized.expand(TEST_CASES)
     def test_shape(self, input_param, input_data, expected_shape):
         result = RemoveRepeatedChannel(**input_param)(input_data)
         self.assertEqual(result.shape, expected_shape)
diff --git a/tests/test_repeat_channel.py b/tests/test_repeat_channel.py
index 643ebc64de..d19abcc8b4 100644
--- a/tests/test_repeat_channel.py
+++ b/tests/test_repeat_channel.py
@@ -12,15 +12,18 @@
 import unittest
 
 import numpy as np
+import torch
 from parameterized import parameterized
 
 from monai.transforms import RepeatChannel
 
-TEST_CASE_1 = [{"repeats": 3}, np.array([[[0, 1], [1, 2]]]), (3, 2, 2)]
+TEST_CASES = []
+for q in (torch.Tensor, np.array):
+    TEST_CASES.append([{"repeats": 3}, q([[[0, 1], [1, 2]]]), (3, 2, 2)])  # type: ignore
 
 
 class TestRepeatChannel(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_1])
+    @parameterized.expand(TEST_CASES)
     def test_shape(self, input_param, input_data, expected_shape):
         result = RepeatChannel(**input_param)(input_data)
         self.assertEqual(result.shape, expected_shape)
diff --git a/tests/test_resampler.py b/tests/test_resampler.py
index 2be94acebd..7eba61def6 100644
--- a/tests/test_resampler.py
+++ b/tests/test_resampler.py
@@ -17,69 +17,128 @@
 from monai.transforms import Resample
 from monai.transforms.utils import create_grid
+from tests.utils import TEST_NDARRAYS
 
-TEST_CASES = [
-    [
-        dict(padding_mode="zeros", as_tensor_output=False, device=None),
-        {"grid": create_grid((2, 2)), "img": np.arange(4).reshape((1, 2, 2))},
-        np.array([[[0.0, 1.0], [2.0, 3.0]]]),
-    ],
-    [
-        dict(padding_mode="zeros", as_tensor_output=False, device=None),
-        {"grid": create_grid((4, 4)), "img": np.arange(4).reshape((1, 2, 2))},
-        np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]),
-    ],
-    [
-        dict(padding_mode="border", as_tensor_output=False, device=None),
-        {"grid": create_grid((4, 4)), "img": np.arange(4).reshape((1, 2, 2))},
-        np.array([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3, 3.0], [2.0, 2.0, 3.0, 3.0]]]),
-    ],
-    [
-        dict(padding_mode="reflection", as_tensor_output=False, device=None),
-        {"grid": create_grid((4, 4)), "img": np.arange(4).reshape((1, 2, 2)), "mode": "nearest"},
-        np.array([[[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]]]),
-    ],
-    [
-        dict(padding_mode="zeros", as_tensor_output=False, device=None),
-        {"grid": create_grid((4, 4, 4)), "img": np.arange(8).reshape((1, 2, 2, 2)), "mode": "bilinear"},
-        np.array(
+TESTS = []
+for p in TEST_NDARRAYS:
+    for q in TEST_NDARRAYS:
+        TESTS.append(
             [
-            [
-                [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
-                [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
-                [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
-                [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
-            ]
+                dict(padding_mode="zeros", device=None),
+                {"grid": p(create_grid((2, 2))), "img": q(np.arange(4).reshape((1, 2, 2)))},
+                q(np.array([[[0.0, 1.0], [2.0, 3.0]]])),
             ]
-        ),
-    ],
-    [
-        dict(padding_mode="border", as_tensor_output=False, device=None),
-        {"grid": create_grid((4, 4, 4)), "img": np.arange(8).reshape((1, 2, 2, 2)), "mode": "bilinear"},
-        np.array(
+        )
+        TESTS.append(
             [
-            [
-                [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0,
+TESTS = []
+for p in TEST_NDARRAYS:
+    for q in TEST_NDARRAYS:
+        TESTS.append(
+            [
+                dict(padding_mode="zeros", device=None),
+                {"grid": p(create_grid((2, 2))), "img": q(np.arange(4).reshape((1, 2, 2)))},
+                q(np.array([[[0.0, 1.0], [2.0, 3.0]]])),
+            ]
+        )
+        TESTS.append(
+            [
+                dict(padding_mode="zeros", device=None),
+                {"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2)))},
+                q(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])),
+            ]
+        )
+        TESTS.append(
+            [
+                dict(padding_mode="border", device=None),
+                {"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2)))},
+                q(np.array([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3, 3.0], [2.0, 2.0, 3.0, 3.0]]])),
+            ]
+        )
+        TESTS.append(
+            [
+                dict(padding_mode="reflection", device=None),
+                {"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2))), "mode": "nearest"},
+                q(np.array([[[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]]])),
+            ]
+        )
+        TESTS.append(
+            [
+                dict(padding_mode="zeros", device=None),
+                {"grid": p(create_grid((4, 4, 4))), "img": q(np.arange(8).reshape((1, 2, 2, 2))), "mode": "bilinear"},
+                q(
+                    np.array(
+                        [
+                            [
+                                [
+                                    [0.0, 0.0, 0.0, 0.0],
+                                    [0.0, 0.0, 0.0, 0.0],
+                                    [0.0, 0.0, 0.0, 0.0],
+                                    [0.0, 0.0, 0.0, 0.0],
+                                ],
+                                [
+                                    [0.0, 0.0, 0.0, 0.0],
+                                    [0.0, 0.0, 1.0, 0.0],
+                                    [0.0, 2.0, 3.0, 0.0],
+                                    [0.0, 0.0, 0.0, 0.0],
+                                ],
+                                [
+                                    [0.0, 0.0, 0.0, 0.0],
+                                    [0.0, 4.0, 5.0, 0.0],
+                                    [0.0, 6.0, 7.0, 0.0],
+                                    [0.0, 0.0, 0.0, 0.0],
+                                ],
+                                [
+                                    [0.0, 0.0, 0.0, 0.0],
+                                    [0.0, 0.0, 0.0, 0.0],
+                                    [0.0, 0.0, 0.0, 0.0],
+                                    [0.0, 0.0, 0.0, 0.0],
+                                ],
+                            ]
+                        ]
+                    )
+                ),
+            ]
+        )
+        TESTS.append(
+            [
+                dict(padding_mode="border", device=None),
+                {"grid": p(create_grid((4, 4, 4))), "img": q(np.arange(8).reshape((1, 2, 2, 2))), "mode": "bilinear"},
+                q(
+                    np.array(
+                        [
+                            [
+                                [
+                                    [0.0, 0.0, 1.0, 1.0],
+                                    [0.0, 0.0, 1.0, 1.0],
+                                    [2.0, 2.0, 3.0, 3.0],
+                                    [2.0, 2.0, 3.0, 3.0],
+                                ],
+                                [
+                                    [0.0, 0.0, 1.0, 1.0],
+                                    [0.0, 0.0, 1.0, 1.0],
+                                    [2.0, 2.0, 3.0, 3.0],
+                                    [2.0, 2.0, 3.0, 3.0],
+                                ],
+                                [
+                                    [4.0, 4.0, 5.0, 5.0],
+                                    [4.0, 4.0, 5.0, 5.0],
+                                    [6.0, 6.0, 7.0, 7.0],
+                                    [6.0, 6.0, 7.0, 7.0],
+                                ],
+                                [
+                                    [4.0, 4.0, 5.0, 5.0],
+                                    [4.0, 4.0, 5.0, 5.0],
+                                    [6.0, 6.0, 7.0, 7.0],
+                                    [6.0, 6.0, 7.0, 7.0],
+                                ],
+                            ]
+                        ]
+                    )
+                ),
+            ]
+        )


 class TestResample(unittest.TestCase):
-    @parameterized.expand(TEST_CASES)
+    @parameterized.expand(TESTS)
     def test_resample(self, input_param, input_data, expected_val):
         g = Resample(**input_param)
         result = g(**input_data)
-        self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor))
+        self.assertEqual(type(result), type(expected_val))
         if isinstance(result, torch.Tensor):
-            np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4)
-        else:
-            np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4)
+            self.assertEqual(result.device, expected_val.device)
+            result = result.cpu()
+            expected_val = expected_val.cpu()
+        np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4)


 if __name__ == "__main__":
diff --git a/tests/test_resize.py b/tests/test_resize.py
index 2f54dcc04f..7b1488b0e8 100644
--- a/tests/test_resize.py
+++ b/tests/test_resize.py
@@ -10,19 +10,27 @@
 # limitations under the License.

 import unittest
+from typing import List, Tuple

 import numpy as np
 import skimage.transform
+import torch
 from parameterized import parameterized

 from monai.transforms import Resize
-from tests.utils import NumpyImageTestCase2D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D

-TEST_CASE_0 = [{"spatial_size": 15}, (6, 11, 15)]
-
-TEST_CASE_1 = [{"spatial_size": 15, "mode": "area"}, (6, 11, 15)]
-
-TEST_CASE_2 = [{"spatial_size": 6, "mode": "trilinear", "align_corners": True}, (3, 5, 6)]
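+# TESTS entries are (container, spatial_size, mode); TEST_LONGEST entries are run with size_mode="longest"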
+TESTS: List[Tuple] = []
+TEST_LONGEST: List[Tuple] = []
+for p in TEST_NDARRAYS:
+    TESTS.append((p, (32, -1), "area"))
+    TESTS.append((p, (32, 32), "area"))
+    TESTS.append((p, (32, 32, 32), "trilinear"))
+    TESTS.append((p, (256, 256), "bilinear"))
+
+    TEST_LONGEST.append((p, {"spatial_size": 15}, (6, 11, 15)))
+    TEST_LONGEST.append((p, {"spatial_size": 15, "mode": "area"}, (6, 11, 15)))
+    TEST_LONGEST.append((p, {"spatial_size": 6, "mode": "trilinear", "align_corners": True}, (3, 5, 6)))


 class TestResize(NumpyImageTestCase2D):
@@ -35,10 +43,8 @@ def test_invalid_inputs(self):
             resize = Resize(spatial_size=(128,), mode="order")
             resize(self.imt[0])

-    @parameterized.expand(
-        [((32, -1), "area"), ((32, 32), "area"), ((32, 32, 32), "trilinear"), ((256, 256), "bilinear")]
-    )
-    def test_correct_results(self, spatial_size, mode):
+    @parameterized.expand(TESTS)
+    def test_correct_results(self, in_type, spatial_size, mode):
         resize = Resize(spatial_size, mode=mode)
         _order = 0
         if mode.endswith("linear"):
@@ -53,12 +59,17 @@ def test_correct_results(self, spatial_size, mode):
             )
         )
         expected = np.stack(expected).astype(np.float32)
-        out = resize(self.imt[0])
+        im = in_type(self.imt[0])
+        out = resize(im)
+        self.assertEqual(type(im), type(out))
+        if isinstance(out, torch.Tensor):
+            self.assertEqual(im.device, out.device)
+            out = out.cpu()
         np.testing.assert_allclose(out, expected, atol=0.9)

-    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2])
-    def test_longest_shape(self, input_param, expected_shape):
-        input_data = np.random.randint(0, 2, size=[3, 4, 7, 10])
+    @parameterized.expand(TEST_LONGEST)
+    def test_longest_shape(self, im_type, input_param, expected_shape):
+        input_data = im_type(np.random.randint(0, 2, size=[3, 4, 7, 10]))
         input_param["size_mode"] = "longest"
         result = Resize(**input_param)(input_data)
         np.testing.assert_allclose(result.shape[1:], expected_shape)
diff --git a/tests/test_resize_with_pad_or_crop.py b/tests/test_resize_with_pad_or_crop.py
index 46f1fc86cc..94a62292fe 100644
--- a/tests/test_resize_with_pad_or_crop.py
+++ b/tests/test_resize_with_pad_or_crop.py
@@ -12,9 +12,11 @@
 import unittest

 import numpy as np
+import torch
 from parameterized import parameterized

 from monai.transforms import ResizeWithPadOrCrop
+from tests.utils import TEST_NDARRAYS

 TEST_CASES = [
     [
@@ -48,11 +50,22 @@
 class TestResizeWithPadOrCrop(unittest.TestCase):
     @parameterized.expand(TEST_CASES)
     def test_pad_shape(self, input_param, input_shape, expected_shape):
-        paddcroper = ResizeWithPadOrCrop(**input_param)
-        result = paddcroper(np.zeros(input_shape))
-        np.testing.assert_allclose(result.shape, expected_shape)
-        result = paddcroper(np.zeros(input_shape), mode="constant")
-        np.testing.assert_allclose(result.shape, expected_shape)
+        results1, results2 = [], []
+        for p in TEST_NDARRAYS:
+            im = p(np.zeros(input_shape))
+            paddcroper = ResizeWithPadOrCrop(**input_param)
+            result1 = paddcroper(im)
+            result2 = paddcroper(im, mode="constant")
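+            # the default call and the explicit mode="constant" call are checked identically below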
+            for result, results in zip((result1, result2), (results1, results2)):
+                self.assertEqual(type(im), type(result))
+                if isinstance(result, torch.Tensor):
+                    self.assertEqual(result.device, im.device)
+                    result = result.cpu().numpy()
+                np.testing.assert_allclose(result.shape, expected_shape)
+                results.append(result)
+            # check that outputs from numpy, torch and torch.cuda inputs match
+            if len(results) > 1:
+                np.testing.assert_allclose(results[0], results[-1])


 if __name__ == "__main__":
diff --git a/tests/test_resize_with_pad_or_cropd.py b/tests/test_resize_with_pad_or_cropd.py
index 32a62a9e16..851ec18df7 100644
--- a/tests/test_resize_with_pad_or_cropd.py
+++ b/tests/test_resize_with_pad_or_cropd.py
@@ -12,9 +12,11 @@
 import unittest

 import numpy as np
+import torch
 from parameterized import parameterized

 from monai.transforms import ResizeWithPadOrCropd
+from tests.utils import TEST_NDARRAYS

 TEST_CASES = [
     [
@@ -47,10 +49,22 @@
 class TestResizeWithPadOrCropd(unittest.TestCase):
     @parameterized.expand(TEST_CASES)
-    def test_pad_shape(self, input_param, input_data, expected_val):
+    def test_pad_shape(self, input_param, input_data, expected_shape):
         paddcroper = ResizeWithPadOrCropd(**input_param)
-        result = paddcroper(input_data)
-        np.testing.assert_allclose(result["img"].shape, expected_val)
+        results = []
+        for p in TEST_NDARRAYS:
+            input_data_mod = {"img": p(input_data["img"])}
+            result = paddcroper(input_data_mod)
+            r, i = result["img"], input_data_mod["img"]
+            self.assertEqual(type(i), type(r))
+            if isinstance(r, torch.Tensor):
+                self.assertEqual(r.device, i.device)
+                r = r.cpu().numpy()
+            np.testing.assert_allclose(r.shape, expected_shape)
+            results.append(r)
+        # check that outputs from numpy, torch and torch.cuda inputs match
+        if len(results) > 1:
+            np.testing.assert_allclose(results[0], results[-1])


 if __name__ == "__main__":
diff --git a/tests/test_resized.py b/tests/test_resized.py
index 6c4f31c9c8..c0bd9a49a7 100644
--- a/tests/test_resized.py
+++ b/tests/test_resized.py
@@ -10,24 +10,39 @@
 # limitations under the License.

 import unittest
+from typing import List, Tuple

 import numpy as np
 import skimage.transform
+import torch
 from parameterized import parameterized

 from monai.transforms import Resized
-from tests.utils import NumpyImageTestCase2D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D

-TEST_CASE_0 = [{"keys": "img", "spatial_size": 15}, (6, 11, 15)]
-
-TEST_CASE_1 = [{"keys": "img", "spatial_size": 15, "mode": "area"}, (6, 11, 15)]
-
-TEST_CASE_2 = [{"keys": "img", "spatial_size": 6, "mode": "trilinear", "align_corners": True}, (3, 5, 6)]
-
-TEST_CASE_3 = [
-    {"keys": ["img", "label"], "spatial_size": 6, "mode": ["trilinear", "nearest"], "align_corners": [True, None]},
-    (3, 5, 6),
-]
+TESTS: List[Tuple] = []
+TEST_LONGEST: List[Tuple] = []
+for p in TEST_NDARRAYS:
+    TESTS.append((p, (32, -1), "area"))
+    TESTS.append((p, (64, 64), "area"))
+    TESTS.append((p, (32, 32, 32), "area"))
+    TESTS.append((p, (256, 256), "bilinear"))

+    TEST_LONGEST.append((p, {"keys": "img", "spatial_size": 15}, (6, 11, 15)))
+    TEST_LONGEST.append((p, {"keys": "img", "spatial_size": 15, "mode": "area"}, (6, 11, 15)))
+    TEST_LONGEST.append((p, {"keys": "img", "spatial_size": 6, "mode": "trilinear", "align_corners": True}, (3, 5, 6)))
+    TEST_LONGEST.append(
+        (
+            p,
+            {
+                "keys": ["img", "label"],
+                "spatial_size": 6,
+                "mode": ["trilinear", "nearest"],
+                "align_corners": [True, None],
+            },
+            (3, 5, 6),
+        )
+    )


 class TestResized(NumpyImageTestCase2D):
@@ -40,8 +55,8 @@ def test_invalid_inputs(self):
         resize = Resized(keys="img", spatial_size=(128,), mode="order")
         resize({"img": self.imt[0]})

-    @parameterized.expand([((32, -1), "area"), ((64, 64), "area"), ((32, 32, 32), "area"), ((256, 256), "bilinear")])
-    def test_correct_results(self, spatial_size, mode):
+    @parameterized.expand(TESTS)
+    def test_correct_results(self, in_type, spatial_size, mode):
         resize = Resized("img", spatial_size, mode=mode)
         _order = 0
         if mode.endswith("linear"):
@@ -56,14 +71,16 @@ def test_correct_results(self, spatial_size, mode):
             )
         )
         expected = np.stack(expected).astype(np.float32)
-        out = resize({"img": self.imt[0]})["img"]
+        out = resize({"img": in_type(self.imt[0])})["img"]
+        if isinstance(out, torch.Tensor):
+            out = out.cpu()
         np.testing.assert_allclose(out, expected, atol=0.9)

-    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
-    def test_longest_shape(self, input_param, expected_shape):
+    @parameterized.expand(TEST_LONGEST)
+    def test_longest_shape(self, in_type, input_param, expected_shape):
         input_data = {
-            "img": np.random.randint(0, 2, size=[3, 4, 7, 10]),
-            "label": np.random.randint(0, 2, size=[3, 4, 7, 10]),
+            "img": in_type(np.random.randint(0, 2, size=[3, 4, 7, 10])),
+            "label": in_type(np.random.randint(0, 2, size=[3, 4, 7, 10])),
         }
         input_param["size_mode"] = "longest"
         rescaler = Resized(**input_param)
diff --git a/tests/test_rotate.py b/tests/test_rotate.py
index 436c952d4b..16a9c6d124 100644
--- a/tests/test_rotate.py
+++ b/tests/test_rotate.py
@@ -10,42 +10,44 @@
 # limitations under the License.

 import unittest
+from typing import List, Tuple

 import numpy as np
 import scipy.ndimage
+import torch
 from parameterized import parameterized

 from monai.transforms import Rotate
-from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D
-
-TEST_CASES_2D = [
-    (np.pi / 6, False, "bilinear", "border", False),
-    (np.pi / 4, True, "bilinear", "border", False),
-    (-np.pi / 4.5, True, "nearest", "reflection", False),
-    (np.pi, False, "nearest", "zeros", False),
-    (-np.pi / 2, False, "bilinear", "zeros", True),
-]
-
-TEST_CASES_3D = [
-    (-np.pi / 2, True, "nearest", "border", False),
-    (np.pi / 4, True, "bilinear", "border", False),
-    (-np.pi / 4.5, True, "nearest", "reflection", False),
-    (np.pi, False, "nearest", "zeros", False),
-    (-np.pi / 2, False, "bilinear", "zeros", False),
-]
-
-TEST_CASES_SHAPE_3D = [
-    ([-np.pi / 2, 1.0, 2.0], "nearest", "border", False),
-    ([np.pi / 4, 0, 0], "bilinear", "border", False),
-    ([-np.pi / 4.5, -20, 20], "nearest", "reflection", False),
-]
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D
+
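+# each rotation case is generated once per container type in TEST_NDARRAYS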
+TEST_CASES_2D: List[Tuple] = []
+for p in TEST_NDARRAYS:
+    TEST_CASES_2D.append((p, np.pi / 6, False, "bilinear", "border", False))
+    TEST_CASES_2D.append((p, np.pi / 4, True, "bilinear", "border", False))
+    TEST_CASES_2D.append((p, -np.pi / 4.5, True, "nearest", "reflection", False))
+    TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", False))
+    TEST_CASES_2D.append((p, -np.pi / 2, False, "bilinear", "zeros", True))
+
+TEST_CASES_3D: List[Tuple] = []
+for p in TEST_NDARRAYS:
+    TEST_CASES_3D.append((p, -np.pi / 2, True, "nearest", "border", False))
+    TEST_CASES_3D.append((p, np.pi / 4, True, "bilinear", "border", False))
+    TEST_CASES_3D.append((p, -np.pi / 4.5, True, "nearest", "reflection", False))
+    TEST_CASES_3D.append((p, np.pi, False, "nearest", "zeros", False))
+    TEST_CASES_3D.append((p, -np.pi / 2, False, "bilinear", "zeros", False))
+
+TEST_CASES_SHAPE_3D: List[Tuple] = []
+for p in TEST_NDARRAYS:
+    TEST_CASES_SHAPE_3D.append((p, [-np.pi / 2, 1.0, 2.0], "nearest", "border", False))
+    TEST_CASES_SHAPE_3D.append((p, [np.pi / 4, 0, 0], "bilinear", "border", False))
+    TEST_CASES_SHAPE_3D.append((p, [-np.pi / 4.5, -20, 20], "nearest", "reflection", False))


 class TestRotate2D(NumpyImageTestCase2D):
     @parameterized.expand(TEST_CASES_2D)
-    def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners):
+    def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners):
         rotate_fn = Rotate(angle, keep_size, mode, padding_mode, align_corners)
-        rotated = rotate_fn(self.imt[0])
+        rotated = rotate_fn(im_type(self.imt[0]))
         if keep_size:
             np.testing.assert_allclose(self.imt[0].shape, rotated.shape)
         _order = 0 if mode == "nearest" else 1
@@ -70,15 +72,16 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne
             )
         )
         expected = np.stack(expected).astype(np.float32)
+        rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated
         good = np.sum(np.isclose(expected, rotated, atol=1e-3))
         self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels")


 class TestRotate3D(NumpyImageTestCase3D):
     @parameterized.expand(TEST_CASES_3D)
-    def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners):
+    def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners):
         rotate_fn = Rotate([angle, 0, 0], keep_size, mode, padding_mode, align_corners)
-        rotated = rotate_fn(self.imt[0])
+        rotated = rotate_fn(im_type(self.imt[0]))
         if keep_size:
             np.testing.assert_allclose(self.imt[0].shape, rotated.shape)
         _order = 0 if mode == "nearest" else 1
@@ -103,23 +106,25 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne
             )
         )
         expected = np.stack(expected).astype(np.float32)
+        rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated
         n_good = np.sum(np.isclose(expected, rotated, atol=1e-3))
         self.assertLessEqual(expected.size - n_good, 5, "diff at most 5 pixels")

     @parameterized.expand(TEST_CASES_SHAPE_3D)
-    def test_correct_shape(self, angle, mode, padding_mode, align_corners):
+    def test_correct_shape(self, im_type, angle, mode, padding_mode, align_corners):
         rotate_fn = Rotate(angle, True, align_corners=align_corners)
-        rotated = rotate_fn(self.imt[0], mode=mode, padding_mode=padding_mode)
+        rotated = rotate_fn(im_type(self.imt[0]), mode=mode, padding_mode=padding_mode)
         np.testing.assert_allclose(self.imt[0].shape, rotated.shape)

     def test_ill_case(self):
-        rotate_fn = Rotate(10, True)
-        with self.assertRaises(ValueError):  # wrong shape
-            rotate_fn(self.imt)
-
-        rotate_fn = Rotate(10, keep_size=False)
-        with self.assertRaises(ValueError):  # wrong mode
-            rotate_fn(self.imt[0], mode="trilinear")
+        for p in TEST_NDARRAYS:
+            rotate_fn = Rotate(10, True)
+            with self.assertRaises(ValueError):  # wrong shape
+                rotate_fn(p(self.imt))
+
+            rotate_fn = Rotate(10, keep_size=False)
+            with self.assertRaises(ValueError):  # wrong mode
+                rotate_fn(p(self.imt[0]), mode="trilinear")


 if __name__ == "__main__":
diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py
index 4ab39d5cf6..7780976d3c 100644
--- a/tests/test_rotate90.py
+++ b/tests/test_rotate90.py
@@ -12,47 +12,52 @@
 import unittest

 import numpy as np
+import torch

 from monai.transforms import Rotate90
-from tests.utils import NumpyImageTestCase2D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D


 class TestRotate90(NumpyImageTestCase2D):
     def test_rotate90_default(self):
         rotate = Rotate90()
-        rotated = rotate(self.imt[0])
-        expected = []
-        for channel in self.imt[0]:
-            expected.append(np.rot90(channel, 1, (0, 1)))
-        expected = np.stack(expected)
-        self.assertTrue(np.allclose(rotated, expected))
+        for p in TEST_NDARRAYS:
+            rotated = rotate(p(self.imt[0]))
+            expected = []
+            for channel in self.imt[0]:
+                expected.append(np.rot90(channel, 1, (0, 1)))
+            expected = p(np.stack(expected))
+            torch.testing.assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8)

     def test_k(self):
         rotate = Rotate90(k=2)
-        rotated = rotate(self.imt[0])
-        expected = []
-        for channel in self.imt[0]:
-            expected.append(np.rot90(channel, 2, (0, 1)))
-        expected = np.stack(expected)
-        self.assertTrue(np.allclose(rotated, expected))
+        for p in TEST_NDARRAYS:
+            rotated = rotate(p(self.imt[0]))
+            expected = []
+            for channel in self.imt[0]:
+                expected.append(np.rot90(channel, 2, (0, 1)))
+            expected = p(np.stack(expected))
+            torch.testing.assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8)

     def test_spatial_axes(self):
         rotate = Rotate90(spatial_axes=(0, -1))
-        rotated = rotate(self.imt[0])
-        expected = []
-        for channel in self.imt[0]:
-            expected.append(np.rot90(channel, 1, (0, -1)))
-        expected = np.stack(expected)
-        self.assertTrue(np.allclose(rotated, expected))
+        for p in TEST_NDARRAYS:
+            rotated = rotate(p(self.imt[0]))
+            expected = []
+            for channel in self.imt[0]:
+                expected.append(np.rot90(channel, 1, (0, -1)))
+            expected = p(np.stack(expected))
+            torch.testing.assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8)

     def test_prob_k_spatial_axes(self):
         rotate = Rotate90(k=2, spatial_axes=(0, 1))
-        rotated = rotate(self.imt[0])
-        expected = []
-        for channel in self.imt[0]:
-            expected.append(np.rot90(channel, 2, (0, 1)))
-        expected = np.stack(expected)
-        self.assertTrue(np.allclose(rotated, expected))
+        for p in TEST_NDARRAYS:
+            rotated = rotate(p(self.imt[0]))
+            expected = []
+            for channel in self.imt[0]:
+                expected.append(np.rot90(channel, 2, (0, 1)))
+            expected = p(np.stack(expected))
+            torch.testing.assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8)


 if __name__ == "__main__":
diff --git a/tests/test_rotated.py b/tests/test_rotated.py
index 2ea421101b..cd27dd5406 100644
--- a/tests/test_rotated.py
+++ b/tests/test_rotated.py
@@ -10,36 +10,38 @@
 # limitations under the License.

 import unittest
+from typing import List, Tuple

 import numpy as np
 import scipy.ndimage
+import torch
 from parameterized import parameterized

 from monai.transforms import Rotated
-from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D

-TEST_CASES_2D = [
-    (-np.pi / 6, False, "bilinear", "border", False),
-    (-np.pi / 4, True, "bilinear", "border", False),
-    (np.pi / 4.5, True, "nearest", "reflection", False),
-    (-np.pi, False, "nearest", "zeros", False),
-    (np.pi / 2, False, "bilinear", "zeros", True),
-]
+TEST_CASES_2D: List[Tuple] = []
+for p in TEST_NDARRAYS:
+    TEST_CASES_2D.append((p, -np.pi / 6, False, "bilinear", "border", False))
+    TEST_CASES_2D.append((p, -np.pi / 4, True, "bilinear", "border", False))
+    TEST_CASES_2D.append((p, np.pi / 4.5, True, "nearest", "reflection", False))
+    TEST_CASES_2D.append((p, -np.pi, False, "nearest", "zeros", False))
+    TEST_CASES_2D.append((p, np.pi / 2, False, "bilinear", "zeros", True))

-TEST_CASES_3D = [
-    (-np.pi / 6, False, "bilinear", "border", False),
-    (-np.pi / 4, True, "bilinear", "border", False),
-    (np.pi / 4.5, True, "nearest", "reflection", False),
-    (-np.pi, False, "nearest", "zeros", False),
-    (np.pi / 2, False, "bilinear", "zeros", True),
-]
+TEST_CASES_3D: List[Tuple] = []
+for p in TEST_NDARRAYS:
+    TEST_CASES_3D.append((p, -np.pi / 6, False, "bilinear", "border", False))
+    TEST_CASES_3D.append((p, -np.pi / 4, True, "bilinear", "border", False))
+    TEST_CASES_3D.append((p, np.pi / 4.5, True, "nearest", "reflection", False))
+    TEST_CASES_3D.append((p, -np.pi, False, "nearest", "zeros", False))
+    TEST_CASES_3D.append((p, np.pi / 2, False, "bilinear", "zeros", True))


 class TestRotated2D(NumpyImageTestCase2D):
     @parameterized.expand(TEST_CASES_2D)
-    def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners):
+    def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners):
         rotate_fn = Rotated(("img", "seg"), angle, keep_size, (mode, "nearest"), padding_mode, align_corners)
-        rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]})
+        rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])})
         if keep_size:
             np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape)
         _order = 0 if mode == "nearest" else 1
@@ -52,6 +54,8 @@
         expected = scipy.ndimage.rotate(
             self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False
         )
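+        # move any torch outputs back to CPU so numpy can compare them with the scipy reference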
rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") @@ -64,9 +68,9 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne class TestRotated3D(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotated(("img", "seg"), [0, angle, 0], keep_size, (mode, "nearest"), padding_mode, align_corners) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) _order = 0 if mode == "nearest" else 1 @@ -79,6 +83,8 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne expected = scipy.ndimage.rotate( self.imt[0, 0], np.rad2deg(angle), (0, 2), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v good = np.sum(np.isclose(expected.astype(np.float32), rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 voxels.") @@ -91,9 +97,9 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne class TestRotated3DXY(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotated(("img", "seg"), [0, 0, angle], keep_size, (mode, "nearest"), padding_mode, align_corners) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) _order = 0 if mode == "nearest" else 1 @@ -106,6 +112,8 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne expected = scipy.ndimage.rotate( self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 voxels") diff --git a/tests/test_save_image.py b/tests/test_save_image.py index f7c8e07f06..c9ee2585e2 100644 --- a/tests/test_save_image.py +++ b/tests/test_save_image.py @@ -17,24 +17,37 @@ from parameterized import parameterized from monai.transforms import SaveImage +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - torch.randint(0, 255, (1, 2, 3, 4)), - {"filename_or_obj": "testfile0.nii.gz"}, - ".nii.gz", - False, -] +_, has_pil = optional_import("PIL") +_, has_nib = optional_import("nibabel") -TEST_CASE_2 = [ - torch.randint(0, 255, (1, 2, 3, 4)), - None, - ".nii.gz", - False, -] +exts = [ext for has_lib, ext in zip((has_nib, has_pil), (".nii.gz", ".png")) if has_lib] + +TESTS = [] +for p in TEST_NDARRAYS: + for ext in exts: + TESTS.append( + [ + p(torch.randint(0, 255, (1, 2, 3, 4))), + {"filename_or_obj": "testfile0" + ext}, + ext, + False, + ] + ) + TESTS.append( + [ + p(torch.randint(0, 255, (1, 2, 3, 4))), + None, + ext, + 
+TESTS = []
+for p in TEST_NDARRAYS:
+    for ext in exts:
+        TESTS.append(
+            [
+                p(torch.randint(0, 255, (1, 2, 3, 4))),
+                {"filename_or_obj": "testfile0" + ext},
+                ext,
+                False,
+            ]
+        )
+        TESTS.append(
+            [
+                p(torch.randint(0, 255, (1, 2, 3, 4))),
+                None,
+                ext,
+                False,
+            ]
+        )


 class TestSaveImage(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
+    @parameterized.expand(TESTS, skip_on_empty=True)
     def test_saved_content(self, test_data, meta_data, output_ext, resample):
         with tempfile.TemporaryDirectory() as tempdir:
             trans = SaveImage(
diff --git a/tests/test_save_imaged.py b/tests/test_save_imaged.py
index 35bbea9628..d512c81f9d 100644
--- a/tests/test_save_imaged.py
+++ b/tests/test_save_imaged.py
@@ -17,29 +17,42 @@
 from parameterized import parameterized

 from monai.transforms import SaveImaged
+from monai.utils.module import optional_import
+from tests.utils import TEST_NDARRAYS

-TEST_CASE_1 = [
-    {
-        "img": torch.randint(0, 255, (1, 2, 3, 4)),
-        "img_meta_dict": {"filename_or_obj": "testfile0.nii.gz"},
-    },
-    ".nii.gz",
-    False,
-]
-
-TEST_CASE_2 = [
-    {
-        "img": torch.randint(0, 255, (1, 2, 3, 4)),
-        "img_meta_dict": {"filename_or_obj": "testfile0.nii.gz"},
-        "patch_index": 6,
-    },
-    ".nii.gz",
-    False,
-]
+_, has_pil = optional_import("PIL")
+_, has_nib = optional_import("nibabel")
+
+exts = [ext for has_lib, ext in zip((has_nib, has_pil), (".nii.gz", ".png")) if has_lib]
+
+TESTS = []
+for p in TEST_NDARRAYS:
+    for ext in exts:
+        TESTS.append(
+            [
+                {
+                    "img": p(torch.randint(0, 255, (1, 2, 3, 4))),
+                    "img_meta_dict": {"filename_or_obj": "testfile0" + ext},
+                },
+                ext,
+                False,
+            ]
+        )
+        TESTS.append(
+            [
+                {
+                    "img": p(torch.randint(0, 255, (1, 2, 3, 4))),
+                    "img_meta_dict": {"filename_or_obj": "testfile0" + ext},
+                    "patch_index": 6,
+                },
+                ext,
+                False,
+            ]
+        )


 class TestSaveImaged(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
+    @parameterized.expand(TESTS, skip_on_empty=True)
     def test_saved_content(self, test_data, output_ext, resample):
         with tempfile.TemporaryDirectory() as tempdir:
             trans = SaveImaged(
diff --git a/tests/test_savitzky_golay_smooth.py b/tests/test_savitzky_golay_smooth.py
index 63dcce1b05..45d0ea3e4d 100644
--- a/tests/test_savitzky_golay_smooth.py
+++ b/tests/test_savitzky_golay_smooth.py
@@ -12,9 +12,11 @@
 import unittest

 import numpy as np
+import torch
 from parameterized import parameterized

 from monai.transforms import SavitzkyGolaySmooth
+from tests.utils import TEST_NDARRAYS

 # Zero-padding trivial tests

@@ -59,12 +61,18 @@
 class TestSavitzkyGolaySmooth(unittest.TestCase):
     @parameterized.expand([TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH])
     def test_value(self, arguments, image, expected_data, atol):
-        result = SavitzkyGolaySmooth(**arguments)(image)
-        np.testing.assert_allclose(result, expected_data, atol=atol)
+        for p in TEST_NDARRAYS:
+            result = SavitzkyGolaySmooth(**arguments)(p(image))
+            torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-7, atol=atol)


 class TestSavitzkyGolaySmoothREP(unittest.TestCase):
     @parameterized.expand([TEST_CASE_SINGLE_VALUE_REP])
     def test_value(self, arguments, image, expected_data, atol):
-        result = SavitzkyGolaySmooth(**arguments)(image)
-        np.testing.assert_allclose(result, expected_data, atol=atol)
+        for p in TEST_NDARRAYS:
+            result = SavitzkyGolaySmooth(**arguments)(p(image))
+            torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-7, atol=atol)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_scale_intensity.py b/tests/test_scale_intensity.py
index 61e89191fd..2c28cec73c 100644
--- a/tests/test_scale_intensity.py
+++ b/tests/test_scale_intensity.py
@@ -12,26 +12,29 @@
 import unittest

 import numpy as np
+import torch

 from monai.transforms import ScaleIntensity
-from tests.utils import NumpyImageTestCase2D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D


 class TestScaleIntensity(NumpyImageTestCase2D):
     def test_range_scale(self):
-        scaler = ScaleIntensity(minv=1.0, maxv=2.0)
-        result = scaler(self.imt)
-        mina = np.min(self.imt)
-        maxa = np.max(self.imt)
-        norm = (self.imt - mina) / (maxa - mina)
-        expected = (norm * (2.0 - 1.0)) + 1.0
-        np.testing.assert_allclose(result, expected)
+        for p in TEST_NDARRAYS:
+            scaler = ScaleIntensity(minv=1.0, maxv=2.0)
+            result = scaler(p(self.imt))
+            mina = self.imt.min()
+            maxa = self.imt.max()
+            norm = (self.imt - mina) / (maxa - mina)
+            expected = p((norm * (2.0 - 1.0)) + 1.0)
+            torch.testing.assert_allclose(result, expected, rtol=1e-7, atol=0)

     def test_factor_scale(self):
-        scaler = ScaleIntensity(minv=None, maxv=None, factor=0.1)
-        result = scaler(self.imt)
-        expected = (self.imt * (1 + 0.1)).astype(np.float32)
-        np.testing.assert_allclose(result, expected)
+        for p in TEST_NDARRAYS:
+            scaler = ScaleIntensity(minv=None, maxv=None, factor=0.1)
+            result = scaler(p(self.imt))
+            expected = p((self.imt * (1 + 0.1)).astype(np.float32))
+            torch.testing.assert_allclose(result, expected, rtol=1e-7, atol=0)


 if __name__ == "__main__":
diff --git a/tests/test_seg_loss_integration.py b/tests/test_seg_loss_integration.py
index d2f991f160..807597cdc1 100644
--- a/tests/test_seg_loss_integration.py
+++ b/tests/test_seg_loss_integration.py
@@ -18,7 +18,7 @@
 from parameterized import parameterized

 from monai.losses import DiceLoss, FocalLoss, GeneralizedDiceLoss, TverskyLoss
-from monai.networks import one_hot
+from monai.networks.utils import one_hot as one_hot_torch

 TEST_CASES = [
     [DiceLoss, {"to_onehot_y": True, "squared_pred": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, {}],
@@ -86,7 +86,7 @@ def test_convergence(self, loss_type, loss_args, forward_args):
         num_classes = 2
         num_voxels = 3 * 4 * 4
-        target_onehot = one_hot(target_seg, num_classes=num_classes)
+        target_onehot = one_hot_torch(target_seg, num_classes=num_classes)

         # define a one layer model
         class OnelayerNet(nn.Module):
diff --git a/tests/test_smartcachedataset.py b/tests/test_smartcachedataset.py
index e2675f4d8c..74d64719ef 100644
--- a/tests/test_smartcachedataset.py
+++ b/tests/test_smartcachedataset.py
@@ -17,6 +17,7 @@

 import nibabel as nib
 import numpy as np
+import torch
 from parameterized import parameterized

 from monai.data import DataLoader, SmartCacheDataset
@@ -66,7 +67,7 @@ def test_shape(self, replace_rate, num_replace_workers, transform):
         for _ in range(3):
             dataset.update_cache()
         self.assertIsNotNone(dataset[15])
-        if isinstance(dataset[15]["image"], np.ndarray):
+        if isinstance(dataset[15]["image"], (np.ndarray, torch.Tensor)):
             np.testing.assert_allclose(dataset[15]["image"], dataset[15]["label"])
         else:
             self.assertIsInstance(dataset[15]["image"], str)
diff --git a/tests/test_spacing.py b/tests/test_spacing.py
index 6be6730c5a..5b7d054a7c 100644
--- a/tests/test_spacing.py
+++ b/tests/test_spacing.py
@@ -12,155 +12,204 @@
 import unittest

 import numpy as np
+import torch
 from parameterized import parameterized

 from monai.transforms import Spacing
 from monai.utils import ensure_tuple, fall_back_tuple
+from tests.utils import TEST_NDARRAYS

-TEST_CASES = [
-    [
-        {"pixdim": (1.0, 1.5), "padding_mode": "zeros", "dtype": float},
-        np.arange(4).reshape((1, 2, 2)) + 1.0,  # data
-        {"affine": np.eye(4)},
-        np.array([[[1.0, 1.0], [3.0, 2.0]]]),
-    ],
-    [
-        {"pixdim": 1.0, "padding_mode": "zeros", "dtype": float},
-        np.ones((1, 2, 1, 2)),  # data
-        {"affine": np.eye(4)},
-        np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]),
-    ],
-    [
-        {"pixdim": (1.0, 1.0, 1.0), "padding_mode": "zeros", "dtype": float},
-        np.ones((1, 2, 1, 2)),  # data
-        {"affine": np.eye(4)},
-        np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]),
-    ],
-    [
-        {"pixdim": (1.0, 0.2, 1.5), "diagonal": False, "padding_mode": "zeros", "align_corners": True},
-        np.ones((1, 2, 1, 2)),  # data
-        {"affine": np.array([[2, 1, 0, 4], [-1, -3, 0, 5], [0, 0, 2.0, 5], [0, 0, 0, 1]])},
-        np.array([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]]),
-    ],
-    [
-        {"pixdim": (3.0, 1.0), "padding_mode": "zeros"},
-        np.arange(24).reshape((2, 3, 4)),  # data
-        {"affine": np.diag([-3.0, 0.2, 1.5, 1])},
-        np.array([[[0, 0], [4, 0], [8, 0]], [[12, 0], [16, 0], [20, 0]]]),
-    ],
-    [
-        {"pixdim": (3.0, 1.0), "padding_mode": "zeros"},
-        np.arange(24).reshape((2, 3, 4)),  # data
-        {},
-        np.array([[[0, 1, 2, 3], [0, 0, 0, 0]], [[12, 13, 14, 15], [0, 0, 0, 0]]]),
-    ],
-    [
-        {"pixdim": (1.0, 1.0)},
-        np.arange(24).reshape((2, 3, 4)),  # data
-        {},
-        np.array(
-            [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]
-        ),
-    ],
-    [
-        {"pixdim": (4.0, 5.0, 6.0)},
-        np.arange(24).reshape((1, 2, 3, 4)),  # data
-        {"affine": np.array([[-4, 0, 0, 4], [0, 5, 0, -5], [0, 0, 6, -6], [0, 0, 0, 1]])},
-        np.arange(24).reshape((1, 2, 3, 4)),  # data
-    ],
-    [
-        {"pixdim": (4.0, 5.0, 6.0), "diagonal": True},
-        np.arange(24).reshape((1, 2, 3, 4)),  # data
-        {"affine": np.array([[-4, 0, 0, 4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]])},
-        np.array(
-            [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]]
-        ),
-    ],
-    [
-        {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True},
-        np.arange(24).reshape((1, 2, 3, 4)),  # data
-        {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]])},
-        np.array(
-            [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]]
-        ),
-    ],
-    [
-        {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True},
-        np.arange(24).reshape((1, 2, 3, 4)),  # data
-        {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "nearest"},
-        np.array(
-            [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]]
-        ),
-    ],
-    [
-        {"pixdim": (1.9, 4.0), "padding_mode": "zeros", "diagonal": True},
-        np.arange(24).reshape((1, 4, 6)),  # data
-        {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "nearest"},
-        np.array(
-            [
-                [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0],
-                [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0],
-                [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0],
-                [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0],
-                [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0],
-                [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0],
-                [0.0, 1.0, 2.0, 2.0, 3.0, 4.0, 5.0],
-            ]
-        ),
-    ],
+TESTS = []
+for p in TEST_NDARRAYS:
+    TESTS.append(
+        [
+            p,
+            {"pixdim": (1.0, 1.5), "padding_mode": "zeros", "dtype": float},
+            np.arange(4).reshape((1, 2, 2)) + 1.0,  # data
+            {"affine": np.eye(4)},
+            np.array([[[1.0, 1.0], [3.0, 2.0]]]),
+        ]
+    )
+    TESTS.append(
+        [
+            p,
+            {"pixdim": 1.0, "padding_mode": "zeros", "dtype": float},
+            np.ones((1, 2, 1, 2)),  # data
+            {"affine": np.eye(4)},
+            np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]),
+        ]
+    )
+    TESTS.append(
+        [
+            p,
+            {"pixdim": (1.0, 1.0, 1.0), "padding_mode": "zeros", "dtype": float},
+            np.ones((1, 2, 1, 2)),  # data
+            {"affine": np.eye(4)},
+            np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]),
+        ]
+    )
+    TESTS.append(
+        [
+            p,
+            {"pixdim": (1.0, 0.2, 1.5), "diagonal": False, "padding_mode": "zeros", "align_corners": True},
+            np.ones((1, 2, 1, 2)),  # data
+            {"affine": np.array([[2, 1, 0, 4], [-1, -3, 0, 5], [0, 0, 2.0, 5], [0, 0, 0, 1]])},
+            np.array([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]]),
+        ]
+    )
+    TESTS.append(
+        [
+            p,
+            {"pixdim": (3.0, 1.0), "padding_mode": "zeros"},
+            np.arange(24).reshape((2, 3, 4)),  # data
+            {"affine": np.diag([-3.0, 0.2, 1.5, 1])},
+            np.array([[[0, 0], [4, 0], [8, 0]], [[12, 0], [16, 0], [20, 0]]]),
+        ]
+    )
+    TESTS.append(
+        [
+            p,
+            {"pixdim": (3.0, 1.0), "padding_mode": "zeros"},
+            np.arange(24).reshape((2, 3, 4)),  # data
+            {},
+            np.array([[[0, 1, 2, 3], [0, 0, 0, 0]], [[12, 13, 14, 15], [0, 0, 0, 0]]]),
+        ]
+    )
+    TESTS.append(
+        [
+            p,
+            {"pixdim": (1.0, 1.0)},
+            np.arange(24).reshape((2, 3, 4)),  # data
+            {},
+            np.array(
+                [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]
+            ),
+        ]
+    )
+    TESTS.append(
+        [
+            p,
+            {"pixdim": (4.0, 5.0, 6.0)},
+            np.arange(24).reshape((1, 2, 3, 4)),  # data
+            {"affine": np.array([[-4, 0, 0, 4], [0, 5, 0, -5], [0, 0, 6, -6], [0, 0, 0, 1]])},
+            np.arange(24).reshape((1, 2, 3, 4)),  # data
+        ]
+    )
+    TESTS.append(
+        [
+            p,
+            {"pixdim": (4.0, 5.0, 6.0), "diagonal": True},
+            np.arange(24).reshape((1, 2, 3, 4)),  # data
+            {"affine": np.array([[-4, 0, 0, 4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]])},
+            np.array(
+                [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]]
+            ),
+        ]
+    )
+    TESTS.append(
+        [
+            p,
+            {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True},
+            np.arange(24).reshape((1, 2, 3, 4)),  # data
+            {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]])},
+            np.array(
+                [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]]
+            ),
+        ]
+    )
+    TESTS.append(
+        [
+            p,
+            {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True},
+            np.arange(24).reshape((1, 2, 3, 4)),  # data
+            {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "nearest"},
+            np.array(
+                [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]]
+            ),
+        ]
+    )
+    TESTS.append(
+        [
+            p,
+            {"pixdim": (1.9, 4.0), "padding_mode": "zeros", "diagonal": True},
+            np.arange(24).reshape((1, 4, 6)),  # data
+            {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "nearest"},
+            np.array(
+                [
+                    [
+                        [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0],
+                        [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0],
+                        [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0],
+                        [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0],
+                        [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0],
+                        [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0],
+                        [0.0, 1.0, 2.0, 2.0, 3.0, 4.0, 5.0],
+                    ]
+                ]
+            ),
+        ]
+    )
-    [
-        {"pixdim": (5.0, 3.0), "padding_mode": "border", "diagonal": True, "dtype": np.float32},
-        np.arange(24).reshape((1, 4, 6)),  # data
-        {"affine": np.array([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "bilinear"},
-        np.array(
-            [
-                [18.0, 18.6, 19.2, 19.8, 20.400002, 21.0, 21.6, 22.2, 22.8],
-                [10.5, 11.1, 11.700001, 12.299999, 12.900001, 13.5, 14.1, 14.700001, 15.3],
-                [3.0, 3.6000001, 4.2000003, 4.8, 5.4000006, 6.0, 6.6000004, 7.200001, 7.8],
-            ]
-        ),
-    ],
+    TESTS.append(
+        [
+            p,
+            {"pixdim": (5.0, 3.0), "padding_mode": "border", "diagonal": True, "dtype": np.float32},
+            np.arange(24).reshape((1, 4, 6)),  # data
+            {"affine": np.array([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "bilinear"},
+            np.array(
+                [
+                    [
+                        [18.0, 18.6, 19.2, 19.8, 20.400002, 21.0, 21.6, 22.2, 22.8],
+                        [10.5, 11.1, 11.700001, 12.299999, 12.900001, 13.5, 14.1, 14.700001, 15.3],
+                        [3.0, 3.6000001, 4.2000003, 4.8, 5.4000006, 6.0, 6.6000004, 7.200001, 7.8],
+                    ]
+                ]
+            ),
+        ]
+    )
-    [
-        {"pixdim": (5.0, 3.0), "padding_mode": "zeros", "diagonal": True, "dtype": np.float32},
-        np.arange(24).reshape((1, 4, 6)),  # data
-        {"affine": np.array([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "bilinear"},
-        np.array(
-            [
-                [18.0000, 18.6000, 19.2000, 19.8000, 20.4000, 21.0000, 21.6000, 22.2000, 22.8000],
-                [10.5000, 11.1000, 11.7000, 12.3000, 12.9000, 13.5000, 14.1000, 14.7000, 15.3000],
-                [3.0000, 3.6000, 4.2000, 4.8000, 5.4000, 6.0000, 6.6000, 7.2000, 7.8000],
-            ]
-        ),
-    ],
-    [
-        {"pixdim": [-1, -1, 0.5], "padding_mode": "zeros", "dtype": float},
-        np.ones((1, 2, 1, 2)),  # data
-        {"affine": np.eye(4)},
-        np.array([[[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]]]),
-    ],
-]
+    TESTS.append(
+        [
+            p,
+            {"pixdim": (5.0, 3.0), "padding_mode": "zeros", "diagonal": True, "dtype": np.float32},
+            np.arange(24).reshape((1, 4, 6)),  # data
+            {"affine": np.array([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "bilinear"},
+            np.array(
+                [
+                    [
+                        [18.0000, 18.6000, 19.2000, 19.8000, 20.4000, 21.0000, 21.6000, 22.2000, 22.8000],
+                        [10.5000, 11.1000, 11.7000, 12.3000, 12.9000, 13.5000, 14.1000, 14.7000, 15.3000],
+                        [3.0000, 3.6000, 4.2000, 4.8000, 5.4000, 6.0000, 6.6000, 7.2000, 7.8000],
+                    ]
+                ]
+            ),
+        ]
+    )
+    TESTS.append(
+        [
+            p,
+            {"pixdim": [-1, -1, 0.5], "padding_mode": "zeros", "dtype": float},
+            np.ones((1, 2, 1, 2)),  # data
+            {"affine": np.eye(4)},
+            np.array([[[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]]]),
+        ]
+    )


 class TestSpacingCase(unittest.TestCase):
-    @parameterized.expand(TEST_CASES)
-    def test_spacing(self, init_param, img, data_param, expected_output):
-        res = Spacing(**init_param)(img, **data_param)
-        if not isinstance(res, tuple):
-            np.testing.assert_allclose(res, expected_output, atol=1e-6)
-            return
-        np.testing.assert_allclose(res[0], expected_output, atol=1e-6)
-        sr = len(res[0].shape) - 1
+    @parameterized.expand(TESTS)
+    def test_spacing(self, in_type, init_param, img, data_param, expected_output):
+        _img = in_type(img)
+        output_data, _, new_affine = Spacing(**init_param)(_img, **data_param)
+        if isinstance(_img, torch.Tensor):
+            self.assertEqual(_img.device, output_data.device)
+            output_data = output_data.cpu()
+
+        np.testing.assert_allclose(output_data, expected_output, atol=1e-3, rtol=1e-3)
+        sr = len(output_data.shape) - 1
         if isinstance(init_param["pixdim"], float):
             init_param["pixdim"] = [init_param["pixdim"]] * sr
         init_pixdim = ensure_tuple(init_param["pixdim"])
         init_pixdim = init_param["pixdim"][:sr]
-        norm = np.sqrt(np.sum(np.square(res[2]), axis=0))[:sr]
+        norm = np.sqrt(np.sum(np.square(new_affine), axis=0))[:sr]
         np.testing.assert_allclose(fall_back_tuple(init_pixdim, norm), norm)
diff --git a/tests/test_spacingd.py b/tests/test_spacingd.py
index 61a4a4c38b..fd1ee7fd54 100644
--- a/tests/test_spacingd.py
+++ b/tests/test_spacingd.py
@@ -10,82 +10,95 @@
 # limitations under the License.

 import unittest
+from typing import List, Tuple

 import numpy as np
+import torch
+from parameterized import parameterized

 from monai.transforms import Spacingd
+from tests.utils import TEST_NDARRAYS

-
-class TestSpacingDCase(unittest.TestCase):
-    def test_spacingd_3d(self):
-        data = {"image": np.ones((2, 10, 15, 20)), "image_meta_dict": {"affine": np.eye(4)}}
-        spacing = Spacingd(keys="image", pixdim=(1, 2, 1.4))
-        res = spacing(data)
-        self.assertEqual(("image", "image_meta_dict", "image_transforms"), tuple(sorted(res)))
-        np.testing.assert_allclose(res["image"].shape, (2, 10, 8, 15))
-        np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag([1, 2, 1.4, 1.0]))
-
-    def test_spacingd_2d(self):
-        data = {"image": np.ones((2, 10, 20)), "image_meta_dict": {"affine": np.eye(3)}}
-        spacing = Spacingd(keys="image", pixdim=(1, 2))
-        res = spacing(data)
-        self.assertEqual(("image", "image_meta_dict", "image_transforms"), tuple(sorted(res)))
-        np.testing.assert_allclose(res["image"].shape, (2, 10, 10))
-        np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 2, 1)))
-
-    def test_spacingd_2d_no_metadata(self):
-        data = {"image": np.ones((2, 10, 20))}
-        spacing = Spacingd(keys="image", pixdim=(1, 2))
-        res = spacing(data)
-        self.assertEqual(("image", "image_meta_dict", "image_transforms"), tuple(sorted(res)))
-        np.testing.assert_allclose(res["image"].shape, (2, 10, 10))
-        np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 2, 1)))
-
-    def test_interp_all(self):
-        data = {
-            "image": np.arange(20).reshape((2, 1, 10)),
-            "seg": np.ones((2, 1, 10)),
-            "image_meta_dict": {"affine": np.eye(4)},
-            "seg_meta_dict": {"affine": np.eye(4)},
-        }
-        spacing = Spacingd(
-            keys=("image", "seg"),
-            mode="nearest",
-            pixdim=(
-                1,
-                0.2,
-            ),
-        )
-        res = spacing(data)
-        self.assertEqual(
-            ("image", "image_meta_dict", "image_transforms", "seg", "seg_meta_dict", "seg_transforms"),
-            tuple(sorted(res)),
-        )
-        np.testing.assert_allclose(res["image"].shape, (2, 1, 46))
-        np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 0.2, 1, 1)))
-
-    def test_interp_sep(self):
-        data = {
-            "image": np.ones((2, 1, 10)),
-            "seg": np.ones((2, 1, 10)),
-            "image_meta_dict": {"affine": np.eye(4)},
-            "seg_meta_dict": {"affine": np.eye(4)},
-        }
-        spacing = Spacingd(
-            keys=("image", "seg"),
-            mode=("bilinear", "nearest"),
-            pixdim=(
-                1,
-                0.2,
-            ),
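+# each case is (name, input data, Spacingd kwargs, expected keys, expected shape, expected affine)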
"seg_transforms"), + (2, 1, 46), + np.diag((1, 0.2, 1, 1)), ) - res = spacing(data) - self.assertEqual( + ) + TESTS.append( + ( + "interp sep", + { + "image": np.ones((2, 1, 10)), + "seg": np.ones((2, 1, 10)), + "image_meta_dict": {"affine": np.eye(4)}, + "seg_meta_dict": {"affine": np.eye(4)}, + }, + dict(keys=("image", "seg"), mode=("bilinear", "nearest"), pixdim=(1, 0.2)), ("image", "image_meta_dict", "image_transforms", "seg", "seg_meta_dict", "seg_transforms"), - tuple(sorted(res)), + (2, 1, 46), + np.diag((1, 0.2, 1, 1)), ) - np.testing.assert_allclose(res["image"].shape, (2, 1, 46)) - np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 0.2, 1, 1))) + ) + + +class TestSpacingDCase(unittest.TestCase): + @parameterized.expand(TESTS) + def test_spacingd(self, _, data, kw_args, expected_keys, expected_shape, expected_affine): + res = Spacingd(**kw_args)(data) + if isinstance(data["image"], torch.Tensor): + self.assertEqual(data["image"].device, res["image"].device) + self.assertEqual(expected_keys, tuple(sorted(res))) + np.testing.assert_allclose(res["image"].shape, expected_shape) + np.testing.assert_allclose(res["image_meta_dict"]["affine"], expected_affine) if __name__ == "__main__": diff --git a/tests/test_spatial_crop.py b/tests/test_spatial_crop.py index c76915f0a3..50c5b74139 100644 --- a/tests/test_spatial_crop.py +++ b/tests/test_spatial_crop.py @@ -16,8 +16,9 @@ from parameterized import parameterized from monai.transforms import SpatialCrop +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ +TESTS = [ [ {"roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, (3, 3, 3, 3), @@ -53,17 +54,25 @@ class TestSpatialCrop(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_shape, expected_shape): + results = [] input_data = np.random.randint(0, 2, size=input_shape) - result = SpatialCrop(**input_param)(input_data) - self.assertTupleEqual(result.shape, expected_shape) - - @parameterized.expand(TEST_CASES) - def test_tensor_shape(self, input_param, input_shape, expected_shape): - input_data = torch.randint(0, 2, size=input_shape, device="cuda" if torch.cuda.is_available() else "cpu") - result = SpatialCrop(**input_param)(input_data) - self.assertTupleEqual(result.shape, expected_shape) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS + (None,): + input_param_mod = { + k: q(v) if k != "roi_slices" and q is not None else v for k, v in input_param.items() + } + im = p(input_data) + result = SpatialCrop(**input_param_mod)(im) + self.assertEqual(type(im), type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, im.device) + result = result.cpu().numpy() + self.assertTupleEqual(result.shape, expected_shape) + results.append(result) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) @parameterized.expand(TEST_ERRORS) def test_error(self, input_param): diff --git a/tests/test_spatial_pad.py b/tests/test_spatial_pad.py index 93241610de..142c0c3e4e 100644 --- a/tests/test_spatial_pad.py +++ b/tests/test_spatial_pad.py @@ -10,47 +10,86 @@ # limitations under the License. 

 import unittest
+from typing import List

 import numpy as np
+import torch
 from parameterized import parameterized

 from monai.transforms import SpatialPad
+from monai.utils.enums import NumpyPadMode
+from monai.utils.misc import set_determinism
+from tests.utils import TEST_NDARRAYS

-TEST_CASE_1 = [
-    {"spatial_size": [15, 8, 8], "method": "symmetric", "mode": "constant"},
-    np.zeros((3, 8, 8, 4)),
-    np.zeros((3, 15, 8, 8)),
-]
-
-TEST_CASE_2 = [
-    {"spatial_size": [15, 8, 8], "method": "end", "mode": "constant"},
-    np.zeros((3, 8, 8, 4)),
-    np.zeros((3, 15, 8, 8)),
-]
-
-TEST_CASE_3 = [
-    {"spatial_size": [15, 4, -1], "method": "symmetric", "mode": "constant"},
-    np.zeros((3, 8, 8, 4)),
-    np.zeros((3, 15, 8, 4)),
-]
+TESTS = []
+
+# Numpy modes
+MODES: List = [
+    "constant",
+    "edge",
+    "linear_ramp",
+    "maximum",
+    "mean",
+    "median",
+    "minimum",
+    "reflect",
+    "symmetric",
+    "wrap",
+    "empty",
+]
+MODES += [NumpyPadMode(i) for i in MODES]
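+# every numpy pad mode is exercised, both as a plain string and as a NumpyPadMode member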
+for mode in MODES:
+    TESTS.append(
+        [
+            {"spatial_size": [50, 50], "method": "end", "mode": mode},
+            (1, 2, 2),
+            (1, 50, 50),
+        ]
+    )
+
+    TESTS.append(
+        [
+            {"spatial_size": [15, 4, -1], "method": "symmetric", "mode": mode},
+            (3, 8, 8, 4),
+            (3, 15, 8, 4),
+        ]
+    )


 class TestSpatialPad(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
-    def test_pad_shape(self, input_param, input_data, expected_val):
-        padder = SpatialPad(**input_param)
-        result = padder(input_data)
-        np.testing.assert_allclose(result.shape, expected_val.shape)
-        result = padder(input_data, mode=input_param["mode"])
-        np.testing.assert_allclose(result.shape, expected_val.shape)
+    def setUp(self) -> None:
+        set_determinism(seed=0)
+
+    def tearDown(self) -> None:
+        set_determinism(None)
+
+    @staticmethod
+    def get_arr(shape):
+        return np.random.randint(100, size=shape).astype(float)
+
+    @parameterized.expand(TESTS)
+    def test_pad_shape(self, input_param, input_shape, expected_shape):
+        results_1 = []
+        results_2 = []
+        input_data = self.get_arr(input_shape)
+        for p in TEST_NDARRAYS:
+            padder = SpatialPad(**input_param)
+            results_1.append(padder(p(input_data)))
+            results_2.append(padder(p(input_data), mode=input_param["mode"]))
+            for results in (results_1, results_2):
+                np.testing.assert_allclose(results[-1].shape, expected_shape)
+                if input_param["mode"] not in ("empty", NumpyPadMode.EMPTY):
+                    torch.testing.assert_allclose(p(results[0]), results[-1], atol=0, rtol=1e-5)

     def test_pad_kwargs(self):
         padder = SpatialPad(
             spatial_size=[15, 8], method="end", mode="constant", constant_values=((0, 0), (1, 1), (2, 2))
         )
-        result = padder(np.zeros((3, 8, 4)))
-        np.testing.assert_allclose(result[:, 8:, :4], np.ones((3, 7, 4)))
-        np.testing.assert_allclose(result[:, :, 4:], np.ones((3, 15, 4)) + 1)
+        for p in TEST_NDARRAYS:
+            result = padder(p(np.zeros((3, 8, 4))))
+            torch.testing.assert_allclose(result[:, 8:, :4], p(np.ones((3, 7, 4))), rtol=1e-7, atol=0)
+            torch.testing.assert_allclose(result[:, :, 4:], p(np.ones((3, 15, 4)) + 1), rtol=1e-7, atol=0)


 if __name__ == "__main__":
diff --git a/tests/test_squeezedim.py b/tests/test_squeezedim.py
index 01ea489320..0d627a8ca8 100644
--- a/tests/test_squeezedim.py
+++ b/tests/test_squeezedim.py
@@ -17,29 +17,27 @@
 from monai.transforms import SqueezeDim

-TEST_CASE_1 = [{"dim": None}, np.random.rand(1, 2, 1, 3), (2, 3)]
-
-TEST_CASE_2 = [{"dim": 2}, np.random.rand(1, 2, 1, 8, 16), (1, 2, 8, 16)]
-
-TEST_CASE_3 = [{"dim": -1}, np.random.rand(1, 1, 16, 8, 1), (1, 1, 16, 8)]
-
-TEST_CASE_4 = [{}, np.random.rand(1, 2, 1, 3), (2, 1, 3)]
-
-TEST_CASE_4_PT = [{}, torch.rand(1, 2, 1, 3), (2, 1, 3)]
-
-TEST_CASE_5 = [ValueError, {"dim": -2}, np.random.rand(1, 1, 16, 8, 1)]
-
-TEST_CASE_6 = [TypeError, {"dim": 0.5}, np.random.rand(1, 1, 16, 8, 1)]
+TESTS, TESTS_FAIL = [], []
+for r in (np.random.rand, torch.rand):
+    TESTS.append([{"dim": None}, r(1, 2, 1, 3), (2, 3)])  # type: ignore
+    TESTS.append([{"dim": 2}, r(1, 2, 1, 8, 16), (1, 2, 8, 16)])  # type: ignore
+    TESTS.append([{"dim": -1}, r(1, 1, 16, 8, 1), (1, 1, 16, 8)])  # type: ignore
+    TESTS.append([{}, r(1, 2, 1, 3), (2, 1, 3)])  # type: ignore
+
+    TESTS_FAIL.append([ValueError, {"dim": -2}, r(1, 1, 16, 8, 1)])  # type: ignore
+    TESTS_FAIL.append([TypeError, {"dim": 0.5}, r(1, 1, 16, 8, 1)])  # type: ignore


 class TestSqueezeDim(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_4_PT])
+    @parameterized.expand(TESTS)
     def test_shape(self, input_param, test_data, expected_shape):
+
         result = SqueezeDim(**input_param)(test_data)
         self.assertTupleEqual(result.shape, expected_shape)

-    @parameterized.expand([TEST_CASE_5, TEST_CASE_6])
+    @parameterized.expand(TESTS_FAIL)
     def test_invalid_inputs(self, exception, input_param, test_data):
+
         with self.assertRaises(exception):
             SqueezeDim(**input_param)(test_data)
diff --git a/tests/test_std_shift_intensity.py b/tests/test_std_shift_intensity.py
index f317330435..5c16e14c45 100644
--- a/tests/test_std_shift_intensity.py
+++ b/tests/test_std_shift_intensity.py
@@ -12,45 +12,53 @@
 import unittest

 import numpy as np
+import torch

 from monai.transforms import ShiftIntensity, StdShiftIntensity
-from tests.utils import NumpyImageTestCase2D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D


 class TestStdShiftIntensity(NumpyImageTestCase2D):
     def test_value(self):
-        factor = np.random.rand()
-        offset = np.std(self.imt) * factor
-        shifter = ShiftIntensity(offset=offset)
-        expected = shifter(self.imt)
-        std_shifter = StdShiftIntensity(factor=factor)
-        result = std_shifter(self.imt)
-        np.testing.assert_allclose(result, expected, rtol=1e-5)
+        for p in TEST_NDARRAYS:
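+            # the StdShiftIntensity result should match ShiftIntensity with offset = std(imt) * factor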
np.random.rand() + std_shifter = StdShiftIntensity(factor=factor, nonzero=True) + result = std_shifter(image) + expected = p(np.asarray([[4 + factor, 0, 2 + factor], [0, 2 + factor, 4 + factor]], dtype=np.float32)) + torch.testing.assert_allclose(result, expected, atol=0, rtol=1e-5) def test_channel_wise(self): - image = np.stack((np.asarray([1.0, 2.0]), np.asarray([1.0, 1.0]))) # std: 0.5, 0 - factor = np.random.rand() - std_shifter = StdShiftIntensity(factor=factor, channel_wise=True) - result = std_shifter(image) - expected = np.stack((np.asarray([1 + 0.5 * factor, 2 + 0.5 * factor]), np.asarray([1, 1]))) - np.testing.assert_allclose(result, expected, rtol=1e-5) + for p in TEST_NDARRAYS: + image = p(np.stack((np.asarray([1.0, 2.0]), np.asarray([1.0, 1.0])))) # std: 0.5, 0 + factor = np.random.rand() + std_shifter = StdShiftIntensity(factor=factor, channel_wise=True) + result = std_shifter(image) + expected = p( + np.stack((np.asarray([1 + 0.5 * factor, 2 + 0.5 * factor]), np.asarray([1, 1]))).astype(np.float32) + ) + torch.testing.assert_allclose(result, expected, atol=0, rtol=1e-5) def test_dtype(self): trans_dtype = np.float32 diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index 66d7627971..e5038b4241 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -82,7 +82,6 @@ def test_test_time_augmentation(self): scale_range=((0.8, 1), (0.8, 1)), padding_mode="zeros", mode=("bilinear", "nearest"), - as_tensor_output=False, ), CropForegroundd(keys, source_key="image"), DivisiblePadd(keys, 4), diff --git a/tests/test_threshold_intensity.py b/tests/test_threshold_intensity.py index a6d3895709..094efc1257 100644 --- a/tests/test_threshold_intensity.py +++ b/tests/test_threshold_intensity.py @@ -12,22 +12,26 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import ThresholdIntensity +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"threshold": 5, "above": True, "cval": 0}, (0, 0, 0, 0, 0, 0, 6, 7, 8, 9)] - -TEST_CASE_2 = [{"threshold": 5, "above": False, "cval": 0}, (0, 1, 2, 3, 4, 0, 0, 0, 0, 0)] - -TEST_CASE_3 = [{"threshold": 5, "above": True, "cval": 5}, (5, 5, 5, 5, 5, 5, 6, 7, 8, 9)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"threshold": 5, "above": True, "cval": 0}, (0, 0, 0, 0, 0, 0, 6, 7, 8, 9)]) + TESTS.append([p, {"threshold": 5, "above": False, "cval": 0}, (0, 1, 2, 3, 4, 0, 0, 0, 0, 0)]) + TESTS.append([p, {"threshold": 5, "above": True, "cval": 5}, (5, 5, 5, 5, 5, 5, 6, 7, 8, 9)]) class TestThresholdIntensity(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_value(self, input_param, expected_value): - test_data = np.arange(10) + @parameterized.expand(TESTS) + def test_value(self, in_type, input_param, expected_value): + test_data = in_type(np.arange(10)) result = ThresholdIntensity(**input_param)(test_data) + if isinstance(result, torch.Tensor): + result = result.cpu() np.testing.assert_allclose(result, expected_value) diff --git a/tests/test_threshold_intensityd.py b/tests/test_threshold_intensityd.py index efcfcfe604..5d8b69a1de 100644 --- a/tests/test_threshold_intensityd.py +++ b/tests/test_threshold_intensityd.py @@ -12,31 +12,45 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import ThresholdIntensityd - -TEST_CASE_1 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 0}, - 
(0, 0, 0, 0, 0, 0, 6, 7, 8, 9), -] - -TEST_CASE_2 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": False, "cval": 0}, - (0, 1, 2, 3, 4, 0, 0, 0, 0, 0), -] - -TEST_CASE_3 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 5}, - (5, 5, 5, 5, 5, 5, 6, 7, 8, 9), -] +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 0}, + (0, 0, 0, 0, 0, 0, 6, 7, 8, 9), + ] + ) + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": False, "cval": 0}, + (0, 1, 2, 3, 4, 0, 0, 0, 0, 0), + ] + ) + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 5}, + (5, 5, 5, 5, 5, 5, 6, 7, 8, 9), + ] + ) class TestThresholdIntensityd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_value(self, input_param, expected_value): - test_data = {"image": np.arange(10), "label": np.arange(10), "extra": np.arange(10)} + @parameterized.expand(TESTS) + def test_value(self, in_type, input_param, expected_value): + test_data = {"image": in_type(np.arange(10)), "label": in_type(np.arange(10)), "extra": in_type(np.arange(10))} result = ThresholdIntensityd(**input_param)(test_data) + for k, v in result.items(): + if isinstance(result[k], torch.Tensor): + result[k] = v.cpu() np.testing.assert_allclose(result["image"], expected_value) np.testing.assert_allclose(result["label"], expected_value) np.testing.assert_allclose(result["extra"], expected_value) diff --git a/tests/test_to_affine_nd.py b/tests/test_to_affine_nd.py new file mode 100644 index 0000000000..087f006ddd --- /dev/null +++ b/tests/test_to_affine_nd.py @@ -0,0 +1,37 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
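The new tests/test_to_affine_nd.py below checks that to_affine_nd yields equivalent affines whether its arguments arrive as numpy arrays or torch tensors. A minimal standalone sketch of that contract (hedged: it assumes, as the test implies, that to_affine_nd accepts either array type, and that torch results may need moving to CPU before a numpy comparison):

    import numpy as np
    import torch
    from monai.data.utils import to_affine_nd

    affine = np.eye(4)
    out_np = to_affine_nd(2, affine)                   # (3, 3) affine extracted from a 4x4 input
    out_pt = to_affine_nd(2, torch.as_tensor(affine))  # same values through the torch backend
    if isinstance(out_pt, torch.Tensor):
        out_pt = out_pt.cpu()
    np.testing.assert_allclose(out_np, out_pt)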
+ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data.utils import to_affine_nd +from tests.test_dtype_convert import TEST_NDARRAYS + +TESTS = [] +TESTS.append([2, np.eye(4)]) + + +class TestToAffinend(unittest.TestCase): + @parameterized.expand(TESTS) + def test_to_affine_nd(self, r, affine): + outs = [] + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + res = to_affine_nd(p(r), q(affine)) + outs.append(res.cpu() if isinstance(res, torch.Tensor) else res) + np.testing.assert_allclose(outs[-1], outs[0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_to_numpy.py b/tests/test_to_numpy.py index 291601ffeb..53706f5942 100644 --- a/tests/test_to_numpy.py +++ b/tests/test_to_numpy.py @@ -23,7 +23,7 @@ class TestToNumpy(unittest.TestCase): @skipUnless(has_cp, "CuPy is required.") - def test_cumpy_input(self): + def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]]) test_data = cp.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) diff --git a/tests/test_to_onehot.py b/tests/test_to_onehot.py index c3e373955d..c6c3264f74 100644 --- a/tests/test_to_onehot.py +++ b/tests/test_to_onehot.py @@ -16,6 +16,8 @@ from parameterized import parameterized from monai.networks import one_hot +from monai.utils.misc import dtype_convert +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [ # single channel 2D, batch 3, shape (2, 1, 2, 2) {"labels": torch.tensor([[[[0, 1], [1, 2]]], [[[2, 1], [1, 0]]]]), "num_classes": 3}, @@ -44,16 +46,30 @@ class TestToOneHot(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_shape(self, input_data, expected_shape, expected_result=None): - result = one_hot(**input_data) - self.assertEqual(result.shape, expected_shape) - if expected_result is not None: - self.assertTrue(np.allclose(expected_result, result.numpy())) - - if "dtype" in input_data: - self.assertEqual(result.dtype, input_data["dtype"]) - else: - # by default, expecting float type - self.assertEqual(result.dtype, torch.float) + results = [] + for p in TEST_NDARRAYS: + input_data_mod = {k: p(v) if isinstance(v, torch.Tensor) else v for k, v in input_data.items()} + orig_dtype = input_data_mod["labels"].dtype + result = one_hot(**input_data_mod) + self.assertEqual(result.shape, expected_shape) + self.assertEqual(type(result), type(input_data_mod["labels"])) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, input_data_mod["labels"].device) + result = result.cpu().numpy() + + if expected_result is not None: + self.assertTrue(np.allclose(expected_result, result)) + + self.assertEqual(input_data_mod["labels"].dtype, orig_dtype) + if "dtype" in input_data: + self.assertEqual(result.dtype, dtype_convert(input_data_mod["dtype"], type(result))) + else: + # by default, expecting float type + self.assertEqual(result.dtype, dtype_convert(torch.float, type(result))) + + results.append(result) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) if __name__ == "__main__": diff --git a/tests/test_to_tensor.py b/tests/test_to_tensor.py index 4a36254743..e35a493bb8 100644 --- a/tests/test_to_tensor.py +++ b/tests/test_to_tensor.py @@ -10,19 +10,53 @@ # limitations under the License. 
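The one_hot dtype assertions above route through dtype_convert so the same check works for numpy and torch results. A hedged sketch of that mapping (the concrete equivalence shown, torch.float to np.float32 for numpy outputs, is inferred from the assertions above rather than from dtype_convert's documentation):

    import numpy as np
    import torch
    from monai.utils.misc import dtype_convert

    result = np.zeros((2, 3), dtype=np.float32)           # stand-in for a one_hot output
    expected_dtype = dtype_convert(torch.float, type(result))
    assert result.dtype == expected_dtype                 # np.float32 for an ndarray result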
import unittest +from unittest import skipUnless import numpy as np import torch from monai.transforms import ToTensor +from monai.utils import optional_import + +cp, has_cp = optional_import("cupy") class TestToTensor(unittest.TestCase): - def test_array_input(self): - for test_data in ([[1, 2], [3, 4]], np.array([[1, 2], [3, 4]]), torch.as_tensor([[1, 2], [3, 4]])): - result = ToTensor()(test_data) - torch.testing.assert_allclose(result, test_data) - self.assertTupleEqual(result.shape, (2, 2)) + @skipUnless(has_cp, "CuPy is required.") + def test_cupy_input(self): + test_data = cp.array([[1, 2], [3, 4]]) + test_data = cp.rot90(test_data) + self.assertFalse(test_data.flags["C_CONTIGUOUS"]) + result = ToTensor()(test_data) + self.assertTrue(isinstance(result, torch.Tensor)) + self.assertTrue(result.is_contiguous()) + torch.testing.assert_allclose(result, test_data.get()) + + def test_numpy_input(self): + test_data = np.array([[1, 2], [3, 4]]) + test_data = np.rot90(test_data) + self.assertFalse(test_data.flags["C_CONTIGUOUS"]) + result = ToTensor()(test_data) + self.assertTrue(isinstance(result, torch.Tensor)) + self.assertTrue(result.is_contiguous()) + torch.testing.assert_allclose(result, np.ascontiguousarray(test_data)) + + def test_tensor_input(self): + test_data = torch.tensor([[1, 2], [3, 4]]) + test_data = test_data.rot90() + self.assertFalse(test_data.is_contiguous()) + result = ToTensor()(test_data) + self.assertTrue(isinstance(result, torch.Tensor)) + self.assertTrue(result.is_contiguous()) + torch.testing.assert_allclose(result, test_data) + + def test_list_tuple(self): + test_data = [[1, 2], [3, 4]] + result = ToTensor()(test_data) + torch.testing.assert_allclose(result, test_data) + test_data = ((1, 2), (3, 4)) + result = ToTensor()(test_data) + torch.testing.assert_allclose(result, test_data) def test_single_input(self): for test_data in (5, np.asarray(5), torch.tensor(5)): diff --git a/tests/test_transpose.py b/tests/test_transpose.py index 3b758b5aa2..8602027d87 100644 --- a/tests/test_transpose.py +++ b/tests/test_transpose.py @@ -12,27 +12,30 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import Transpose +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ - np.arange(5 * 4).reshape(5, 4), - None, -] -TEST_CASE_1 = [ - np.arange(5 * 4 * 3).reshape(5, 4, 3), - [2, 0, 1], -] -TEST_CASES = [TEST_CASE_0, TEST_CASE_1] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p(np.arange(5 * 4).reshape(5, 4)), None]) + TESTS.append([p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), [2, 0, 1]]) class TestTranspose(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_transpose(self, im, indices): tr = Transpose(indices) + out1 = tr(im) - out2 = np.transpose(im, indices) + im_cpu = im if isinstance(im, np.ndarray) else im.cpu() + out2 = np.transpose(im_cpu, indices) + self.assertEqual(type(im), type(out1)) + if isinstance(out1, torch.Tensor): + self.assertEqual(im.device, out1.device) + out1 = out1.cpu() np.testing.assert_array_equal(out1, out2) diff --git a/tests/test_transposed.py b/tests/test_transposed.py index 56375f3981..9998a8c76c 100644 --- a/tests/test_transposed.py +++ b/tests/test_transposed.py @@ -13,44 +13,72 @@ from copy import deepcopy import numpy as np +import torch from parameterized import parameterized from monai.transforms import Transposed +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ - np.arange(5 * 4).reshape(5, 4), - [1, 0], -] -TEST_CASE_1 = [ - 
np.arange(5 * 4).reshape(5, 4), - None, -] -TEST_CASE_2 = [ - np.arange(5 * 4 * 3).reshape(5, 4, 3), - [2, 0, 1], -] -TEST_CASE_3 = [ - np.arange(5 * 4 * 3).reshape(5, 4, 3), - None, -] -TEST_CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3] +KEYS = ("i", "j") + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + p(np.arange(5 * 4).reshape(5, 4)), + [1, 0], + ] + ) + TESTS.append( + [ + p(np.arange(5 * 4).reshape(5, 4)), + None, + ] + ) + TESTS.append( + [ + p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), + [2, 0, 1], + ] + ) + TESTS.append( + [ + p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), + None, + ] + ) class TestTranspose(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_transpose(self, im, indices): - data = {"i": deepcopy(im), "j": deepcopy(im)} - tr = Transposed(["i", "j"], indices) + im_cpu = deepcopy(im if isinstance(im, np.ndarray) else im.cpu()) + out_gt = np.transpose(im_cpu, indices) + + data = {k: deepcopy(im) for k in KEYS} + tr = Transposed(KEYS, indices) out_data = tr(data) - out_im1, out_im2 = out_data["i"], out_data["j"] - out_gt = np.transpose(im, indices) - np.testing.assert_array_equal(out_im1, out_gt) - np.testing.assert_array_equal(out_im2, out_gt) + + for k, v in out_data.items(): + if k not in KEYS: + continue + self.assertEqual(type(im), type(v)) + if isinstance(v, torch.Tensor): + self.assertEqual(im.device, v.device) + v = v.cpu() + np.testing.assert_array_equal(v, out_gt) # test inverse fwd_inv_data = tr.inverse(out_data) - for i, j in zip(data.values(), fwd_inv_data.values()): - np.testing.assert_array_equal(i, j) + + for k, v in fwd_inv_data.items(): + if k not in KEYS: + continue + self.assertEqual(type(im), type(v)) + if isinstance(v, torch.Tensor): + self.assertEqual(im.device, v.device) + v = v.cpu() + np.testing.assert_array_equal(v, im_cpu) if __name__ == "__main__": diff --git a/tests/test_vote_ensemble.py b/tests/test_vote_ensemble.py index 74c19d5f48..ba48ed1825 100644 --- a/tests/test_vote_ensemble.py +++ b/tests/test_vote_ensemble.py @@ -11,59 +11,91 @@ import unittest +import numpy as np import torch from parameterized import parameterized from monai.transforms import VoteEnsemble +from tests.utils import TEST_NDARRAYS -# shape: [2, 1, 1] -TEST_CASE_1 = [ - {"num_classes": None}, - [torch.tensor([[[1]], [[0]]]), torch.tensor([[[1]], [[0]]]), torch.tensor([[[0]], [[1]]])], - torch.tensor([[[1.0]], [[0.0]]]), -] - -# shape: [1, 2, 1, 1] -TEST_CASE_2 = [ - {"num_classes": None}, - torch.stack([torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])]), - torch.tensor([[[[1.0]], [[0.0]]]]), -] - -# shape: [1, 2, 1] -TEST_CASE_3 = [ - {"num_classes": 3}, - [torch.tensor([[[0], [2]]]), torch.tensor([[[0], [2]]]), torch.tensor([[[1], [1]]])], - torch.tensor([[[0], [2]]]), -] - -# shape: [1, 2, 1] -TEST_CASE_4 = [ - {"num_classes": 5}, - [torch.tensor([[[0], [2]]]), torch.tensor([[[0], [2]]]), torch.tensor([[[1], [1]]])], - torch.tensor([[[0], [2]]]), -] - -# shape: [1] -TEST_CASE_5 = [ - {"num_classes": 3}, - [torch.tensor([2]), torch.tensor([2]), torch.tensor([1])], - torch.tensor([2]), -] - -# shape: 1 -TEST_CASE_6 = [ - {"num_classes": 3}, - [torch.tensor(2), torch.tensor(2), torch.tensor(1)], - torch.tensor(2), -] +TESTS = [] +for p in TEST_NDARRAYS: + # shape: [2, 1, 1] + TESTS.append( + [ + p, + {"num_classes": None}, + [torch.tensor([[[1]], [[0]]]), torch.tensor([[[1]], [[0]]]), torch.tensor([[[0]], [[1]]])], + torch.tensor([[[1.0]], [[0.0]]]), + ] + ) + 
TESTS.append( + # shape: [1, 2, 1, 1] + [ + p, + {"num_classes": None}, + torch.stack( + [torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])] + ), + torch.tensor([[[[1.0]], [[0.0]]]]), + ] + ) + TESTS.append( + # shape: [1, 2, 1] + [ + p, + {"num_classes": 3}, + [torch.tensor([[[0], [2]]]), torch.tensor([[[0], [2]]]), torch.tensor([[[1], [1]]])], + torch.tensor([[[0], [2]]]), + ] + ) + TESTS.append( + # shape: [1, 2, 1] + [ + p, + {"num_classes": 5}, + [torch.tensor([[[0], [2]]]), torch.tensor([[[0], [2]]]), torch.tensor([[[1], [1]]])], + torch.tensor([[[0], [2]]]), + ] + ) + TESTS.append( + # shape: [1] + [ + p, + {"num_classes": 3}, + [torch.tensor([2]), torch.tensor([2]), torch.tensor([1])], + torch.tensor([2]), + ] + ) + TESTS.append( + # shape: 1 + [ + p, + {"num_classes": 3}, + [torch.tensor(2), torch.tensor(2), torch.tensor(1)], + torch.tensor(2), + ] + ) class TestVoteEnsemble(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) - def test_value(self, input_param, img, expected_value): - result = VoteEnsemble(**input_param)(img) - torch.testing.assert_allclose(result, expected_value) + @parameterized.expand(TESTS) + def test_value(self, in_type, input_param, img, expected_value): + if isinstance(img, list): + im = [in_type(i) for i in img] + im_type = type(im[0]) + im_device = im[0].device if isinstance(im[0], torch.Tensor) else None + else: + im = in_type(img) + im_type = type(im) + im_device = im.device if isinstance(im, torch.Tensor) else None + + result = VoteEnsemble(**input_param)(im) + self.assertEqual(im_type, type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, im_device) + result = result.cpu() + np.testing.assert_allclose(result, expected_value) def test_cuda_value(self): img = torch.stack( diff --git a/tests/test_vote_ensembled.py b/tests/test_vote_ensembled.py index e94213733f..972469d1db 100644 --- a/tests/test_vote_ensembled.py +++ b/tests/test_vote_ensembled.py @@ -11,68 +11,94 @@ import unittest +import numpy as np import torch from parameterized import parameterized from monai.transforms import VoteEnsembled +from tests.utils import TEST_NDARRAYS -# shape: [1, 2, 1, 1] -TEST_CASE_1 = [ - {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": None}, - { - "pred0": torch.tensor([[[[1]], [[0]]]]), - "pred1": torch.tensor([[[[1]], [[0]]]]), - "pred2": torch.tensor([[[[0]], [[1]]]]), - }, - torch.tensor([[[[1.0]], [[0.0]]]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + # shape: [1, 2, 1, 1] + TESTS.append( + [ + p, + {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": None}, + { + "pred0": torch.tensor([[[[1]], [[0]]]]), + "pred1": torch.tensor([[[[1]], [[0]]]]), + "pred2": torch.tensor([[[[0]], [[1]]]]), + }, + torch.tensor([[[[1.0]], [[0.0]]]]), + ] + ) -# shape: [1, 2, 1, 1] -TEST_CASE_2 = [ - {"keys": "output", "output_key": "output", "num_classes": None}, - { - "output": torch.stack( - [torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])] - ) - }, - torch.tensor([[[[1.0]], [[0.0]]]]), -] + # shape: [1, 2, 1, 1] + TESTS.append( + [ + p, + {"keys": "output", "output_key": "output", "num_classes": None}, + { + "output": torch.stack( + [torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])] + ) + }, + torch.tensor([[[[1.0]], [[0.0]]]]), + ] + ) -# shape: [1, 2, 1] -TEST_CASE_3 = [ - {"keys": 
["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 3}, - { - "pred0": torch.tensor([[[0], [2]]]), - "pred1": torch.tensor([[[0], [2]]]), - "pred2": torch.tensor([[[1], [1]]]), - }, - torch.tensor([[[0], [2]]]), -] + # shape: [1, 2, 1] + TESTS.append( + [ + p, + {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 3}, + { + "pred0": torch.tensor([[[0], [2]]]), + "pred1": torch.tensor([[[0], [2]]]), + "pred2": torch.tensor([[[1], [1]]]), + }, + torch.tensor([[[0], [2]]]), + ] + ) -# shape: [1, 2, 1] -TEST_CASE_4 = [ - {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 5}, - { - "pred0": torch.tensor([[[0], [2]]]), - "pred1": torch.tensor([[[0], [2]]]), - "pred2": torch.tensor([[[1], [1]]]), - }, - torch.tensor([[[0], [2]]]), -] + # shape: [1, 2, 1] + TESTS.append( + [ + p, + {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 5}, + { + "pred0": torch.tensor([[[0], [2]]]), + "pred1": torch.tensor([[[0], [2]]]), + "pred2": torch.tensor([[[1], [1]]]), + }, + torch.tensor([[[0], [2]]]), + ] + ) -# shape: [1] -TEST_CASE_5 = [ - {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 3}, - {"pred0": torch.tensor([2]), "pred1": torch.tensor([2]), "pred2": torch.tensor([1])}, - torch.tensor([2]), -] + # shape: [1] + TESTS.append( + [ + p, + {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 3}, + {"pred0": torch.tensor([2]), "pred1": torch.tensor([2]), "pred2": torch.tensor([1])}, + torch.tensor([2]), + ] + ) class TestVoteEnsembled(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) - def test_value(self, input_param, img, expected_value): - result = VoteEnsembled(**input_param)(img) - torch.testing.assert_allclose(result["output"], expected_value) + @parameterized.expand(TESTS) + def test_value(self, in_type, input_param, img, expected_value): + for k, v in img.items(): + img[k] = in_type(v) + result = VoteEnsembled(**input_param)(img)["output"] + in_im = img["pred0"] if "pred0" in img else img["output"] + self.assertEqual(type(result), type(in_im)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, in_im.device) + result = result.cpu() + np.testing.assert_allclose(result, expected_value) def test_cuda_value(self): img = torch.stack( diff --git a/tests/test_zoom.py b/tests/test_zoom.py index dcc401f16c..3875046ded 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -12,11 +12,12 @@ import unittest import numpy as np +import torch from parameterized import parameterized from scipy.ndimage import zoom as zoom_scipy from monai.transforms import Zoom -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D VALID_CASES = [(1.5, "nearest"), (1.5, "nearest"), (0.8, "bilinear"), (0.8, "area")] @@ -26,38 +27,46 @@ class TestZoom(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) def test_correct_results(self, zoom, mode): - zoom_fn = Zoom(zoom=zoom, mode=mode, keep_size=False) - zoomed = zoom_fn(self.imt[0]) - _order = 0 - if mode.endswith("linear"): - _order = 1 - expected = [] - for channel in self.imt[0]: - expected.append(zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False)) - expected = np.stack(expected).astype(np.float32) - np.testing.assert_allclose(zoomed, expected, atol=1.0) + for p in TEST_NDARRAYS: + zoom_fn = Zoom(zoom=zoom, mode=mode, keep_size=False) + zoomed = zoom_fn(p(self.imt[0])) + 
_order = 0 + if mode.endswith("linear"): + _order = 1 + expected = [] + for channel in self.imt[0]: + expected.append(zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False)) + expected = np.stack(expected).astype(np.float32) + if isinstance(zoomed, torch.Tensor): + zoomed = zoomed.detach().cpu().numpy() + np.testing.assert_allclose(zoomed, expected, atol=1.0) def test_keep_size(self): - zoom_fn = Zoom(zoom=[0.6, 0.6], keep_size=True, align_corners=True) - zoomed = zoom_fn(self.imt[0], mode="bilinear") - np.testing.assert_allclose(zoomed.shape, self.imt.shape[1:]) + for p in TEST_NDARRAYS: + zoom_fn = Zoom(zoom=[0.6, 0.6], keep_size=True, align_corners=True) + zoomed = zoom_fn(p(self.imt[0]), mode="bilinear") + np.testing.assert_allclose(zoomed.shape, self.imt.shape[1:]) - zoom_fn = Zoom(zoom=[1.3, 1.3], keep_size=True) - zoomed = zoom_fn(self.imt[0]) - np.testing.assert_allclose(zoomed.shape, self.imt.shape[1:]) + zoom_fn = Zoom(zoom=[1.3, 1.3], keep_size=True) + zoomed = zoom_fn(p(self.imt[0])) + if isinstance(zoomed, torch.Tensor): + zoomed = zoomed.detach().cpu().numpy() + np.testing.assert_allclose(zoomed.shape, self.imt.shape[1:]) @parameterized.expand(INVALID_CASES) def test_invalid_inputs(self, zoom, mode, raises): - with self.assertRaises(raises): - zoom_fn = Zoom(zoom=zoom, mode=mode) - zoom_fn(self.imt[0]) + for p in TEST_NDARRAYS: + with self.assertRaises(raises): + zoom_fn = Zoom(zoom=zoom, mode=mode) + zoom_fn(p(self.imt[0])) def test_padding_mode(self): - zoom_fn = Zoom(zoom=0.5, mode="nearest", padding_mode="constant", keep_size=True) - test_data = np.array([[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]]) - zoomed = zoom_fn(test_data) - expected = np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]) - np.testing.assert_allclose(zoomed, expected) + for p in TEST_NDARRAYS: + zoom_fn = Zoom(zoom=0.5, mode="nearest", padding_mode="constant", keep_size=True) + test_data = p([[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]]) + zoomed = zoom_fn(test_data) + expected = p([[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]) + torch.testing.assert_allclose(zoomed, expected) if __name__ == "__main__": diff --git a/tests/test_zoom_affine.py b/tests/test_zoom_affine.py index 49c3c0dcac..7d21341a01 100644 --- a/tests/test_zoom_affine.py +++ b/tests/test_zoom_affine.py @@ -10,69 +10,99 @@ # limitations under the License. 
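Several assertions in the zoom tests above first bring torch outputs back to host numpy before comparing against the scipy reference. A small helper sketch of that normalization (the name as_numpy is illustrative only, not part of this patch):

    import numpy as np
    import torch

    def as_numpy(x):
        # detach() and cpu() are no-ops for CPU tensors without grad, so this is safe everywhere
        return x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else np.asarray(x)

    # e.g. np.testing.assert_allclose(as_numpy(zoomed), expected, atol=1.0)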
import unittest +from typing import List, Tuple import nibabel as nib import numpy as np from parameterized import parameterized from monai.data.utils import zoom_affine +from tests.utils import TEST_NDARRAYS -VALID_CASES = [ - ( - np.array([[2, 1, 4], [-1, -3, 5], [0, 0, 1]]), - (10, 20, 30), - np.array([[8.94427191, -8.94427191, 0], [-4.47213595, -17.88854382, 0], [0.0, 0.0, 1.0]]), - ), - ( - np.array([[1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]), - (10, 20, 30), - np.array([[10, 0, 0, 0], [0, 20, 0, 0], [0, 0, 30, 0], [0, 0, 0, 1]]), - ), - ( - np.array([[1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]), - (10, 20), - np.array([[10, 0, 0, 0], [0, 20, 0, 0], [0, 0, 3, 0], [0, 0, 0, 1]]), - ), - ( - np.array([[1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]), - (10,), - np.array([[10, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 1]]), - ), - ( - [[1, 0, 10], [0, 1, 20], [0, 0, 1]] - @ ([[0, -1, 0], [1, 0, 0], [0, 0, 1]] @ np.array([[2, 0.3, 0], [0, 3, 0], [0, 0, 1]])), - (4, 5, 6), - ([[0, -1, 0], [1, 0, 0], [0, 0, 1]] @ np.array([[4, 0, 0], [0, 5, 0], [0, 0, 1]])), - ), -] +VALID_CASES: List[Tuple] = [] +DIAGONAL_CASES: List[Tuple] = [] +for p in TEST_NDARRAYS: + VALID_CASES.append( + ( + p, + np.array([[2, 1, 4], [-1, -3, 5], [0, 0, 1]]), + (10, 20, 30), + np.array([[8.94427191, -8.94427191, 0], [-4.47213595, -17.88854382, 0], [0.0, 0.0, 1.0]]), + ) + ) + VALID_CASES.append( + ( + p, + np.array([[1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]), + (10, 20, 30), + np.array([[10, 0, 0, 0], [0, 20, 0, 0], [0, 0, 30, 0], [0, 0, 0, 1]]), + ) + ) + VALID_CASES.append( + ( + p, + np.array([[1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]), + (10, 20), + np.array([[10, 0, 0, 0], [0, 20, 0, 0], [0, 0, 3, 0], [0, 0, 0, 1]]), + ) + ) + VALID_CASES.append( + ( + p, + np.array([[1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]), + (10,), + np.array([[10, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 1]]), + ) + ) + VALID_CASES.append( + ( + p, + [[1, 0, 10], [0, 1, 20], [0, 0, 1]] + @ ([[0, -1, 0], [1, 0, 0], [0, 0, 1]] @ np.array([[2, 0.3, 0], [0, 3, 0], [0, 0, 1]])), + (4, 5, 6), + ([[0, -1, 0], [1, 0, 0], [0, 0, 1]] @ np.array([[4, 0, 0], [0, 5, 0], [0, 0, 1]])), + ) + ) -DIAGONAL_CASES = [ - ( - np.array([[-1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]), - (10, 20, 30), - np.array([[10, 0, 0, 0], [0, 20, 0, 0], [0, 0, 30, 0], [0, 0, 0, 1]]), - ), - (np.array([[2, 1, 4], [-1, -3, 5], [0, 0, 1]]), (10, 20, 30), np.array([[10, 0, 0], [0, 20, 0], [0.0, 0.0, 1.0]])), - ( # test default scale from affine - np.array([[2, 1, 4], [-1, -3, 5], [0, 0, 1]]), - (10,), - np.array([[10, 0, 0], [0, 3.162278, 0], [0.0, 0.0, 1.0]]), - ), -] + DIAGONAL_CASES.append( + ( + p, + np.array([[-1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]), + (10, 20, 30), + np.array([[10, 0, 0, 0], [0, 20, 0, 0], [0, 0, 30, 0], [0, 0, 0, 1]]), + ) + ) + DIAGONAL_CASES.append( + ( + p, + np.array([[2, 1, 4], [-1, -3, 5], [0, 0, 1]]), + (10, 20, 30), + np.array([[10, 0, 0], [0, 20, 0], [0.0, 0.0, 1.0]]), + ) + ) + DIAGONAL_CASES.append( + ( # test default scale from affine + p, + np.array([[2, 1, 4], [-1, -3, 5], [0, 0, 1]]), + (10,), + np.array([[10, 0, 0], [0, 3.162278, 0], [0.0, 0.0, 1.0]]), + ) + ) class TestZoomAffine(unittest.TestCase): @parameterized.expand(VALID_CASES) - def test_correct(self, affine, scale, expected): - output = zoom_affine(affine, scale, diagonal=False) + def test_correct(self, in_type, affine, scale, expected): + _affine = in_type(affine) + 
output = zoom_affine(_affine, scale, diagonal=False) ornt_affine = nib.orientations.ornt2axcodes(nib.orientations.io_orientation(output)) ornt_output = nib.orientations.ornt2axcodes(nib.orientations.io_orientation(affine)) np.testing.assert_array_equal(ornt_affine, ornt_output) np.testing.assert_allclose(output, expected, rtol=1e-6, atol=1e-6) @parameterized.expand(DIAGONAL_CASES) - def test_diagonal(self, affine, scale, expected): - output = zoom_affine(affine, scale, diagonal=True) + def test_diagonal(self, in_type, affine, scale, expected): + output = zoom_affine(in_type(affine), scale, diagonal=True) np.testing.assert_allclose(output, expected, rtol=1e-6, atol=1e-6) diff --git a/tests/utils.py b/tests/utils.py index ce280a13f0..988d40488a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,9 +20,10 @@ import traceback import unittest import warnings +from functools import partial from io import BytesIO from subprocess import PIPE, Popen -from typing import Optional +from typing import Callable, Optional, Tuple from urllib.error import ContentTooShortError, HTTPError, URLError import numpy as np @@ -32,6 +33,7 @@ from monai.config.deviceconfig import USE_COMPILED from monai.data import create_test_image_2d, create_test_image_3d from monai.utils import ensure_tuple, optional_import, set_determinism +from monai.utils.misc import is_module_ver_at_least from monai.utils.module import version_leq nib, _ = optional_import("nibabel") @@ -112,8 +114,7 @@ class SkipIfBeforePyTorchVersion: def __init__(self, pytorch_version_tuple): self.min_version = pytorch_version_tuple - test_ver = ".".join(map(str, self.min_version)) - self.version_too_old = torch.__version__ != test_ver and version_leq(torch.__version__, test_ver) + self.version_too_old = not is_module_ver_at_least(torch, pytorch_version_tuple) def __call__(self, obj): return unittest.skipIf( @@ -562,5 +563,11 @@ def query_memory(n=2): return ",".join(f"{int(x)}" for x in ids) +TEST_NDARRAYS: Tuple[Callable] = (np.array, torch.as_tensor) # type: ignore +if torch.cuda.is_available(): + gpu_tensor: Callable = partial(torch.as_tensor, device="cuda") + TEST_NDARRAYS = TEST_NDARRAYS + (gpu_tensor,) # type: ignore + + if __name__ == "__main__": print(query_memory())
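With TEST_NDARRAYS defined as above, a single test body can exercise every supported backend in one loop: each entry is a callable that maps array-like data to an np.ndarray, a CPU torch.Tensor, or (when CUDA is available) a CUDA tensor. A minimal usage sketch (the doubling step is a stand-in for whatever MONAI transform is under test):

    import numpy as np
    import torch
    from tests.utils import TEST_NDARRAYS

    data = np.arange(6).reshape(2, 3).astype(np.float32)
    for p in TEST_NDARRAYS:
        im = p(data)              # same values, different backend/device per iteration
        out = im * 2              # stand-in for e.g. ThresholdIntensity(...)(im)
        if isinstance(out, torch.Tensor):
            out = out.cpu()
        np.testing.assert_allclose(out, data * 2)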