Skip to content

Commit

Permalink
Support resize data augmentation according to original image size (#291)
Browse files Browse the repository at this point in the history
* Support resize data augmentation according to original image size (img_scale=None and ratio_range is tuple)

* fix docstring

* fix bug

* add unittest

* img_scale=None in TTA

* fix bug

* add unittest

* fix typos

* fix bug
  • Loading branch information
Junjun2016 authored Dec 15, 2020
1 parent 7970e0f commit 061a295
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 17 deletions.
25 changes: 19 additions & 6 deletions mmseg/datasets/pipelines/test_time_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class MultiScaleFlipAug(object):
Args:
transforms (list[dict]): Transforms to apply in each augmentation.
img_scale (tuple | list[tuple]): Images scales for resizing.
img_scale (None | tuple | list[tuple]): Images scales for resizing.
img_ratios (float | list[float]): Image ratios for resizing
flip (bool): Whether apply flip augmentation. Default: False.
flip_direction (str | list[str]): Flip augmentation directions,
Expand All @@ -58,20 +58,27 @@ def __init__(self,
flip_direction='horizontal'):
self.transforms = Compose(transforms)
if img_ratios is not None:
# mode 1: given a scale and a range of image ratio
img_ratios = img_ratios if isinstance(img_ratios,
list) else [img_ratios]
assert mmcv.is_list_of(img_ratios, float)
assert isinstance(img_scale, tuple) and len(img_scale) == 2
if img_scale is None:
# mode 1: given img_scale=None and a range of image ratio
self.img_scale = None
assert mmcv.is_list_of(img_ratios, float)
elif isinstance(img_scale, tuple) and mmcv.is_list_of(
img_ratios, float):
assert len(img_scale) == 2
# mode 2: given a scale and a range of image ratio
self.img_scale = [(int(img_scale[0] * ratio),
int(img_scale[1] * ratio))
for ratio in img_ratios]
else:
# mode 2: given multiple scales
# mode 3: given multiple scales
self.img_scale = img_scale if isinstance(img_scale,
list) else [img_scale]
assert mmcv.is_list_of(self.img_scale, tuple)
assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None
self.flip = flip
self.img_ratios = img_ratios
self.flip_direction = flip_direction if isinstance(
flip_direction, list) else [flip_direction]
assert mmcv.is_list_of(self.flip_direction, str)
Expand All @@ -95,8 +102,14 @@ def __call__(self, results):
"""

aug_data = []
if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float):
h, w = results['img'].shape[:2]
img_scale = [(int(h * ratio), int(w * ratio))
for ratio in self.img_ratios]
else:
img_scale = self.img_scale
flip_aug = [False, True] if self.flip else [False]
for scale in self.img_scale:
for scale in img_scale:
for flip in flip_aug:
for direction in self.flip_direction:
_results = results.copy()
Expand Down
31 changes: 20 additions & 11 deletions mmseg/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,21 @@ class Resize(object):
contains the key "scale", then the scale in the input dict is used,
otherwise the specified scale in the init method is used.
``img_scale`` can either be a tuple (single-scale) or a list of tuple
(multi-scale). There are 3 multiscale modes:
``img_scale`` can be None, a tuple (single-scale) or a list of tuple
(multi-scale). There are 4 multiscale modes:
- ``ratio_range is not None``: randomly sample a ratio from the ratio range
and multiply it with the image scale.
- ``ratio_range is not None``:
1. When img_scale is None, img_scale is the shape of image in results
(img_scale = results['img'].shape[:2]) and the image is resized based
on the original size. (mode 1)
2. When img_scale is a tuple (single-scale), randomly sample a ratio from
the ratio range and multiply it with the image scale. (mode 2)
- ``ratio_range is None and multiscale_mode == "range"``: randomly sample a
scale from the a range.
scale from a range. (mode 3)
- ``ratio_range is None and multiscale_mode == "value"``: randomly sample a
scale from multiple scales.
scale from multiple scales. (mode 4)
Args:
img_scale (tuple or list[tuple]): Images scales for resizing.
Expand All @@ -49,10 +53,11 @@ def __init__(self,
assert mmcv.is_list_of(self.img_scale, tuple)

if ratio_range is not None:
# mode 1: given a scale and a range of image ratio
assert len(self.img_scale) == 1
# mode 1: given img_scale=None and a range of image ratio
# mode 2: given a scale and a range of image ratio
assert self.img_scale is None or len(self.img_scale) == 1
else:
# mode 2: given multiple scales or a range of scales
# mode 3 and 4: given multiple scales or a range of scales
assert multiscale_mode in ['value', 'range']

self.multiscale_mode = multiscale_mode
Expand Down Expand Up @@ -150,8 +155,12 @@ def _random_scale(self, results):
"""

if self.ratio_range is not None:
scale, scale_idx = self.random_sample_ratio(
self.img_scale[0], self.ratio_range)
if self.img_scale is None:
scale, scale_idx = self.random_sample_ratio(
results['img'].shape[:2], self.ratio_range)
else:
scale, scale_idx = self.random_sample_ratio(
self.img_scale[0], self.ratio_range)
elif len(self.img_scale) == 1:
scale, scale_idx = self.img_scale[0], 0
elif self.multiscale_mode == 'range':
Expand Down
10 changes: 10 additions & 0 deletions tests/test_data/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def test_resize():
resize_module = build_from_cfg(transform, PIPELINES)

results = dict()
# (288, 512, 3)
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
results['img'] = img
Expand Down Expand Up @@ -92,6 +93,15 @@ def test_resize():
resized_results = resize_module(results.copy())
assert max(resized_results['img_shape'][:2]) <= 1333 * 1.1

# test img_scale=None and ratio_range is tuple.
# img shape: (288, 512, 3)
transform = dict(
type='Resize', img_scale=None, ratio_range=(0.5, 2.0), keep_ratio=True)
resize_module = build_from_cfg(transform, PIPELINES)
resized_results = resize_module(results.copy())
assert int(288 * 0.5) <= resized_results['img_shape'][0] <= 288 * 2.0
assert int(512 * 0.5) <= resized_results['img_shape'][1] <= 512 * 2.0


def test_flip():
# test assertion for invalid prob
Expand Down
150 changes: 150 additions & 0 deletions tests/test_data/test_tta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import os.path as osp

import mmcv
import pytest
from mmcv.utils import build_from_cfg

from mmseg.datasets.builder import PIPELINES


def test_multi_scale_flip_aug():
    """Exercise ``MultiScaleFlipAug`` built through the ``PIPELINES`` registry.

    Two things are checked:

    1. Configurations that must be rejected at build time with an
       ``AssertionError``.
    2. For accepted configurations, the ``scale`` and ``flip`` lists found in
       the transform's output, both with and without flipping enabled.
    """

    def make_cfg(img_scale, img_ratios, **extra):
        # A fresh config (and a fresh inner transforms list) per call, so no
        # state is shared between the individual cases.
        return dict(
            type='MultiScaleFlipAug',
            img_scale=img_scale,
            img_ratios=img_ratios,
            transforms=[dict(type='Resize', keep_ratio=False)],
            **extra)

    # Build-time failures:
    #   * img_scale=None, img_ratios=1 (not float)
    #   * img_scale=None, img_ratios=None
    #   * img_scale=(512, 512), img_ratios=1 (not float)
    rejected = [
        (None, 1),
        (None, None),
        ((512, 512), 1),
    ]
    for bad_scale, bad_ratios in rejected:
        with pytest.raises(AssertionError):
            build_from_cfg(make_cfg(bad_scale, bad_ratios), PIPELINES)

    # Input image; its shape is (288, 512, 3).
    image_path = osp.join(osp.dirname(__file__), '../data/color.jpg')
    img = mmcv.imread(image_path, 'color')
    results = dict(
        img=img,
        img_shape=img.shape,
        ori_shape=img.shape,
        # Set initial values for default meta_keys
        pad_shape=img.shape,
        scale_factor=1.0)

    # (img_scale, img_ratios, scales expected in the output, in order)
    accepted = [
        # a single scale multiplied by several ratios
        ((512, 512), [0.5, 1.0, 2.0],
         [(256, 256), (512, 512), (1024, 1024)]),
        # a single scale with a single (non-list) float ratio
        ((512, 512), 1.0, [(512, 512)]),
        # img_scale=None: ratios are applied to the image's own (h, w)
        (None, [0.5, 1.0, 2.0], [(144, 256), (288, 512), (576, 1024)]),
        # an explicit list of scales and no ratios
        ([(256, 256), (512, 512), (1024, 1024)], None,
         [(256, 256), (512, 512), (1024, 1024)]),
    ]
    for img_scale, img_ratios, scales in accepted:
        for use_flip in (False, True):
            tta_module = build_from_cfg(
                make_cfg(img_scale, img_ratios, flip=use_flip), PIPELINES)
            tta_results = tta_module(results.copy())

            # Every scale appears once per flip option, with the flip
            # options cycling fastest.
            flip_options = [False, True] if use_flip else [False]
            assert tta_results['scale'] == [
                scale for scale in scales for _ in flip_options
            ]
            assert tta_results['flip'] == flip_options * len(scales)

0 comments on commit 061a295

Please sign in to comment.