@@ -332,14 +332,18 @@ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
332332 # in the usually fixed policy and sample magnitude from a normal distribution
333333 # with mean `magnitude` and std-dev of `magnitude_std`.
334334 # NOTE This is my own hack, being tested, not in papers or reference impls.
335+ # If magnitude_std is inf, we sample magnitude from a uniform distribution
335336 self .magnitude_std = self .hparams .get ('magnitude_std' , 0 )
336337
337338 def __call__ (self , img ):
338339 if self .prob < 1.0 and random .random () > self .prob :
339340 return img
340341 magnitude = self .magnitude
341- if self .magnitude_std and self .magnitude_std > 0 :
342- magnitude = random .gauss (magnitude , self .magnitude_std )
342+ if self .magnitude_std :
343+ if self .magnitude_std == float ('inf' ):
344+ magnitude = random .uniform (0 , magnitude )
345+ elif self .magnitude_std > 0 :
346+ magnitude = random .gauss (magnitude , self .magnitude_std )
343347 magnitude = min (_MAX_LEVEL , max (0 , magnitude )) # clip to valid range
344348 level_args = self .level_fn (magnitude , self .hparams ) if self .level_fn is not None else tuple ()
345349 return self .aug_fn (img , * level_args , ** self .kwargs )
@@ -790,6 +794,7 @@ def augment_and_mix_transform(config_str, hparams):
790794 depth = - 1
791795 alpha = 1.
792796 blended = False
797+ hparams ['magnitude_std' ] = float ('inf' )
793798 config = config_str .split ('-' )
794799 assert config [0 ] == 'augmix'
795800 config = config [1 :]
0 commit comments