Prototype references #6433

Draft
wants to merge 100 commits into
base: main
Changes from 17 commits
Commits
c19f838
use prototype transforms in classification reference
pmeier Aug 17, 2022
7b7602e
cleanup
pmeier Aug 17, 2022
4ea2aaf
Merge branch 'main' into prototype-references/classification
pmeier Aug 18, 2022
4990b89
move WrapIntoFeatures into transforms module
pmeier Aug 18, 2022
1cfa965
Merge branch 'main' into prototype-references/classification
pmeier Aug 18, 2022
ca4c5a7
[skip ci] add p=1.0 to CutMix and MixUp
pmeier Aug 18, 2022
a2e24e1
Merge branch 'main' into prototype-references/classification
pmeier Aug 23, 2022
693795e
[skip ci]
pmeier Aug 23, 2022
69f5299
Merge branch 'main' into prototype-references/classification
pmeier Aug 24, 2022
fe96a54
use prototype transforms in detection references
pmeier Aug 24, 2022
0f06516
Merge branch 'main' into prototype-references/classification
pmeier Aug 24, 2022
148885d
Merge branch 'main' into prototype-references/classification
pmeier Aug 26, 2022
6fd5e50
[skip ci]
pmeier Aug 26, 2022
6edb7f4
Merge branch 'main' into prototype-references/classification
pmeier Aug 30, 2022
6fcffb2
[skip ci]
pmeier Aug 30, 2022
4b68e2f
[skip ci] Merge branch 'main' into prototype-references/classification
pmeier Aug 30, 2022
4d73fe7
[skip ci] Merge branch 'main' into prototype-references/classification
pmeier Aug 31, 2022
7cb08d5
[skip ci] fix scripts
pmeier Sep 1, 2022
e51791d
[skip ci] Merge branch 'main' into prototype-references/classification
pmeier Sep 1, 2022
3e5e064
[skip ci] Merge branch 'main' into prototype-references/classification
pmeier Sep 1, 2022
ec120ff
[skip ci] Merge branch 'main' into prototype-references/classification
pmeier Sep 1, 2022
f993926
Merge branch 'main' into prototype-references/classification
datumbox Sep 5, 2022
6167038
[skip ci] Merge branch 'main' into prototype-references/classification
pmeier Sep 6, 2022
4699e55
[skip ci] Merge branch 'prototype-references/classification' of https…
pmeier Sep 6, 2022
e2e459d
Merge branch 'main' into prototype-references/classification
datumbox Sep 7, 2022
aa7a655
[skip ci] Merge branch 'main' into prototype-references/classification
pmeier Sep 7, 2022
cb02041
Merge branch 'prototype-references/classification' of https://github.…
pmeier Sep 7, 2022
a98c05d
[SKIP CI] CircleCI
pmeier Sep 7, 2022
49e653f
[skip ci]
pmeier Sep 7, 2022
fcd37d9
[skip ci] Merge branch 'main' into prototype-references/classification
pmeier Sep 8, 2022
05be06d
Merge branch 'main' into prototype-references/classification
pmeier Sep 12, 2022
9459b0a
Merge branch 'main' into prototype-references/classification
pmeier Sep 13, 2022
6c90b3a
Merge branch 'main' into prototype-references/classification
pmeier Sep 13, 2022
2eccb84
Merge branch 'main' into prototype-references/classification
pmeier Sep 14, 2022
8df9043
update segmentation references
pmeier Sep 13, 2022
99e6c36
[skip ci]
pmeier Sep 14, 2022
47772ac
Merge branch 'main' into prototype-references/classification
pmeier Sep 14, 2022
94ac15d
[skip ci]
pmeier Sep 14, 2022
51307b7
[skip ci] fix workaround
pmeier Sep 14, 2022
9dad6e0
only wrap segmentation mask
pmeier Sep 14, 2022
f5f1716
fix pretrained weights test only
pmeier Sep 14, 2022
2aefd09
[skip ci]
pmeier Sep 14, 2022
a2893a1
Restore get_dimensions
datumbox Sep 14, 2022
8df0cf4
Merge branch 'main' into prototype-references/classification
pmeier Sep 21, 2022
e912976
fix segmentation transforms
pmeier Sep 21, 2022
2e7e168
[skip ci]
pmeier Sep 21, 2022
74ecb49
Merge branch 'prototype-references/classification' of https://github.…
pmeier Sep 21, 2022
585c64a
fix mask rewrapping
pmeier Sep 21, 2022
93d7a32
[skip ci]
pmeier Sep 21, 2022
5a311b3
Merge branch 'main' into prototype-references/classification
datumbox Sep 21, 2022
aac24c1
Merge branch 'main' into prototype-references/classification
datumbox Sep 23, 2022
766af6c
Fix merge issue
datumbox Sep 23, 2022
cb6c90e
Tensor Backend + antialiasing=True
datumbox Sep 23, 2022
e9c480e
Switch to view to reshape to avoid incompatibilities with size/stride
datumbox Sep 23, 2022
3894efb
Merge branch 'main' into prototype-references/classification
datumbox Sep 25, 2022
6ef4d82
Cherrypick PR #6642
datumbox Sep 25, 2022
5a1de52
Merge branch 'main' into prototype-references/classification
datumbox Sep 26, 2022
c6950ae
Merge branch 'main'
pmeier Oct 10, 2022
b59beae
Merge branch 'main' into prototype-references/classification
pmeier Oct 10, 2022
906428a
Merge branch 'main' into prototype-references/classification
pmeier Oct 10, 2022
758de46
[skip ci] add support for video_classification
pmeier Oct 10, 2022
a0895c1
Merge branch 'prototype-references/classification' of https://github.…
pmeier Oct 10, 2022
2bd4291
Merge branch 'main' into prototype-references/classification
datumbox Oct 11, 2022
669b1ba
Restoring original reference transforms so that test can run
datumbox Oct 11, 2022
591a773
Adding AA, Random Erase, MixUp/CutMix and a different resize/crop str…
datumbox Oct 11, 2022
0db3ce2
Merge branch 'main' into prototype-references/classification
datumbox Oct 11, 2022
9e95b78
image_size to spatial_size
datumbox Oct 13, 2022
4f3b593
Merge branch 'main' into prototype-references/classification
datumbox Oct 14, 2022
a364b15
Merge branch 'main' into prototype-references/classification
datumbox Oct 14, 2022
00d1b9b
Update the RandomShortestSize behaviour on Video presets.
datumbox Oct 14, 2022
ef3dc55
Fix ToDtype transform to accept dictionaries.
datumbox Oct 14, 2022
711128c
Merge branch 'main' into prototype-references/classification
datumbox Oct 14, 2022
25c4664
Fix issue with collate and audio using Philip's proposal.
datumbox Oct 14, 2022
091948e
Fix linter
datumbox Oct 14, 2022
6b23587
Fix ToDtype parameters.
datumbox Oct 14, 2022
eb37f8f
Wrapping id into a no-op.
datumbox Oct 14, 2022
bb468ba
Define `_Feature` in the dict.
datumbox Oct 14, 2022
5f8d233
Merge branch 'main' into prototype-references/classification
datumbox Oct 14, 2022
5928876
Handling hot-encoded tensors in `accuracy`
datumbox Oct 14, 2022
b63e607
Handle ConvertBCHWtoCBHW interactions with mixup/cutmix.
datumbox Oct 14, 2022
ab141f9
Merge branch 'main' into prototype-references/classification
datumbox Oct 14, 2022
d5f1532
Add Permute Transform.
datumbox Oct 14, 2022
707190c
Merge branch 'main' into prototype-references/classification
datumbox Oct 19, 2022
598542c
Merge branch 'main' into prototype-references/classification
datumbox Oct 21, 2022
6a0a32c
Switch to `TransposeDimensions`
datumbox Oct 21, 2022
a59f995
Merge branch 'main' into prototype-references/classification
datumbox Oct 26, 2022
f72f5b2
Merge branch 'main' into prototype-references/classification
datumbox Oct 27, 2022
9d0a0a3
Fix linter.
datumbox Oct 27, 2022
7c41f0c
Merge branch 'main' into prototype-references/classification
datumbox Oct 31, 2022
87031f1
Merge branch 'main' into prototype-references/classification
datumbox Nov 4, 2022
7c5da3a
Merge branch 'main' into prototype-references/classification
datumbox Nov 4, 2022
d8b5202
Fix method location.
datumbox Nov 4, 2022
959af2d
Fixing minor bug
datumbox Nov 7, 2022
d435378
Merge branch 'main' into prototype-references/classification
datumbox Nov 15, 2022
bda072d
Merge branch 'main' into prototype-references/classification
datumbox Nov 16, 2022
8f07159
Convert to floats at the beginning.
datumbox Nov 17, 2022
8344ce9
Revert "Convert to floats at the beginning."
datumbox Nov 17, 2022
8b53036
Switch to PIL backend
datumbox Nov 17, 2022
c7f2ac8
Revert "Switch to PIL backend"
datumbox Nov 17, 2022
f205f1e
Merge branch 'main' of github.com:pytorch/vision into prototype-refer…
NicolasHug Feb 9, 2023
18 changes: 9 additions & 9 deletions references/classification/presets.py
@@ -1,5 +1,5 @@
 import torch
-from torchvision.transforms import autoaugment, transforms
+from torchvision.prototype import transforms
 from torchvision.transforms.functional import InterpolationMode


@@ -17,20 +17,20 @@ def __init__(
     ):
         trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
         if hflip_prob > 0:
-            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
+            trans.append(transforms.RandomHorizontalFlip(p=hflip_prob))
         if auto_augment_policy is not None:
             if auto_augment_policy == "ra":
-                trans.append(autoaugment.RandAugment(interpolation=interpolation))
+                trans.append(transforms.RandAugment(interpolation=interpolation))
             elif auto_augment_policy == "ta_wide":
-                trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
+                trans.append(transforms.TrivialAugmentWide(interpolation=interpolation))
             elif auto_augment_policy == "augmix":
-                trans.append(autoaugment.AugMix(interpolation=interpolation))
+                trans.append(transforms.AugMix(interpolation=interpolation))
             else:
-                aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
-                trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
+                aa_policy = transforms.AutoAugmentPolicy(auto_augment_policy)
+                trans.append(transforms.AutoAugment(policy=aa_policy, interpolation=interpolation))
         trans.extend(
             [
-                transforms.PILToTensor(),
+                transforms.ToImageTensor(),
                 transforms.ConvertImageDtype(torch.float),
                 transforms.Normalize(mean=mean, std=std),
             ]
@@ -59,7 +59,7 @@ def __init__(
             [
                 transforms.Resize(resize_size, interpolation=interpolation),
                 transforms.CenterCrop(crop_size),
-                transforms.PILToTensor(),
+                transforms.ToImageTensor(),
                 transforms.ConvertImageDtype(torch.float),
                 transforms.Normalize(mean=mean, std=std),
             ]
34 changes: 23 additions & 11 deletions references/classification/train.py
@@ -4,14 +4,17 @@
 import warnings

 import presets
+from sampler import RASampler
+from transforms import WrapIntoFeatures
+import utils # usort: skip
+
 import torch
 import torch.utils.data
 import torchvision
-import transforms
-import utils
-from sampler import RASampler
+
 from torch import nn
 from torch.utils.data.dataloader import default_collate
+from torchvision.prototype import features, transforms
 from torchvision.transforms.functional import InterpolationMode


@@ -128,12 +131,13 @@ def load_data(traindir, valdir, args):
         random_erase_prob = getattr(args, "random_erase", 0.0)
         dataset = torchvision.datasets.ImageFolder(
             traindir,
-            presets.ClassificationPresetTrain(
+            transform=presets.ClassificationPresetTrain(
                 crop_size=train_crop_size,
                 interpolation=interpolation,
                 auto_augment_policy=auto_augment_policy,
                 random_erase_prob=random_erase_prob,
             ),
+            target_transform=lambda target: features.Label(target),
         )
         if args.cache_dataset:
             print(f"Saving dataset_train to {cache_path}")
@@ -158,7 +162,8 @@ def load_data(traindir, valdir, args):

         dataset_test = torchvision.datasets.ImageFolder(
             valdir,
-            preprocessing,
+            transform=preprocessing,
+            target_transform=lambda target: features.Label(target),
         )
         if args.cache_dataset:
             print(f"Saving dataset_test to {cache_path}")
@@ -200,14 +205,21 @@ def main(args):

     collate_fn = None
     num_classes = len(dataset.classes)
-    mixup_transforms = []
+    mixup_or_cutmix = []
     if args.mixup_alpha > 0.0:
-        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
+        mixup_or_cutmix.append(transforms.RandomMixup(alpha=args.mixup_alpha, p=1.0))
     if args.cutmix_alpha > 0.0:
-        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
-    if mixup_transforms:
-        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
-        collate_fn = lambda batch: mixupcutmix(*default_collate(batch)) # noqa: E731
+        mixup_or_cutmix.append(transforms.RandomCutmix(alpha=args.cutmix_alpha, p=1.0))
+    if mixup_or_cutmix:
+        batch_transform = transforms.Compose(
+            [
+                WrapIntoFeatures(),
Contributor

Don't we have to apply WrapIntoFeatures regardless of whether we use mixup/cutmix?

Collaborator Author

The problem is that default_collate does not respect tensor subclasses. Since we use this transform afterwards, we need to wrap here. Of course we can also wrap before, but it is not necessary since the input is a plain tensor that defaults to images and an integer which is completely ignored.

Contributor

My question is, what if I don't use mixup or cutmix? Shouldn't we wrap the data into features.Image anyway? I might be missing something here. My main point is that since we are testing the new API, we should probably wrap all inputs using their appropriate types and see how the new kernels behave (Rather than relying on their default/legacy pure tensor implementations).

Collaborator Author

We can only wrap after we have converted from PIL. This happens fairly late in the transform:

transforms.PILToTensor(),

I remember @vfdev-5 noting that on the CPU the PIL kernels are faster (I don't remember if there was a special case or other constraints; please fill in the blanks). Thus, if we want to optimize for speed, we should probably leave it as is. No strong opinion though.

Contributor

Right now we are testing the PIL backend. In order to test the features.Image one, we will need to move this at the beginning (aka on the transforms defined for ImageFolder), right?

+                transforms.LabelToOneHot(num_categories=num_classes),
Contributor

@datumbox datumbox Sep 23, 2022

I removed from this line the WrapIntoFeatures() call and now get:

TypeError: RandomCutmix() is only defined for tensor images and one-hot labels.

I expected that the Label is already annotated because the dataset is defined with:

target_transform=lambda target: features.Label(target),

I'll check it out further on Monday. It's because of the collating. UI-wise, it's not great that batching causes the _Feature information to be lost. Is there a way to improve upon this?

cc @pmeier

Collaborator Author

Yeah, that is a known issue. The default_collate internally uses torch.stack. Since we only automatically wrap on .to() and .clone() this will drop the feature type.

Three options here:

  1. Allow torch.stack to retain the feature type. I would not go that route, since we cannot know in general if the result of a stack is still the feature we started out with. We at least need more elaborate checks that make sure this is happening.
  2. Fix default_collate to handle our feature types. It is a possible solution, but won't happen, because the feature types are torchvision only and default_collate is a general solution.
  3. Provide a custom collate like vision_collate. This would mean that, depending on the task, users can no longer go with the default value in the DataLoader, but judging from our references, this seems to be somewhat common. Plus, this vision_collate could also include more functionality that we need:
    1. Don't torch.stack images if we find BoundingBox'es.
    2. Support for None that some of our datasets already return (add test if dataset samples can be collated #5233) and will be heavily utilized by the new datasets for missing data in a configuration.

My vote is for 3. To avoid rewriting default_collate from scratch in torchvision and thus risking it going out of sync with the core version, I made a push to make default_collate more extensible in this gist. @ejguan wanted to handle this from the torchdata side. Any progress there?

Contributor

It's a hard problem. I agree option 2 is not possible. What I don't like about option 3 is that if TorchVision has its own collate, it will miss out on potential improvements from Core/Data. I also wouldn't put it inside the library but rather in the examples.

Option 1 sounds promising only if we feel we can do all the necessary checks to avoid gotchas (and consider whether this is a performance bottleneck). If that's the case, we could offer implementations that whitelist stack, similarly to to() and clone(). But if checking all the TV metadata plus the Tensor metadata is too slow, then we might leave it to the user to decide how to handle it efficiently on a case-by-case basis (as you have it here). I'm also open to exploring other ideas to rewrap the types, though nothing elegant comes to mind right now.

Collaborator Author

What I don't like about option 3 is that if TorchVision has its own collate, it will miss out on potential improvements from Core/Data. I also wouldn't put it inside the library but rather in the examples.

Depends on what architecture default_collate actually gets after the refactor. My proposal would allow us to extend the functionality. For example, the collation for features.Image's could just call the collation for tensors and rewrap at the end. This way if there are any improvements for tensor collation, we automatically also benefit.
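The "delegate to the tensor collation, then rewrap" idea can be sketched with plain Python stand-ins. Everything below is an assumption made for illustration: `Image` is a toy marker class rather than `features.Image`, `collate_tensors` and `collate_images` are hypothetical names, and a tuple plays the role of a stacked batch.

```python
class Image(tuple):
    """Stand-in for a tensor subclass such as features.Image (hypothetical)."""


def collate_tensors(batch):
    # Stand-in for the core tensor collation: the "stacked" result is a plain
    # tuple, so the subclass information of the inputs is lost.
    return tuple(tuple(item) for item in batch)


def collate_images(batch):
    # Reuse the generic collation and only add the rewrapping step on top.
    return Image(collate_tensors(batch))


batch = [Image((1.0, 2.0)), Image((3.0, 4.0))]
plain = collate_tensors(batch)
wrapped = collate_images(batch)
```

Because `collate_images` only wraps the result, any improvement to `collate_tensors` carries over without further changes.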

Contributor

Talked to Vitaly about it a while ago, and we thought adding an argument to default_collate might not be the ideal option, because users can always do it the other way around by providing a custom collate function and calling default_collate if needed.

Would it be possible to let default_collate dispatch by checking whether the object has a specific method attached (tentative name: collate_fn)?

class CustomObject:
    @staticmethod
    def collate_fn(batch):
        return ...

def default_collate(batch):
    elem = batch[0]
    ...
    if hasattr(elem, "collate_fn"):
        return elem.collate_fn(batch)

This would provide the same mechanism as pin_memory
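A runnable toy version of this hook-based dispatch might look as follows. It is only a sketch: `toy_default_collate` and `Point` are hypothetical names, and the fallback branch simply returns a list instead of reproducing the real default_collate logic.

```python
def toy_default_collate(batch):
    """Toy collate that defers to a `collate_fn` hook when the elements define one."""
    elem = batch[0]
    if hasattr(elem, "collate_fn"):
        return elem.collate_fn(batch)
    # Fallback for this sketch only: return the samples as a plain list.
    return list(batch)


class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

    @staticmethod
    def collate_fn(batch):
        # Collate a list of points into one list per coordinate.
        return [p.x for p in batch], [p.y for p in batch]


points = toy_default_collate([Point(1, 2), Point(3, 4)])
numbers = toy_default_collate([1, 2, 3])
```

Note that the hook only sees a batch of its own type, not the rest of the sample.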

Collaborator Author

  • Second go through each key in dict and check isinstance for derived class

Shouldn't it be an OrderedDict then? Otherwise if an input is a subtype of multiple types in the map, two runs may yield different results making the collation non-deterministic.


I would prefer 1. since otherwise an object would need to know how it wants to be collated, but doesn't have all the context. For example, for classification and segmentation tasks, we want to torch.stack images, but for detection tasks they should be kept in a list. We can figure this out from the full sample passed to the collation function, but the object itself can't.

Plus, we don't have to worry about JIT and that can be quite a pain.

If you send a PR, please ping me there so I can see if it actually is extensible enough for us. @datumbox can you confirm that you would also be ok with handling this in the collation function if we only need to provide custom behavior for our stuff, but otherwise can depend on the core functionality?
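To illustrate why the decision needs the full sample, here is a minimal sketch with plain Python stand-ins. `Image`, `BoundingBox`, and `collate_samples` are hypothetical names, a tuple stands in for a torch.stack result, and a list stands in for "kept as a list".

```python
class Image(list):
    """Stand-in for features.Image (hypothetical; holds no real tensor data)."""


class BoundingBox(list):
    """Stand-in for features.BoundingBox (hypothetical)."""


def collate_samples(batch):
    """Collate a list of samples, where each sample is a tuple of fields."""
    # Only the full sample reveals the task: any bounding box means detection.
    is_detection = any(
        isinstance(field, BoundingBox) for sample in batch for field in sample
    )
    columns = zip(*batch)
    if is_detection:
        # Detection: keep every field as a list of per-sample items.
        return [list(column) for column in columns]
    # Classification / segmentation: "stack" each field into a tuple.
    return [tuple(column) for column in columns]


classification = collate_samples([(Image([0.1]), 3), (Image([0.2]), 7)])
detection = collate_samples(
    [(Image([0.1]), BoundingBox([0, 0, 4, 4])), (Image([0.2]), BoundingBox([1, 1, 2, 2]))]
)
```

The same `Image` inputs end up stacked in one case and listed in the other, which a per-object hook could not decide on its own.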

Contributor

Shouldn't it be an OrderedDict then?

If the type exists in the dict, it should invoke the corresponding collate function. And, dict is actually ordered in Python. So, if an input is a subtype of multiple types and its type doesn't equal any key, it will follow the order.

Will open a PR today. It should be a simple change.
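The lookup described here (exact type first, then the first isinstance match in insertion order) could be sketched like this. `collate_with_map` and `type_map` are hypothetical names for illustration, not the API of the eventual PR.

```python
def collate_with_map(batch, type_map):
    elem = batch[0]
    # First, look for an exact type match.
    if type(elem) in type_map:
        return type_map[type(elem)](batch)
    # Second, use the first registered type (in insertion order) that matches.
    for registered_type, collate_fn in type_map.items():
        if isinstance(elem, registered_type):
            return collate_fn(batch)
    raise TypeError(f"no collate function registered for {type(elem).__name__}")


class Base: ...
class Mixin: ...
class Child(Base, Mixin): ...


def by_base(batch):
    return "base"


def by_mixin(batch):
    return "mixin"


def by_child(batch):
    return "child"


first = collate_with_map([Child()], {Base: by_base, Mixin: by_mixin})
second = collate_with_map([Child()], {Mixin: by_mixin, Base: by_base})
exact = collate_with_map([Child()], {Base: by_base, Child: by_child})
```

With two matching base types, the result depends on the order in which the map was built, so the map has to be constructed deterministically.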

Collaborator Author

And, dict is actually ordered in Python.

Well, not really. Since Python 3.7 (and in CPython 3.6 as an implementation detail), dicts preserve insertion order. That means if you create a dictionary like

dct = {"a": 1, "b": 2}

you can be sure that "a" is always the first key. However, dicts are not ordered in the sense that order affects equality:

assert {"a": 1, "b": 2} == {"b": 2, "a": 1}
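For contrast, a short sketch using only the standard library: two `collections.OrderedDict` instances compare equal only if their order matches, while plain dicts ignore order for equality but still iterate in insertion order. The variable names are just for illustration.

```python
from collections import OrderedDict

# Plain dicts: equality ignores insertion order.
plain_equal = {"a": 1, "b": 2} == {"b": 2, "a": 1}

# OrderedDicts: equality between two OrderedDicts is order-sensitive.
ordered_equal = OrderedDict([("a", 1), ("b", 2)]) == OrderedDict([("b", 2), ("a", 1)])

# Iteration over a plain dict still follows insertion order.
first_key = next(iter({"a": 1, "b": 2}))
```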

Contributor

If you try to iterate over those dicts, you will find that they are ordered.

d1 = {"a": 1, "b": 2}
for k in d1:
    print(k)  # a, b

d2 = {"b": 2, "a": 1}
for k in d2:
    print(k)  # b, a

Contributor

@datumbox can you confirm that you would also be ok with handling this in the collation function if we only need to provide custom behavior for our stuff, but otherwise can depend on the core functionality?

Just to confirm, is your proposal to handle custom behaviour in our reference scripts? If yes, then that sounds good. I wouldn't want to add this to main TorchVision but rather leave it to the users to handle those special cases based on their use-case/dataset. Does that make sense?

+                transforms.ToDtype(torch.float, features.OneHotLabel),
+                transforms.RandomChoice(*mixup_or_cutmix),
pmeier marked this conversation as resolved.
+            ]
+        )
+        collate_fn = lambda batch: batch_transform(default_collate(batch)) # noqa: E731
     data_loader = torch.utils.data.DataLoader(
         dataset,
         batch_size=args.batch_size,
187 changes: 6 additions & 181 deletions references/classification/transforms.py
@@ -1,183 +1,8 @@
-import math
-from typing import Tuple
+from torch import nn
+from torchvision.prototype import features

-import torch
-from torch import Tensor
-from torchvision.transforms import functional as F
-

-class RandomMixup(torch.nn.Module):
-    """Randomly apply Mixup to the provided batch and targets.
-    The class implements the data augmentations as described in the paper
-    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
-
-    Args:
-        num_classes (int): number of classes used for one-hot encoding.
-        p (float): probability of the batch being transformed. Default value is 0.5.
-        alpha (float): hyperparameter of the Beta distribution used for mixup.
-            Default value is 1.0.
-        inplace (bool): boolean to make this transform inplace. Default set to False.
-    """
-
-    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
-        super().__init__()
-
-        if num_classes < 1:
-            raise ValueError(
-                f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
-            )
-
-        if alpha <= 0:
-            raise ValueError("Alpha param can't be zero.")
-
-        self.num_classes = num_classes
-        self.p = p
-        self.alpha = alpha
-        self.inplace = inplace
-
-    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
-        """
-        Args:
-            batch (Tensor): Float tensor of size (B, C, H, W)
-            target (Tensor): Integer tensor of size (B, )
-
-        Returns:
-            Tensor: Randomly transformed batch.
-        """
-        if batch.ndim != 4:
-            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
-        if target.ndim != 1:
-            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
-        if not batch.is_floating_point():
-            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
-        if target.dtype != torch.int64:
-            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
-
-        if not self.inplace:
-            batch = batch.clone()
-            target = target.clone()
-
-        if target.ndim == 1:
-            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
-
-        if torch.rand(1).item() >= self.p:
-            return batch, target
-
-        # It's faster to roll the batch by one instead of shuffling it to create image pairs
-        batch_rolled = batch.roll(1, 0)
-        target_rolled = target.roll(1, 0)
-
-        # Implemented as on mixup paper, page 3.
-        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
-        batch_rolled.mul_(1.0 - lambda_param)
-        batch.mul_(lambda_param).add_(batch_rolled)
-
-        target_rolled.mul_(1.0 - lambda_param)
-        target.mul_(lambda_param).add_(target_rolled)
-
-        return batch, target
-
-    def __repr__(self) -> str:
-        s = (
-            f"{self.__class__.__name__}("
-            f"num_classes={self.num_classes}"
-            f", p={self.p}"
-            f", alpha={self.alpha}"
-            f", inplace={self.inplace}"
-            f")"
-        )
-        return s
-
-
-class RandomCutmix(torch.nn.Module):
-    """Randomly apply Cutmix to the provided batch and targets.
-    The class implements the data augmentations as described in the paper
-    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
-    <https://arxiv.org/abs/1905.04899>`_.
-
-    Args:
-        num_classes (int): number of classes used for one-hot encoding.
-        p (float): probability of the batch being transformed. Default value is 0.5.
-        alpha (float): hyperparameter of the Beta distribution used for cutmix.
-            Default value is 1.0.
-        inplace (bool): boolean to make this transform inplace. Default set to False.
-    """
-
-    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
-        super().__init__()
-        if num_classes < 1:
-            raise ValueError("Please provide a valid positive value for the num_classes.")
-        if alpha <= 0:
-            raise ValueError("Alpha param can't be zero.")
-
-        self.num_classes = num_classes
-        self.p = p
-        self.alpha = alpha
-        self.inplace = inplace
-
-    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
-        """
-        Args:
-            batch (Tensor): Float tensor of size (B, C, H, W)
-            target (Tensor): Integer tensor of size (B, )
-
-        Returns:
-            Tensor: Randomly transformed batch.
-        """
-        if batch.ndim != 4:
-            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
-        if target.ndim != 1:
-            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
-        if not batch.is_floating_point():
-            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
-        if target.dtype != torch.int64:
-            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
-
-        if not self.inplace:
-            batch = batch.clone()
-            target = target.clone()
-
-        if target.ndim == 1:
-            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
-
-        if torch.rand(1).item() >= self.p:
-            return batch, target
-
-        # It's faster to roll the batch by one instead of shuffling it to create image pairs
-        batch_rolled = batch.roll(1, 0)
-        target_rolled = target.roll(1, 0)
-
-        # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
-        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
-        _, H, W = F.get_dimensions(batch)
-
-        r_x = torch.randint(W, (1,))
-        r_y = torch.randint(H, (1,))
-
-        r = 0.5 * math.sqrt(1.0 - lambda_param)
-        r_w_half = int(r * W)
-        r_h_half = int(r * H)
-
-        x1 = int(torch.clamp(r_x - r_w_half, min=0))
-        y1 = int(torch.clamp(r_y - r_h_half, min=0))
-        x2 = int(torch.clamp(r_x + r_w_half, max=W))
-        y2 = int(torch.clamp(r_y + r_h_half, max=H))
-
-        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
-        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
-
-        target_rolled.mul_(1.0 - lambda_param)
-        target.mul_(lambda_param).add_(target_rolled)
-
-        return batch, target
-
-    def __repr__(self) -> str:
-        s = (
-            f"{self.__class__.__name__}("
-            f"num_classes={self.num_classes}"
-            f", p={self.p}"
-            f", alpha={self.alpha}"
-            f", inplace={self.inplace}"
-            f")"
-        )
-        return s
+class WrapIntoFeatures(nn.Module):
+    def forward(self, sample):
+        input, target = sample
+        return features.Image(input), features.Label(target)