Skip to content

Commit 3e4e353

Browse files
authored
Cutmix -> CutMix (#7784)
1 parent edde825 commit 3e4e353

File tree

9 files changed

+29
-29
lines changed

9 files changed

+29
-29
lines changed

docs/source/transforms.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,8 @@ are combining pairs of images together. These can be used after the dataloader
274274
:toctree: generated/
275275
:template: class.rst
276276

277-
v2.Cutmix
278-
v2.Mixup
277+
v2.CutMix
278+
v2.MixUp
279279

280280
.. _functional_transforms:
281281

references/classification/transforms.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,24 @@ def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_categories, use_v2):
1313
mixup_cutmix = []
1414
if mixup_alpha > 0:
1515
mixup_cutmix.append(
16-
transforms_module.Mixup(alpha=mixup_alpha, num_categories=num_categories)
16+
transforms_module.MixUp(alpha=mixup_alpha, num_categories=num_categories)
1717
if use_v2
18-
else RandomMixup(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
18+
else RandomMixUp(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
1919
)
2020
if cutmix_alpha > 0:
2121
mixup_cutmix.append(
22-
transforms_module.Cutmix(alpha=mixup_alpha, num_categories=num_categories)
22+
transforms_module.CutMix(alpha=mixup_alpha, num_categories=num_categories)
2323
if use_v2
24-
else RandomCutmix(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
24+
else RandomCutMix(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
2525
)
2626
if not mixup_cutmix:
2727
return None
2828

2929
return transforms_module.RandomChoice(mixup_cutmix)
3030

3131

32-
class RandomMixup(torch.nn.Module):
33-
"""Randomly apply Mixup to the provided batch and targets.
32+
class RandomMixUp(torch.nn.Module):
33+
"""Randomly apply MixUp to the provided batch and targets.
3434
The class implements the data augmentations as described in the paper
3535
`"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
3636
@@ -112,8 +112,8 @@ def __repr__(self) -> str:
112112
return s
113113

114114

115-
class RandomCutmix(torch.nn.Module):
116-
"""Randomly apply Cutmix to the provided batch and targets.
115+
class RandomCutMix(torch.nn.Module):
116+
"""Randomly apply CutMix to the provided batch and targets.
117117
The class implements the data augmentations as described in the paper
118118
`"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
119119
<https://arxiv.org/abs/1905.04899>`_.

test/test_prototype_transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ def parametrize(transforms_with_inputs):
6060
],
6161
)
6262
for transform in [
63-
transforms.RandomMixup(alpha=1.0),
64-
transforms.RandomCutmix(alpha=1.0),
63+
transforms.RandomMixUp(alpha=1.0),
64+
transforms.RandomCutMix(alpha=1.0),
6565
]
6666
]
6767
)

test/test_transforms_v2_refactored.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1914,7 +1914,7 @@ def __getitem__(self, idx):
19141914
def __len__(self):
19151915
return self.size
19161916

1917-
@pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
1917+
@pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
19181918
def test_supported_input_structure(self, T):
19191919

19201920
batch_size = 32
@@ -1964,7 +1964,7 @@ def collate_fn_2(batch):
19641964
check_output(img, target)
19651965

19661966
@needs_cuda
1967-
@pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
1967+
@pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
19681968
def test_cpu_vs_gpu(self, T):
19691969
num_classes = 10
19701970
batch_size = 3
@@ -1976,7 +1976,7 @@ def test_cpu_vs_gpu(self, T):
19761976

19771977
_check_kernel_cuda_vs_cpu(cutmix_mixup, imgs, labels, rtol=None, atol=None)
19781978

1979-
@pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
1979+
@pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
19801980
def test_error(self, T):
19811981

19821982
num_classes = 10
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from ._presets import StereoMatching # usort: skip
22

3-
from ._augment import RandomCutmix, RandomMixup, SimpleCopyPaste
3+
from ._augment import RandomCutMix, RandomMixUp, SimpleCopyPaste
44
from ._geometry import FixedSizeCrop
55
from ._misc import PermuteDimensions, TransposeDimensions
66
from ._type_conversion import LabelToOneHot

torchvision/prototype/transforms/_augment.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_size
1515

1616

