diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index 003a564507..5673b646fa 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -99,13 +99,19 @@ class Resize(object): Default: None keep_ratio (bool): Whether to keep the aspect ratio when resizing the image. Default: True + min_size (int, optional): The minimum size for input and the shape + of the image and seg map will not be less than ``min_size``. + As the shape of model input is fixed like 'SETR' and 'BEiT'. + Following the setting in these models, resized images must be + bigger than the crop size in ``slide_inference``. Default: None """ def __init__(self, img_scale=None, multiscale_mode='range', ratio_range=None, - keep_ratio=True): + keep_ratio=True, + min_size=None): if img_scale is None: self.img_scale = None else: @@ -126,6 +132,7 @@ def __init__(self, self.multiscale_mode = multiscale_mode self.ratio_range = ratio_range self.keep_ratio = keep_ratio + self.min_size = min_size @staticmethod def random_select(img_scales): @@ -240,6 +247,23 @@ def _random_scale(self, results): def _resize_img(self, results): """Resize images with ``results['scale']``.""" if self.keep_ratio: + if self.min_size is not None: + # TODO: Now 'min_size' is an 'int' which means the minimum + # shape of images is (min_size, min_size, 3). 'min_size' + # with tuple type will be supported, i.e. the width and + # height are not equal. + if min(results['scale']) < self.min_size: + new_short = self.min_size + else: + new_short = min(results['scale']) + + h, w = results['img'].shape[:2] + if h > w: + new_h, new_w = new_short * h / w, new_short + else: + new_h, new_w = new_short, new_short * w / h + results['scale'] = (new_h, new_w) + img, scale_factor = mmcv.imrescale( results['img'], results['scale'], return_scale=True) # the w_scale and h_scale has minor difference diff --git a/tests/test_data/test_transform.py b/tests/test_data/test_transform.py index e9aa1d75ae..fcc46e7d02 100644 --- a/tests/test_data/test_transform.py +++ b/tests/test_data/test_transform.py @@ -123,6 +123,31 @@ def test_resize(): assert int(288 * 0.5) <= resized_results['img_shape'][0] <= 288 * 2.0 assert int(512 * 0.5) <= resized_results['img_shape'][1] <= 512 * 2.0 + # test min_size=640 + transform = dict(type='Resize', img_scale=(2560, 640), min_size=640) + resize_module = build_from_cfg(transform, PIPELINES) + resized_results = resize_module(results.copy()) + assert resized_results['img_shape'] == (640, 1138, 3) + + # test min_size=640 and img_scale=(512, 640) + transform = dict(type='Resize', img_scale=(512, 640), min_size=640) + resize_module = build_from_cfg(transform, PIPELINES) + resized_results = resize_module(results.copy()) + assert resized_results['img_shape'] == (640, 1138, 3) + + # test h > w + img = np.random.randn(512, 288, 3) + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + # Set initial values for default meta_keys + results['pad_shape'] = img.shape + results['scale_factor'] = 1.0 + transform = dict(type='Resize', img_scale=(2560, 640), min_size=640) + resize_module = build_from_cfg(transform, PIPELINES) + resized_results = resize_module(results.copy()) + assert resized_results['img_shape'] == (1138, 640, 3) + def test_flip(): # test assertion for invalid prob