promote Mixup and Cutmix from prototype to transforms v2 #7731
Changes from 2 commits
@@ -8,22 +8,25 @@
import torch.utils.data
import torchvision
import torchvision.transforms
import transforms
import utils
from sampler import RASampler
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
from transforms import get_batch_transform


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
def train_one_epoch(
    model, criterion, optimizer, data_loader, batch_transform, device, epoch, args, model_ema=None, scaler=None
):
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    header = f"Epoch: [{epoch}]"
    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        if batch_transform:
            image, target = batch_transform(image, target)

        start_time = time.time()
        image, target = image.to(device), target.to(device)
        with torch.cuda.amp.autocast(enabled=scaler is not None):
@@ -218,31 +221,22 @@ def main(args):
    val_dir = os.path.join(args.data_path, "val")
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

    collate_fn = None
    num_classes = len(dataset.classes)
    mixup_transforms = []
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
    if mixup_transforms:
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)

        def collate_fn(batch):
            return mixupcutmix(*default_collate(batch))

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    num_classes = len(dataset.classes)
    batch_transform = get_batch_transform(
        mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_categories=num_classes, use_v2=args.use_v2
    )

    print("Creating model")
    model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
    model.to(device)
@@ -364,7 +358,9 @@ def collate_fn(batch):
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
        train_one_epoch(
            model, criterion, optimizer, data_loader, batch_transform, device, epoch, args, model_ema, scaler
        )
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if model_ema:
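The get_batch_transform helper imported above lives in the reference transforms module and is not shown in this hunk. Below is a minimal sketch of what such a helper could look like, assuming the Mixup/Cutmix classes added later in this diff end up exported from torchvision.transforms.v2; everything other than the helper's name and keyword arguments is an assumption, not the PR's actual implementation.

```python
from torchvision.transforms import v2


def get_batch_transform(*, mixup_alpha, cutmix_alpha, num_categories, use_v2):
    # Sketch only: the v1 branch (the reference RandomMixup/RandomCutmix) is omitted.
    if not use_v2:
        raise NotImplementedError("only the transforms v2 path is sketched here")

    choices = []
    if mixup_alpha > 0.0:
        choices.append(v2.Mixup(alpha=mixup_alpha, num_categories=num_categories))
    if cutmix_alpha > 0.0:
        choices.append(v2.Cutmix(alpha=cutmix_alpha, num_categories=num_categories))
    if not choices:
        return None  # train_one_epoch skips the call when no transform is configured
    # Apply exactly one of the enabled transforms to each batch, chosen at random.
    return v2.RandomChoice(choices)
```

In train_one_epoch above, the returned transform is then applied per batch via image, target = batch_transform(image, target); the v2 Mixup and Cutmix classes it would dispatch to are the ones added in the second file below.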
| @@ -1,15 +1,18 @@ | ||
| import math | ||
| import numbers | ||
| import warnings | ||
| from typing import Any, Dict, List, Tuple, Union | ||
| from typing import Any, Dict, List, Optional, Tuple, Union | ||
|
|
||
| import PIL.Image | ||
| import torch | ||
| from torch.nn.functional import one_hot | ||
| from torch.utils._pytree import tree_flatten, tree_unflatten | ||
| from torchvision import datapoints, transforms as _transforms | ||
| from torchvision.transforms.v2 import functional as F | ||
|
|
||
| from ._transform import _RandomApplyTransform | ||
| from .utils import is_simple_tensor, query_chw | ||
| from ._transform import _RandomApplyTransform, Transform | ||
| from ._utils import _parse_labels_getter | ||
| from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size | ||
|
|
||
|
|
||
| class RandomErasing(_RandomApplyTransform): | ||
|
|
@@ -135,3 +138,150 @@ def _transform(
        inpt = F.erase(inpt, **params, inplace=self.inplace)

        return inpt


class _BaseMixupCutmix(Transform):
    def __init__(self, *, alpha: float, num_categories: Optional[int] = None, labels_getter="default") -> None:
        super().__init__()
        self.alpha = alpha
        self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

        self.num_categories = num_categories

        self.labels_getter = labels_getter
        self._labels_getter = _parse_labels_getter(labels_getter)

    def forward(self, *inputs):
        inputs = inputs if len(inputs) > 1 else inputs[0]
        flat_inputs, spec = tree_flatten(inputs)
        needs_transform_list = self._needs_transform_list(flat_inputs)

        if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBox, datapoints.Mask):
            raise TypeError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.")

        labels = self._labels_getter(inputs)
        if labels is None:
            msg = "Couldn't find a label in the inputs."
            if self.labels_getter == "default":
                msg = f"{msg} To overwrite the default find behavior, pass a callable for labels_getter."
            raise RuntimeError(msg)
        elif not isinstance(labels, torch.Tensor):
            raise ValueError(f"The labels must be a torch.Tensor, but got {type(labels)} instead.")
        elif labels.ndim in {1, 2}:
            if labels.ndim == 2 and self.num_categories is not None and labels.shape[-1] != self.num_categories:
                raise ValueError(
                    f"2D labels are assumed to be probability based, "
                    f"but the number of elements in last dimension does not match the number of categories: "
                    f"{labels.shape[-1]} != {self.num_categories}."
                )
        else:
            raise ValueError(
                f"labels should be a index based with shape (batch_size,) "
                f"or a probability based with shape (batch_size, num_categories), "
                f"but got a tensor of shape {labels.shape} instead."
            )

        params = {
            "labels": labels,
            "batch_size": labels.shape[0],
            **self._get_params(
                [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
            ),
        }

        # By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor, but coming
        # after an image or video. However, since we want to handle them in _transform, we force their entry to True.
        needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True
Member: We used a different strategy in …

Contributor (PR author): We can't use the same strategy here.

Contributor: Shouldn't we transform all images (each image is collated as (N, C, H, W))?

Member: I think I understand what you mean: we'd need to re-implement the "tensor pass-through heuristic" in … We are transforming all images, yes.

Contributor (PR author): Yes, we can certainly also use …
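For context on the discussion above, a small standalone sketch (not part of the diff) of what the flattened batch looks like after default_collate: every image is already batched as (N, C, H, W), and the labels tensor sits next to it in the flattened inputs, which is why forward() has to opt it back in explicitly.

```python
import torch
from torch.utils._pytree import tree_flatten

images = torch.rand(4, 3, 32, 32)      # default_collate stacks images into (N, C, H, W)
labels = torch.randint(0, 10, (4,))    # and labels into an (N,) tensor

flat_inputs, spec = tree_flatten((images, labels))
print([t.shape for t in flat_inputs])  # [torch.Size([4, 3, 32, 32]), torch.Size([4])]
# Under the "tensor pass-through heuristic", only the first plain tensor would be
# treated as an image to transform; the labels entry is therefore flipped back to
# True in forward() above so that _transform() also sees it.
```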
        flat_outputs = [
            self._transform(inpt, params) if needs_transform else inpt
            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
        ]

        return tree_unflatten(flat_outputs, spec)

    def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
        if inpt.ndim != (5 if isinstance(inpt, datapoints.Video) else 4):
            raise ValueError(
                f"The transform expects a batched input, but got an {type(inpt).__name__} with {inpt.ndim} dimensions."
            )
        if inpt.shape[0] != batch_size:
            raise ValueError(
                f"The batch size of the image or video does not match the batch size of the labels: "
                f"{inpt.shape[0]} != {batch_size}."
            )
    def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
        if label.ndim == 1:
            if self.num_categories is None:
                raise ValueError(
                    "Cannot transform an index based labels (1D tensor) into an probability based one (2D tensor), "
                    "when num_categories is not set."
                )
            label = one_hot(label, num_classes=self.num_categories)
        if not label.dtype.is_floating_point:
            label = label.float()
        return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))

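For intuition, here is a small standalone illustration (not part of the diff) of the label mixing that _mixup_label performs: each one-hot row is blended with the row of the previous sample in the batch, weighted by lam.

```python
import torch
from torch.nn.functional import one_hot

labels = torch.tensor([0, 2, 1])                 # index-based labels, batch_size=3
lam = 0.7
probs = one_hot(labels, num_classes=3).float()   # (3, 3) one-hot matrix
# Blend each row with the row of the previous sample in the batch (roll by 1):
mixed = probs.roll(1, 0).mul(1.0 - lam).add(probs.mul(lam))
print(mixed)
# tensor([[0.7000, 0.3000, 0.0000],   <- sample 0 (label 0) mixed with sample 2 (label 1)
#         [0.3000, 0.0000, 0.7000],   <- sample 1 (label 2) mixed with sample 0 (label 0)
#         [0.0000, 0.7000, 0.3000]])  <- sample 2 (label 1) mixed with sample 1 (label 2)
```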
class Mixup(_BaseMixupCutmix):
    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        return dict(lam=float(self._dist.sample(())))  # type: ignore[arg-type]

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        lam = params["lam"]

        if inpt is params["labels"]:
            return self._mixup_label(inpt, lam=lam)
        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
            self._check_image_or_video(inpt, batch_size=params["batch_size"])

            output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))

            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
                output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]

            return output
        else:
            return inpt

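A quick usage sketch of the class above, assuming it ends up exported as torchvision.transforms.v2.Mixup (the export itself is not part of this hunk, and the name and default labels_getter behavior may still change during review):

```python
import torch
from torchvision.transforms import v2

mixup = v2.Mixup(alpha=0.2, num_categories=10)

images = torch.rand(4, 3, 32, 32)    # a collated batch of images (N, C, H, W)
labels = torch.randint(0, 10, (4,))  # index-based labels of shape (N,)

mixed_images, mixed_labels = mixup(images, labels)
print(mixed_images.shape)  # torch.Size([4, 3, 32, 32])
print(mixed_labels.shape)  # torch.Size([4, 10]) -- labels become per-class probabilities
```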
class Cutmix(_BaseMixupCutmix):
    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        lam = float(self._dist.sample(()))  # type: ignore[arg-type]

        H, W = query_spatial_size(flat_inputs)

        r_x = torch.randint(W, ())
        r_y = torch.randint(H, ())

        r = 0.5 * math.sqrt(1.0 - lam)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))
        box = (x1, y1, x2, y2)

        lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        return dict(box=box, lam_adjusted=lam_adjusted)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if inpt is params["labels"]:
            return self._mixup_label(inpt, lam=params["lam_adjusted"])
        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
            self._check_image_or_video(inpt, batch_size=params["batch_size"])

            x1, y1, x2, y2 = params["box"]
            rolled = inpt.roll(1, 0)
            output = inpt.clone()
            output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]

            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
                output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]

            return output
        else:
            return inpt
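For reference, the box and lam_adjusted computation above follows the usual CutMix formulation: a rectangle whose area is roughly a (1 - lam) fraction of the image is cut around a random center, and the label weight is recomputed from the clipped box so it matches the fraction of pixels that actually stayed. A standalone numeric walk-through (values chosen for illustration only, not part of the diff):

```python
import math

H, W = 224, 224
lam = 0.6                        # would normally be sampled from Beta(alpha, alpha)
r_x, r_y = 100, 50               # would normally be a random cut center

r = 0.5 * math.sqrt(1.0 - lam)   # half of the relative cut size per side
r_w_half, r_h_half = int(r * W), int(r * H)

x1, y1 = max(r_x - r_w_half, 0), max(r_y - r_h_half, 0)
x2, y2 = min(r_x + r_w_half, W), min(r_y + r_h_half, H)

# The label weight is recomputed from the clipped box, so it reflects the
# fraction of pixels that still come from the original image.
lam_adjusted = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)
print((x1, y1, x2, y2), round(lam_adjusted, 3))  # (30, 0, 170, 120) 0.665
```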
Sooooo to avoid bikeshedding on how we should call those (batch transforms vs pairwise transforms vs something else), maybe we should just rename that to get_cutmix_mixup?