@@ -67,8 +67,9 @@ class RandAugment(object):
67
67
augmentation. For those which have magnitude, (given to the fact
68
68
they are named differently in different augmentation, )
69
69
`magnitude_key` and `magnitude_range` shall be the magnitude
70
- argument (str) and the range of magnitude (tuple in the format or
71
- (minval, maxval)), respectively.
70
+ argument (str) and the range of magnitude (tuple in the format of
71
+ (val1, val2)), respectively. Note that val1 is not necessarily
72
+ less than val2.
72
73
num_policies (int): Number of policies to select from policies each
73
74
time.
74
75
magnitude_level (int | float): Magnitude level for all the augmentation
@@ -85,6 +86,10 @@ class RandAugment(object):
85
86
Note:
86
87
`magnitude_std` will introduce some randomness to policy, modified by
87
88
https://github.com/rwightman/pytorch-image-models
89
+ When magnitude_std=0, we calculate the magnitude as follows:
90
+
91
+ .. math::
92
+ magnitude = magnitude_level / total_level * (val2 - val1) + val1
88
93
"""
89
94
90
95
def __init__ (self ,
@@ -130,18 +135,20 @@ def _process_policies(self, policies):
130
135
processed_policy = copy .deepcopy (policy )
131
136
magnitude_key = processed_policy .pop ('magnitude_key' , None )
132
137
if magnitude_key is not None :
133
- minval , maxval = processed_policy .pop ('magnitude_range' )
138
+ val1 , val2 = processed_policy .pop ('magnitude_range' )
134
139
magnitude_value = (self .magnitude_level / self .total_level
135
- ) * float (maxval - minval ) + minval
140
+ ) * float (val2 - val1 ) + val1
136
141
137
142
# if magnitude_std is positive number or 'inf', move
138
143
# magnitude_value randomly.
144
+ maxval = max (val1 , val2 )
145
+ minval = min (val1 , val2 )
139
146
if self .magnitude_std == 'inf' :
140
147
magnitude_value = random .uniform (minval , magnitude_value )
141
148
elif self .magnitude_std > 0 :
142
149
magnitude_value = random .gauss (magnitude_value ,
143
150
self .magnitude_std )
144
- magnitude_value = min (maxval , max (0 , magnitude_value ))
151
+ magnitude_value = min (maxval , max (minval , magnitude_value ))
145
152
processed_policy .update ({magnitude_key : magnitude_value })
146
153
processed_policies .append (processed_policy )
147
154
return processed_policies
0 commit comments