Skip to content

Commit

Permalink
[Feature] Add min_size arg in Resize to keep the shape after resi…
Browse files Browse the repository at this point in the history
…ze bigger than slide window (#1318)

* [Feature] add setr_resize

* fix a bug

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix
  • Loading branch information
linfangjian01 authored Mar 1, 2022
1 parent 9947a39 commit 2d66179
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
26 changes: 25 additions & 1 deletion mmseg/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,19 @@ class Resize(object):
Default: None
keep_ratio (bool): Whether to keep the aspect ratio when resizing the
image. Default: True
min_size (int, optional): Minimum length of the shorter side of the
resized image and seg map; it will not be less than ``min_size``.
Models such as 'SETR' and 'BEiT' have a fixed input shape, so,
following the setting in these models, resized images must not be
smaller than the crop size in ``slide_inference``. Default: None
"""

def __init__(self,
img_scale=None,
multiscale_mode='range',
ratio_range=None,
keep_ratio=True):
keep_ratio=True,
min_size=None):
if img_scale is None:
self.img_scale = None
else:
Expand All @@ -126,6 +132,7 @@ def __init__(self,
self.multiscale_mode = multiscale_mode
self.ratio_range = ratio_range
self.keep_ratio = keep_ratio
self.min_size = min_size

@staticmethod
def random_select(img_scales):
Expand Down Expand Up @@ -240,6 +247,23 @@ def _random_scale(self, results):
def _resize_img(self, results):
"""Resize images with ``results['scale']``."""
if self.keep_ratio:
if self.min_size is not None:
# TODO: Currently 'min_size' is an 'int', which means the minimum
# shape of images is (min_size, min_size, 3). A tuple-typed
# 'min_size' will be supported later, i.e. the minimum width and
# height may differ.
if min(results['scale']) < self.min_size:
new_short = self.min_size
else:
new_short = min(results['scale'])

h, w = results['img'].shape[:2]
if h > w:
new_h, new_w = new_short * h / w, new_short
else:
new_h, new_w = new_short, new_short * w / h
results['scale'] = (new_h, new_w)

img, scale_factor = mmcv.imrescale(
results['img'], results['scale'], return_scale=True)
# the w_scale and h_scale have a minor difference
Expand Down
25 changes: 25 additions & 0 deletions tests/test_data/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,31 @@ def test_resize():
assert int(288 * 0.5) <= resized_results['img_shape'][0] <= 288 * 2.0
assert int(512 * 0.5) <= resized_results['img_shape'][1] <= 512 * 2.0

# test min_size=640
transform = dict(type='Resize', img_scale=(2560, 640), min_size=640)
resize_module = build_from_cfg(transform, PIPELINES)
resized_results = resize_module(results.copy())
assert resized_results['img_shape'] == (640, 1138, 3)

# test min_size=640 and img_scale=(512, 640)
transform = dict(type='Resize', img_scale=(512, 640), min_size=640)
resize_module = build_from_cfg(transform, PIPELINES)
resized_results = resize_module(results.copy())
assert resized_results['img_shape'] == (640, 1138, 3)

# test h > w
img = np.random.randn(512, 288, 3)
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0
transform = dict(type='Resize', img_scale=(2560, 640), min_size=640)
resize_module = build_from_cfg(transform, PIPELINES)
resized_results = resize_module(results.copy())
assert resized_results['img_shape'] == (1138, 640, 3)


def test_flip():
# test assertion for invalid prob
Expand Down

0 comments on commit 2d66179

Please sign in to comment.