
NEW Feature: Mixup transform for Object Detection #6721

Draft
wants to merge 44 commits into
base: main

Conversation

ambujpawar
Contributor

@ambujpawar ambujpawar commented Oct 7, 2022

Official implementation of the paper: Here

Minimalist code to reproduce:

import PIL
from torchvision import io, utils
from torchvision.prototype import features, transforms as T
from torchvision.prototype.transforms import functional as F


# Defining and wrapping input to appropriate Tensor Subclasses
path = "/Users/ambujpawar/Desktop/Cat03.jpeg"
path2 = "/Users/ambujpawar/Desktop/dog_2.jpeg"

# img = features.Image(io.read_image(path), color_space=features.ColorSpace.RGB)
img = PIL.Image.open(path)
img2 = PIL.Image.open(path2)

bbox_1 = features.BoundingBox(
    [[2, 0, 100, 100], [396, 92, 479, 241]],
    format=features.BoundingBoxFormat.XYXY,
    spatial_size=F.get_spatial_size(img),
)
bbox_2 = features.BoundingBox(
    [[200, 100, 300, 300], [424, 38, 479, 250]],
    format=features.BoundingBoxFormat.XYXY,
    spatial_size=F.get_spatial_size(img2),
)
label = features.Label([59, 58])


# Defining and applying Transforms V2
trans = T.Compose(
    [
        T.MixupDetection(),
    ]
)

imgs = [img, img2]
bboxes = [bbox_1, bbox_2]
labels = [label, label]

imgs, bboxes, labels = trans(imgs, bboxes, labels)

# Visualizing results
viz = utils.draw_bounding_boxes(F.to_image_tensor(imgs[1]), boxes=bboxes[0])
F.to_pil_image(viz).show()

Examples output:
Please don't pay attention to the bounding boxes in this particular image; I just entered those boxes randomly.
Screenshot 2022-11-06 at 16 48 00

@ambujpawar ambujpawar marked this pull request as draft October 7, 2022 12:43
@datumbox datumbox mentioned this pull request Oct 7, 2022
@ambujpawar ambujpawar marked this pull request as ready for review November 6, 2022 15:50
@pmeier pmeier self-assigned this Nov 7, 2022
@pmeier
Collaborator

pmeier commented Nov 7, 2022

Hey @ambujpawar and thanks a lot for the PR! I'll try to help you land it in the near future. As you might have noticed, this transform is not straightforward to implement, since it requires a batch of detection samples. In this context, that means a list of samples, whereas for classification a "batch" usually means an extra batch dimension on a tensor. This makes the implementation a lot harder compared to regular MixUp.

Still, we need to be able to support it. I'll look into how we can streamline the process, for example by providing a _DetectionBatchTransform or standalone utilities that make this easier. I'll get back to you when I've found a solution or need your input. Is that OK with you?
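For illustration, such a base class could look roughly like this. To be clear, this is a hypothetical sketch: the class name comes from the comment above, but the method names and the pairing scheme are assumptions, not the actual torchvision API.

```python
# Hypothetical sketch of the _DetectionBatchTransform idea; method names
# and the pairing scheme are assumptions for illustration only.

class _DetectionBatchTransform:
    """Handles the batch plumbing so subclasses only implement the algorithm."""

    def __call__(self, batch):
        # a detection "batch" is a sequence of samples, not a stacked tensor
        if not isinstance(batch, (list, tuple)):
            raise TypeError("detection batch transforms expect a sequence of samples")
        return self._transform_batch(list(batch))

    def _transform_batch(self, samples):
        raise NotImplementedError


class MixupDetection(_DetectionBatchTransform):
    def _transform_batch(self, samples):
        # mix each sample with the next one, wrapping around (simplified pairing)
        return [
            self._mixup(samples[i], samples[(i + 1) % len(samples)])
            for i in range(len(samples))
        ]

    def _mixup(self, sample1, sample2):
        # stand-in for the real algorithm: the boxes of both samples are kept
        return {"boxes": sample1["boxes"] + sample2["boxes"]}


batch = [{"boxes": [[0, 0, 10, 10]]}, {"boxes": [[5, 5, 20, 20]]}]
output = MixupDetection()(batch)  # still a list with one entry per sample
```

The point of the split is that subclasses never touch the batch container, only pairs of samples.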

@ambujpawar
Contributor Author

ambujpawar commented Nov 7, 2022

Hi @pmeier, it sounds perfect to me. Looking forward to your suggestions :)

I agree with your comment regarding MixupForDetection taking batches of detection samples. However, shouldn't it be similar to what we do in the CopyPaste transform? Because in CopyPaste we also expect a batch of images.

@pmeier
Collaborator

pmeier commented Nov 8, 2022

However, shouldn't it be similar to what we do in the CopyPaste transform? Because in CopyPaste we also expect a batch of images.

Exactly. Before, we operated under the assumption that SimpleCopyPaste was the only batch detection transform, and thus a one-off solution for it was good enough. With DetectionMixUp in the picture, our assumption is no longer true and we need to look at how we can provide utilities to ease the implementation of these transforms.

Right now, the largest part of the implementation deals with the "infrastructure", i.e. extracting the right inputs and putting them back afterwards. Only a small part is spent on the actual algorithm. In the best-case scenario, I find a solution so you only have to write the algorithm and the remainder is handled by a base class or some high-level utilities.

@ambujpawar
Contributor Author

ambujpawar commented Nov 8, 2022

That clears up all the questions for me. Thanks!

Yeah, a base class is perhaps the best solution in that regard. Please let me know if my help is needed :)

@pmeier pmeier marked this pull request as draft November 8, 2022 12:23
Collaborator

@pmeier pmeier left a comment


@ambujpawar I took the liberty of pushing a patch to your PR. I've added two functions, flatten_and_extract as well as unflatten_and_insert, that implement what their names imply. I'm actively looking for your feedback, so nothing is fixed yet. Two things that I already noticed:

  1. Both SimpleCopyPaste as well as MixUpDetection use a "split" layout for images and targets. Is that by design, or could we use one container, like a dictionary, for both of them? Imagine something like sample = {"image": ..., "boxes": ...}.
  2. The old extraction and insertion logic converted to tensor and back for images and (un-)wrapped the other features. Right now, the new logic does not do this. Instead, this is moved inside the _mixup function. We could move that back into the logic as well. What do you prefer?
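To make the idea concrete, here is a simplified stand-in for the two helpers mentioned above. The function names come from the patch, but the dict-of-lists layout and everything else here is an assumption; the actual implementations in the PR may differ.

```python
# Simplified stand-ins for flatten_and_extract / unflatten_and_insert;
# the actual implementations in the PR may differ.

def flatten_and_extract(batch, keys=("image", "boxes", "label")):
    # sequence of dict samples -> dict of parallel lists, one list per key
    return {key: [sample[key] for sample in batch] for key in keys}


def unflatten_and_insert(batch, extracted):
    # write the (possibly transformed) values back into copies of the samples
    return [
        {**sample, **{key: values[i] for key, values in extracted.items()}}
        for i, sample in enumerate(batch)
    ]


batch = [
    {"image": "img1", "boxes": [[0, 0, 10, 10]], "label": 0},
    {"image": "img2", "boxes": [[5, 5, 20, 20]], "label": 1},
]
extracted = flatten_and_extract(batch)
extracted["image"] = [img.upper() for img in extracted["image"]]  # stand-in transform
roundtrip = unflatten_and_insert(batch, extracted)
```

With helpers like these, the transform itself only operates on the extracted lists; the round trip restores the original sample structure.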

@ambujpawar
Contributor Author

ambujpawar commented Nov 8, 2022

Thanks for adding the patch! :)

I think it looks nice. Regarding your questions:

  1. Both SimpleCopyPaste as well as MixUpDetection use a "split" layout for images and targets. Is that by design, or could we use one container, like a dictionary, for both of them? Imagine something like sample = {"image": ..., "boxes": ...}.

Yes, they both use a "split" design, but MixupDetection does not use "Masks". Mixup is only used for detection, not segmentation.

  2. The old extraction and insertion logic converted to tensor and back for images and (un-)wrapped the other features. Right now, the new logic does not do this. Instead, this is moved inside the _mixup function. We could move that back into the logic as well. What do you prefer?

If I had to choose one design, I would choose the former, but I don't have any strong arguments for it.

Shall we include the developers of the SimpleCopyPaste transform as well? They might have some comments regarding these changes.

Collaborator

@pmeier pmeier left a comment


This is based on the example listed in the blog post regarding Transforms v2.
What should we call this transform instead?

Well, MixupDetection as well as SimpleCopyPaste are detection batch transforms and thus fall outside of the "regular" transforms. This is why we need some extra machinery to implement them properly. If you look at the others, their implementations are much simpler.

Batch transform here means that the input needs to be batched. For image classification transforms like CutMix or MixUp this simply means an extra batch dimension on the input tensors:

if inpt.ndim < expected_ndim:
    raise ValueError("The transform expects a batched input")

However, this is not possible for detection tasks. Each sample can have a different number of bounding boxes, so we cannot put them into a single tensor. Hence, a "detection batch" is just a sequence of individual samples. For your example, this could be

batch = [(img, bbox_1, label), (img2, bbox_2, label)]
transformed_batch = trans(batch)

Of course you can also do

batch = [{"image": img, "boxes": bbox_1, "label": label}, ...]

or something else as long as the outer container is a sequence.

Figured out SimpleCopyPaste still doesn't work. Working on it

You don't have to. Let's make sure DetectionMixup works as we want it to, and I'll fix SimpleCopyPaste afterwards.

Let's make sure we expand the tests a little. Basically we should have three test cases:

  1. a) and b) Make sure that _mixup is a no-op in case the ratio is == 0 or >= 1.0.
  2. Make sure that we get the correct output for a different ratio, e.g. == 0.5. Right now, we are only doing a smoke test that checks the shapes.
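Sketched as standalone assertions, with a toy stand-in for _mixup that works on flat pixel lists (the real transform operates on image tensors and also concatenates the bounding boxes, so this is an illustration, not the actual test code):

```python
# Toy stand-in for _mixup to illustrate the three test cases; the real
# implementation works on image tensors and also merges bounding boxes.

def mixup(image1, image2, ratio):
    # degenerate ratios must be a no-op that keeps one input untouched
    if ratio >= 1.0:
        return list(image1)
    if ratio <= 0.0:
        return list(image2)
    # otherwise blend pixel-wise with weights ratio and (1 - ratio)
    return [ratio * a + (1 - ratio) * b for a, b in zip(image1, image2)]


img1 = [10.0, 20.0, 30.0]
img2 = [0.0, 100.0, 50.0]

assert mixup(img1, img2, 1.0) == img1               # 1. a) no-op for ratio >= 1.0
assert mixup(img1, img2, 0.0) == img2               # 1. b) no-op for ratio == 0
assert mixup(img1, img2, 0.5) == [5.0, 60.0, 40.0]  # 2. exact values, not just shapes
```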

I think when that is done, I can take over and fix the rest.

@@ -1436,63 +1437,6 @@ def create_fake_image(self, mocker, image_type):
return PIL.Image.new("RGB", (32, 32), 123)
return mocker.MagicMock(spec=image_type)

def test__extract_image_targets_assertion(self, mocker):
Collaborator


Why did you delete this?

Contributor Author


I was using it to test _extract_image_targets function. However, since we removed those functions I removed them from here as well

Contributor Author


Oh sorry! Realized this was for TestSimpleCopyPaste. Undoing the changes, sorry

test/test_prototype_transforms.py (outdated)
torchvision/prototype/transforms/_augment.py (outdated)
Comment on lines 345 to 346
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
    return dict(ratio=float(self._dist.sample()))
Collaborator


I've opted to sample the ratio in the _get_params method. This has two advantages:

  1. People familiar with the other transforms can see at a glance that we are sampling something and this is not buried deep in the implementation.
  2. _mixup is easier to test since it has no random behavior.
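As a sketch of that separation (the class layout and the Beta parameters here are assumptions for illustration, not the transform's actual values):

```python
# Sketch of the design: all randomness lives in _get_params, so the
# deterministic _mixup can be unit-tested with a fixed ratio. The Beta
# parameters and the class layout are assumptions for illustration.
import random


class MixupDetectionSketch:
    def __init__(self, alpha=1.5):
        self._alpha = alpha

    def _get_params(self, flat_inputs):
        # the only random call; visible at a glance, like in other transforms
        return dict(ratio=random.betavariate(self._alpha, self._alpha))

    def _mixup(self, image1, image2, ratio):
        # deterministic given ratio, hence straightforward to test
        return [ratio * a + (1.0 - ratio) * b for a, b in zip(image1, image2)]


transform = MixupDetectionSketch()
params = transform._get_params([])  # flat_inputs kept for protocol consistency
mixed = transform._mixup([8.0], [4.0], ratio=0.5)
```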

Contributor Author


I agree, it looks much tidier this way!
However, I have a question: we don't use flat_inputs. Shall we just remove it?

Collaborator


I would keep it for consistency with the other transformations. This is the basic protocol for all Transform._get_params calls. Although we call it manually here, there is some benefit in aligning them. Someone who is not familiar with this transform, but with the others in general, might trip over the fact that the parameter is not there.

Collaborator


I just realized we are actually not passing anything here. Let's just pass the flat inputs for completeness.

torchvision/prototype/transforms/_augment.py
torchvision/prototype/transforms/_augment.py (outdated)
@pmeier
Collaborator

pmeier commented Jan 16, 2023

Hey @ambujpawar 👋 I hope you are all right. I wanted to check in on this PR. Are you planning on finishing it or should I take over?

@ambujpawar
Contributor Author

Hi @pmeier, thanks for asking! I'm doing good :) I'm just back from a super long Christmas and New Year vacation, so I did not have time to work on this PR. I would still like to work on it if we are not running against a deadline or something.

I can work on it this weekend and request a re-review from you :)
Does that work for you?

@pmeier
Collaborator

pmeier commented Jan 17, 2023

No rush from my side. I thought I'd check in on you after roughly one month of inactivity. In case you didn't plan to finish this, we would still like to have it and I would have taken over. This weekend sounds good.

@ambujpawar
Contributor Author

Thanks! I'll update the PR this weekend then! :)

@ambujpawar
Contributor Author

ambujpawar commented Jan 22, 2023

Hi, I added the test cases for when the mixup ratios are 0 and 1. However, I still think there is a bug when we are mixing the two images. I am not able to pinpoint exactly what causes it, though. Perhaps something rings a bell for you after looking at the code.

So, this is the expected output (or something similar). Notice the light appearance of cat in the background
Screenshot 2023-01-22 at 14 39 04

However, after our latest changes it looks like this. Notice that the picture of the cat completely overwrites the picture of the dog.
Screenshot 2023-01-22 at 14 42 38

I am not exactly sure, but I suspect something is going wrong when we are mixing the images in augment.py, lines 376-381. I am not able to solve it, but perhaps you can have a look at it?

@pmeier
Collaborator

pmeier commented Jan 23, 2023

I am not exactly sure but I suspect something is going wrong when we are mixing images in augment.py Line 376-381. I am not able to solve it but perhaps you can have a look at it please?

Yup, the problem is that we replace the values in the first image with the ones from the second rather than adding them. To demonstrate, let's first establish a visual benchmark that we both can easily reproduce:

import PIL.Image

import torch

from torchvision.io import read_image
from torchvision.prototype import datapoints, transforms
from torchvision.utils import make_grid


def read_sample(path, label):
    image = datapoints.Image(read_image(path))
    bounding_box = datapoints.BoundingBox(
        [[0, 0, *image.spatial_size[::-1]]], format="xyxy", spatial_size=image.spatial_size
    )
    label = datapoints.Label([label])
    return dict(
        path=path,
        image=image,
        bounding_box=bounding_box,
        label=label,
    )


batch = [
    read_sample("test/assets/encode_jpeg/grace_hopper_517x606.jpg", 0),
    read_sample("test/assets/fakedata/logos/rgb_pytorch.png", 1),
]

transform = transforms.MixupDetection()

torch.manual_seed(0)
output = transform(batch)

image = make_grid([sample["image"] for sample in output])
PIL.Image.fromarray(image.permute(1, 2, 0).numpy()).save("mixup_detection.jpg")

Output with the current implementation is

mixup_detection

So, in the left image, the PyTorch logo is the second image and thus we are just pasting it over Grace Hopper. On the right side the PyTorch logo is completely gone, since Grace Hopper is larger and thus completely paints over it.

Applying the first suggestion from below gives us

mixup_detection

And thus the behavior we want.
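The difference between the broken and the fixed behavior boils down to the following (a scalar stand-in for pixel values, just for illustration):

```python
# Scalar illustration of the bug: the broken code replaced the first image's
# values with the second's, while mixup should compute a weighted sum so
# both images stay visible.

ratio = 0.5
pixel1, pixel2 = 200.0, 50.0  # overlapping pixels from the two images

broken = pixel2                                # second image pastes over the first
fixed = ratio * pixel1 + (1 - ratio) * pixel2  # both images contribute
```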

@pmeier
Collaborator

pmeier commented Jan 23, 2023

I'm working on fixing SimpleCopyPaste now.

Co-authored-by: Philip Meier <github.pmeier@posteo.de>
@ambujpawar
Contributor Author

And thus the behavior we want.

Yup, exactly! This is the behavior we want.
Thanks for fixing it, my eyes were not able to find it haha

@pmeier
Collaborator

pmeier commented Jan 23, 2023

I've pushed an update to SimpleCopyPaste, but so far I have only done visual checks. The tests for it very much relied on the internals and so I'll need to fix them as well. @ambujpawar is there anything left on your side that you want to do? Otherwise, I'm going to finish over the next few days.

@ambujpawar
Contributor Author

@ambujpawar is there anything left on your side that you want to do?

Nope. I think everything is done on my side and this MixupDetection feature is ready. :)

@ambujpawar
Contributor Author

Hi @pmeier, Congrats on the torchvision v0.15 release.
I just wanted to check in on the future of the MixupDetection transform.
Is it still waiting on the topic of "How to smoothly support 'pairwise' transforms" listed in #7319?

Thanks in advance!! :)

@pmeier
Collaborator

pmeier commented Apr 4, 2023

Yes, unfortunately we are blocked by this. Sorry for not informing you earlier. We are holding off on the batch transforms for now for the reason you listed above. I'll ping you here when we have figured it out. Thanks a lot for your patience!

@ambujpawar
Contributor Author

Ah sure! No worries
Thanks for the update! :)
BTW, the new transforms_v2 really look good. Thanks for them

Successfully merging this pull request may close these issues.

New Feature: Mixup Transform for Object Detection