Prototype references #6433

Draft · wants to merge 100 commits into base: main · Changes from all commits

Commits (100):
c19f838 use prototype transforms in classification reference (pmeier, Aug 17, 2022)
7b7602e cleanup (pmeier, Aug 17, 2022)
4ea2aaf Merge branch 'main' into prototype-references/classification (pmeier, Aug 18, 2022)
4990b89 move WrapIntoFeatures into transforms module (pmeier, Aug 18, 2022)
1cfa965 Merge branch 'main' into prototype-references/classification (pmeier, Aug 18, 2022)
ca4c5a7 [skip ci] add p=1.0 to CutMix and MixUp (pmeier, Aug 18, 2022)
a2e24e1 Merge branch 'main' into prototype-references/classification (pmeier, Aug 23, 2022)
693795e [skip ci] (pmeier, Aug 23, 2022)
69f5299 Merge branch 'main' into prototype-references/classification (pmeier, Aug 24, 2022)
fe96a54 use prototype transforms in detection references (pmeier, Aug 24, 2022)
0f06516 Merge branch 'main' into prototype-references/classification (pmeier, Aug 24, 2022)
148885d Merge branch 'main' into prototype-references/classification (pmeier, Aug 26, 2022)
6fd5e50 [skip ci] (pmeier, Aug 26, 2022)
6edb7f4 Merge branch 'main' into prototype-references/classification (pmeier, Aug 30, 2022)
6fcffb2 [skip ci] (pmeier, Aug 30, 2022)
4b68e2f [skip ci] Merge branch 'main' into prototype-references/classification (pmeier, Aug 30, 2022)
4d73fe7 [skip ci] Merge branch 'main' into prototype-references/classification (pmeier, Aug 31, 2022)
7cb08d5 [skip ci] fix scripts (pmeier, Sep 1, 2022)
e51791d [skip ci] Merge branch 'main' into prototype-references/classification (pmeier, Sep 1, 2022)
3e5e064 [skip ci] Merge branch 'main' into prototype-references/classification (pmeier, Sep 1, 2022)
ec120ff [skip ci] Merge branch 'main' into prototype-references/classification (pmeier, Sep 1, 2022)
f993926 Merge branch 'main' into prototype-references/classification (datumbox, Sep 5, 2022)
6167038 [skip ci] Merge branch 'main' into prototype-references/classification (pmeier, Sep 6, 2022)
4699e55 [skip ci] Merge branch 'prototype-references/classification' of https… (pmeier, Sep 6, 2022)
e2e459d Merge branch 'main' into prototype-references/classification (datumbox, Sep 7, 2022)
aa7a655 [skip ci] Merge branch 'main' into prototype-references/classification (pmeier, Sep 7, 2022)
cb02041 Merge branch 'prototype-references/classification' of https://github.… (pmeier, Sep 7, 2022)
a98c05d [SKIP CI] CircleCI (pmeier, Sep 7, 2022)
49e653f [skip ci] (pmeier, Sep 7, 2022)
fcd37d9 [skip ci] Merge branch 'main' into prototype-references/classification (pmeier, Sep 8, 2022)
05be06d Merge branch 'main' into prototype-references/classification (pmeier, Sep 12, 2022)
9459b0a Merge branch 'main' into prototype-references/classification (pmeier, Sep 13, 2022)
6c90b3a Merge branch 'main' into prototype-references/classification (pmeier, Sep 13, 2022)
2eccb84 Merge branch 'main' into prototype-references/classification (pmeier, Sep 14, 2022)
8df9043 update segmentation references (pmeier, Sep 13, 2022)
99e6c36 [skip ci] (pmeier, Sep 14, 2022)
47772ac Merge branch 'main' into prototype-references/classification (pmeier, Sep 14, 2022)
94ac15d [skip ci] (pmeier, Sep 14, 2022)
51307b7 [skip ci] fix workaround (pmeier, Sep 14, 2022)
9dad6e0 only wrap segmentation mask (pmeier, Sep 14, 2022)
f5f1716 fix pretrained weights test only (pmeier, Sep 14, 2022)
2aefd09 [skip ci] (pmeier, Sep 14, 2022)
a2893a1 Restore get_dimensions (datumbox, Sep 14, 2022)
8df0cf4 Merge branch 'main' into prototype-references/classification (pmeier, Sep 21, 2022)
e912976 fix segmentation transforms (pmeier, Sep 21, 2022)
2e7e168 [skip ci] (pmeier, Sep 21, 2022)
74ecb49 Merge branch 'prototype-references/classification' of https://github.… (pmeier, Sep 21, 2022)
585c64a fix mask rewrapping (pmeier, Sep 21, 2022)
93d7a32 [skip ci] (pmeier, Sep 21, 2022)
5a311b3 Merge branch 'main' into prototype-references/classification (datumbox, Sep 21, 2022)
aac24c1 Merge branch 'main' into prototype-references/classification (datumbox, Sep 23, 2022)
766af6c Fix merge issue (datumbox, Sep 23, 2022)
cb6c90e Tensor Backend + antialiasing=True (datumbox, Sep 23, 2022)
e9c480e Switch to view to reshape to avoid incompatibilities with size/stride (datumbox, Sep 23, 2022)
3894efb Merge branch 'main' into prototype-references/classification (datumbox, Sep 25, 2022)
6ef4d82 Cherrypick PR #6642 (datumbox, Sep 25, 2022)
5a1de52 Merge branch 'main' into prototype-references/classification (datumbox, Sep 26, 2022)
c6950ae Merge branch 'main' (pmeier, Oct 10, 2022)
b59beae Merge branch 'main' into prototype-references/classification (pmeier, Oct 10, 2022)
906428a Merge branch 'main' into prototype-references/classification (pmeier, Oct 10, 2022)
758de46 [skip ci] add support for video_classification (pmeier, Oct 10, 2022)
a0895c1 Merge branch 'prototype-references/classification' of https… (pmeier, Oct 10, 2022)
2bd4291 Merge branch 'main' into prototype-references/classification (datumbox, Oct 11, 2022)
669b1ba Restoring original reference transforms so that test can run (datumbox, Oct 11, 2022)
591a773 Adding AA, Random Erase, MixUp/CutMix and a different resize/crop str… (datumbox, Oct 11, 2022)
0db3ce2 Merge branch 'main' into prototype-references/classification (datumbox, Oct 11, 2022)
9e95b78 image_size to spatial_size (datumbox, Oct 13, 2022)
4f3b593 Merge branch 'main' into prototype-references/classification (datumbox, Oct 14, 2022)
a364b15 Merge branch 'main' into prototype-references/classification (datumbox, Oct 14, 2022)
00d1b9b Update the RandomShortestSize behaviour on Video presets. (datumbox, Oct 14, 2022)
ef3dc55 Fix ToDtype transform to accept dictionaries. (datumbox, Oct 14, 2022)
711128c Merge branch 'main' into prototype-references/classification (datumbox, Oct 14, 2022)
25c4664 Fix issue with collate and audio using Philip's proposal. (datumbox, Oct 14, 2022)
091948e Fix linter (datumbox, Oct 14, 2022)
6b23587 Fix ToDtype parameters. (datumbox, Oct 14, 2022)
eb37f8f Wrapping id into a no-op. (datumbox, Oct 14, 2022)
bb468ba Define `_Feature` in the dict. (datumbox, Oct 14, 2022)
5f8d233 Merge branch 'main' into prototype-references/classification (datumbox, Oct 14, 2022)
5928876 Handling hot-encoded tensors in `accuracy` (datumbox, Oct 14, 2022)
b63e607 Handle ConvertBCHWtoCBHW interactions with mixup/cutmix. (datumbox, Oct 14, 2022)
ab141f9 Merge branch 'main' into prototype-references/classification (datumbox, Oct 14, 2022)
d5f1532 Add Permute Transform. (datumbox, Oct 14, 2022)
707190c Merge branch 'main' into prototype-references/classification (datumbox, Oct 19, 2022)
598542c Merge branch 'main' into prototype-references/classification (datumbox, Oct 21, 2022)
6a0a32c Switch to `TransposeDimensions` (datumbox, Oct 21, 2022)
a59f995 Merge branch 'main' into prototype-references/classification (datumbox, Oct 26, 2022)
f72f5b2 Merge branch 'main' into prototype-references/classification (datumbox, Oct 27, 2022)
9d0a0a3 Fix linter. (datumbox, Oct 27, 2022)
7c41f0c Merge branch 'main' into prototype-references/classification (datumbox, Oct 31, 2022)
87031f1 Merge branch 'main' into prototype-references/classification (datumbox, Nov 4, 2022)
7c5da3a Merge branch 'main' into prototype-references/classification (datumbox, Nov 4, 2022)
d8b5202 Fix method location. (datumbox, Nov 4, 2022)
959af2d Fixing minor bug (datumbox, Nov 7, 2022)
d435378 Merge branch 'main' into prototype-references/classification (datumbox, Nov 15, 2022)
bda072d Merge branch 'main' into prototype-references/classification (datumbox, Nov 16, 2022)
8f07159 Convert to floats at the beginning. (datumbox, Nov 17, 2022)
8344ce9 Revert "Convert to floats at the beginning." (datumbox, Nov 17, 2022)
8b53036 Switch to PIL backend (datumbox, Nov 17, 2022)
c7f2ac8 Revert "Switch to PIL backend" (datumbox, Nov 17, 2022)
f205f1e Merge branch 'main' of github.com:pytorch/vision into prototype-refer… (NicolasHug, Feb 9, 2023)
24 changes: 13 additions & 11 deletions references/classification/presets.py
@@ -1,5 +1,5 @@
 import torch
-from torchvision.transforms import autoaugment, transforms
+from torchvision.prototype import transforms
 from torchvision.transforms.functional import InterpolationMode
@@ -17,22 +17,24 @@ def __init__(
         augmix_severity=3,
         random_erase_prob=0.0,
     ):
-        trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
+        trans = [
+            transforms.ToImageTensor(),
+            transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True),
+        ]
         if hflip_prob > 0:
-            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
+            trans.append(transforms.RandomHorizontalFlip(p=hflip_prob))
         if auto_augment_policy is not None:
             if auto_augment_policy == "ra":
-                trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
+                trans.append(transforms.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
             elif auto_augment_policy == "ta_wide":
-                trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
+                trans.append(transforms.TrivialAugmentWide(interpolation=interpolation))
             elif auto_augment_policy == "augmix":
-                trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
+                trans.append(transforms.AugMix(interpolation=interpolation, severity=augmix_severity))
             else:
-                aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
-                trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
+                aa_policy = transforms.AutoAugmentPolicy(auto_augment_policy)
+                trans.append(transforms.AutoAugment(policy=aa_policy, interpolation=interpolation))
 
         trans.extend(
             [
-                transforms.PILToTensor(),
                 transforms.ConvertImageDtype(torch.float),
                 transforms.Normalize(mean=mean, std=std),
             ]
@@ -59,9 +61,9 @@ def __init__(
 
         self.transforms = transforms.Compose(
             [
-                transforms.Resize(resize_size, interpolation=interpolation),
+                transforms.ToImageTensor(),
+                transforms.Resize(resize_size, interpolation=interpolation, antialias=True),
                 transforms.CenterCrop(crop_size),
-                transforms.PILToTensor(),
                 transforms.ConvertImageDtype(torch.float),
                 transforms.Normalize(mean=mean, std=std),
             ]
32 changes: 22 additions & 10 deletions references/classification/train.py
@@ -4,14 +4,17 @@
 import warnings
 
 import presets
+from sampler import RASampler
+from transforms import WrapIntoFeatures
+import utils  # usort: skip
 
 import torch
 import torch.utils.data
 import torchvision
-import transforms
-import utils
-from sampler import RASampler
 
 from torch import nn
 from torch.utils.data.dataloader import default_collate
+from torchvision.prototype import features, transforms
 from torchvision.transforms.functional import InterpolationMode


@@ -144,6 +147,7 @@ def load_data(traindir, valdir, args):
                 ra_magnitude=ra_magnitude,
                 augmix_severity=augmix_severity,
             ),
+            target_transform=lambda target: features.Label(target),
         )
     if args.cache_dataset:
         print(f"Saving dataset_train to {cache_path}")
@@ -168,7 +172,8 @@ def load_data(traindir, valdir, args):
 
     dataset_test = torchvision.datasets.ImageFolder(
         valdir,
-        preprocessing,
+        transform=preprocessing,
+        target_transform=lambda target: features.Label(target),
     )
     if args.cache_dataset:
         print(f"Saving dataset_test to {cache_path}")
@@ -210,16 +215,23 @@ def main(args):
 
     collate_fn = None
     num_classes = len(dataset.classes)
-    mixup_transforms = []
+    mixup_or_cutmix = []
     if args.mixup_alpha > 0.0:
-        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
+        mixup_or_cutmix.append(transforms.RandomMixup(alpha=args.mixup_alpha, p=1.0))
     if args.cutmix_alpha > 0.0:
-        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
-    if mixup_transforms:
-        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
+        mixup_or_cutmix.append(transforms.RandomCutmix(alpha=args.cutmix_alpha, p=1.0))
+    if mixup_or_cutmix:
+        batch_transform = transforms.Compose(
+            [
+                WrapIntoFeatures(),
Review thread on the `WrapIntoFeatures()` line — Contributor:

Don't we have to apply WrapIntoFeatures unconditionally, regardless of whether we use mixup/cutmix?

Collaborator Author:

The problem is that default_collate does not respect tensor subclasses. Since we use this transform afterwards, we need to wrap here. Of course we could also wrap earlier, but it is not necessary, since the input is a plain tensor (which is treated as an image by default) and an integer (which is ignored completely).
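For illustration, a minimal sketch of the failure mode being described, assuming the prototype features API used in this PR:

import torch
from torch.utils.data.dataloader import default_collate
from torchvision.prototype import features

samples = [(torch.rand(3, 8, 8), features.Label(0)), (torch.rand(3, 8, 8), features.Label(1))]
images, labels = default_collate(samples)
print(type(labels))  # torch.Tensor -- the Label type is dropped by the internal torch.stack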

Contributor:

My question is: what if I don't use mixup or cutmix? Shouldn't we wrap the data into features.Image anyway? I might be missing something here. My main point is that since we are testing the new API, we should probably wrap all inputs using their appropriate types and see how the new kernels behave (rather than relying on their default/legacy pure-tensor implementations).

Collaborator Author:

We can only wrap after we have converted from PIL. This happens fairly late in the transform:

transforms.PILToTensor(),

I remember @vfdev-5 noting that on the CPU the PIL kernels are faster (I don't remember if there was a special case or other constraints; please fill in the blanks). Thus, if we want to optimize for speed, we should probably leave it as is. No strong opinion though.

Contributor:

Right now we are testing the PIL backend. In order to test the features.Image one, we will need to move this to the beginning (i.e., onto the transforms defined for ImageFolder), right?

+                transforms.LabelToOneHot(num_categories=num_classes),
Review thread on the `LabelToOneHot` line — Contributor (datumbox), Sep 23, 2022:

I removed the WrapIntoFeatures() call from this line and now get:

TypeError: RandomCutmix() is only defined for tensor images and one-hot labels.

I expected that the Label is already annotated, because the dataset is defined with:

target_transform=lambda target: features.Label(target),

I'll check it out further on Monday. It's because of the collating. UI-wise, it's not great that batching causes the _Feature information to be lost. Is there a way to improve upon this?

cc @pmeier

Collaborator Author:

Yeah, that is a known issue. The default_collate internally uses torch.stack. Since we only automatically wrap on .to() and .clone(), this will drop the feature type.

Three options here:

  1. Allow torch.stack to retain the feature type. I would not go that route, since we cannot know in general whether the result of a stack is still the feature we started out with. We would at least need more elaborate checks to make sure of that.
  2. Fix default_collate to handle our feature types. It is a possible solution, but won't happen, because the feature types are torchvision-only and default_collate is a general solution.
  3. Provide a custom collate like vision_collate. This would mean that, depending on the task, users can no longer go with the default value in the DataLoader, but judging from our references, this seems to be somewhat common already. Plus, this vision_collate could also include more functionality that we need:
    1. Don't torch.stack images if we find BoundingBoxes.
    2. Support for None, which some of our datasets already return (add test if dataset samples can be collated #5233) and which will be heavily utilized by the new datasets for missing data in a configuration.

My vote goes to 3. To avoid rewriting default_collate from scratch in torchvision and thus risking it going out of sync with the core version, I made a push to make default_collate more extensible in this gist. @ejguan wanted to handle this from the torchdata side. Any progress there?
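For illustration, a rough sketch of option 3: a hypothetical vision_collate that delegates the actual stacking to default_collate and rewraps the feature type afterwards. The name comes from the comment above; the rewrapping via the constructor is an assumption for this sketch and ignores per-type metadata (e.g. BoundingBox.format), which a real implementation would have to carry over:

import torch
from torch.utils.data.dataloader import default_collate
from torchvision.prototype import features


def vision_collate(batch):
    elem = batch[0]
    if isinstance(elem, features._Feature):
        # Collate as plain tensors, then rewrap so the feature type survives.
        stacked = default_collate([t.as_subclass(torch.Tensor) for t in batch])
        return type(elem)(stacked)
    if isinstance(elem, tuple):
        return tuple(vision_collate(items) for items in zip(*batch))
    if isinstance(elem, dict):
        return {key: vision_collate([d[key] for d in batch]) for key in elem}
    return default_collate(batch)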

Contributor:

It's a hard problem. I agree option 2 is not possible. What I don't like about option 3 is that if TorchVision has its own collate, it will miss out on potential improvements from Core/Data. I also wouldn't put it inside the library but rather in the examples.

Option 1 sounds promising only if we feel we can do all the necessary checks to avoid gotchas (and think about whether this is a performance bottleneck). If that's the case, we could offer implementations that whitelist the stack similarly to to() and clone(). But if checking all the TV metadata plus Tensor metadata is too slow, then we might leave it to the user to decide how to handle it efficiently on a case-by-case basis (as you have it here). I'm also open to exploring other ideas to rewrap the types, though nothing elegant pops to mind right now.
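For illustration, option 1 could look roughly like the following: a hypothetical tensor subclass whose __torch_function__ keeps the subclass only for a whitelist of ops and degrades to a plain Tensor otherwise, mimicking the unwrap-by-default behavior of the prototype _Feature types. The class and the whitelist are assumptions for this sketch, not the actual implementation:

import torch


class MyFeature(torch.Tensor):
    # Ops whose results keep the feature type; everything else degrades to Tensor.
    _WHITELIST = {torch.Tensor.to, torch.Tensor.clone, torch.stack}

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        out = super().__torch_function__(func, types, args, kwargs or {})
        if isinstance(out, torch.Tensor):
            return out.as_subclass(cls if func in cls._WHITELIST else torch.Tensor)
        return out


a = torch.tensor([1, 2]).as_subclass(MyFeature)
b = torch.tensor([3, 4]).as_subclass(MyFeature)
assert type(torch.stack([a, b])) is MyFeature  # whitelisted: type survives
assert type(a + b) is torch.Tensor  # not whitelisted: degrades to Tensor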

Collaborator Author:

> What I don't like about option 3 is that if TorchVision has its own collate, it will miss out on potential improvements from Core/Data. I also wouldn't put it inside the library but rather in the examples.

Depends on what architecture default_collate actually gets after the refactor. My proposal would allow us to extend the functionality. For example, the collation for features.Image's could just call the collation for tensors and rewrap at the end. This way, if there are any improvements to tensor collation, we automatically benefit as well.

Contributor:

Talked to Vitaly about it a while ago, and we thought adding an argument to default_collate might not be the ideal option, because users can always do it the other way around by providing a custom collate function that calls default_collate if needed.

Would it be possible to let default_collate do the collation by checking whether the object has a specific method attached (tentatively: collate_fn)?

class CustomObject:
    @staticmethod
    def collate_fn(batch):
        return ...

def default_collate(batch):
    elem = batch[0]
    ...
    if hasattr(elem, "collate_fn"):
        return elem.collate_fn(batch)

This would provide the same mechanism as pin_memory.
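Under that proposal, a tensor type could opt in along these lines (hypothetical sketch; MyLabel and its collate_fn are illustrative, not an existing API):

import torch


class MyLabel(torch.Tensor):
    @staticmethod
    def collate_fn(batch):
        # Stack as plain tensors, then rewrap to keep the label type.
        return torch.stack([t.as_subclass(torch.Tensor) for t in batch]).as_subclass(MyLabel)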

Collaborator Author:

> Second, go through each key in the dict and check isinstance for derived classes

Shouldn't it be an OrderedDict then? Otherwise, if an input is a subtype of multiple types in the map, two runs may yield different results, making the collation non-deterministic.

I would prefer 1., since otherwise an object would need to know how it wants to be collated, but it doesn't have all the context. For example, for classification and segmentation tasks we want to torch.stack images, but for detection tasks they should be kept in a list. We can figure this out from the full sample passed to the collation function, but the object itself can't.

Plus, we don't have to worry about JIT, and that can be quite a pain.

If you send a PR, please ping me there so I can see if it actually is extensible enough for us. @datumbox, can you confirm that you would also be OK with handling this in the collation function if we only need to provide custom behavior for our stuff, but otherwise can depend on the core functionality?

Contributor:

> Shouldn't it be an OrderedDict then?

If the type exists in the dict, it should invoke the corresponding collate function. And dict is actually ordered in Python. So, if an input is a subtype of multiple types, and the type of the input doesn't equal any of the keys, it will follow the order.

Will open a PR today. It should be a simple change.

Collaborator Author:

> And, dict is actually ordered in Python.

Well, not really. Since Python 3.6, dicts are insertion-ordered. Meaning, if you create a dictionary like

dct = {"a": 1, "b": 2}

you can be sure that "a" is always the first key. However, dicts are not ordered:

assert {"a": 1, "b": 2} == {"b": 2, "a": 1}

Contributor:

If you iterate over those dicts, you will find they preserve insertion order:

d1 = {"a": 1, "b": 2}
for k in d1:
    print(k)  # a, b

d2 = {"b": 2, "a": 1}
for k in d2:
    print(k)  # b, a

Contributor:

> @datumbox, can you confirm that you would also be OK with handling this in the collation function if we only need to provide custom behavior for our stuff, but otherwise can depend on the core functionality?

Just to confirm: is your proposal to handle custom behaviour in our reference scripts? If yes, then that sounds good. I wouldn't want to add this to main TorchVision, but rather leave it to the users to handle those special cases based on their use case/dataset. Does that make sense?

+                transforms.ToDtype({features.OneHotLabel: torch.float, features.Image: None}),
+                transforms.RandomChoice(mixup_or_cutmix),
+            ]
+        )
 
         def collate_fn(batch):
-            return mixupcutmix(*default_collate(batch))
+            return batch_transform(*default_collate(batch))
 
     data_loader = torch.utils.data.DataLoader(
         dataset,
11 changes: 11 additions & 0 deletions references/classification/transforms.py
@@ -3,9 +3,20 @@
 
 import torch
 from torch import Tensor
+from torchvision.prototype import features
+from torchvision.prototype.transforms import functional as PF
 from torchvision.transforms import functional as F
 
 
+class WrapIntoFeatures(torch.nn.Module):
+    def forward(self, sample):
+        image, target = sample
+        return PF.to_image_tensor(image), features.Label(target)
+
+
+# Original Transforms can be removed:
+
+
 class RandomMixup(torch.nn.Module):
     """Randomly apply Mixup to the provided batch and targets.
     The class implements the data augmentations as described in the paper
55 changes: 30 additions & 25 deletions references/detection/coco_utils.py
@@ -1,30 +1,13 @@
-import copy
 import os
 
 import torch
 import torch.utils.data
 import torchvision
-import transforms as T
 
 from pycocotools import mask as coco_mask
 from pycocotools.coco import COCO
 
 
-class FilterAndRemapCocoCategories:
-    def __init__(self, categories, remap=True):
-        self.categories = categories
-        self.remap = remap
-
-    def __call__(self, image, target):
-        anno = target["annotations"]
-        anno = [obj for obj in anno if obj["category_id"] in self.categories]
-        if not self.remap:
-            target["annotations"] = anno
-            return image, target
-        anno = copy.deepcopy(anno)
-        for obj in anno:
-            obj["category_id"] = self.categories.index(obj["category_id"])
-        target["annotations"] = anno
-        return image, target
+from torchvision.prototype import features, transforms as T
+from torchvision.prototype.transforms import functional as F
 
 
 def convert_coco_poly_to_mask(segmentations, height, width):
@@ -45,7 +28,8 @@ def convert_coco_poly_to_mask(segmentations, height, width):
 
 
 class ConvertCocoPolysToMask:
-    def __call__(self, image, target):
+    def __call__(self, sample):
+        image, target = sample
         w, h = image.size
 
         image_id = target["image_id"]
@@ -100,6 +84,27 @@ def __call__(self, image, target):
         return image, target
 
 
+class WrapIntoFeatures:
+    def __call__(self, sample):
+        image, target = sample
+
+        wrapped_target = dict(
+            boxes=features.BoundingBox(
+                target["boxes"],
+                format=features.BoundingBoxFormat.XYXY,
+                spatial_size=(image.height, image.width),
+            ),
+            # TODO: add categories
+            labels=features.Label(target["labels"], categories=None),
+            masks=features.Mask(target["masks"]),
+            image_id=int(target["image_id"]),
+            area=target["area"].tolist(),
+            iscrowd=target["iscrowd"].bool().tolist(),
+        )
+
+        return F.to_image_tensor(image), wrapped_target
+
+
 def _coco_remove_images_without_annotations(dataset, cat_list=None):
     def _has_only_empty_bbox(anno):
         return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
@@ -225,10 +230,12 @@ def get_coco(root, image_set, transforms, mode="instances"):
     PATHS = {
         "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
         "val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))),
-        # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
     }
 
-    t = [ConvertCocoPolysToMask()]
+    t = [
+        ConvertCocoPolysToMask(),
+        WrapIntoFeatures(),
+    ]
 
     if transforms is not None:
         t.append(transforms)
@@ -243,8 +250,6 @@ def get_coco(root, image_set, transforms, mode="instances"):
     if image_set == "train":
         dataset = _coco_remove_images_without_annotations(dataset)
 
-    # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])
-
     return dataset
4 changes: 2 additions & 2 deletions references/detection/engine.py
@@ -26,7 +26,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
 
     for images, targets in metric_logger.log_every(data_loader, print_freq, header):
         images = list(image.to(device) for image in images)
-        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
         with torch.cuda.amp.autocast(enabled=scaler is not None):
             loss_dict = model(images, targets)
             losses = sum(loss for loss in loss_dict.values())
@@ -97,7 +97,7 @@ def evaluate(model, data_loader, device):
         outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
         model_time = time.time() - model_time
 
-        res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
+        res = {target["image_id"]: output for target, output in zip(targets, outputs)}
         evaluator_time = time.time()
         coco_evaluator.update(res)
         evaluator_time = time.time() - evaluator_time
91 changes: 37 additions & 54 deletions references/detection/presets.py
@@ -1,73 +1,56 @@
+from collections import defaultdict
+
 import torch
-import transforms as T
+from torchvision.prototype import features, transforms as T
 
 
-class DetectionPresetTrain:
+class DetectionPresetTrain(T.Compose):
     def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
         if data_augmentation == "hflip":
-            self.transforms = T.Compose(
-                [
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms = [
+                T.RandomHorizontalFlip(p=hflip_prob),
+                T.ConvertImageDtype(torch.float),
+            ]
         elif data_augmentation == "lsj":
-            self.transforms = T.Compose(
-                [
-                    T.ScaleJitter(target_size=(1024, 1024)),
-                    T.FixedSizeCrop(size=(1024, 1024), fill=mean),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms = [
+                T.ScaleJitter(target_size=(1024, 1024), antialias=True),
+                T.FixedSizeCrop(size=(1024, 1024), fill=defaultdict(lambda: mean, {features.Mask: 0})),
+                T.RandomHorizontalFlip(p=hflip_prob),
+                T.ConvertImageDtype(torch.float),
+            ]
         elif data_augmentation == "multiscale":
-            self.transforms = T.Compose(
-                [
-                    T.RandomShortestSize(
-                        min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
-                    ),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms = [
+                T.RandomShortestSize(
+                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333, antialias=True
+                ),
+                T.RandomHorizontalFlip(p=hflip_prob),
+                T.ConvertImageDtype(torch.float),
+            ]
         elif data_augmentation == "ssd":
-            self.transforms = T.Compose(
-                [
-                    T.RandomPhotometricDistort(),
-                    T.RandomZoomOut(fill=list(mean)),
-                    T.RandomIoUCrop(),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms = [
+                T.RandomPhotometricDistort(),
+                T.RandomZoomOut(fill=defaultdict(lambda: mean, {features.Mask: 0})),
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+                T.ConvertImageDtype(torch.float),
+            ]
         elif data_augmentation == "ssdlite":
-            self.transforms = T.Compose(
-                [
-                    T.RandomIoUCrop(),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms = [
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+                T.ConvertImageDtype(torch.float),
+            ]
         else:
             raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
 
-    def __call__(self, img, target):
-        return self.transforms(img, target)
+        super().__init__(transforms)
 
 
-class DetectionPresetEval:
+class DetectionPresetEval(T.Compose):
     def __init__(self):
-        self.transforms = T.Compose(
+        super().__init__(
             [
-                T.PILToTensor(),
+                T.ToImageTensor(),
                 T.ConvertImageDtype(torch.float),
             ]
         )
-
-    def __call__(self, img, target):
-        return self.transforms(img, target)
4 changes: 2 additions & 2 deletions references/detection/train.py
@@ -31,12 +31,12 @@
 from coco_utils import get_coco, get_coco_kp
 from engine import evaluate, train_one_epoch
 from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler
+from torchvision.prototype import transforms as T
 from torchvision.transforms import InterpolationMode
-from transforms import SimpleCopyPaste
 
 
 def copypaste_collate_fn(batch):
-    copypaste = SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR)
+    copypaste = T.SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR, antialias=True)
     return copypaste(*utils.collate_fn(batch))
1 change: 1 addition & 0 deletions references/detection/transforms.py
@@ -1,3 +1,4 @@
+# Original Transforms can be removed:
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch