From f39e8c1cdb3489df8e5b1f884e817d1438131259 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 5 Mar 2020 12:03:54 +0000 Subject: [PATCH] 108-resize (#125) * Add Resize transform (spatial scaling). * Adding tests. --- monai/transforms/transforms.py | 42 +++++++++++++++++++++++++++ tests/test_resize.py | 53 ++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+) create mode 100644 tests/test_resize.py diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index 454c3aa7f6..4447cdd282 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -15,6 +15,7 @@ import numpy as np import torch +from skimage.transform import resize import scipy.ndimage import monai @@ -81,6 +82,47 @@ def __call__(self, img): return np.flip(img, self.axis) +@export +class Resize: + """ + Resize the input image to given resolution. Uses skimage.transform.resize underneath. + For additional details, see https://scikit-image.org/docs/dev/api/skimage.transform.html#skimage.transform.resize. + + Args: + order (int): Order of spline interpolation. Default=1. + mode (str): Points outside boundaries are filled according to given mode. + Options are 'constant', 'edge', 'symmetric', 'reflect', 'wrap'. + cval (float): Used with mode 'constant', the value outside image boundaries. + clip (bool): Wheter to clip range of output values after interpolation. Default: True. + preserve_range (bool): Whether to keep original range of values. Default is True. + If False, input is converted according to conventions of img_as_float. See + https://scikit-image.org/docs/dev/user_guide/data_types.html. + anti_aliasing (bool): Whether to apply a gaussian filter to image before down-scaling. + Default is True. + anti_aliasing_sigma (float, tuple of floats): Standard deviation for gaussian filtering. + """ + + def __init__(self, output_shape, order=1, mode='reflect', cval=0, + clip=True, preserve_range=True, + anti_aliasing=True, anti_aliasing_sigma=None): + assert isinstance(order, int), "order must be integer." + self.output_shape = output_shape + self.order = order + self.mode = mode + self.cval = cval + self.clip = clip + self.preserve_range = preserve_range + self.anti_aliasing = anti_aliasing + self.anti_aliasing_sigma = anti_aliasing_sigma + + def __call__(self, img): + return resize(img, self.output_shape, order=self.order, + mode=self.mode, cval=self.cval, + clip=self.clip, preserve_range=self.preserve_range, + anti_aliasing=self.anti_aliasing, + anti_aliasing_sigma=self.anti_aliasing_sigma) + + @export class Rotate: """ diff --git a/tests/test_resize.py b/tests/test_resize.py new file mode 100644 index 0000000000..7feaf9f634 --- /dev/null +++ b/tests/test_resize.py @@ -0,0 +1,53 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import skimage +from parameterized import parameterized + +from monai.transforms import Resize +from tests.utils import NumpyImageTestCase2D + + +class ResizeTest(NumpyImageTestCase2D): + + @parameterized.expand([ + ("invalid_order", "order", AssertionError) + ]) + def test_invalid_inputs(self, _, order, raises): + with self.assertRaises(raises): + resize = Resize(output_shape=(128, 128, 3), order=order) + resize(self.imt) + + @parameterized.expand([ + ((1, 1, 64, 64), 1, 'reflect', 0, True, True, True, None), + ((1, 1, 32, 32), 2, 'constant', 3, False, False, False, None), + ((1, 1, 256, 256), 3, 'constant', 3, False, False, False, None), + ]) + def test_correct_results(self, output_shape, order, mode, + cval, clip, preserve_range, + anti_aliasing, anti_aliasing_sigma): + resize = Resize(output_shape, order, mode, cval, clip, + preserve_range, anti_aliasing, + anti_aliasing_sigma) + expected = skimage.transform.resize(self.imt, output_shape, + order=order, mode=mode, + cval=cval, clip=clip, + preserve_range=preserve_range, + anti_aliasing=anti_aliasing, + anti_aliasing_sigma=anti_aliasing_sigma) + self.assertTrue(np.allclose(resize(self.imt), expected)) + + +if __name__ == '__main__': + unittest.main()