forked from open-mmlab/mmsegmentation
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_time_aug.py
132 lines (115 loc) · 5.09 KB
/
test_time_aug.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import warnings
import mmcv
from ..builder import PIPELINES
from .compose import Compose
@PIPELINES.register_module()
class MultiScaleFlipAug(object):
    """Test-time augmentation with multiple scales and flipping.

    An example configuration is as follows:

    .. code-block::

        img_scale=(2048, 1024),
        img_ratios=[0.5, 1.0],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ]

    After MultiScaleFlipAug with the above configuration, the results are
    wrapped into lists of the same length as follows:

    .. code-block::

        dict(
            img=[...],
            img_shape=[...],
            scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)]
            flip=[False, True, False, True]
            ...
        )

    One augmented sample is produced for every combination of scale, flip
    flag (``False`` only, or ``False`` and ``True`` when ``flip`` is set) and
    flip direction. With several flip directions the un-flipped variant is
    therefore emitted once per direction.

    Args:
        transforms (list[dict]): Transforms to apply in each augmentation.
        img_scale (None | tuple | list[tuple]): Images scales for resizing.
            If None, the scales are derived for every input image from its
            own size multiplied by each entry of ``img_ratios``.
        img_ratios (float | list[float]): Image ratios for resizing
        flip (bool): Whether apply flip augmentation. Default: False.
        flip_direction (str | list[str]): Flip augmentation directions,
            options are "horizontal" and "vertical". If flip_direction is
            list, multiple flip augmentations will be applied.
            It has no effect when flip == False. Default: "horizontal".
    """

    def __init__(self,
                 transforms,
                 img_scale,
                 img_ratios=None,
                 flip=False,
                 flip_direction='horizontal'):
        self.transforms = Compose(transforms)

        # Normalise a single ratio to a one-element list.
        if img_ratios is not None:
            img_ratios = img_ratios if isinstance(img_ratios,
                                                  list) else [img_ratios]
            assert mmcv.is_list_of(img_ratios, float)

        if img_scale is None:
            # mode 1: given img_scale=None and a range of image ratio;
            # the concrete scales are computed per image in __call__.
            self.img_scale = None
            assert mmcv.is_list_of(img_ratios, float)
        elif isinstance(img_scale, tuple) and len(
                img_scale) == 2 and mmcv.is_list_of(img_ratios, float):
            # mode 2: given a scale and a range of image ratio
            self.img_scale = [(int(img_scale[0] * ratio),
                               int(img_scale[1] * ratio))
                              for ratio in img_ratios]
        else:
            # mode 3: given multiple scales
            self.img_scale = img_scale if isinstance(img_scale,
                                                     list) else [img_scale]
        assert mmcv.is_list_of(self.img_scale,
                               tuple) or self.img_scale is None

        self.flip = flip
        self.img_ratios = img_ratios
        self.flip_direction = flip_direction if isinstance(
            flip_direction, list) else [flip_direction]
        assert mmcv.is_list_of(self.flip_direction, str)

        if not self.flip and self.flip_direction != ['horizontal']:
            warnings.warn(
                'flip_direction has no effect when flip is set to False')
        if (self.flip
                and not any(t['type'] == 'RandomFlip' for t in transforms)):
            warnings.warn(
                'flip has no effect when RandomFlip is not in transforms')

    def __call__(self, results):
        """Call function to apply test time augment transforms on results.

        Args:
            results (dict): Result dict contains the data to transform.

        Returns:
            dict[str: list]: The augmented data, where each value is wrapped
                into a list.
        """
        if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float):
            # Ratio-only mode: derive the scales from *this* image. Keep them
            # in a local variable; storing them on ``self`` would freeze the
            # first image's size for every later call.
            # shape[:2] is (height, width), while the scales handled by this
            # class are written as (width, height), hence the swap.
            height, width = results['img'].shape[:2]
            img_scale = [(int(width * ratio), int(height * ratio))
                         for ratio in self.img_ratios]
        else:
            img_scale = self.img_scale

        flip_aug = [False, True] if self.flip else [False]

        aug_data = []
        for scale in img_scale:
            for flip in flip_aug:
                for direction in self.flip_direction:
                    # Shallow copy so each augmentation starts from the same
                    # untouched input dict.
                    _results = results.copy()
                    _results['scale'] = scale
                    _results['flip'] = flip
                    _results['flip_direction'] = direction
                    data = self.transforms(_results)
                    aug_data.append(data)

        # list of dict to dict of list
        aug_data_dict = {key: [] for key in aug_data[0]}
        for data in aug_data:
            for key, val in data.items():
                aug_data_dict[key].append(val)
        return aug_data_dict

    def __repr__(self):
        """Return a string listing the configured transforms and options."""
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={self.transforms}, '
        repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
        repr_str += f'flip_direction={self.flip_direction})'
        return repr_str