Prototype references #6433
base: main
references/classification/train.py
```diff
@@ -4,14 +4,17 @@
 import warnings

 import presets
+from sampler import RASampler
+from transforms import WrapIntoFeatures
+import utils  # usort: skip
+
 import torch
 import torch.utils.data
 import torchvision
-import transforms
-import utils
-from sampler import RASampler
 from torch import nn
 from torch.utils.data.dataloader import default_collate
+from torchvision.prototype import features, transforms
 from torchvision.transforms.functional import InterpolationMode

```
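Note that after this change the bare name `transforms` refers to `torchvision.prototype.transforms`, while the local reference-script module is only imported for `WrapIntoFeatures`. A quick, purely illustrative sanity check of the new bindings (assuming a torchvision build that still ships the prototype namespace, as at the time of this PR) could look like this:

```python
# Illustrative only: confirm what the rewritten import block exposes.
from torchvision.prototype import features, transforms

print(transforms.__name__)   # "torchvision.prototype.transforms"
print(features.Label)        # prototype Label feature used as target_transform below
print(transforms.RandomMixup, transforms.RandomCutmix)  # prototype batch transforms used in main()
```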
```diff
@@ -128,12 +131,13 @@ def load_data(traindir, valdir, args):
         random_erase_prob = getattr(args, "random_erase", 0.0)
         dataset = torchvision.datasets.ImageFolder(
             traindir,
-            presets.ClassificationPresetTrain(
+            transform=presets.ClassificationPresetTrain(
                 crop_size=train_crop_size,
                 interpolation=interpolation,
                 auto_augment_policy=auto_augment_policy,
                 random_erase_prob=random_erase_prob,
             ),
+            target_transform=lambda target: features.Label(target),
         )
     if args.cache_dataset:
         print(f"Saving dataset_train to {cache_path}")
```
```diff
@@ -158,7 +162,8 @@ def load_data(traindir, valdir, args):

         dataset_test = torchvision.datasets.ImageFolder(
             valdir,
-            preprocessing,
+            transform=preprocessing,
+            target_transform=lambda target: features.Label(target),
         )
     if args.cache_dataset:
         print(f"Saving dataset_test to {cache_path}")
```
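For illustration only (the directory path below is a hypothetical placeholder, not from the script), this is what the datasets now yield per sample: the usual transformed image plus the integer class index wrapped into a prototype `features.Label`:

```python
# Illustrative sketch, not part of the diff: with target_transform set as above,
# ImageFolder returns the class index wrapped into a prototype Label tensor.
import torchvision
from torchvision.prototype import features

dataset = torchvision.datasets.ImageFolder(
    "path/to/val",   # hypothetical directory laid out as one sub-folder per class
    transform=None,  # the real script passes the preset/preprocessing pipeline here
    target_transform=lambda target: features.Label(target),
)
image, label = dataset[0]
assert isinstance(label, features.Label)  # the int class index is now a Label tensor
```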
```diff
@@ -200,14 +205,21 @@ def main(args):

     collate_fn = None
     num_classes = len(dataset.classes)
-    mixup_transforms = []
+    mixup_or_cutmix = []
     if args.mixup_alpha > 0.0:
-        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
+        mixup_or_cutmix.append(transforms.RandomMixup(alpha=args.mixup_alpha, p=1.0))
     if args.cutmix_alpha > 0.0:
-        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
-    if mixup_transforms:
-        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
-        collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
+        mixup_or_cutmix.append(transforms.RandomCutmix(alpha=args.cutmix_alpha, p=1.0))
+    if mixup_or_cutmix:
+        batch_transform = transforms.Compose(
+            [
+                WrapIntoFeatures(),
+                transforms.LabelToOneHot(num_categories=num_classes),
+                transforms.ToDtype(torch.float, features.OneHotLabel),
+                transforms.RandomChoice(*mixup_or_cutmix),
+            ]
+        )
+        collate_fn = lambda batch: batch_transform(default_collate(batch))  # noqa: E731
     data_loader = torch.utils.data.DataLoader(
         dataset,
         batch_size=args.batch_size,
```

Review comment on `WrapIntoFeatures(),`:

Right now we are testing the PIL backend. In order to test the …

Review thread on `transforms.LabelToOneHot(num_categories=num_classes),`:

I removed from this line the … I expected that the `target_transform=lambda target: features.Label(target),` … cc @pmeier

Yeah, that is a known issue. The … Three options here: … My vote is out for 3. To avoid rewriting …

It's a hard problem. I agree option 2 is not possible. What I don't like about option 3 is that if TorchVision has its own collate, it will miss out on potential improvements from Core/Data. I also wouldn't put it inside the library but rather in the examples. Option 1 sounds promising only if we feel we can do all the necessary checks to avoid gotchas (plus think about whether this is a performance bottleneck). If that's the case, we could offer implementations that whitelist the …

Depends on what architecture …

Talked to Vitaly about it a while ago and we thought adding an argument to … Would it be possible if we let …

```python
class CustomObject:
    @staticmethod
    def collate_fn(batch):
        return ...


def default_collate(batch):
    elem = batch[0]
    ...
    if hasattr(elem, "collate_fn"):
        return elem.collate_fn(batch)
```

This would provide the same mechanism as …

Shouldn't it be an …

I would prefer 1., since otherwise an object would need to know how it wants to be collated, but doesn't have all the context. For example, for classification and segmentation tasks, we want to … Plus, we don't have to worry about JIT, and that can be quite a pain. If you send a PR, please ping me there so I can see if it actually is extensible enough for us. @datumbox can you confirm that you would also be ok with handling this in the collation function if we only need to provide custom behavior for our stuff, but otherwise can depend on the core functionality?

If the … Will open a PR today. It should be a simple change.

Well, not really. Since Python 3.6, given `dct = {"a": 1, "b": 2}` you can be sure that `assert {"a": 1, "b": 2} == {"b": 2, "a": 1}` passes.

If you try to iterate over those dicts, you will find they are ordered:

```python
d1 = {"a": 1, "b": 2}
for k in d1:
    print(k)  # a, b

d2 = {"b": 2, "a": 1}
for k in d2:
    print(k)  # b, a
```

Just to confirm, is your proposal to handle custom behaviour in our reference scripts? If yes, then that sounds good. I wouldn't want to add this to main TorchVision but rather leave it to the users to handle those special cases based on their use-case/dataset. Does that make sense?
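To make the hook idea from the thread concrete without touching the core implementation, here is a minimal, assumed sketch of how a reference script could defer to a per-type `collate_fn` hook and otherwise fall back to the stock `default_collate`. The wrapper name is illustrative and not part of this PR:

```python
# Illustrative sketch only: a script-level wrapper around the core collate that
# honors an optional per-type collate_fn hook, as floated in the thread above.
from torch.utils.data.dataloader import default_collate


def collate_with_hook(batch):
    elem = batch[0]
    if hasattr(type(elem), "collate_fn"):
        # The sample type knows how to collate itself.
        return type(elem).collate_fn(batch)
    # Otherwise rely on the core default_collate from torch.utils.data.
    return default_collate(batch)
```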
references/classification/transforms.py
```diff
@@ -1,183 +1,8 @@
-import math
-from typing import Tuple
-
-import torch
-from torch import Tensor
-from torchvision.transforms import functional as F
-
-
-class RandomMixup(torch.nn.Module):
-    """Randomly apply Mixup to the provided batch and targets.
-    The class implements the data augmentations as described in the paper
-    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
-
-    Args:
-        num_classes (int): number of classes used for one-hot encoding.
-        p (float): probability of the batch being transformed. Default value is 0.5.
-        alpha (float): hyperparameter of the Beta distribution used for mixup.
-            Default value is 1.0.
-        inplace (bool): boolean to make this transform inplace. Default set to False.
-    """
-
-    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
-        super().__init__()
-
-        if num_classes < 1:
-            raise ValueError(
-                f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
-            )
-
-        if alpha <= 0:
-            raise ValueError("Alpha param can't be zero.")
-
-        self.num_classes = num_classes
-        self.p = p
-        self.alpha = alpha
-        self.inplace = inplace
-
-    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
-        """
-        Args:
-            batch (Tensor): Float tensor of size (B, C, H, W)
-            target (Tensor): Integer tensor of size (B, )
-
-        Returns:
-            Tensor: Randomly transformed batch.
-        """
-        if batch.ndim != 4:
-            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
-        if target.ndim != 1:
-            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
-        if not batch.is_floating_point():
-            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
-        if target.dtype != torch.int64:
-            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
-
-        if not self.inplace:
-            batch = batch.clone()
-            target = target.clone()
-
-        if target.ndim == 1:
-            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
-
-        if torch.rand(1).item() >= self.p:
-            return batch, target
-
-        # It's faster to roll the batch by one instead of shuffling it to create image pairs
-        batch_rolled = batch.roll(1, 0)
-        target_rolled = target.roll(1, 0)
-
-        # Implemented as on mixup paper, page 3.
-        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
-        batch_rolled.mul_(1.0 - lambda_param)
-        batch.mul_(lambda_param).add_(batch_rolled)
-
-        target_rolled.mul_(1.0 - lambda_param)
-        target.mul_(lambda_param).add_(target_rolled)
-
-        return batch, target
-
-    def __repr__(self) -> str:
-        s = (
-            f"{self.__class__.__name__}("
-            f"num_classes={self.num_classes}"
-            f", p={self.p}"
-            f", alpha={self.alpha}"
-            f", inplace={self.inplace}"
-            f")"
-        )
-        return s
-
-
-class RandomCutmix(torch.nn.Module):
-    """Randomly apply Cutmix to the provided batch and targets.
-    The class implements the data augmentations as described in the paper
-    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
-    <https://arxiv.org/abs/1905.04899>`_.
-
-    Args:
-        num_classes (int): number of classes used for one-hot encoding.
-        p (float): probability of the batch being transformed. Default value is 0.5.
-        alpha (float): hyperparameter of the Beta distribution used for cutmix.
-            Default value is 1.0.
-        inplace (bool): boolean to make this transform inplace. Default set to False.
-    """
-
-    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
-        super().__init__()
-        if num_classes < 1:
-            raise ValueError("Please provide a valid positive value for the num_classes.")
-        if alpha <= 0:
-            raise ValueError("Alpha param can't be zero.")
-
-        self.num_classes = num_classes
-        self.p = p
-        self.alpha = alpha
-        self.inplace = inplace
-
-    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
-        """
-        Args:
-            batch (Tensor): Float tensor of size (B, C, H, W)
-            target (Tensor): Integer tensor of size (B, )
-
-        Returns:
-            Tensor: Randomly transformed batch.
-        """
-        if batch.ndim != 4:
-            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
-        if target.ndim != 1:
-            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
-        if not batch.is_floating_point():
-            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
-        if target.dtype != torch.int64:
-            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
-
-        if not self.inplace:
-            batch = batch.clone()
-            target = target.clone()
-
-        if target.ndim == 1:
-            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
-
-        if torch.rand(1).item() >= self.p:
-            return batch, target
-
-        # It's faster to roll the batch by one instead of shuffling it to create image pairs
-        batch_rolled = batch.roll(1, 0)
-        target_rolled = target.roll(1, 0)
-
-        # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
-        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
-        _, H, W = F.get_dimensions(batch)
-
-        r_x = torch.randint(W, (1,))
-        r_y = torch.randint(H, (1,))
-
-        r = 0.5 * math.sqrt(1.0 - lambda_param)
-        r_w_half = int(r * W)
-        r_h_half = int(r * H)
-
-        x1 = int(torch.clamp(r_x - r_w_half, min=0))
-        y1 = int(torch.clamp(r_y - r_h_half, min=0))
-        x2 = int(torch.clamp(r_x + r_w_half, max=W))
-        y2 = int(torch.clamp(r_y + r_h_half, max=H))
-
-        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
-        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
-
-        target_rolled.mul_(1.0 - lambda_param)
-        target.mul_(lambda_param).add_(target_rolled)
-
-        return batch, target
-
-    def __repr__(self) -> str:
-        s = (
-            f"{self.__class__.__name__}("
-            f"num_classes={self.num_classes}"
-            f", p={self.p}"
-            f", alpha={self.alpha}"
-            f", inplace={self.inplace}"
-            f")"
-        )
-        return s
+from torch import nn
+from torchvision.prototype import features
+
+
+class WrapIntoFeatures(nn.Module):
+    def forward(self, sample):
+        input, target = sample
+        return features.Image(input), features.Label(target)
```
Review thread on `WrapIntoFeatures`:

Don't we have to WrapIntoFeatures unconditionally, regardless of whether we use mixup/cutmix?

The problem is that `default_collate` does not respect tensor subclasses. Since we use this transform afterwards, we need to wrap here. Of course we can also wrap before, but it is not necessary since the input is a plain tensor that defaults to images and an integer which is completely ignored.

My question is, what if I don't use mixup or cutmix? Shouldn't we wrap the data into `features.Image` anyway? I might be missing something here. My main point is that since we are testing the new API, we should probably wrap all inputs using their appropriate types and see how the new kernels behave (rather than relying on their default/legacy pure tensor implementations).

We can only wrap after we have converted from PIL. This happens fairly late in the transform: vision/references/classification/presets.py, line 33 at b83d5f7. I remember @vfdev-5 noting that on the CPU the PIL kernels are faster (I don't remember if there was a special case or other constraints; please fill in the blanks). Thus, if we want to optimize for speed, we should probably leave it as is. No strong opinion though.
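As a quick, illustrative sketch of what this transform does (the random tensors below are placeholders, and the local import assumes the script is run from references/classification as the training script does), WrapIntoFeatures turns an already-collated pair of plain tensors into the prototype feature types expected by the batch transforms:

```python
# Illustrative only: wrap a collated (images, labels) pair of plain tensors
# into the prototype feature types used by the batch transforms in train.py.
import torch
from torchvision.prototype import features
from transforms import WrapIntoFeatures  # the reference-script module changed in this PR

images = torch.rand(4, 3, 224, 224)   # placeholder float image batch
labels = torch.randint(0, 10, (4,))   # placeholder integer class indices

images, labels = WrapIntoFeatures()((images, labels))
assert isinstance(images, features.Image) and isinstance(labels, features.Label)
```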