class NumpyReader(ImageReader):
    """
    Load NPY or NPZ format data based on Numpy library, they can be arrays or pickled objects.
    A typical usage is to load the `mask` data for classification task.
    It can load part of the npz file with specified `npz_keys`.

    Args:
        npz_keys: if loading npz file, only load the specified keys, if None, load all the items.
            stack the loaded items together to construct a new first dimension.

    """

    def __init__(self, npz_keys: Optional[KeysCollection] = None):
        super().__init__()
        # Filled by `read()`; stays None until then so `get_data()` can detect misuse.
        self._img: Optional[List[np.ndarray]] = None
        if npz_keys is not None:
            npz_keys = ensure_tuple(npz_keys)
        self.npz_keys = npz_keys

    def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool:
        """
        Verify whether the specified file or files format is supported by Numpy reader.

        Args:
            filename: file name or a list of file names to read.
                if a list of files, verify all the suffixes.

        """
        suffixes: Sequence[str] = ["npz", "npy"]
        return is_supported_format(filename, suffixes)

    def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs):
        """
        Read image data from specified file or files, or set a Numpy array.
        Note that the returned object is Numpy array or list of Numpy arrays.
        `self._img` is always a list, even only has 1 image.

        Args:
            data: file name or a list of file names to read, or an already loaded Numpy array.
            kwargs: additional args for `numpy.load` API except `allow_pickle`. more details about available args:
                https://numpy.org/doc/stable/reference/generated/numpy.load.html

        Returns:
            `data` itself when it is a Numpy array. Otherwise the list of every loaded item
            when more than one file name was given, else the first loaded item.

        """
        loaded: List[np.ndarray] = []
        self._img = loaded
        if isinstance(data, np.ndarray):
            loaded.append(data)
            return data

        filenames: Sequence[str] = ensure_tuple(data)
        for name in filenames:
            # NOTE(security): `allow_pickle=True` unpickles object arrays, which can execute
            # arbitrary code; only read files that come from a trusted source.
            img = np.load(name, allow_pickle=True, **kwargs)
            # Decide on what numpy actually returned rather than on the file name.
            if isinstance(img, np.lib.npyio.NpzFile):
                try:
                    # `files` lists every stored name, so archives written with keyword
                    # names load fully too, not only the positional `arr_<i>` ones.
                    keys = img.files if self.npz_keys is None else self.npz_keys
                    # NpzFile members are read lazily: pull them out before closing.
                    loaded.extend(img[k] for k in keys)
                finally:
                    # The archive keeps its file handle open until closed explicitly.
                    img.close()
            else:
                loaded.append(img)

        return loaded if len(filenames) > 1 else loaded[0]

    def get_data(self):
        """
        Extract data array and meta data from loaded data and return them.
        This function returns 2 objects, first is numpy array of image data, second is dict of meta data.
        It constructs `spatial_shape=data.shape` and stores in meta dict if the data is numpy array.
        If loading a list of files, stack them together and add a new dimension as first dimension,
        and use the meta data of the first image to represent the stacked result.

        Raises:
            RuntimeError: when `read()` has not been called yet, or when the loaded items
                do not all share the same `spatial_shape`.

        """
        if self._img is None:
            raise RuntimeError("please call read() first then use get_data().")

        img_array: List[np.ndarray] = []
        compatible_meta: Optional[Dict] = None
        for img in self._img:
            header: Dict = {}
            if isinstance(img, np.ndarray):
                header["spatial_shape"] = img.shape
            img_array.append(img)

            if compatible_meta is None:
                compatible_meta = header
            # Exact tuple comparison: `np.allclose` broadcasts, so it would accept e.g.
            # (4, 4, 4) vs (4,) and raise ValueError for shapes of different rank.
            elif header.get("spatial_shape") != compatible_meta.get("spatial_shape"):
                raise RuntimeError("spatial_shape of all images should be same.")

        img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
        return img_array_, compatible_meta
import os
import tempfile
import unittest

import numpy as np

from monai.data import NumpyReader


class TestNumpyReader(unittest.TestCase):
    """Round-trip tests: write data with numpy, read it back through `NumpyReader`."""

    def test_npy(self):
        """A single NPY file comes back unchanged, with its shape as `spatial_shape`."""
        test_data = np.random.randint(0, 256, size=[3, 4, 4])
        with tempfile.TemporaryDirectory() as tempdir:
            filepath = os.path.join(tempdir, "test_data.npy")
            np.save(filepath, test_data)

            reader = NumpyReader()
            reader.read(filepath)
            result = reader.get_data()
        self.assertTupleEqual(result[1]["spatial_shape"], test_data.shape)
        self.assertTupleEqual(result[0].shape, test_data.shape)
        np.testing.assert_allclose(result[0], test_data)

    def test_npz1(self):
        """An NPZ archive holding one positional array is returned without stacking."""
        test_data1 = np.random.randint(0, 256, size=[3, 4, 4])
        with tempfile.TemporaryDirectory() as tempdir:
            # Write a real NPZ archive so the NPZ code path is exercised, not the NPY one.
            filepath = os.path.join(tempdir, "test_data.npz")
            np.savez(filepath, test_data1)

            reader = NumpyReader()
            reader.read(filepath)
            result = reader.get_data()
        self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape)
        self.assertTupleEqual(result[0].shape, test_data1.shape)
        np.testing.assert_allclose(result[0], test_data1)

    def test_npz2(self):
        """Two positional arrays in one NPZ archive are stacked along a new first axis."""
        test_data1 = np.random.randint(0, 256, size=[3, 4, 4])
        test_data2 = np.random.randint(0, 256, size=[3, 4, 4])
        with tempfile.TemporaryDirectory() as tempdir:
            filepath = os.path.join(tempdir, "test_data.npz")
            np.savez(filepath, test_data1, test_data2)

            reader = NumpyReader()
            reader.read(filepath)
            result = reader.get_data()
        self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape)
        self.assertTupleEqual(result[0].shape, (2, 3, 4, 4))
        np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2]))

    def test_npz3(self):
        """Named arrays selected through `npz_keys` are loaded and stacked in key order."""
        test_data1 = np.random.randint(0, 256, size=[3, 4, 4])
        test_data2 = np.random.randint(0, 256, size=[3, 4, 4])
        with tempfile.TemporaryDirectory() as tempdir:
            filepath = os.path.join(tempdir, "test_data.npz")
            np.savez(filepath, test1=test_data1, test2=test_data2)

            reader = NumpyReader(npz_keys=["test1", "test2"])
            reader.read(filepath)
            result = reader.get_data()
        self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape)
        self.assertTupleEqual(result[0].shape, (2, 3, 4, 4))
        np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2]))

    def test_npy_pickle(self):
        """A pickled dict saved as NPY is recovered through the object array's `item()`."""
        test_data = {"test": np.random.randint(0, 256, size=[3, 4, 4])}
        with tempfile.TemporaryDirectory() as tempdir:
            filepath = os.path.join(tempdir, "test_data.npy")
            np.save(filepath, test_data, allow_pickle=True)

            reader = NumpyReader()
            reader.read(filepath)
            result = reader.get_data()[0].item()
        self.assertTupleEqual(result["test"].shape, test_data["test"].shape)
        np.testing.assert_allclose(result["test"], test_data["test"])


if __name__ == "__main__":
    unittest.main()