Add gallery example for MixUp and CutMix #7772
Changes from all commits
1df3b8a
8b8b752
fa9790c
e7e5977
4f54ac0
afb15d8
a367808
@@ -1,8 +1,152 @@
"""
===========================
How to use CutMix and MixUp
===========================

:class:`~torchvision.transforms.v2.Cutmix` and
:class:`~torchvision.transforms.v2.Mixup` are popular augmentation strategies
that can improve classification accuracy.

These transforms are slightly different from the rest of the Torchvision
transforms, because they expect **batches** of samples as input, not
individual images. In this example we'll explain how to use them: after the
``DataLoader``, or as part of a collation function.
"""
# %%
import torch
import torchvision
from torchvision.datasets import FakeData

# We are using BETA APIs, so we deactivate the associated warning, thereby
# acknowledging that some APIs may slightly change in the future.
torchvision.disable_beta_transforms_warning()

from torchvision.transforms import v2

NUM_CLASSES = 100
# %%
# Pre-processing pipeline
# -----------------------
#
# We'll use a simple but typical image classification pipeline:

preproc = v2.Compose([
    v2.PILToTensor(),
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),  # to float32 in [0, 1]
    v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # typically from ImageNet
])

dataset = FakeData(size=1000, num_classes=NUM_CLASSES, transform=preproc)

img, label = dataset[0]
print(f"{type(img) = }, {img.dtype = }, {img.shape = }, {label = }")
# %%
#
# One important thing to note is that neither CutMix nor MixUp is part of this
# pre-processing pipeline. We'll add them a bit later, once we define the
# DataLoader. Just as a refresher, this is what the DataLoader and training
# loop would look like if we weren't using CutMix or MixUp:

from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

for images, labels in dataloader:
    print(f"{images.shape = }, {labels.shape = }")
    print(labels.dtype)
    # <rest of the training loop here>
    break
# %%
# Where to use MixUp and CutMix
# -----------------------------
#
# After the DataLoader
# ^^^^^^^^^^^^^^^^^^^^
#
# Now let's add CutMix and MixUp. The simplest way to do this is right after
# the DataLoader: the DataLoader has already batched the images and labels for
# us, and this is exactly what these transforms expect as input:

dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

cutmix = v2.Cutmix(num_classes=NUM_CLASSES)
mixup = v2.Mixup(num_classes=NUM_CLASSES)
cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])

for images, labels in dataloader:
    print(f"Before CutMix/MixUp: {images.shape = }, {labels.shape = }")
    images, labels = cutmix_or_mixup(images, labels)
    print(f"After CutMix/MixUp: {images.shape = }, {labels.shape = }")

    # <rest of the training loop here>
    break
# %%
#
# Note how the labels were also transformed: we went from a batched label of
# shape (batch_size,) to a tensor of shape (batch_size, num_classes). The
# transformed labels can still be passed as-is to a loss function like
# :func:`torch.nn.functional.cross_entropy`.
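# %%
# To make this concrete, here is a minimal sketch of a training step with the
# mixed labels. The tiny linear "model" below is only a stand-in for
# illustration (it is not part of torchvision; in practice you would plug in
# your own network):

import torch.nn.functional as F

toy_model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 224 * 224, NUM_CLASSES))
logits = toy_model(images)  # (batch_size, NUM_CLASSES)
loss = F.cross_entropy(logits, labels)  # accepts soft labels of shape (batch_size, num_classes)
print(f"{loss = }")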
# %%
# As part of the collation function
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Passing the transforms after the DataLoader is the simplest way to use CutMix
# and MixUp, but one disadvantage is that it does not take advantage of the
# DataLoader multi-processing. For that, we can pass those transforms as part of
# the collation function (refer to the `PyTorch docs
# <https://pytorch.org/docs/stable/data.html#dataloader-collate-fn>`_ to learn
# more about collation).
from torch.utils.data import default_collate


def collate_fn(batch):
    return cutmix_or_mixup(*default_collate(batch))


dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_fn)

for images, labels in dataloader:
    print(f"{images.shape = }, {labels.shape = }")
    # No need to call cutmix_or_mixup here: it has already been applied as part of the DataLoader!
    # <rest of the training loop here>
    break
# %%
# Non-standard input format
# -------------------------
#
# So far we've used a typical sample structure where we pass ``(images,
# labels)`` as inputs. MixUp and CutMix will magically work by default with
# most common sample structures: tuples where the second parameter is a tensor
# label, or dicts with a "label[s]" key. Look at the documentation of the
# ``labels_getter`` parameter for more details.
#
# If your samples have a different structure, you can still use CutMix and
# MixUp by passing a callable to the ``labels_getter`` parameter. For example:
batch = {
    "imgs": torch.rand(4, 3, 224, 224),
    "target": {
        "classes": torch.randint(0, NUM_CLASSES, size=(4,)),
        "some_other_key": "this is going to be passed-through",
    },
}


def labels_getter(batch):
    return batch["target"]["classes"]


out = v2.Cutmix(num_classes=NUM_CLASSES, labels_getter=labels_getter)(batch)
print(f"{out['imgs'].shape = }, {out['target']['classes'].shape = }")
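# %%
# As a quick (illustrative) check that the other entries are indeed passed
# through untouched, we can simply print the extra key from the output above:

print(f"{out['target']['some_other_key'] = }")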
@@ -141,9 +141,9 @@ def _transform(
 class _BaseMixupCutmix(Transform):
-    def __init__(self, *, alpha: float = 1, num_classes: int, labels_getter="default") -> None:
+    def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None:
         super().__init__()
-        self.alpha = alpha
+        self.alpha = float(alpha)
         self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
         self.num_classes = num_classes
@@ -204,13 +204,20 @@ def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
 class Mixup(_BaseMixupCutmix):
-    """[BETA] Apply Mixup to the provided batch of images and labels.
+    """[BETA] Apply MixUp to the provided batch of images and labels.

Comment on lines 206 to +207:

Contributor: Should we then also rename the classes?

Member (Author): thought about it, went for consistency with timm (i.e.

Contributor: I don't mind either, but I want both our documentation and class name to be the same. Your choice.

Member (Author): 🥲 #7731 (comment) I'm going to let you guys decide what to do as I want to avoid re-undo stuff. (I don't think we absolutely have to align docs and code: CutMix is the technique, Cutmix is the class object - it's OK to have a distinction.)

Contributor: @vfdev-5 Could you clarify if you meant to only switch to

     .. v2betastatus:: Mixup transform

     Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.

-    See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.
+    .. note::
+        This transform is meant to be used on **batches** of samples, not
+        individual images. See
+        :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage
+        examples.
+
+        The sample pairing is deterministic and done by matching consecutive
+        samples in the batch, so the batch needs to be shuffled (this is an
+        implementation detail, not a guaranteed convention.)

     In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
     into a tensor of shape ``(batch_size, num_classes)``.
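The "matching consecutive samples" note amounts to mixing each sample with the next one in the (already shuffled) batch. Below is a minimal, self-contained sketch of the MixUp idea; it is not the library's exact implementation, and the batch shape, number of classes, and ``alpha`` are made up for illustration:

import torch

alpha = 1.0
images = torch.rand(4, 3, 224, 224)  # an already-shuffled batch
labels = torch.nn.functional.one_hot(torch.randint(0, 100, (4,)), num_classes=100).float()

lam = float(torch.distributions.Beta(alpha, alpha).sample())
mixed_images = lam * images + (1 - lam) * images.roll(1, 0)  # pair each sample with the next one
mixed_labels = lam * labels + (1 - lam) * labels.roll(1, 0)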
@@ -246,14 +253,21 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 class Cutmix(_BaseMixupCutmix):
-    """[BETA] Apply Cutmix to the provided batch of images and labels.
+    """[BETA] Apply CutMix to the provided batch of images and labels.

     .. v2betastatus:: Cutmix transform

     Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
     <https://arxiv.org/abs/1905.04899>`_.

-    See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.
+    .. note::
+        This transform is meant to be used on **batches** of samples, not
+        individual images. See
+        :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage
+        examples.
+
+        The sample pairing is deterministic and done by matching consecutive
+        samples in the batch, so the batch needs to be shuffled (this is an
+        implementation detail, not a guaranteed convention.)

     In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
     into a tensor of shape ``(batch_size, num_classes)``.
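For CutMix, the paper cited above pastes a box from the paired sample and re-weights the labels by the pasted area. A rough, self-contained sketch of that idea (again not the library's exact implementation; shapes, number of classes, and ``alpha`` are illustrative):

import torch

alpha = 1.0
images = torch.rand(4, 3, 224, 224)  # an already-shuffled batch
labels = torch.nn.functional.one_hot(torch.randint(0, 100, (4,)), num_classes=100).float()

lam = float(torch.distributions.Beta(alpha, alpha).sample())
H, W = images.shape[-2:]
cut_h, cut_w = int(H * (1 - lam) ** 0.5), int(W * (1 - lam) ** 0.5)
cy, cx = int(torch.randint(H, (1,))), int(torch.randint(W, (1,)))
y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, H)
x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, W)

mixed_images = images.clone()
mixed_images[..., y1:y2, x1:x2] = images.roll(1, 0)[..., y1:y2, x1:x2]  # paste a box from the paired sample
lam_adjusted = 1 - (y2 - y1) * (x2 - x1) / (H * W)  # re-weight labels by the area actually pasted
mixed_labels = lam_adjusted * labels + (1 - lam_adjusted) * labels.roll(1, 0)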
Comment: Using `# %%` instead of the long `#####...` as in the other examples allows for those scripts to be properly interpreted as notebooks within vscode / pycharm. It makes writing / debugging those examples a lot easier. I'm tempted to align the other scripts to use this syntax as well?

Reply: Yes, sure, but please in a follow-up PR.