Skip to content

Commit

Permalink
[Feature] Add RandomRotate transform (#215)
Browse files Browse the repository at this point in the history
* add RandomRotate for transforms

* change rotation function to mmcv.imrotate

* refactor

* add unittest

* fixed test

* fixed docstring

* fixed test

* add more test

* fixed repr

* rename to prob

* fixed unittest

Co-authored-by: hkzhang95 <GodBlessZhk@outlook.com>
  • Loading branch information
xvjiarui and hkzhang95 committed Nov 7, 2020
1 parent 0d10921 commit 3d18775
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 18 deletions.
2 changes: 1 addition & 1 deletion configs/_base_/datasets/ade20k.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
Expand Down
2 changes: 1 addition & 1 deletion configs/_base_/datasets/cityscapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
Expand Down
2 changes: 1 addition & 1 deletion configs/_base_/datasets/cityscapes_769x769.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
Expand Down
2 changes: 1 addition & 1 deletion configs/_base_/datasets/pascal_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
Expand Down
2 changes: 1 addition & 1 deletion configs/_base_/datasets/pascal_voc12.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
Expand Down
99 changes: 92 additions & 7 deletions mmseg/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import mmcv
import numpy as np
from mmcv.utils import deprecated_api_warning
from numpy import random

from ..builder import PIPELINES
Expand Down Expand Up @@ -232,16 +233,17 @@ class RandomFlip(object):
method.
Args:
flip_ratio (float, optional): The flipping probability. Default: None.
prob (float, optional): The flipping probability. Default: None.
direction(str, optional): The flipping direction. Options are
'horizontal' and 'vertical'. Default: 'horizontal'.
"""

def __init__(self, flip_ratio=None, direction='horizontal'):
self.flip_ratio = flip_ratio
@deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip')
def __init__(self, prob=None, direction='horizontal'):
self.prob = prob
self.direction = direction
if flip_ratio is not None:
assert flip_ratio >= 0 and flip_ratio <= 1
if prob is not None:
assert prob >= 0 and prob <= 1
assert direction in ['horizontal', 'vertical']

def __call__(self, results):
Expand All @@ -257,7 +259,7 @@ def __call__(self, results):
"""

if 'flip' not in results:
flip = True if np.random.rand() < self.flip_ratio else False
flip = True if np.random.rand() < self.prob else False
results['flip'] = flip
if 'flip_direction' not in results:
results['flip_direction'] = self.direction
Expand All @@ -274,7 +276,7 @@ def __call__(self, results):
return results

def __repr__(self):
return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})'
return self.__class__.__name__ + f'(prob={self.prob})'


@PIPELINES.register_module()
Expand Down Expand Up @@ -463,6 +465,89 @@ def __repr__(self):
return self.__class__.__name__ + f'(crop_size={self.crop_size})'


@PIPELINES.register_module()
class RandomRotate(object):
    """Randomly rotate the image and its segmentation maps.

    With probability ``prob`` the image stored under ``results['img']`` and
    every map listed in ``results['seg_fields']`` are rotated by the same
    angle, drawn uniformly from the configured degree range.

    Args:
        prob (float): The rotation probability, in the range [0, 1].
        degree (float, tuple[float]): Range of degrees to select from. If
            degree is a number instead of tuple like (min, max),
            the range of degree will be (``-degree``, ``+degree``).
        pad_val (float, optional): Padding value of image. Default: 0.
        seg_pad_val (float, optional): Padding value of segmentation map.
            Default: 255.
        center (tuple[float], optional): Center point (w, h) of the rotation
            in the source image. If not specified, the center of the image
            will be used. Default: None.
        auto_bound (bool): Whether to adjust the image size to cover the whole
            rotated image. Default: False.

    Raises:
        AssertionError: If ``prob`` is outside [0, 1], if a scalar ``degree``
            is not positive, or if a sequence ``degree`` does not have
            exactly two elements.
    """

    def __init__(self,
                 prob,
                 degree,
                 pad_val=0,
                 seg_pad_val=255,
                 center=None,
                 auto_bound=False):
        self.prob = prob
        assert 0 <= prob <= 1
        if isinstance(degree, (float, int)):
            assert degree > 0, f'degree {degree} should be positive'
            # A single number describes a symmetric range around zero.
            self.degree = (-degree, degree)
        else:
            self.degree = degree
        assert len(self.degree) == 2, f'degree {self.degree} should be a ' \
                                      f'tuple of (min, max)'
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val
        self.center = center
        self.auto_bound = auto_bound

    @property
    def pal_val(self):
        """Misspelled alias of ``pad_val``, kept for backward compatibility."""
        return self.pad_val

    @pal_val.setter
    def pal_val(self, value):
        self.pad_val = value

    def __call__(self, results):
        """Call function to rotate image, semantic segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Rotated results. The same dict is returned unmodified when
                the random draw decides not to rotate.
        """

        rotate = np.random.rand() < self.prob
        # The angle is drawn on every call, even when no rotation is applied,
        # so the number of random values consumed per call stays constant.
        degree = np.random.uniform(min(self.degree), max(self.degree))
        if rotate:
            # rotate image
            results['img'] = mmcv.imrotate(
                results['img'],
                angle=degree,
                border_value=self.pad_val,
                center=self.center,
                auto_bound=self.auto_bound)

            # rotate segs; nearest interpolation is requested so that label
            # values are not blended into new, invalid class ids
            for key in results.get('seg_fields', []):
                results[key] = mmcv.imrotate(
                    results[key],
                    angle=degree,
                    border_value=self.seg_pad_val,
                    center=self.center,
                    auto_bound=self.auto_bound,
                    interpolation='nearest')
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, ' \
                    f'degree={self.degree}, ' \
                    f'pad_val={self.pad_val}, ' \
                    f'seg_pad_val={self.seg_pad_val}, ' \
                    f'center={self.center}, ' \
                    f'auto_bound={self.auto_bound})'
        return repr_str


