
Change draw_segmentation_masks to accept boolean masks #3820

Closed · wants to merge 11 commits

Conversation

@NicolasHug (Member) commented May 12, 2021

This PR changes the behaviour of draw_segmentation_masks so that it accepts boolean masks.

In master, draw_segmentation_masks supports masks as returned by the semantic segmentation models, i.e. batches of masks of shape num_masks x H x W with scores inside, and it internally does masks = masks.argmax(0) to get the most likely class for each pixel. This is too restrictive and too specific to semantic segmentation models, and it doesn't play well with instance segmentation models: #3774 (comment).

This PR thus proposes to accept only boolean masks, which are more general: it's up to the users to properly convert score masks into boolean masks (a sketch follows below). This was noted by @fmassa in the original PR #3330 (comment). I will follow up with another PR updating the visualization example to properly explain how to do such conversions (I'm still working on it; I hope I can get it done by tomorrow).
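For illustration, here is a minimal sketch of the kind of conversion this leaves to the user (the shapes, model outputs, and the 0.5 threshold below are assumptions for the example, not part of this PR):

```python
import torch

# Semantic segmentation (e.g. FCN): scores of shape (num_classes, H, W).
scores = torch.randn(21, 224, 224)
classes = scores.argmax(0)                               # (H, W) most likely class per pixel
bool_masks = classes == torch.arange(21)[:, None, None]  # (num_classes, H, W), boolean

# Instance segmentation (e.g. Mask R-CNN): soft masks in [0, 1], one per instance.
soft_masks = torch.rand(3, 1, 224, 224)                  # (num_instances, 1, H, W)
bool_masks = soft_masks.squeeze(1) > 0.5                 # (num_instances, H, W), boolean
```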

This PR also adds a bunch of tests, plus support for drawing on float images, not just uint8. This makes it possible to call the function once to draw a mask, and then call it again on the output to draw another mask. This is currently not supported since the function takes uint8 as input but outputs float images. (EDIT: I'm wrong here, the function properly returns uint8.) It also lets users work with float images throughout instead of keeping the uint8 images around.
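As a sketch of the chaining this enables (assuming the boolean-mask, float-passthrough behaviour proposed here; the exact released signature may differ):

```python
import torch
from torchvision.utils import draw_segmentation_masks

img = torch.rand(3, 224, 224)  # float image in [0, 1]
mask_a = torch.zeros(224, 224, dtype=torch.bool)
mask_a[10:60, 10:60] = True
mask_b = torch.zeros(224, 224, dtype=torch.bool)
mask_b[100:160, 100:160] = True

# With dtype passthrough, the float output can be fed straight back in.
out = draw_segmentation_masks(img, mask_a[None], colors=["blue"])
out = draw_segmentation_masks(out, mask_b[None], colors=["red"])
```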

Note: this is not BC-breaking since the function hasn't been released yet.

CC @oke-aditya :)

@datumbox (Contributor) left a comment

Since it's still a WIP, I only checked the changes in draw_segmentation_masks() and they look good to me. I left only one comment; let me know your thoughts.

```python
color = ImageColor.getrgb(color)
color = torch.tensor(color, dtype=out_dtype)
if out_dtype == torch.float:
    color /= 255
```
@datumbox (Contributor):

I wonder if we have to support anything other than uint8 here, given that this method is used for visualization. To be specific, the only issue I have is that you are forced to assume that the colour needs to be in [0, 1], hence the division by 255. Thoughts?

@oke-aditya (Contributor) commented May 12, 2021

I thought the same (I will have a thorough look tomorrow!). The other utility, draw_bounding_boxes, returns uint8 dtype only. So perhaps we should be consistent one way or the other.

@NicolasHug (Member, author) commented May 12, 2021

The current interface takes a uint8 image as input and outputs a float image. This means that you can't call draw_segmentation_masks on that output again, which isn't super practical. EDIT: this isn't the case; in master the function properly returns a uint8 image.

Also, the models accept float images, not uint8 images. If you look at the current example https://pytorch.org/vision/master/auto_examples/plot_visualization_utils.html#sphx-glr-auto-examples-plot-visualization-utils-py we're forced to keep 2 copies of each image: one float and one uint8. This isn't optimal either.

Supporting both float and uint8 and having the dtypes "pass through" solves both of these issues at little maintenance cost; I wrote tests that ensure both dtypes yield the same results.
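A sketch of what such an equivalence test might look like (my reconstruction under the float-support proposal, not the actual test code in this PR):

```python
import torch
from torchvision.transforms.functional import convert_image_dtype
from torchvision.utils import draw_segmentation_masks

img_uint8 = torch.randint(0, 256, (3, 50, 50), dtype=torch.uint8)
img_float = convert_image_dtype(img_uint8, torch.float)
masks = torch.zeros(1, 50, 50, dtype=torch.bool)
masks[0, 10:30, 10:30] = True

out_uint8 = draw_segmentation_masks(img_uint8, masks)
out_float = draw_segmentation_masks(img_float, masks)

# Up to rounding, both dtypes should produce the same drawing.
torch.testing.assert_close(
    convert_image_dtype(out_uint8, torch.float), out_float, atol=1 / 255, rtol=0
)
```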

Regarding consistency, I'm planning on modifying draw_bounding_boxes to have the same behavior.

@datumbox (Contributor):

> Also, the models accept float images, not uint8 images.

It is true that the models receive float images, which are also typically normalized. These are not the versions of the images that you will use for visualization, though. Most models scale to [0, 1] while others recently require scaling to [-1, 1] (see the SSD models). So in the end the user will need to keep the non-normalized uint8 version around to visualize it.

An alternative course of action might be to continue handling all types of images but avoid any assumption about the color scales. You could do that by creating a palette only if the image is uint8, and expecting the user to provide a valid list of colors in the right scale otherwise. Happy to discuss this more to unblock your work.

@NicolasHug (Member, author):

> It is true that the models receive float images, which are also typically normalized. These are not the versions of the images that you will use for visualization, though.

According to #3774 (comment), Mask-RCNN and Faster-RCNN don't require normalization, so one could use the same non-normalized float images in this case, I think.

IIUC, the main concern seems to be the assumption that if an image is a float, then we expect it to be in [0, 1]. According to our docs, this assumption already holds for the transforms:

> The expected range of the values of a tensor image is implicitly defined by the tensor dtype. Tensor images with a float dtype are expected to have values in [0, 1). Tensor images with an integer dtype are expected to have values in [0, MAX_DTYPE] where MAX_DTYPE is the largest value that can be represented in that dtype.

It also holds when calling F.convert_image_dtype(img, torch.float), and for all our conversion utilities as far as I can tell, so it seems to be a reasonable assumption.
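For instance (a quick illustration of that convention):

```python
import torch
from torchvision.transforms.functional import convert_image_dtype

img_uint8 = torch.tensor([[[0, 128, 255]]], dtype=torch.uint8)  # tiny 1x3 single-channel image
img_float = convert_image_dtype(img_uint8, torch.float)
print(img_float)  # tensor([[[0.0000, 0.5020, 1.0000]]]) -- values land in [0, 1]
```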

I'll add a comment in the docstring to clarify that. The worst that can happen if the user passes a float image that's not in [0, 1] is that the colors will be slightly off.

@NicolasHug (Member, author):

In other words: as far as I can tell, removing the assumption that float images are in [0, 1] for draw_segmentation_masks will not facilitate any use-case, and it won't make anything worse either.

OTOH, removing this assumption forces users to keep the uint8 image around, or it forces them to manually pass colors. By seamlessly allowing float images in [0, 1], users only need to keep the float image in [0, 1], and possibly the normalized image to pass to a model. They can discard the uint8 image.

@NicolasHug (Member, author):

BTW, this idea that float images are implicitly expected to be in 0-1 is widely adopted throughout the ecosystem. Matplotlib assumes this too: see below how both images are rendered exactly the same. By allowing this, we make our tools easy to use for people who are used to the existing conventions.

[screenshot: the same image rendered by matplotlib as uint8 (0-255) and as float (0-1), displayed identically]
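The claim is easy to reproduce (a minimal matplotlib sketch; the random image is just for illustration):

```python
import matplotlib.pyplot as plt
import numpy as np

img_uint8 = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
img_float = img_uint8 / 255.0  # float in [0, 1]

fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.imshow(img_uint8)  # interpreted as values in [0, 255]
ax2.imshow(img_float)  # interpreted as values in [0, 1]; renders identically
plt.show()
```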

@datumbox (Contributor):

Are there plans to change the behaviour of ToTensor() and all of the transforms w.r.t. this assumption with float inputs?

My point here is about keeping this assumption contained and not exposing it to other places of the library. Whether we should change the behaviour of ToTensor to remove the silent rescaling is something we could discuss for TorchVision v1.0.

> plotting the resulting masks on an image that has been normalized (i.e. that isn't in 0-1) leads to terrible visualizations anyway

This is precisely why the original versions of these methods supported only uint8 images, and that's why we can't just throw away the uint8 versions of the images. The segmentation models rescale and normalize the images outside of the models, meaning you would have to undo those steps before visualizing.

At this point I think it would be best not to merge the PR without the Gallery examples that would give us a clear view of the usage and of how corner-cases are handled. Perhaps @fmassa can weigh in here.

@NicolasHug (Member, author):

> The segmentation models rescale and normalize the images outside of the models, meaning you would have to undo those steps before visualizing.

Users have to convert their images to float in [0, 1] at some point anyway:

> All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].

So it's not just something contained in the transforms. It's pervasive throughout the library and throughout the entire ecosystem. By supporting those images directly, we allow users to drop the uint8 images and keep only the float images (normalized and non-normalized versions).
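For reference, the standard pipeline the quoted docs describe looks like this (a sketch using the quoted mean/std values):

```python
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.ToTensor(),  # PIL / uint8 -> float tensor in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```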

@NicolasHug (Member, author):

Still WIP, but see also https://github.com/pytorch/vision/pull/3824/files#r631748970 for an illustration of why and how we can drop references to uint8 images once we support floats in [0, 1].

@NicolasHug (Member, author) commented:

Thanks for the review! This isn't WIP :)

I had to comment out the example because it's not compatible anymore, but I can also just remove that part if you'd prefer.

I'll re-write it in another PR to keep things self-contained.

@datumbox (Contributor) left a comment

First of all, I see the value of this PR in terms of adding support for more segmentation use-cases. The part I'm concerned with is the direction of supporting float images, as this forces us to unnecessarily expose internals of the transforms such as the [0, 1] normalization.

I think it might be beneficial to take a step back and consider:

  1. Is supporting floats absolutely necessary for this PR? Is it something that can be removed from the PR to unblock it and facilitate further discussion on the topic?
  2. Is the goal of removing the integer version of the images worth it, given that in the examples we are still forced to keep a normalized and a non-normalized version of the image? Also, why try to remove the integer version if it's needed by some other methods?

Note that this PR needs to be assessed together with #3827 and #3824 to get a clearer picture. A potentially better way to split this work is to separate the change of behaviour of draw_segmentation_masks() and the proposal of adding float support into 2 separate PRs, so that they don't block each other. Given the above, I can't merge this PR as is, but I'm happy to reconsider if changes are made to address the concerns I raised.

@NicolasHug (Member, author) commented May 14, 2021

> The part I'm concerned with is the direction of supporting float images, as this forces us to unnecessarily expose internals of the transforms such as the [0, 1] normalization.

It doesn't expose anything. Float images in [0, 1] are a standard. We explicitly rely on it in many parts of the library (it's not an internal detail), tensorflow relies on it, matplotlib relies on it, and PIL relies on it (I haven't double-checked that one, so I'm not sure).

> Is the goal of removing the integer version of the images worth it, given that in the examples we are still forced to keep a normalized and a non-normalized version of the image?

Yes: having to keep track of 2 versions of an image is better than having to keep track of 3 versions.

> Also, why try to remove the integer version if it's needed by some other methods?

We're only forced to keep the uint8 version because we illustrate the use of make_grid; otherwise we wouldn't need it. As mentioned in other threads, I think make_grid is not worth illustrating (it's a confusing util with a confusing interface), so I'm more than happy to remove it from the example so that we can get rid of the uint8 images altogether.

> Is supporting floats absolutely necessary for this PR? Is it something that can be removed from the PR to unblock it and facilitate further discussion on the topic?

I would rather not. I really think that supporting floats in [0, 1] is useful, simple, and has zero risk of requiring a breaking change in the future. What other scale would make sense for float images anyway? It doesn't rely on implicit assumptions of torchvision: again, float images are always expected to be in [0, 1] when it comes to plotting. I'm more than happy to raise an error when dtype == float and (min < 0 or max > 1), if we want to prevent users from calling this with a float image that's not in [0, 1].
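Such a guard could be as simple as (a hypothetical sketch, not code from this PR):

```python
# Hypothetical validation inside draw_segmentation_masks:
if image.is_floating_point() and (image.min() < 0 or image.max() > 1):
    raise ValueError("Float images are expected to have values in [0, 1]")
```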

@datumbox, would you mind answering this: what is the risk of introducing support for [0, 1] images? I'm sorry, but so far I can't identify any.

@fmassa (Member) commented May 17, 2021

We have precedent in utils.py for supporting float images (and even assuming they are in 0-1).

I think the discussion around supporting floats in 0-1 versus only uint8 is a very valid one, and we should start a discussion about what to do in the future so that we are consistent.
The source of confusion that Vasilis is mentioning is that some of our functions might be applied not only to the input image, but also to intermediate features of the model (which do not carry the 0-1 assumption).

In order to move forward with merging this PR for the next release, I think it's easiest for now to assume the input tensor is uint8 and raise an error if not. @NicolasHug can you do this?

@NicolasHug @datumbox let's set up some time to discuss what input formats will be assumed within the codebase.

@NicolasHug (Member, author):

> The source of confusion that Vasilis is mentioning is that some of our functions might be applied not only to the input image, but also to intermediate features of the model (which do not carry the 0-1 assumption).

I understand that, and my point is that draw_segmentation_masks is useless/invalid in this case, i.e. without the 0-1 assumption. There are no floating-point values apart from the 0-1 scale for which draw_segmentation_masks makes sense. So I don't think there's any risk here.

I will update this PR and the 2 follow-up ones to remove 0-1 support, despite the extra work. In the meantime, I would appreciate it if this key question were addressed:

> what is the risk of introducing support for [0, 1] images?

@NicolasHug (Member, author):

To avoid pushing changes to many PRs, I'm closing this one in favor of #3824

@NicolasHug closed this May 17, 2021
@fmassa (Member) commented May 17, 2021

> what is the risk of introducing support for [0, 1] images?

For this particular case, I don't think there are any risks in most situations.

But I think it's time for us to rethink which conventions we want to keep in torchvision and which ones we want to start dropping.

The fact that to_tensor always scales uint8 images by dividing by 255 has been a constant source of issues in the past.
That's the reason why we decided to decouple it into two functions: pil_to_tensor (which just converts a PIL.Image into a torch tensor of the same dtype) and convert_image_dtype (which performs conversions between uint8 / float / etc.); see #2060 (comment) for some more discussion.
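A sketch of the decoupled flow (standard torchvision calls; the toy image is just for illustration):

```python
import torch
from PIL import Image
from torchvision.transforms.functional import convert_image_dtype, pil_to_tensor

pil_img = Image.new("RGB", (4, 4), color=(255, 0, 0))

# pil_to_tensor preserves the dtype: uint8 stays in [0, 255], no silent rescaling.
t_uint8 = pil_to_tensor(pil_img)                     # dtype=torch.uint8

# convert_image_dtype is the explicit, separate conversion step.
t_float = convert_image_dtype(t_uint8, torch.float)  # dtype=torch.float32, values in [0, 1]
```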
