Change draw_segmentation_masks to accept boolean masks #3820

Closed · wants to merge 11 commits
35 changes: 0 additions & 35 deletions gallery/plot_visualization_utils.py
@@ -94,38 +94,3 @@ def show(imgs):
for dog_int, output in zip((dog1_int, dog2_int), outputs)
]
show(dogs_with_boxes)

#####################################
# Visualizing segmentation masks
# ------------------------------
# The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to
# draw segmentation masks on images. We can set the colors as well as the
# transparency of the masks.
#
# Here is a demo with torchvision's FCN ResNet-50, loaded with
# :func:`~torchvision.models.segmentation.fcn_resnet50`.
# You can also try using
# DeepLabv3 (:func:`~torchvision.models.segmentation.deeplabv3_resnet50`)
# or lraspp mobilenet models
# (:func:`~torchvision.models.segmentation.lraspp_mobilenet_v3_large`).
#
# Like :func:`~torchvision.utils.draw_bounding_boxes`,
# :func:`~torchvision.utils.draw_segmentation_masks` requires a single RGB image
# of dtype `uint8`.

from torchvision.models.segmentation import fcn_resnet50
from torchvision.utils import draw_segmentation_masks


model = fcn_resnet50(pretrained=True, progress=False)
model = model.eval()

# The model expects the batch to be normalized
batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
outputs = model(batch)

dogs_with_masks = [
draw_segmentation_masks(dog_int, masks=masks, alpha=0.6)
for dog_int, masks in zip((dog1_int, dog2_int), outputs['out'])
]
show(dogs_with_masks)
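For context, a minimal sketch (not part of this diff) of how the removed example above could be rewritten against the new boolean-mask API. The softmax threshold of 0.5 and the Pascal VOC "dog" class index 12 are illustrative assumptions, and `outputs`, `dog1_int`, `dog2_int` and `show` come from the surrounding example:

import torch

# Turn the per-class logits into probabilities, then threshold the "dog"
# channel to obtain boolean masks of shape (batch_size, H, W).
normalized_masks = torch.nn.functional.softmax(outputs['out'], dim=1)
dog_masks = normalized_masks[:, 12] > 0.5

dogs_with_masks = [
    draw_segmentation_masks(dog_int, masks=mask, alpha=0.6)
    for dog_int, mask in zip((dog1_int, dog2_int), dog_masks)
]
show(dogs_with_masks)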
Binary file removed test/assets/fakedata/draw_segm_masks_colors_util.png
Binary file removed test/assets/fakedata/draw_segm_masks_no_colors_util.png
156 changes: 107 additions & 49 deletions test/test_utils.py
@@ -1,3 +1,4 @@
import pytest
import numpy as np
import os
import sys
@@ -7,7 +8,7 @@
import unittest
from io import BytesIO
import torchvision.transforms.functional as F
from PIL import Image, __version__ as PILLOW_VERSION
from PIL import Image, __version__ as PILLOW_VERSION, ImageColor


PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.'))
@@ -159,55 +160,112 @@ def test_draw_invalid_boxes(self):
self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong1, boxes)
self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong2, boxes)

def test_draw_segmentation_masks_colors(self):
img = torch.full((3, 5, 5), 255, dtype=torch.uint8)
img_cp = img.clone()
masks_cp = masks.clone()
colors = ["#FF00FF", (0, 255, 0), "red"]
result = utils.draw_segmentation_masks(img, masks, colors=colors)

path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
"fakedata", "draw_segm_masks_colors_util.png")

if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)

expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected))
# Check if modification is not in place
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item())

def test_draw_segmentation_masks_no_colors(self):
img = torch.full((3, 20, 20), 255, dtype=torch.uint8)
img_cp = img.clone()
masks_cp = masks.clone()
result = utils.draw_segmentation_masks(img, masks, colors=None)

path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
"fakedata", "draw_segm_masks_no_colors_util.png")

if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)

expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected))
# Check if modification is not in place
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item())

def test_draw_invalid_masks(self):
img_tp = ((1, 1, 1), (1, 2, 3))
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
img_wrong3 = torch.full((4, 5, 5), 255, dtype=torch.uint8)

self.assertRaises(TypeError, utils.draw_segmentation_masks, img_tp, masks)
self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong1, masks)
self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong2, masks)
self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong3, masks)
@pytest.mark.parametrize('dtype', (torch.float, torch.uint8))
@pytest.mark.parametrize('colors', [
None,
['red', 'blue'],
['#FF00FF', (1, 34, 122)],
])
@pytest.mark.parametrize('alpha', (0, .5, .7, 1))
def test_draw_segmentation_masks(dtype, colors, alpha):
"""This test makes sure that masks draw their corresponding color where they should"""
num_masks, h, w = 2, 100, 100
img = torch.randint(0, 256, size=(3, h, w), dtype=dtype)
masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool)

# For testing we enforce that there's no overlap between the masks. The
# current behaviour is that the last mask's color takes priority when
# masks overlap, but that would make the test harder to write, so we
# simply avoid overlap here.
overlap = masks[0] & masks[1]
masks[:, overlap] = False

out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha)
assert out.dtype == dtype
assert out is not img

# Make sure the image didn't change where there's no mask
masked_pixels = masks[0] | masks[1]
assert (img[:, ~masked_pixels] == out[:, ~masked_pixels]).all()

if colors is None:
colors = utils._generate_color_palette(num_masks)

# Make sure each mask draws with its own color
for mask, color in zip(masks, colors):
if isinstance(color, str):
color = ImageColor.getrgb(color)
color = torch.tensor(color, dtype=dtype)
if dtype == torch.float:
color /= 255

if alpha == 0:
assert (out[:, mask] == color[:, None]).all()
elif alpha == 1:
assert (out[:, mask] == img[:, mask]).all()

interpolated_color = (img[:, mask] * alpha + color[:, None] * (1 - alpha))
max_diff = (out[:, mask] - interpolated_color).abs().max()
if dtype == torch.uint8:
assert max_diff <= 1
else:
assert max_diff <= 1e-5
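For reference, the expected values above follow the blending rule this PR implements, where `alpha` weights the original image: out = image * alpha + color * (1 - alpha). A quick numeric sanity check with illustrative values:

import torch

img_pixel = torch.tensor([200., 100., 50.])  # an original RGB pixel
color = torch.tensor([255., 0., 0.])         # the mask color ("red")
alpha = 0.7
blended = img_pixel * alpha + color * (1 - alpha)
print(blended)  # tensor([216.5000,  70.0000,  35.0000])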


def test_draw_segmentation_masks_int_vs_float():
"""Make sure float and uint8 dtypes produce similar images"""
h, w = 100, 100
masks = torch.randint(0, 2, size=(2, h, w), dtype=torch.bool)
img_int = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)
img_float = F.convert_image_dtype(img_int, torch.float)

out_int = utils.draw_segmentation_masks(image=img_int, masks=masks, colors=['red', 'blue'])
out_float = utils.draw_segmentation_masks(image=img_float, masks=masks, colors=['red', 'blue'])

assert out_int.dtype == img_int.dtype
assert out_float.dtype == img_float.dtype

out_float_int = F.convert_image_dtype(out_float, torch.uint8).int()
out_int = out_int.int()

assert (out_int - out_float_int).abs().max() <= 1


def test_draw_segmentation_masks_errors():
h, w = 10, 10

masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool)
img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)

with pytest.raises(TypeError, match="The image must be a tensor"):
utils.draw_segmentation_masks(image="Not A Tensor Image", masks=masks)
with pytest.raises(ValueError, match="The image dtype must be"):
img_bad_dtype = torch.randint(0, 256, size=(3, h, w), dtype=torch.int64)
utils.draw_segmentation_masks(image=img_bad_dtype, masks=masks)
with pytest.raises(ValueError, match="Pass individual images, not batches"):
batch = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8)
utils.draw_segmentation_masks(image=batch, masks=masks)
with pytest.raises(ValueError, match="Pass an RGB image"):
one_channel = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8)
utils.draw_segmentation_masks(image=one_channel, masks=masks)
with pytest.raises(ValueError, match="The masks must be of dtype bool"):
masks_bad_dtype = torch.randint(0, 2, size=(h, w), dtype=torch.float)
utils.draw_segmentation_masks(image=img, masks=masks_bad_dtype)
with pytest.raises(ValueError, match="masks must be of shape"):
masks_bad_shape = torch.randint(0, 2, size=(3, 2, h, w), dtype=torch.bool)
utils.draw_segmentation_masks(image=img, masks=masks_bad_shape)
with pytest.raises(ValueError, match="must have the same height and width"):
masks_bad_shape = torch.randint(0, 2, size=(h + 4, w), dtype=torch.bool)
utils.draw_segmentation_masks(image=img, masks=masks_bad_shape)
with pytest.raises(ValueError, match="There are more masks"):
utils.draw_segmentation_masks(image=img, masks=masks, colors=[])
with pytest.raises(ValueError, match="colors must be a tuple or a string, or a list thereof"):
bad_colors = np.array(['red', 'blue']) # should be a list
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
with pytest.raises(ValueError, match="It seems that you passed a tuple of colors instead of"):
bad_colors = ('red', 'blue') # should be a list
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)


if __name__ == '__main__':
89 changes: 55 additions & 34 deletions torchvision/utils.py
@@ -226,52 +226,73 @@ def draw_segmentation_masks(

"""
Draws segmentation masks on given RGB image.
The values of the input image should be uint8 between 0 and 255.
The values of the input image should be uint8 between 0 and 255, or float values between 0 and 1.

Args:
image (Tensor): Tensor of shape (3 x H x W) and dtype uint8.
masks (Tensor): Tensor of shape (num_masks, H, W). Each containing probability of predicted class.
alpha (float): Float number between 0 and 1 denoting factor of transparency of masks.
colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can
be represented as `str` or `Tuple[int, int, int]`.
image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
colors (list or None): List containing the colors of the masks. The colors can
be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
When ``masks`` has a single entry of shape (H, W), you can pass a single color instead of a list
with one element. By default, random colors are generated for each mask.

Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with segmentation masks plotted.
img (Tensor[C, H, W]): Image Tensor with the same dtype as the input image, with segmentation masks
drawn on top.
"""

if not isinstance(image, torch.Tensor):
raise TypeError(f"Tensor expected, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype not in (torch.uint8, torch.float):
raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size()[0] != 3:
raise ValueError("Pass an RGB image. Other Image formats are not supported")
if masks.ndim == 2:
masks = masks[None, :, :]
if masks.ndim != 3:
raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
if masks.dtype != torch.bool:
raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
if masks.shape[-2:] != image.shape[-2:]:
raise ValueError("The image and the masks must have the same height and width")

num_masks = masks.size()[0]
masks = masks.argmax(0)
if colors is not None and num_masks > len(colors):
raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")

if colors is None:
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
colors_t = torch.as_tensor([i for i in range(num_masks)])[:, None] * palette
color_arr = (colors_t % 255).numpy().astype("uint8")
else:
color_list = []
for color in colors:
if isinstance(color, str):
# This will automatically raise Error if rgb cannot be parsed.
fill_color = ImageColor.getrgb(color)
color_list.append(fill_color)
elif isinstance(color, tuple):
color_list.append(color)

color_arr = np.array(color_list).astype("uint8")

_, h, w = image.size()
img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize((w, h))
img_to_draw.putpalette(color_arr)

img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGB')))
img_to_draw = img_to_draw.permute((2, 0, 1))

return (image.float() * alpha + img_to_draw.float() * (1.0 - alpha)).to(dtype=torch.uint8)
colors = _generate_color_palette(num_masks)

if not isinstance(colors, list):
colors = [colors]
if not isinstance(colors[0], (tuple, str)):
raise ValueError("colors must be a tuple or a string, or a list thereof")
if isinstance(colors[0], tuple) and len(colors[0]) != 3:
raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")

out_dtype = image.dtype

colors_ = []
for color in colors:
if isinstance(color, str):
color = ImageColor.getrgb(color)
color = torch.tensor(color, dtype=out_dtype)
if out_dtype == torch.float:
color /= 255
Contributor:
I wonder if we have to support anything other than uint8 here, given that this method is used for visualization. To be specific, the only issue I have with this is that you are being forced to assume that the colour needs to be between 0-1, hence divide by 255. Thoughts?

Contributor @oke-aditya (May 12, 2021):
I thought so too (I will have a thorough look tomorrow!). The other utility, draw_bounding_boxes, returns uint8 dtype only, so perhaps we can be consistent on either.

Member (Author) @NicolasHug (May 12, 2021):
The current interface takes a uint8 image as input and outputs a float image. This means that you can't call draw_segmentation_masks on that output again, which isn't super practical. EDIT: this isn't the case; on master the function properly returns a uint8 image.

Also, the models accept float images, not uint8 images. If you look at the current example https://pytorch.org/vision/master/auto_examples/plot_visualization_utils.html#sphx-glr-auto-examples-plot-visualization-utils-py we're forced to keep 2 copies of each image: one float and one uint8. This isn't optimal either.

Supporting both float and uint8 and having the dtypes "pass through" solves both of these issues for little maintenance cost. I wrote tests that ensure both dtypes yield the same results.

Regarding consistency, I'm planning on modifying draw_bounding_boxes to have the same behavior.

Contributor:
> Also, the models accept float images, not uint8 images.

It is true that the models receive float images, which are also typically normalized. These are not the versions of the images that you will use for visualization, though. Most models scale to [0, 1] while others recently require scaling to [-1, 1] (see the SSD models). So in the end the user will need to keep around the non-normalized uint8 version to visualize it.

An alternative course of action might be to continue handling all types of images but avoid any assumption about the color scales. You could do that by creating a palette only if the image is uint8 and expecting the user to provide a valid list of colors on the right scale otherwise. Happy to discuss this more to unblock your work.

Member (Author) @NicolasHug:
> It is true that the models receive float images, which are also typically normalized. These are not the versions of the images that you will use for visualization, though.

According to #3774 (comment), Mask R-CNN and Faster R-CNN don't require normalization, so one could use the same non-normalized float images in this case, I think.

IIUC, the main concern seems to be the assumption that if an image is float, then we expect it to be in [0, 1]. According to our docs, this assumption already holds for the transforms:

> The expected range of the values of a tensor image is implicitly defined by the tensor dtype. Tensor images with a float dtype are expected to have values in [0, 1). Tensor images with an integer dtype are expected to have values in [0, MAX_DTYPE] where MAX_DTYPE is the largest value that can be represented in that dtype.

It also holds when calling F.convert_image_dtype(img, torch.float) and for all our conversion utilities as far as I can tell, so it seems to be a reasonable assumption.

I'll add a comment in the docstring to clarify that. The worst thing that can happen if the user passes a float image that's not in [0, 1] is that the colors will be slightly off.
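As a concrete illustration of that contract, a minimal sketch assuming only torch and torchvision:

import torch
import torchvision.transforms.functional as F

img_uint8 = torch.tensor([[[0, 128, 255]]], dtype=torch.uint8)  # tiny example tensor
img_float = F.convert_image_dtype(img_uint8, torch.float)
print(img_float)  # tensor([[[0.0000, 0.5020, 1.0000]]]): uint8 values land in [0, 1]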

Member (Author) @NicolasHug:
In other words: as far as I can tell, removing the assumption that float images are in [0, 1] for draw_segmentation_masks will not facilitate any use-case, and it won't make anything worse either.

OTOH, removing this assumption forces users to keep the uint8 image around, or it forces them to manually pass colors. By seamlessly allowing float images in [0, 1], users only need to keep the float image in [0, 1], and possibly the normalized image to pass to a model. They can discard the uint8 image.

Member (Author) @NicolasHug:
BTW, this idea that float images are implicitly expected to be in 0-1 is largely adopted throughout the ecosystem. Matplotlib assumes this too: see below how both images are rendered exactly the same. By allowing this, we're making our tools easily usable by users who are used to the existing state of things.

[screenshot: the uint8 and float versions of the same image, rendered identically]
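A sketch reproducing that comparison (assuming matplotlib; imshow interprets float RGB data as values in [0, 1] and uint8 data as 0-255):

import matplotlib.pyplot as plt
import numpy as np

img_uint8 = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
img_float = img_uint8 / 255.0  # same image, float values in [0, 1]

fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.imshow(img_uint8)  # treated as 0-255
ax2.imshow(img_float)  # treated as 0-1: both panels render identically
plt.show()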

Contributor:
> Are there plans to change the behaviour of ToTensor() and all of the transforms w.r.t. this assumption with float inputs?

My point here is about keeping this assumption contained and not exposing it to other places of the library. Whether we should change the behaviour of ToTensor to remove the silent rescaling is something we could discuss for TorchVision v1.0.

> plotting the result masks an image that has been normalized (i.e. that isn't in 0-1) leads to terrible visualizations anyway

This is precisely why the original versions of these methods supported only uint8 images, and that's why we can't just throw away the uint8 versions of the images. The segmentation models rescale and normalize the images outside of the models, meaning you would have to undo that before visualizing.

At this point I think it would be best not to merge the PR without the gallery examples, which would give us a clear view of the usage and how corner cases are handled. Perhaps @fmassa can weigh in here.

Member (Author) @NicolasHug:
> The segmentation models rescale and normalize the images outside of the models, meaning you would have to undo that before visualizing.

Users have to convert their images to float in [0, 1] at some point:

> All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].

So it's not just something contained in the transforms. It's pervasive throughout the library, and throughout the entire ecosystem. By supporting those images directly, we allow users to drop the uint8 images and keep only the float images (normalized and non-normalized versions).

Member (Author) @NicolasHug:
Still WIP, but see also https://github.com/pytorch/vision/pull/3824/files#r631748970 for an illustration of why and how we can drop references to uint8 images once we support floats in 0-1.

colors_.append(color)

img_to_draw = image.detach().clone()
# TODO: There might be a way to vectorize this
for mask, color in zip(masks, colors_):
img_to_draw[:, mask] = color[:, None]

out = image * alpha + img_to_draw * (1 - alpha)
return out.to(out_dtype)


def _generate_color_palette(num_masks):
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
return [tuple((i * palette) % 255) for i in range(num_masks)]
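To close the loop, a minimal end-to-end sketch of the API after this change; the image, masks and colors are illustrative, and the assertion reflects the dtype pass-through behaviour discussed above:

import torch
from torchvision.utils import draw_segmentation_masks

img = torch.randint(0, 256, size=(3, 4, 4), dtype=torch.uint8)
masks = torch.zeros(2, 4, 4, dtype=torch.bool)
masks[0, :2] = True  # first mask covers the top half
masks[1, 2:] = True  # second mask covers the bottom half

out = draw_segmentation_masks(img, masks, alpha=0.5, colors=['red', 'blue'])
assert out.dtype == torch.uint8  # the input dtype passes through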