From 0cfd88d9b39e7cd7d65e22578d1aaa1bedb0bfa5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 1 Aug 2023 12:46:23 +0100 Subject: [PATCH] Fix docs --- gallery/plot_cutmix_mixup.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/gallery/plot_cutmix_mixup.py b/gallery/plot_cutmix_mixup.py index d1c92a27812..932ce325b56 100644 --- a/gallery/plot_cutmix_mixup.py +++ b/gallery/plot_cutmix_mixup.py @@ -4,8 +4,8 @@ How to use CutMix and MixUp =========================== -:class:`~torchvision.transforms.v2.Cutmix` and -:class:`~torchvision.transforms.v2.Mixup` are popular augmentation strategies +:class:`~torchvision.transforms.v2.CutMix` and +:class:`~torchvision.transforms.v2.MixUp` are popular augmentation strategies that can improve classification accuracy. These transforms are slightly different from the rest of the Torchvision @@ -79,8 +79,8 @@ dataloader = DataLoader(dataset, batch_size=4, shuffle=True) -cutmix = v2.Cutmix(num_classes=NUM_CLASSES) -mixup = v2.Mixup(num_classes=NUM_CLASSES) +cutmix = v2.CutMix(num_classes=NUM_CLASSES) +mixup = v2.MixUp(num_classes=NUM_CLASSES) cutmix_or_mixup = v2.RandomChoice([cutmix, mixup]) for images, labels in dataloader: @@ -148,5 +148,5 @@ def labels_getter(batch): return batch["target"]["classes"] -out = v2.Cutmix(num_classes=NUM_CLASSES, labels_getter=labels_getter)(batch) +out = v2.CutMix(num_classes=NUM_CLASSES, labels_getter=labels_getter)(batch) print(f"{out['imgs'].shape = }, {out['target']['classes'].shape = }")