Skip to content

Commit 8c90a87

Browse files
authored
[Fix] Fix magnitude_range in RandAug (#249)
* add increasing in solarize and posterize * fix linting * Revert "add increasing in solarize and posterize" This reverts commit 128af36. * revise according to comments
1 parent f415c49 commit 8c90a87

File tree

2 files changed

+40
-5
lines changed

2 files changed

+40
-5
lines changed

mmcls/datasets/pipelines/auto_augment.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,9 @@ class RandAugment(object):
6767
augmentation. For those which have magnitude, (given to the fact
6868
they are named differently in different augmentation, )
6969
`magnitude_key` and `magnitude_range` shall be the magnitude
70-
argument (str) and the range of magnitude (tuple in the format or
71-
(minval, maxval)), respectively.
70+
argument (str) and the range of magnitude (tuple in the format of
71+
(val1, val2)), respectively. Note that val1 is not necessarily
72+
less than val2.
7273
num_policies (int): Number of policies to select from policies each
7374
time.
7475
magnitude_level (int | float): Magnitude level for all the augmentation
@@ -85,6 +86,10 @@ class RandAugment(object):
8586
Note:
8687
`magnitude_std` will introduce some randomness to policy, modified by
8788
https://github.com/rwightman/pytorch-image-models
89+
When magnitude_std=0, we calculate the magnitude as follows:
90+
91+
.. math::
92+
magnitude = magnitude_level / total_level * (val2 - val1) + val1
8893
"""
8994

9095
def __init__(self,
@@ -130,18 +135,20 @@ def _process_policies(self, policies):
130135
processed_policy = copy.deepcopy(policy)
131136
magnitude_key = processed_policy.pop('magnitude_key', None)
132137
if magnitude_key is not None:
133-
minval, maxval = processed_policy.pop('magnitude_range')
138+
val1, val2 = processed_policy.pop('magnitude_range')
134139
magnitude_value = (self.magnitude_level / self.total_level
135-
) * float(maxval - minval) + minval
140+
) * float(val2 - val1) + val1
136141

137142
# if magnitude_std is positive number or 'inf', move
138143
# magnitude_value randomly.
144+
maxval = max(val1, val2)
145+
minval = min(val1, val2)
139146
if self.magnitude_std == 'inf':
140147
magnitude_value = random.uniform(minval, magnitude_value)
141148
elif self.magnitude_std > 0:
142149
magnitude_value = random.gauss(magnitude_value,
143150
self.magnitude_std)
144-
magnitude_value = min(maxval, max(0, magnitude_value))
151+
magnitude_value = min(maxval, max(minval, magnitude_value))
145152
processed_policy.update({magnitude_key: magnitude_value})
146153
processed_policies.append(processed_policy)
147154
return processed_policies

tests/test_pipelines/test_auto_augment.py

+28
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,34 @@ def test_rand_augment():
185185
# apply rotation with prob=0.
186186
assert (results['img'] == results['ori_img']).all()
187187

188+
# test case where magnitude_range is reversed
189+
random.seed(1)
190+
np.random.seed(0)
191+
results = construct_toy_data()
192+
reversed_policies = [
193+
dict(
194+
type='Translate',
195+
magnitude_key='magnitude',
196+
magnitude_range=(1, 0),
197+
pad_val=128,
198+
prob=1.,
199+
direction='horizontal'),
200+
dict(type='Invert', prob=1.),
201+
dict(
202+
type='Rotate',
203+
magnitude_key='angle',
204+
magnitude_range=(30, 0),
205+
prob=0.)
206+
]
207+
transform = dict(
208+
type='RandAugment',
209+
policies=reversed_policies,
210+
num_policies=1,
211+
magnitude_level=30)
212+
pipeline = build_from_cfg(transform, PIPELINES)
213+
results = pipeline(results)
214+
assert (results['img'] == results['ori_img']).all()
215+
188216
# test case where num_policies = 2
189217
random.seed(0)
190218
np.random.seed(0)

0 commit comments

Comments
 (0)