17-
class _BaseMixupCutmix(_RandomApplyTransform):
17+
class _BaseMixUpCutMix(_RandomApplyTransform):
1818
def __init__(self, alpha: float, p: float = 0.5) -> None:
1919
super().__init__(p=p)
2020
self.alpha = alpha
@@ -38,7 +38,7 @@ def _mixup_onehotlabel(self, inpt: proto_datapoints.OneHotLabel, lam: float) ->
3838
return proto_datapoints.OneHotLabel.wrap_like(inpt, output)
3939

4040

41-
class RandomMixup(_BaseMixupCutmix):
41+
class RandomMixUp(_BaseMixUpCutMix):
4242
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
4343
return dict(lam=float(self._dist.sample(()))) # type: ignore[arg-type]
4444

@@ -60,7 +60,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
6060
return inpt
6161

6262

63-
class RandomCutmix(_BaseMixupCutmix):
63+
class RandomCutMix(_BaseMixUpCutMix):
6464
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
6565
lam = float(self._dist.sample(())) # type: ignore[arg-type]
6666

torchvision/transforms/v2/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from ._transform import Transform # usort: skip
66

7-
from ._augment import Cutmix, Mixup, RandomErasing
7+
from ._augment import CutMix, MixUp, RandomErasing
88
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
99
from ._color import (
1010
ColorJitter,

torchvision/transforms/v2/_augment.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def _transform(
140140
return inpt
141141

142142

143-
class _BaseMixupCutmix(Transform):
143+
class _BaseMixUpCutMix(Transform):
144144
def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None:
145145
super().__init__()
146146
self.alpha = float(alpha)
@@ -203,10 +203,10 @@ def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
203203
return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))
204204

205205

206-
class Mixup(_BaseMixupCutmix):
206+
class MixUp(_BaseMixUpCutMix):
207207
"""[BETA] Apply MixUp to the provided batch of images and labels.
208208
209-
.. v2betastatus:: Mixup transform
209+
.. v2betastatus:: MixUp transform
210210
211211
Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.
212212
@@ -227,7 +227,7 @@ class Mixup(_BaseMixupCutmix):
227227
num_classes (int): number of classes in the batch. Used for one-hot-encoding.
228228
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
229229
By default, this will pick the second parameter a the labels if it's a tensor. This covers the most
230-
common scenario where this transform is called as ``Mixup()(imgs_batch, labels_batch)``.
230+
common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
231231
It can also be a callable that takes the same input as the transform, and returns the labels.
232232
"""
233233

@@ -252,10 +252,10 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
252252
return inpt
253253

254254

255-
class Cutmix(_BaseMixupCutmix):
255+
class CutMix(_BaseMixUpCutMix):
256256
"""[BETA] Apply CutMix to the provided batch of images and labels.
257257
258-
.. v2betastatus:: Cutmix transform
258+
.. v2betastatus:: CutMix transform
259259
260260
Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
261261
<https://arxiv.org/abs/1905.04899>`_.
@@ -277,7 +277,7 @@ class Cutmix(_BaseMixupCutmix):
277277
num_classes (int): number of classes in the batch. Used for one-hot-encoding.
278278
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
279279
By default, this will pick the second parameter a the labels if it's a tensor. This covers the most
280-
common scenario where this transform is called as ``Cutmix()(imgs_batch, labels_batch)``.
280+
common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
281281
It can also be a callable that takes the same input as the transform, and returns the labels.
282282
"""
283283

torchvision/transforms/v2/_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
8989
This heuristic covers three cases:
9090
9191
1. The input is tuple or list whose second item is a labels tensor. This happens for already batched
92-
classification inputs for Mixup and Cutmix (typically after the Dataloder).
92+
classification inputs for MixUp and CutMix (typically after the Dataloder).
9393
2. The input is a tuple or list whose second item is a dictionary that contains the labels tensor
9494
under a label-like (see below) key. This happens for the inputs of detection models.
9595
3. The input is a dictionary that is structured as the one from 2.
@@ -103,7 +103,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
103103
if isinstance(inputs, (tuple, list)):
104104
inputs = inputs[1]
105105

106-
# Mixup, Cutmix
106+
# MixUp, CutMix
107107
if isinstance(inputs, torch.Tensor):
108108
return inputs
109109

0 commit comments

Comments
 (0)