@PIPELINES.register_module()
class SegRescale(object):
"""Rescale semantic segmentation maps.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_custom_dataset():
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(128, 256), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
Expand Down
50 changes: 45 additions & 5 deletions tests/test_data/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,18 +94,17 @@ def test_resize():


def test_flip():
# test assertion for invalid flip_ratio
# test assertion for invalid prob
with pytest.raises(AssertionError):
transform = dict(type='RandomFlip', flip_ratio=1.5)
transform = dict(type='RandomFlip', prob=1.5)
build_from_cfg(transform, PIPELINES)

# test assertion for invalid direction
with pytest.raises(AssertionError):
transform = dict(
type='RandomFlip', flip_ratio=1, direction='horizonta')
transform = dict(type='RandomFlip', prob=1, direction='horizonta')
build_from_cfg(transform, PIPELINES)

transform = dict(type='RandomFlip', flip_ratio=1)
transform = dict(type='RandomFlip', prob=1)
flip_module = build_from_cfg(transform, PIPELINES)

results = dict()
Expand Down Expand Up @@ -197,6 +196,47 @@ def test_pad():
assert img_shape[1] % 32 == 0


def test_rotate():
    """Check RandomRotate argument validation, repr and output shapes."""
    # a scalar degree must be positive
    with pytest.raises(AssertionError):
        bad_cfg = dict(type='RandomRotate', prob=0.5, degree=-10)
        build_from_cfg(bad_cfg, PIPELINES)
    # a sequence degree must hold exactly two values (min, max)
    with pytest.raises(AssertionError):
        bad_cfg = dict(type='RandomRotate', prob=0.5, degree=(10., 20., 30.))
        build_from_cfg(bad_cfg, PIPELINES)

    rotate_cfg = dict(type='RandomRotate', degree=10., prob=1.)
    rotate_module = build_from_cfg(rotate_cfg, PIPELINES)

    # the repr must list every constructor argument with its stored value
    expected_repr = (f'RandomRotate('
                     f'prob={1.}, '
                     f'degree=({-10.}, {10.}), '
                     f'pad_val={0}, '
                     f'seg_pad_val={255}, '
                     f'center={None}, '
                     f'auto_bound={False})')
    assert str(rotate_module) == expected_repr

    here = osp.dirname(__file__)
    img = mmcv.imread(osp.join(here, '../data/color.jpg'), 'color')
    height, width = img.shape[0], img.shape[1]
    seg = np.array(Image.open(osp.join(here, '../data/seg.png')))
    results = dict(
        img=img,
        gt_semantic_seg=seg,
        seg_fields=['gt_semantic_seg'],
        img_shape=img.shape,
        ori_shape=img.shape,
        # Set initial values for default meta_keys
        pad_shape=img.shape,
        scale_factor=1.0)

    # with the default auto_bound=False the spatial size must be preserved
    results = rotate_module(results)
    assert results['img'].shape[:2] == (height, width)
    assert results['gt_semantic_seg'].shape[:2] == (height, width)


def test_normalize():
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53],
Expand Down

0 comments on commit 3d18775

Please sign in to comment.