From 9f961ab86ca6e317bb7dd93f4251790a43851719 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 14 Jan 2022 20:46:48 +0800 Subject: [PATCH] [DLMED] add label transform Signed-off-by: Nic Ma --- monai/data/image_dataset.py | 25 ++++++++++++++---------- tests/test_image_dataset.py | 38 ++++++++++++++++++++++++------------- 2 files changed, 40 insertions(+), 23 deletions(-) diff --git a/monai/data/image_dataset.py b/monai/data/image_dataset.py index 0ab71cd444..51f4e04959 100644 --- a/monai/data/image_dataset.py +++ b/monai/data/image_dataset.py @@ -37,6 +37,7 @@ def __init__( labels: Optional[Sequence[float]] = None, transform: Optional[Callable] = None, seg_transform: Optional[Callable] = None, + label_transform: Optional[Callable] = None, image_only: bool = True, transform_with_metadata: bool = False, dtype: DtypeLike = np.float32, @@ -49,19 +50,20 @@ def __init__( to the images and `seg_transform` to the segmentations. Args: - image_files: list of image filenames - seg_files: if in segmentation task, list of segmentation filenames - labels: if in classification task, list of classification labels - transform: transform to apply to image arrays - seg_transform: transform to apply to segmentation arrays - image_only: if True return only the image volume, otherwise, return image volume and the metadata + image_files: list of image filenames. + seg_files: if in segmentation task, list of segmentation filenames. + labels: if in classification task, list of classification labels. + transform: transform to apply to image arrays. + seg_transform: transform to apply to segmentation arrays. + label_transform: transform to apply to the label data. + image_only: if True return only the image volume, otherwise, return image volume and the metadata. transform_with_metadata: if True, the metadata will be passed to the transforms whenever possible. - dtype: if not None convert the loaded image to this data type + dtype: if not None convert the loaded image to this data type. reader: register reader to load image file and meta data, if None, will use the default readers. If a string of reader name provided, will construct a reader object with the `*args` and `**kwargs` parameters, supported reader name: "NibabelReader", "PILReader", "ITKReader", "NumpyReader" - args: additional parameters for reader if providing a reader name - kwargs: additional parameters for reader if providing a reader name + args: additional parameters for reader if providing a reader name. + kwargs: additional parameters for reader if providing a reader name. Raises: ValueError: When ``seg_files`` length differs from ``image_files`` @@ -79,6 +81,7 @@ def __init__( self.labels = labels self.transform = transform self.seg_transform = seg_transform + self.label_transform = label_transform if image_only and transform_with_metadata: raise ValueError("transform_with_metadata=True requires image_only=False.") self.image_only = image_only @@ -117,7 +120,7 @@ def __getitem__(self, index: int): else: img = apply_transform(self.transform, img, map_items=False) - if self.seg_transform is not None: + if self.seg_files is not None and self.seg_transform is not None: if isinstance(self.seg_transform, Randomizable): self.seg_transform.set_random_state(seed=self._seed) @@ -130,6 +133,8 @@ def __getitem__(self, index: int): if self.labels is not None: label = self.labels[index] + if self.label_transform is not None: + label = apply_transform(self.label_transform, label, map_items=False) # type: ignore # construct outputs data = [img] diff --git a/tests/test_image_dataset.py b/tests/test_image_dataset.py index c478f28d13..41eda803dc 100644 --- a/tests/test_image_dataset.py +++ b/tests/test_image_dataset.py @@ -17,7 +17,15 @@ import numpy as np from monai.data import ImageDataset -from monai.transforms import Compose, EnsureChannelFirst, RandAdjustContrast, RandomizableTransform, Spacing +from monai.transforms import ( + Compose, + EnsureChannelFirst, + MapLabelValue, + RandAdjustContrast, + RandomizableTransform, + Spacing, +) +from monai.transforms.utility.array import ToNumpy FILENAMES = ["test1.nii.gz", "test2.nii", "test3.nii.gz"] @@ -106,16 +114,6 @@ def test_dataset(self): for d, ref in zip(dataset, ref_data): np.testing.assert_allclose(d, ref + 1, atol=1e-3) - # set seg transform, but no seg_files - with self.assertRaises(RuntimeError): - dataset = ImageDataset(full_names, seg_transform=lambda x: x + 1, image_only=True) - _ = dataset[0] - - # set seg transform, but no seg_files - with self.assertRaises(RuntimeError): - dataset = ImageDataset(full_names, seg_transform=lambda x: x + 1, image_only=True) - _ = dataset[0] - # loading image/label, with meta dataset = ImageDataset( full_names, @@ -133,13 +131,27 @@ def test_dataset(self): # loading image/label, with meta dataset = ImageDataset( - full_names, transform=lambda x: x + 1, seg_files=full_names, labels=[1, 2, 3], image_only=False + image_files=full_names, + seg_files=full_names, + labels=[1, 2, 3], + transform=lambda x: x + 1, + label_transform=Compose( + [ + ToNumpy(), + MapLabelValue(orig_labels=[1, 2, 3], target_labels=[30.0, 20.0, 10.0], dtype=np.float32), + ] + ), + image_only=False, ) for idx, (d_tuple, ref) in enumerate(zip(dataset, ref_data)): img, seg, label, meta, seg_meta = d_tuple np.testing.assert_allclose(img, ref + 1, atol=1e-3) np.testing.assert_allclose(seg, ref, atol=1e-3) - np.testing.assert_allclose(idx + 1, label) + # test label_transform + + np.testing.assert_allclose((3 - idx) * 10.0, label) + self.assertTrue(isinstance(label, np.ndarray)) + self.assertEqual(label.dtype, np.float32) np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3) np.testing.assert_allclose(seg_meta["original_affine"], np.eye(4), atol=1e-3)