NEW Feature: Mixup transform for Object Detection #6721

Status: Draft. Wants to merge 44 commits into main from 6720_add_mixup_transform; this view shows changes from 18 of the 44 commits.
Commits (44):

- 676a3ba  ADD: Empty file mixup.py for dummy PR (Oct 7, 2022)
- 60cdf3b  ADD: Empty transform class (Oct 7, 2022)
- fd922ca  Merge branch 'main' into 6720_add_mixup_transform (ambujpawar, Oct 22, 2022)
- 728c7ca  WIP: Random Mixup for detection (Oct 28, 2022)
- 3f204ac  Merge branch 'main' into 6720_add_mixup_transform (Nov 4, 2022)
- f1b70b9  First draft: Mixup detections (Nov 5, 2022)
- cdda41b  Fix: precommit issues (Nov 5, 2022)
- 2d0765c  Fix: failing CI issues (Nov 6, 2022)
- 7e82ff2  Fix: Tests and ADD: get_params and check_inputs functions (Nov 6, 2022)
- b83aedf  Fix: Remove usage of soon to be deprecated to_tensor function (Nov 6, 2022)
- 50bea74  Merge branch 'main' into 6720_add_mixup_transform (ambujpawar, Nov 6, 2022)
- 90799b8  Remove: get_params for mixup (Nov 6, 2022)
- 248737d  Update _mixup_detection.py (ambujpawar, Nov 7, 2022)
- 26316a4  Remove unused type: ignore due to failing CI test (ambujpawar, Nov 7, 2022)
- d7e08d2  Merge branch 'main' into 6720_add_mixup_transform (pmeier, Nov 8, 2022)
- 04c80d7  add batch detection helpers (pmeier, Nov 8, 2022)
- 5667c91  use helpers in detection mixup (pmeier, Nov 8, 2022)
- e0724a3  refactor helpers (pmeier, Nov 8, 2022)
- 10c9033  Merge branch 'main' into 6720_add_mixup_transform (pmeier, Dec 1, 2022)
- 6177057  revert accidental COCO change (pmeier, Dec 1, 2022)
- 2b67017  Move: mixup detection to _augment.py (Dec 4, 2022)
- fa0f54e  Merge branch 'main' into 6720_add_mixup_transform (ambujpawar, Dec 4, 2022)
- cae66d9  Merge branch 'main' into 6720_add_mixup_transform (pmeier, Dec 6, 2022)
- ae9908b  refactor extraction and insertion (pmeier, Dec 6, 2022)
- c2e2757  Fix: failing SimpleCopyPaste and MixupDetection tests (Dec 17, 2022)
- 6b58135  Merge branch 'main' into 6720_add_mixup_transform (pmeier, Dec 19, 2022)
- 5398c73  sample ratio in get_params (pmeier, Dec 19, 2022)
- 044ba0d  fix padding (pmeier, Dec 19, 2022)
- 884ace1  perform image conversion upfront (pmeier, Dec 19, 2022)
- 99de232  create base class (pmeier, Dec 19, 2022)
- 4ceef89  Merge branch 'main' into 6720_add_mixup_transform (pmeier, Dec 19, 2022)
- a6b9ae0  add shortcut for ratio==0 (pmeier, Dec 19, 2022)
- fce49b8  fix dtype (pmeier, Dec 19, 2022)
- 05c0491  Merge branch 'main' into 6720_add_mixup_transform (ambujpawar, Jan 17, 2023)
- d995471  Apply suggestions from code review (ambujpawar, Jan 21, 2023)
- 914a9ee  Merge branch 'main' into 6720_add_mixup_transform (ambujpawar, Jan 21, 2023)
- cbf09c2  Undo removing test_extract_image_target of TestSimpleCopyPaste (ambujpawar, Jan 21, 2023)
- 685d042  ADD: Test cases when mixup ratio is 0, 0.5, 1 (ambujpawar, Jan 22, 2023)
- 3319215  Fix: incorrect asserts (ambujpawar, Jan 22, 2023)
- 02214b6  fix mixing (pmeier, Jan 23, 2023)
- 4486e78  pass flat_inputs to get_params (pmeier, Jan 23, 2023)
- 1b6dbe1  Update torchvision/prototype/transforms/_augment.py (ambujpawar, Jan 23, 2023)
- 8a912ba  refactor SimpleCopyPaste (pmeier, Jan 23, 2023)
- ebd6bfd  Merge branch '6720_add_mixup_transform' of https://github.com/ambujpa… (pmeier, Jan 23, 2023)
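Before diving into the diffs: MixupDetection mixes pairs of detection samples by blending their pixels and concatenating their box and label targets. Judging from the tests below and the commit history (padding fixes, "sample ratio in get_params", "add shortcut for ratio==0"), the core operation looks roughly like the sketch below. This is an illustrative re-implementation, not the PR's code; the name mixup_detection_sketch and the fixed ratio argument are invented for the example.

import torch

def mixup_detection_sketch(image1, target1, image2, target2, ratio=0.5):
    # Zero-pad both images to a shared canvas (max height, max width),
    # mirroring the (3, 32, 64) + (3, 64, 32) -> (3, 64, 64) case in the tests.
    h = max(image1.shape[-2], image2.shape[-2])
    w = max(image1.shape[-1], image2.shape[-1])
    canvas1 = image1.new_zeros(*image1.shape[:-2], h, w)
    canvas1[..., : image1.shape[-2], : image1.shape[-1]] = image1
    canvas2 = image2.new_zeros(*image2.shape[:-2], h, w)
    canvas2[..., : image2.shape[-2], : image2.shape[-1]] = image2

    # Pixels are blended; boxes and labels are concatenated rather than
    # blended, which is why the tests expect 2 + 2 = 4 boxes and labels.
    mixed_image = ratio * canvas1 + (1.0 - ratio) * canvas2
    mixed_target = {
        "boxes": torch.cat([target1["boxes"], target2["boxes"]]),
        "labels": torch.cat([target1["labels"], target2["labels"]]),
    }
    return mixed_image, mixed_target

Since the padding only extends the canvas to the right and bottom, XYXY boxes stay valid on the mixed image without any coordinate shift.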
79 changes: 78 additions & 1 deletion test/test_prototype_transforms.py
@@ -24,7 +24,7 @@
from torchvision.ops.boxes import box_iou
from torchvision.prototype import features, transforms
from torchvision.prototype.transforms._utils import _isinstance
-from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
+from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image, to_tensor

BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]

@@ -1918,3 +1918,80 @@ def test__transform(self, inpt):
assert type(output) is type(inpt)
assert output.shape[-4] == num_samples
assert output.dtype == inpt.dtype


class TestMixupDetection:
def create_fake_image(self, mocker, image_type):
if image_type == PIL.Image.Image:
return PIL.Image.new("RGB", (32, 32), 123)
return mocker.MagicMock(spec=image_type)

def test__extract_image_targets_assertion(self, mocker):
transform = transforms.MixupDetection()

flat_sample = [
# one image
self.create_fake_image(mocker, features.Image),
# label and bbox for the first sample
mocker.MagicMock(spec=features.Label),
mocker.MagicMock(spec=features.BoundingBox),
# a second bbox without a matching image and label
mocker.MagicMock(spec=features.BoundingBox),
]

with pytest.raises(TypeError, match="requires input sample to contain equal-sized list of Images"):
transform._extract_image_targets(flat_sample)

@pytest.mark.parametrize("image_type", [features.Image, PIL.Image.Image, torch.Tensor])
def test__extract_image_targets(self, image_type, mocker):
transform = transforms.MixupDetection()

flat_sample = [
# images, batch size = 2
self.create_fake_image(mocker, image_type),
self.create_fake_image(mocker, image_type),
# labels, bboxes
mocker.MagicMock(spec=features.Label),
mocker.MagicMock(spec=features.BoundingBox),
# labels, bboxes
mocker.MagicMock(spec=features.Label),
mocker.MagicMock(spec=features.BoundingBox),
]

images, targets = transform._extract_image_targets(flat_sample)

assert len(images) == len(targets) == 2
if image_type == PIL.Image.Image:
torch.testing.assert_close(images[0], to_tensor(flat_sample[0]))
torch.testing.assert_close(images[1], to_tensor(flat_sample[1]))
else:
assert images[0] == flat_sample[0]
assert images[1] == flat_sample[1]

def test__mixup(self):
image1 = 2 * torch.ones(3, 32, 64)
target_1 = {
"boxes": features.BoundingBox(
torch.tensor([[0.0, 0.0, 10.0, 10.0], [20.0, 20.0, 30.0, 30.0]]),
format="XYXY",
spatial_size=(32, 64),
),
"labels": features.Label(torch.tensor([1, 2])),
}

image2 = 10 * torch.ones(3, 64, 32)
target_2 = {
"boxes": features.BoundingBox(
torch.tensor([[10.0, 0.0, 20.0, 20.0], [10.0, 20.0, 30.0, 30.0]]),
format="XYXY",
spatial_size=(64, 32),
),
"labels": features.Label(torch.tensor([2, 3])),
}

transform = transforms.MixupDetection()
output_image, output_target = transform._mixup(image1, target_1, image2, target_2)
assert output_image.shape == (3, 64, 64)
assert output_target["boxes"].spatial_size == (64, 64)
assert len(output_target["boxes"]) == 4
assert len(output_target["labels"]) == 4
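Taken together, these asserts pin down the contract of _mixup: one shared (64, 64) canvas, an updated spatial_size, and concatenated boxes and labels. At the batch level the transform runs through forward; the sketch below shows the call convention as I infer it from the flatten/extract helpers further down in _augment.py, so treat it as an assumption against a prototype API that may still change.

import torch
from torchvision.prototype import features, transforms

def make_sample(height, width, fill):
    image = fill * torch.ones(3, height, width)
    target = {
        "boxes": features.BoundingBox(
            torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
            format="XYXY",
            spatial_size=(height, width),
        ),
        "labels": features.Label(torch.tensor([1])),
    }
    return image, target

# A batch is a list of (image, target) samples; inside forward each sample
# is mixed with a partner drawn from the same batch.
batch = [make_sample(32, 64, fill=2), make_sample(64, 32, fill=10)]
mixed_batch = transforms.MixupDetection()(batch)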
4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_builtin/coco.py
@@ -24,7 +24,7 @@
path_accessor,
read_categories_file,
)
-from torchvision.prototype.features import _Feature, BoundingBox, EncodedImage, Label
+from torchvision.prototype.features import _Feature, BoundingBox, EncodedImage, Label, Mask

from .._api import register_dataset, register_info

@@ -114,7 +114,7 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st
labels = [ann["category_id"] for ann in anns]
return dict(
# TODO: create a segmentation feature
-segmentations=_Feature(
+segmentations=Mask(
torch.stack(
[
self._segmentation_to_mask(
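A side change worth calling out: COCO segmentations are now wrapped in the dedicated Mask feature instead of the generic _Feature. As I read the prototype transforms design (this rationale is not stated in the diff itself), the point is dispatch: transforms recognize Mask and resize or flip it together with the image, while a bare _Feature would pass through untouched.

import torch
from torchvision.prototype import features

# Instance masks for two objects on a 480x640 image, typed as Mask so that
# geometric transforms can dispatch on them.
segmentations = features.Mask(torch.zeros(2, 480, 640, dtype=torch.bool))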
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/__init__.py
@@ -51,6 +51,7 @@
ToDtype,
TransposeDimensions,
)
from ._mixup_detection import MixupDetection
from ._temporal import UniformTemporalSubsample
from ._type_conversion import DecodeImage, LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage

166 changes: 96 additions & 70 deletions torchvision/prototype/transforms/_augment.py
@@ -1,17 +1,17 @@
import math
import numbers
import warnings
-from typing import Any, cast, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union

import PIL.Image
import torch
-from torch.utils._pytree import tree_flatten, tree_unflatten
+from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
from torchvision.ops import masks_to_boxes
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, InterpolationMode

from ._transform import _RandomApplyTransform
-from ._utils import has_any, query_chw, query_spatial_size
+from ._utils import _isinstance, has_any, query_chw, query_spatial_size


class RandomErasing(_RandomApplyTransform):
@@ -190,6 +190,90 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt


def flatten_and_extract_data(
Review comment (Contributor): Perhaps it's too early but worth benchmarking that this doesn't introduce a significant speed regression on SimpleCopyPaste.

inputs: Any, **target_types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]
) -> Tuple[Tuple[List[Any], TreeSpec, List[Dict[str, int]]], List[features.TensorImageType], List[Dict[str, Any]]]:
# Images are special in the sense that they will always be extracted and returned
# separately. Internally however, they behave just as the other features.
types_or_checks: Dict[str, Tuple[Union[Type, Callable[[Any], bool]], ...]] = {
"images": (features.Image, PIL.Image.Image, features.is_simple_tensor),
Review comment (Contributor): Should we avoid defaults and just explicitly pass images in the 2 cases? This might be useful if in the future we extended this for video. No strong opinions.

**target_types_or_checks,
}

batch = inputs if len(inputs) > 1 else inputs[0]
flat_batch = []
sample_specs = []

offset = 0
batch_idcs = []
batch_data = []
for sample_idx, sample in enumerate(batch):
flat_sample, sample_spec = tree_flatten(sample)
flat_batch.extend(flat_sample)
sample_specs.append(sample_spec)

sample_types_or_checks = types_or_checks.copy()
sample_idcs = {}
sample_data = {}
for flat_idx, item in enumerate(flat_sample, offset):
if not sample_types_or_checks:
break

for key, types_or_checks_ in sample_types_or_checks.items():
if _isinstance(item, types_or_checks_):
break
else:
continue

del sample_types_or_checks[key]
sample_idcs[key] = flat_idx
sample_data[key] = item

if sample_types_or_checks:
# TODO: improve message
raise TypeError(f"Sample at index {sample_idx} in the batch is missing {sample_types_or_checks.keys()}`")

batch_idcs.append(sample_idcs)
batch_data.append(sample_data)
offset += len(flat_sample)

batch_spec = TreeSpec(list, context=None, children_specs=sample_specs)

targets = batch_data
batch_data = []
for target in targets:
image = target.pop("images")
if isinstance(image, features.Image):
image = image.as_subclass(torch.Tensor)
elif isinstance(image, PIL.Image.Image):
image = F.pil_to_tensor(image)
batch_data.append(image)

return (flat_batch, batch_spec, batch_idcs), batch_data, targets


def unflatten_and_insert_data(
flat_batch_with_spec: Tuple[List[Any], TreeSpec, List[Dict[str, int]]],
images: List[features.TensorImageType],
targets: List[Dict[str, Any]],
) -> Any:
flat_batch, batch_spec, batch_idcs = flat_batch_with_spec

for sample_idx, sample_idcs in enumerate(batch_idcs):
for key, flat_idx in sample_idcs.items():
item = images[sample_idx] if key == "images" else targets[sample_idx][key]

inpt = flat_batch[flat_idx]
if isinstance(inpt, features._Feature):
item = type(inpt).wrap_like(inpt, item)
elif isinstance(inpt, PIL.Image.Image):
item = F.to_image_pil(item)

flat_batch[flat_idx] = item

return tree_unflatten(flat_batch, batch_spec)
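Both helpers lean on torch.utils._pytree: tree_flatten turns each nested sample into a flat list of leaves plus a TreeSpec, the per-sample specs are stitched into one batch-level TreeSpec, and tree_unflatten later rebuilds the original structure around the transformed leaves. A minimal round trip of that mechanism, using plain pytree calls rather than the helpers above:

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

sample = {"image": torch.rand(3, 4, 4), "target": {"labels": torch.tensor([1, 2])}}

# Flatten: leaves come out in a deterministic order, spec records the nesting.
leaves, spec = tree_flatten(sample)

# Edit a leaf in the flat view (here: flip the image horizontally)...
leaves[0] = leaves[0].flip(-1)

# ...then rebuild the nested structure around the edited leaves.
out = tree_unflatten(leaves, spec)
assert out["image"].shape == (3, 4, 4)
assert out["target"]["labels"].tolist() == [1, 2]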


class SimpleCopyPaste(_RandomApplyTransform):
def __init__(
self,
@@ -214,7 +298,6 @@ def _copy_paste(
resize_interpolation: F.InterpolationMode,
antialias: Optional[bool],
) -> Tuple[features.TensorImageType, Dict[str, Any]]:

paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection])
paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection])
paste_labels = paste_target["labels"].wrap_like(
@@ -241,7 +324,7 @@

inverse_paste_alpha_mask = paste_alpha_mask.logical_not()
# Copy-paste images:
-image = image.mul(inverse_paste_alpha_mask).add_(paste_image.mul(paste_alpha_mask))
+out_image = image.mul(inverse_paste_alpha_mask).add_(paste_image.mul(paste_alpha_mask))

# Copy-paste masks:
masks = masks * inverse_paste_alpha_mask
@@ -281,69 +364,15 @@ def _copy_paste(
out_target["masks"] = out_target["masks"][valid_targets]
out_target["labels"] = out_target["labels"][valid_targets]

return image, out_target

def _extract_image_targets(
self, flat_sample: List[Any]
) -> Tuple[List[features.TensorImageType], List[Dict[str, Any]]]:
# fetch all images, bboxes, masks and labels from unstructured input
# with List[image], List[BoundingBox], List[Mask], List[Label]
images, bboxes, masks, labels = [], [], [], []
for obj in flat_sample:
if isinstance(obj, features.Image) or features.is_simple_tensor(obj):
images.append(obj)
elif isinstance(obj, PIL.Image.Image):
images.append(F.to_image_tensor(obj))
elif isinstance(obj, features.BoundingBox):
bboxes.append(obj)
elif isinstance(obj, features.Mask):
masks.append(obj)
elif isinstance(obj, (features.Label, features.OneHotLabel)):
labels.append(obj)

if not (len(images) == len(bboxes) == len(masks) == len(labels)):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain equal sized list of Images, "
"BoundingBoxes, Masks and Labels or OneHotLabels."
)

targets = []
for bbox, mask, label in zip(bboxes, masks, labels):
targets.append({"boxes": bbox, "masks": mask, "labels": label})

return images, targets

def _insert_outputs(
self,
flat_sample: List[Any],
output_images: List[features.TensorImageType],
output_targets: List[Dict[str, Any]],
) -> None:
c0, c1, c2, c3 = 0, 0, 0, 0
for i, obj in enumerate(flat_sample):
if isinstance(obj, features.Image):
flat_sample[i] = features.Image.wrap_like(obj, output_images[c0])
c0 += 1
elif isinstance(obj, PIL.Image.Image):
flat_sample[i] = F.to_image_pil(output_images[c0])
c0 += 1
elif features.is_simple_tensor(obj):
flat_sample[i] = output_images[c0]
c0 += 1
elif isinstance(obj, features.BoundingBox):
flat_sample[i] = features.BoundingBox.wrap_like(obj, output_targets[c1]["boxes"])
c1 += 1
elif isinstance(obj, features.Mask):
flat_sample[i] = features.Mask.wrap_like(obj, output_targets[c2]["masks"])
c2 += 1
elif isinstance(obj, (features.Label, features.OneHotLabel)):
flat_sample[i] = obj.wrap_like(obj, output_targets[c3]["labels"]) # type: ignore[arg-type]
c3 += 1
return out_image, out_target

def forward(self, *inputs: Any) -> Any:
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])

images, targets = self._extract_image_targets(flat_inputs)
flat_batch_with_spec, images, targets = flatten_and_extract_data(
inputs,
boxes=(features.BoundingBox,),
masks=(features.Mask,),
labels=(features.Label, features.OneHotLabel),
)

# images = [t1, t2, ..., tN]
# Let's define paste_images as shifted list of input images
@@ -380,7 +409,4 @@ def forward(self, *inputs: Any) -> Any:
output_images.append(output_image)
output_targets.append(output_target)

# Insert updated images and targets into input flat_sample
self._insert_outputs(flat_inputs, output_images, output_targets)

return tree_unflatten(flat_inputs, spec)
return unflatten_and_insert_data(flat_batch_with_spec, output_images, output_targets)
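The forward body elided by the hunk header above pairs every image with its neighbor, per the "shifted list" comment. A sketch of that pairing (my reconstruction; the elided code may differ in how it rolls the list):

# images = [t1, t2, ..., tN]  ->  paste_images = [t2, ..., tN, t1]
images = ["t1", "t2", "t3"]
paste_images = images[1:] + images[:1]
assert paste_images == ["t2", "t3", "t1"]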