@@ -1955,8 +1955,9 @@ def __repr__(self):
19551955 return self .__class__ .__name__ + '(p={})' .format (self .p )
19561956
19571957
1958+ # TODO: move this to references before merging and delete the tests
19581959class RandomMixupCutmix (torch .nn .Module ):
1959- """Randomly apply Mixum or Cutmix to the provided batch and targets.
1960+ """Randomly apply Mixup or Cutmix to the provided batch and targets.
19601961 The class implements the data augmentations as described in the papers
19611962 `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_ and
19621963 `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
@@ -2014,8 +2015,8 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
20142015 return batch , target
20152016
20162017 # It's faster to roll the batch by one instead of shuffling it to create image pairs
2017- batch_flipped = batch .roll (1 )
2018- target_flipped = target .roll (1 )
2018+ batch_rolled = batch .roll (1 , 0 )
2019+ target_rolled = target .roll (1 )
20192020
20202021 if self .mixup_alpha <= 0.0 :
20212022 use_mixup = False
@@ -2025,8 +2026,8 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
20252026 if use_mixup :
20252026 # Implemented as in the mixup paper, page 3.
20272028 lambda_param = float (torch ._sample_dirichlet (torch .tensor ([self .mixup_alpha , self .mixup_alpha ]))[0 ])
2028- batch_flipped .mul_ (1.0 - lambda_param )
2029- batch .mul_ (lambda_param ).add_ (batch_flipped )
2029+ batch_rolled .mul_ (1.0 - lambda_param )
2030+ batch .mul_ (lambda_param ).add_ (batch_rolled )
20302031 else :
20312032 # Implemented as in the cutmix paper, page 12 (with minor corrections of typos).
20322033 lambda_param = float (torch ._sample_dirichlet (torch .tensor ([self .cutmix_alpha , self .cutmix_alpha ]))[0 ])
@@ -2044,11 +2045,11 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
20442045 x2 = int (torch .clamp (r_x + r_w_half , max = W ))
20452046 y2 = int (torch .clamp (r_y + r_h_half , max = H ))
20462047
2047- batch [:, :, y1 :y2 , x1 :x2 ] = batch_flipped [:, :, y1 :y2 , x1 :x2 ]
2048+ batch [:, :, y1 :y2 , x1 :x2 ] = batch_rolled [:, :, y1 :y2 , x1 :x2 ]
20482049 lambda_param = float (1.0 - (x2 - x1 ) * (y2 - y1 ) / (W * H ))
20492050
2050- target_flipped .mul_ (1.0 - lambda_param )
2051- target .mul_ (lambda_param ).add_ (target_flipped )
2051+ target_rolled .mul_ (1.0 - lambda_param )
2052+ target .mul_ (lambda_param ).add_ (target_rolled )
20522053
20532054 return batch , target
20542055
0 commit comments