Merged
32 changes: 14 additions & 18 deletions references/classification/train.py
@@ -8,22 +8,25 @@
import torch.utils.data
import torchvision
import torchvision.transforms
import transforms
import utils
from sampler import RASampler
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
from transforms import get_batch_transform
Member: Sooooo to avoid bikeshedding on how we should call those (batch transforms vs pairwise transforms vs something else), maybe we should just rename that to get_cutmix_mixup?



def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
def train_one_epoch(
model, criterion, optimizer, data_loader, batch_transform, device, epoch, args, model_ema=None, scaler=None
):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

header = f"Epoch: [{epoch}]"
for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
if batch_transform:
image, target = batch_transform(image, target)
Member: I failed to notice this when we discussed it offline, but we should keep those transforms as collate_fn: calling them after the dataloader as done here means we can't leverage multi-processing.
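A minimal sketch of that suggestion (hypothetical, not part of this diff), mirroring the collate_fn that main() used to build, so the mixing runs inside the DataLoader worker processes:

```python
# Hypothetical: keep the batch transform in collate_fn so it runs in the
# DataLoader workers (batch_transform as returned by get_batch_transform).
from torch.utils.data.dataloader import default_collate

def collate_fn(batch):
    image, target = default_collate(batch)
    if batch_transform is not None:
        image, target = batch_transform(image, target)
    return image, target

data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=args.batch_size,
    sampler=train_sampler,
    num_workers=args.workers,
    pin_memory=True,
    collate_fn=collate_fn,
)
```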

start_time = time.time()
image, target = image.to(device), target.to(device)
with torch.cuda.amp.autocast(enabled=scaler is not None):
@@ -218,31 +221,22 @@ def main(args):
val_dir = os.path.join(args.data_path, "val")
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

collate_fn = None
num_classes = len(dataset.classes)
mixup_transforms = []
if args.mixup_alpha > 0.0:
mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
if args.cutmix_alpha > 0.0:
mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
if mixup_transforms:
mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)

def collate_fn(batch):
return mixupcutmix(*default_collate(batch))

data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
sampler=train_sampler,
num_workers=args.workers,
pin_memory=True,
collate_fn=collate_fn,
)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
)

num_classes = len(dataset.classes)
batch_transform = get_batch_transform(
mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_categories=num_classes, use_v2=args.use_v2
)

print("Creating model")
model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
model.to(device)
@@ -364,7 +358,9 @@ def collate_fn(batch):
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
train_one_epoch(
model, criterion, optimizer, data_loader, batch_transform, device, epoch, args, model_ema, scaler
)
lr_scheduler.step()
evaluate(model, criterion, data_loader_test, device=device)
if model_ema:
23 changes: 23 additions & 0 deletions references/classification/transforms.py
@@ -2,10 +2,33 @@
from typing import Tuple

import torch
from presets import get_module
from torch import Tensor
from torchvision.transforms import functional as F


def get_batch_transform(*, mixup_alpha, cutmix_alpha, num_categories, use_v2):
transforms_module = get_module(use_v2)

batch_transforms = []
if mixup_alpha > 0:
batch_transforms.append(
transforms_module.Mixup(alpha=mixup_alpha, num_categories=num_categories)
if use_v2
else RandomMixup(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
)
if cutmix_alpha > 0:
batch_transforms.append(
transforms_module.Cutmix(alpha=cutmix_alpha, num_categories=num_categories)
if use_v2
else RandomCutmix(num_classes=num_categories, p=1.0, alpha=cutmix_alpha)
)
if not batch_transforms:
return None

return transforms_module.RandomChoice(batch_transforms)


class RandomMixup(torch.nn.Module):
"""Randomly apply Mixup to the provided batch and targets.
The class implements the data augmentations as described in the paper
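For context, a hedged usage sketch of the helper added above (hypothetical values; class names as they exist on this branch):

```python
import torch

# Hypothetical smoke test for get_batch_transform; mirrors how train.py
# calls it on an already-collated batch.
batch_transform = get_batch_transform(
    mixup_alpha=0.2, cutmix_alpha=1.0, num_categories=10, use_v2=True
)
images = torch.rand(8, 3, 224, 224)  # collated batch, (N, C, H, W)
labels = torch.randint(0, 10, (8,))  # index-based labels, shape (batch_size,)
if batch_transform is not None:      # None when both alphas are 0
    images, labels = batch_transform(images, labels)
print(labels.shape)  # torch.Size([8, 10]): probability based after mixing
```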
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/__init__.py
@@ -4,7 +4,7 @@

from ._transform import Transform # usort: skip

from ._augment import RandomErasing
from ._augment import Cutmix, Mixup, RandomErasing
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
156 changes: 153 additions & 3 deletions torchvision/transforms/v2/_augment.py
@@ -1,15 +1,18 @@
import math
import numbers
import warnings
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import PIL.Image
import torch
from torch.nn.functional import one_hot
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints, transforms as _transforms
from torchvision.transforms.v2 import functional as F

from ._transform import _RandomApplyTransform
from .utils import is_simple_tensor, query_chw
from ._transform import _RandomApplyTransform, Transform
from ._utils import _parse_labels_getter
from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size


class RandomErasing(_RandomApplyTransform):
@@ -135,3 +138,150 @@ def _transform(
inpt = F.erase(inpt, **params, inplace=self.inplace)

return inpt


class _BaseMixupCutmix(Transform):
def __init__(self, *, alpha: float, num_categories: Optional[int] = None, labels_getter="default") -> None:
super().__init__()
self.alpha = alpha
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

self.num_categories = num_categories

self.labels_getter = labels_getter
self._labels_getter = _parse_labels_getter(labels_getter)

def forward(self, *inputs):
inputs = inputs if len(inputs) > 1 else inputs[0]
flat_inputs, spec = tree_flatten(inputs)
needs_transform_list = self._needs_transform_list(flat_inputs)

if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBox, datapoints.Mask):
raise TypeError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.")

labels = self._labels_getter(inputs)
if labels is None:
msg = "Couldn't find a label in the inputs."
if self.labels_getter == "default":
msg = f"{msg} To overwrite the default find behavior, pass a callable for labels_getter."
Member: Maybe we can write that entire message regardless of whether "default" was passed. It would simplify the logic a bit and avoid storing self.labels_getter.

raise RuntimeError(msg)
Member: Technically this could qualify as a ValueError as well?

elif not isinstance(labels, torch.Tensor):
raise ValueError(f"The labels must be a torch.Tensor, but got {type(labels)} instead.")
elif labels.ndim in {1, 2}:
if labels.ndim == 2 and self.num_categories is not None and labels.shape[-1] != self.num_categories:
raise ValueError(
f"2D labels are assumed to be probability based, "
f"but the number of elements in last dimension does not match the number of categories: "
f"{labels.shape[-1]} != {self.num_categories}."
)
else:
raise ValueError(
f"labels should be a index based with shape (batch_size,) "
f"or a probability based with shape (batch_size, num_categories), "
f"but got a tensor of shape {labels.shape} instead."
)

params = {
"labels": labels,
"batch_size": labels.shape[0],
**self._get_params(
[inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
),
}

# By default, the labels will be False inside needs_transform_list, since they are a plain
# torch.Tensor coming after an image or video. However, since we want to handle them in
# _transform, we flip their entry to True here.
needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True
Member: We used a different strategy in SanitizeBoundingBox where we called _transform() on all inputs and just handled that filtering logic within _transform(). I don't have a pref right now (haven't thought about it much). But maybe we should align both transforms to follow the same strat? (we could do it in another PR)

Contributor (author): We can't use the same strategy here. SanitizeBoundingBox does not affect images or videos, so we don't care about needs_transform_list there. However, here we transform images. Meaning, we need to use needs_transform_list to make use of the heuristic about which images to transform. This cannot be done in _transform, since there we have no concept of whether an image should be transformed or not.

Contributor (@vfdev-5, Jul 11, 2023): Shouldn't we transform all images (each image is collated as (N, C, H, W))?

Member: I think I understand what you mean: we'd need to re-implement the "tensor pass-through heuristic" in _transform() if we were to do something like in SanitizeBoundingBox(), and we don't want to do that. I feel like we could use the same strategy used here in SanitizeBoundingBox() though. But that's OK.

> Shouldn't we transform all images (each image is collated as (N, C, H, W))?

We are transforming all images, yes.

Contributor (author):

> I feel like we could use the same strategy used here in SanitizeBoundingBox() though. But that's OK.

Yes, we can certainly also use needs_transform_list there. I'm ok with that. Up to you.

flat_outputs = [
self._transform(inpt, params) if needs_transform else inpt
for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
]

return tree_unflatten(flat_outputs, spec)

def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
if inpt.ndim != (5 if isinstance(inpt, datapoints.Video) else 4):
raise ValueError(
f"The transform expects a batched input, but got an {type(inpt).__name__} with {inpt.ndim} dimensions."
)
if inpt.shape[0] != batch_size:
raise ValueError(
f"The batch size of the image or video does not match the batch size of the labels: "
f"{inpt.shape[0]} != {batch_size}."
)

def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
if label.ndim == 1:
if self.num_categories is None:
raise ValueError(
"Cannot transform an index based labels (1D tensor) into an probability based one (2D tensor), "
"when num_categories is not set."
)
label = one_hot(label, num_classes=self.num_categories)
if not label.dtype.is_floating_point:
label = label.float()
return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))
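A quick worked example of the label mixing above (hypothetical numbers, not part of the diff): with lam = 0.3, each one-hot row is blended with the row rolled once along the batch dimension.

```python
import torch
from torch.nn.functional import one_hot

label = one_hot(torch.tensor([0, 2]), num_classes=3).float()
lam = 0.3
mixed = label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))
# row 0: 0.3 * onehot(0) + 0.7 * onehot(2) -> [0.3, 0.0, 0.7]
# row 1: 0.3 * onehot(2) + 0.7 * onehot(0) -> [0.7, 0.0, 0.3]
```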


class Mixup(_BaseMixupCutmix):
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(lam=float(self._dist.sample(()))) # type: ignore[arg-type]

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
lam = params["lam"]

if inpt is params["labels"]:
return self._mixup_label(inpt, lam=lam)
elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
self._check_image_or_video(inpt, batch_size=params["batch_size"])

output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))

if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type]

return output
else:
return inpt


class Cutmix(_BaseMixupCutmix):
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
lam = float(self._dist.sample(())) # type: ignore[arg-type]

H, W = query_spatial_size(flat_inputs)

r_x = torch.randint(W, ())
r_y = torch.randint(H, ())

r = 0.5 * math.sqrt(1.0 - lam)
r_w_half = int(r * W)
r_h_half = int(r * H)

x1 = int(torch.clamp(r_x - r_w_half, min=0))
y1 = int(torch.clamp(r_y - r_h_half, min=0))
x2 = int(torch.clamp(r_x + r_w_half, max=W))
y2 = int(torch.clamp(r_y + r_h_half, max=H))
box = (x1, y1, x2, y2)

lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

return dict(box=box, lam_adjusted=lam_adjusted)
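A numeric sanity check of the box and lam_adjusted computation above (hypothetical values, not part of the diff): when the sampled box is not clipped at the borders, lam_adjusted equals lam.

```python
import math

H = W = 100
lam = 0.36
r = 0.5 * math.sqrt(1.0 - lam)               # 0.5 * 0.8 = 0.4
r_w_half, r_h_half = int(r * W), int(r * H)  # 40, 40
r_x = r_y = 50                               # pretend the sampled center is (50, 50)
x1, y1 = max(r_x - r_w_half, 0), max(r_y - r_h_half, 0)  # 10, 10
x2, y2 = min(r_x + r_w_half, W), min(r_y + r_h_half, H)  # 90, 90
lam_adjusted = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)     # 1 - 0.64 = 0.36
assert abs(lam_adjusted - lam) < 1e-6  # unclipped box: lam_adjusted == lam
```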

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if inpt is params["labels"]:
return self._mixup_label(inpt, lam=params["lam_adjusted"])
elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
self._check_image_or_video(inpt, batch_size=params["batch_size"])

x1, y1, x2, y2 = params["box"]
rolled = inpt.roll(1, 0)
output = inpt.clone()
output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]

if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]

return output
else:
return inpt
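A hedged end-to-end sketch of the two transforms added above (hypothetical values; class names as they exist on this branch), combined with RandomChoice the way the training reference does:

```python
import torch
from torchvision.transforms import v2

# Hypothetical example: pick Mixup or Cutmix at random per batch.
transform = v2.RandomChoice([
    v2.Mixup(alpha=0.2, num_categories=10),
    v2.Cutmix(alpha=1.0, num_categories=10),
])
images = torch.rand(4, 3, 32, 32)    # batched input, (N, C, H, W)
labels = torch.randint(0, 10, (4,))  # index-based labels, shape (batch_size,)
images, labels = transform(images, labels)
print(labels.shape)  # torch.Size([4, 10]): probability based after mixing
```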
68 changes: 8 additions & 60 deletions torchvision/transforms/v2/_misc.py
@@ -1,7 +1,5 @@
import collections
import warnings
from contextlib import suppress
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Sequence, Type, Union
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union

import PIL.Image

@@ -11,7 +9,7 @@
from torchvision import datapoints, transforms as _transforms
from torchvision.transforms.v2 import functional as F, Transform

from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size
from ._utils import _get_defaultdict, _parse_labels_getter, _setup_float_or_seq, _setup_size
from .utils import has_any, is_simple_tensor, query_bounding_box


@@ -298,66 +296,16 @@ def __init__(
self.min_size = min_size

self.labels_getter = labels_getter
self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]]
if labels_getter == "default":
self._labels_getter = self._find_labels_default_heuristic
elif callable(labels_getter):
self._labels_getter = labels_getter
elif isinstance(labels_getter, str):
self._labels_getter = lambda inputs: SanitizeBoundingBox._get_dict_or_second_tuple_entry(inputs)[
labels_getter # type: ignore[index]
]
elif labels_getter is None:
self._labels_getter = None
else:
raise ValueError(
"labels_getter should either be a str, callable, or 'default'. "
f"Got {labels_getter} of type {type(labels_getter)}."
)

@staticmethod
def _get_dict_or_second_tuple_entry(inputs: Any) -> Mapping[str, Any]:
# datasets outputs may be plain dicts like {"img": ..., "labels": ..., "bbox": ...}
# or tuples like (img, {"labels":..., "bbox": ...})
# This hacky helper accounts for both structures.
if isinstance(inputs, tuple):
inputs = inputs[1]

if not isinstance(inputs, collections.abc.Mapping):
raise ValueError(
f"If labels_getter is a str or 'default', "
f"then the input to forward() must be a dict or a tuple whose second element is a dict."
f" Got {type(inputs)} instead."
)
return inputs

@staticmethod
def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
# Tries to find a "labels" key, otherwise tries for the first key that contains "label" - case insensitive
# Returns None if nothing is found
inputs = SanitizeBoundingBox._get_dict_or_second_tuple_entry(inputs)
candidate_key = None
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
if candidate_key is None:
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if "label" in key.lower())
if candidate_key is None:
raise ValueError(
"Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?"
"If there are no samples and it is by design, pass labels_getter=None."
)
return inputs[candidate_key]
self._labels_getter = _parse_labels_getter(labels_getter)

def forward(self, *inputs: Any) -> Any:
inputs = inputs if len(inputs) > 1 else inputs[0]

if self._labels_getter is None:
labels = None
else:
labels = self._labels_getter(inputs)
if labels is not None and not isinstance(labels, torch.Tensor):
raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.")
labels = self._labels_getter(inputs)
if labels is not None and not isinstance(labels, torch.Tensor):
raise ValueError(
f"The labels in the input to forward() must be a tensor or None, got {type(labels)} instead."
)

flat_inputs, spec = tree_flatten(inputs)
# TODO: this enforces one single BoundingBox entry.
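For reference, a hedged reconstruction of the new _parse_labels_getter helper, inferred from the inline logic removed above (the real helper lives in _utils.py and is not shown in this diff; since the new forward() calls self._labels_getter(inputs) unconditionally, the None case presumably returns a no-op callable):

```python
def _parse_labels_getter(labels_getter):
    # Reconstructed sketch only; the actual implementation in
    # torchvision/transforms/v2/_utils.py may differ.
    if labels_getter == "default":
        # assumed moved (and possibly extended) default heuristic
        return _find_labels_default_heuristic
    elif callable(labels_getter):
        return labels_getter
    elif isinstance(labels_getter, str):
        return lambda inputs: _get_dict_or_second_tuple_entry(inputs)[labels_getter]
    elif labels_getter is None:
        return lambda inputs: None  # lets forward() call it unconditionally
    else:
        raise ValueError(
            f"labels_getter should either be 'default', a str, a callable, or None, "
            f"but got {labels_getter} of type {type(labels_getter)}."
        )
```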