diff --git a/docs/source/data.rst b/docs/source/data.rst index f6ed71c266..11609964c3 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -63,6 +63,11 @@ Generic Interfaces :members: :special-members: __getitem__ +`ImageDataset` +~~~~~~~~~~~~~~ +.. autoclass:: ImageDataset + :members: + :special-members: __getitem__ Patch-based dataset ------------------- @@ -104,11 +109,6 @@ PILReader Nifti format handling --------------------- -Reading -~~~~~~~ -.. autoclass:: monai.data.NiftiDataset - :members: - Writing Nifti ~~~~~~~~~~~~~ .. autoclass:: monai.data.NiftiSaver diff --git a/monai/data/__init__.py b/monai/data/__init__.py index e2bd32861c..e0db1e17ae 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -23,9 +23,9 @@ ) from .decathlon_datalist import load_decathlon_datalist, load_decathlon_properties from .grid_dataset import GridPatchDataset, PatchDataset +from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader from .iterable_dataset import IterableDataset -from .nifti_reader import NiftiDataset from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti from .png_saver import PNGSaver diff --git a/monai/data/nifti_reader.py b/monai/data/image_dataset.py similarity index 71% rename from monai/data/nifti_reader.py rename to monai/data/image_dataset.py index 1378fb25a0..7dd55431af 100644 --- a/monai/data/nifti_reader.py +++ b/monai/data/image_dataset.py @@ -9,19 +9,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Optional, Sequence +from typing import Any, Callable, Optional, Sequence, Union import numpy as np from torch.utils.data import Dataset +from monai.data.image_reader import ImageReader from monai.transforms import LoadImage, Randomizable, apply_transform from monai.utils import MAX_SEED, get_seed -class NiftiDataset(Dataset, Randomizable): +class ImageDataset(Dataset, Randomizable): """ - Loads image/segmentation pairs of Nifti files from the given filename lists. Transformations can be specified + Loads image/segmentation pairs of files from the given filename lists. Transformations can be specified for the image and segmentation arrays separately. + The difference between this dataset and `ArrayDataset` is that this dataset can apply a transform chain to images + and segs and return both the images and metadata, and there is no need to specify a transform to load images from files. + """ def __init__( @@ -29,11 +33,13 @@ def __init__( image_files: Sequence[str], seg_files: Optional[Sequence[str]] = None, labels: Optional[Sequence[float]] = None, - as_closest_canonical: bool = False, transform: Optional[Callable] = None, seg_transform: Optional[Callable] = None, image_only: bool = True, dtype: Optional[np.dtype] = np.float32, + reader: Optional[Union[ImageReader, str]] = None, + *args, + **kwargs, ) -> None: """ Initializes the dataset with the image and segmentation filename lists. 
The transform `transform` is applied @@ -43,14 +49,18 @@ def __init__( image_files: list of image filenames seg_files: if in segmentation task, list of segmentation filenames labels: if in classification task, list of classification labels - as_closest_canonical: if True, load the image as closest to canonical orientation transform: transform to apply to image arrays seg_transform: transform to apply to segmentation arrays - image_only: if True return only the image volume, other return image volume and header dict + image_only: if True return only the image volume, otherwise return the image volume and the metadata dtype: if not None convert the loaded image to this data type + reader: register a reader to load image files and metadata; if None, the default readers will be used. + If a reader name string is provided, a reader object will be constructed with the `*args` and `**kwargs` + parameters; supported reader names: "NibabelReader", "PILReader", "ITKReader", "NumpyReader" + args: additional parameters for the reader if providing a reader name + kwargs: additional parameters for the reader if providing a reader name Raises: - ValueError: When ``seg_files`` length differs from ``image_files``. 
+ ValueError: When ``seg_files`` length differs from ``image_files`` """ @@ -63,13 +73,11 @@ def __init__( self.image_files = image_files self.seg_files = seg_files self.labels = labels - self.as_closest_canonical = as_closest_canonical self.transform = transform self.seg_transform = seg_transform self.image_only = image_only - self.dtype = dtype + self.loader = LoadImage(reader, image_only, dtype, *args, **kwargs) self.set_random_state(seed=get_seed()) - self._seed = 0 # transform synchronization seed def __len__(self) -> int: @@ -81,21 +89,18 @@ def randomize(self, data: Optional[Any] = None) -> None: def __getitem__(self, index: int): self.randomize() meta_data = None - img_loader = LoadImage( - reader="NibabelReader", - image_only=self.image_only, - dtype=self.dtype, - as_closest_canonical=self.as_closest_canonical, - ) - if self.image_only: - img = img_loader(self.image_files[index]) - else: - img, meta_data = img_loader(self.image_files[index]) seg = None - if self.seg_files is not None: - seg_loader = LoadImage(image_only=True) - seg = seg_loader(self.seg_files[index]) label = None + + if self.image_only: + img = self.loader(self.image_files[index]) + if self.seg_files is not None: + seg = self.loader(self.seg_files[index]) + else: + img, meta_data = self.loader(self.image_files[index]) + if self.seg_files is not None: + seg, _ = self.loader(self.seg_files[index]) + if self.labels is not None: label = self.labels[index] diff --git a/tests/min_tests.py b/tests/min_tests.py index daf238a154..9a2dc0f05f 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -69,7 +69,7 @@ def run_testsuit(): "test_load_imaged", "test_load_spacing_orientation", "test_mednistdataset", - "test_nifti_dataset", + "test_image_dataset", "test_nifti_header_revise", "test_nifti_rw", "test_nifti_saver", diff --git a/tests/test_nifti_dataset.py b/tests/test_image_dataset.py similarity index 88% rename from tests/test_nifti_dataset.py rename to tests/test_image_dataset.py index 
f5d6e11290..d79a7d884c 100644 --- a/tests/test_nifti_dataset.py +++ b/tests/test_image_dataset.py @@ -16,7 +16,7 @@ import nibabel as nib import numpy as np -from monai.data import NiftiDataset +from monai.data import ImageDataset from monai.transforms import Randomizable FILENAMES = ["test1.nii.gz", "test2.nii", "test3.nii.gz"] @@ -35,7 +35,7 @@ def __call__(self, data): return data + self._a -class TestNiftiDataset(unittest.TestCase): +class TestImageDataset(unittest.TestCase): def test_dataset(self): with tempfile.TemporaryDirectory() as tempdir: full_names, ref_data = [], [] @@ -47,46 +47,46 @@ def test_dataset(self): nib.save(nib.Nifti1Image(test_image, np.eye(4)), save_path) # default loading no meta - dataset = NiftiDataset(full_names) + dataset = ImageDataset(full_names) for d, ref in zip(dataset, ref_data): np.testing.assert_allclose(d, ref, atol=1e-3) # loading no meta, int - dataset = NiftiDataset(full_names, dtype=np.float16) + dataset = ImageDataset(full_names, dtype=np.float16) for d, _ in zip(dataset, ref_data): self.assertEqual(d.dtype, np.float16) # loading with meta, no transform - dataset = NiftiDataset(full_names, image_only=False) + dataset = ImageDataset(full_names, image_only=False) for d_tuple, ref in zip(dataset, ref_data): d, meta = d_tuple np.testing.assert_allclose(d, ref, atol=1e-3) np.testing.assert_allclose(meta["original_affine"], np.eye(4)) # loading image/label, no meta - dataset = NiftiDataset(full_names, seg_files=full_names, image_only=True) + dataset = ImageDataset(full_names, seg_files=full_names, image_only=True) for d_tuple, ref in zip(dataset, ref_data): img, seg = d_tuple np.testing.assert_allclose(img, ref, atol=1e-3) np.testing.assert_allclose(seg, ref, atol=1e-3) # loading image/label, no meta - dataset = NiftiDataset(full_names, transform=lambda x: x + 1, image_only=True) + dataset = ImageDataset(full_names, transform=lambda x: x + 1, image_only=True) for d, ref in zip(dataset, ref_data): np.testing.assert_allclose(d, 
ref + 1, atol=1e-3) # set seg transform, but no seg_files with self.assertRaises(RuntimeError): - dataset = NiftiDataset(full_names, seg_transform=lambda x: x + 1, image_only=True) + dataset = ImageDataset(full_names, seg_transform=lambda x: x + 1, image_only=True) _ = dataset[0] # set seg transform, but no seg_files with self.assertRaises(RuntimeError): - dataset = NiftiDataset(full_names, seg_transform=lambda x: x + 1, image_only=True) + dataset = ImageDataset(full_names, seg_transform=lambda x: x + 1, image_only=True) _ = dataset[0] # loading image/label, with meta - dataset = NiftiDataset( + dataset = ImageDataset( full_names, transform=lambda x: x + 1, seg_files=full_names, @@ -100,7 +100,7 @@ def test_dataset(self): np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3) # loading image/label, with meta - dataset = NiftiDataset( + dataset = ImageDataset( full_names, transform=lambda x: x + 1, seg_files=full_names, labels=[1, 2, 3], image_only=False ) for idx, (d_tuple, ref) in enumerate(zip(dataset, ref_data)): @@ -111,7 +111,7 @@ def test_dataset(self): np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3) # loading image/label, with sync. 
transform - dataset = NiftiDataset( + dataset = ImageDataset( full_names, transform=RandTest(), seg_files=full_names, seg_transform=RandTest(), image_only=False ) for d_tuple, ref in zip(dataset, ref_data): diff --git a/tests/test_integration_sliding_window.py b/tests/test_integration_sliding_window.py index 92cc9397cb..c4d020276e 100644 --- a/tests/test_integration_sliding_window.py +++ b/tests/test_integration_sliding_window.py @@ -19,7 +19,7 @@ from ignite.engine import Engine from torch.utils.data import DataLoader -from monai.data import NiftiDataset, create_test_image_3d +from monai.data import ImageDataset, create_test_image_3d from monai.handlers import SegmentationSaver from monai.inferers import sliding_window_inference from monai.networks import eval_mode, predict_segmentation @@ -30,7 +30,7 @@ def run_test(batch_size, img_name, seg_name, output_dir, device="cuda:0"): - ds = NiftiDataset([img_name], [seg_name], transform=AddChannel(), seg_transform=AddChannel(), image_only=False) + ds = ImageDataset([img_name], [seg_name], transform=AddChannel(), seg_transform=AddChannel(), image_only=False) loader = DataLoader(ds, batch_size=1, pin_memory=torch.cuda.is_available()) net = UNet(