Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NEW Feature: Mixup transform for Object Detection #6721

Draft
wants to merge 44 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
676a3ba
ADD: Empty file mixup.py for dummy PR
Oct 7, 2022
60cdf3b
ADD: Empty transform class
Oct 7, 2022
fd922ca
Merge branch 'main' into 6720_add_mixup_transform
ambujpawar Oct 22, 2022
728c7ca
WIP: Random Mixup for detection
Oct 28, 2022
3f204ac
Merge branch 'main' into 6720_add_mixup_transform
Nov 4, 2022
f1b70b9
First draft: Mixup detections
Nov 5, 2022
cdda41b
Fix: precommit issues
Nov 5, 2022
2d0765c
Fix: failing CI issues
Nov 6, 2022
7e82ff2
Fix: Tests and ADD: get_params and check_inputs functions
Nov 6, 2022
b83aedf
Fix: Remove usage of soon to be deprecated to_tensor function
Nov 6, 2022
50bea74
Merge branch 'main' into 6720_add_mixup_transform
ambujpawar Nov 6, 2022
90799b8
Remove: get params for mixup
Nov 6, 2022
248737d
Update _mixup_detection.py
ambujpawar Nov 7, 2022
26316a4
Remove unused type: ignore due to failing CI test
ambujpawar Nov 7, 2022
d7e08d2
Merge branch 'main' into 6720_add_mixup_transform
pmeier Nov 8, 2022
04c80d7
add batch detection helpers
pmeier Nov 8, 2022
5667c91
use helpers in detection mixup
pmeier Nov 8, 2022
e0724a3
refactor helpers
pmeier Nov 8, 2022
10c9033
Merge branch 'main' into 6720_add_mixup_transform
pmeier Dec 1, 2022
6177057
revert accidental COCO change
pmeier Dec 1, 2022
2b67017
Move: mixup detection to _augment.py
Dec 4, 2022
fa0f54e
Merge branch 'main' into 6720_add_mixup_transform
ambujpawar Dec 4, 2022
cae66d9
Merge branch 'main' into 6720_add_mixup_transform
pmeier Dec 6, 2022
ae9908b
refactor extraction and insertion
pmeier Dec 6, 2022
c2e2757
Fix: Failing SimpleCopyPaste and MixupDetection Failing tests
Dec 17, 2022
6b58135
Merge branch 'main' into 6720_add_mixup_transform
pmeier Dec 19, 2022
5398c73
sample ratio in get_params
pmeier Dec 19, 2022
044ba0d
fix padding
pmeier Dec 19, 2022
884ace1
perform image conversion upfront
pmeier Dec 19, 2022
99de232
create base class
pmeier Dec 19, 2022
4ceef89
Merge branch 'main' into 6720_add_mixup_transform
pmeier Dec 19, 2022
a6b9ae0
add shortcut for ratio==0
pmeier Dec 19, 2022
fce49b8
fix dtype
pmeier Dec 19, 2022
05c0491
Merge branch 'main' into 6720_add_mixup_transform
ambujpawar Jan 17, 2023
d995471
Apply suggestions from code review
ambujpawar Jan 21, 2023
914a9ee
Merge branch 'main' into 6720_add_mixup_transform
ambujpawar Jan 21, 2023
cbf09c2
Undo removing test_extract_image_target of TestSimpleCopyPaste
ambujpawar Jan 21, 2023
685d042
ADD: Test cases when mixup ratio is 0, 0.5, 1
ambujpawar Jan 22, 2023
3319215
Fix: was doing wrong asserts. Corrected it
ambujpawar Jan 22, 2023
02214b6
fix mixing
pmeier Jan 23, 2023
4486e78
pass flat_inputs to get_params
pmeier Jan 23, 2023
1b6dbe1
Update torchvision/prototype/transforms/_augment.py
ambujpawar Jan 23, 2023
8a912ba
refactor SimpleCopyPaste
pmeier Jan 23, 2023
ebd6bfd
Merge branch '6720_add_mixup_transform' of https://github.com/ambujpa…
pmeier Jan 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from torchvision.ops.boxes import box_iou
from torchvision.prototype import datapoints, transforms
from torchvision.prototype.transforms.utils import check_type
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

from torchvision.transforms.functional import get_image_size, InterpolationMode, pil_to_tensor, to_pil_image

BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]

Expand Down Expand Up @@ -1924,3 +1925,59 @@ def test__transform(self, inpt):
assert type(output) is type(inpt)
assert output.shape[-4] == num_samples
assert output.dtype == inpt.dtype


class TestMixupDetection:
def create_fake_image(self, mocker, image_type, *, size=(32, 32), color=123):
if image_type == PIL.Image.Image:
return PIL.Image.new("RGB", size, color)
return mocker.MagicMock(spec=image_type)

@pytest.mark.parametrize("ratio", [0.0, 0.5, 1.0])
def test__mixup(self, mocker, ratio):
image_1 = self.create_fake_image(mocker, PIL.Image.Image, size=(128, 128), color=(124, 124, 124))
image_1 = pil_to_tensor(image_1)
target_1 = {
"boxes": datapoints.BoundingBox(
torch.tensor([[0.0, 0.0, 10.0, 10.0], [20.0, 20.0, 30.0, 30.0]]),
format="XYXY",
spatial_size=get_image_size(image_1),
),
"labels": datapoints.Label(torch.tensor([1, 2])),
}
sample_1 = {
"image": image_1,
"boxes": target_1["boxes"],
"labels": target_1["labels"],
}

image_2 = self.create_fake_image(mocker, PIL.Image.Image, size=(128, 128), color=(0, 0, 0))
image_2 = pil_to_tensor(image_2)
target_2 = {
"boxes": datapoints.BoundingBox(
torch.tensor([[10.0, 0.0, 20.0, 20.0], [10.0, 20.0, 30.0, 30.0]]),
format="XYXY",
spatial_size=get_image_size(image_2),
),
"labels": datapoints.Label(torch.tensor([2, 3])),
}
sample_2 = {
"image": image_2,
"boxes": target_2["boxes"],
"labels": target_2["labels"],
}

transform = transforms.MixupDetection()
output = transform._mixup(sample_1, sample_2, ratio)

if ratio == 0:
assert output == sample_2

elif ratio == 1:
assert output == sample_1

elif ratio == 0.5:
# TODO: Fix this test
assert output["image"] == (np.asarray(image_1) + np.asarray(image_2)) / 2
assert output["boxes"] == torch.cat([target_1["boxes"], target_2["boxes"]])
assert output["labels"] == torch.cat([target_1["labels"], target_2["labels"]])
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ._transform import Transform # usort: skip
from ._presets import StereoMatching # usort: skip

from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
from ._augment import MixupDetection, RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
Expand Down
Loading