diff --git a/gallery/plot_visualization_utils.py b/gallery/plot_visualization_utils.py index 58788437a28..76e32bc586e 100644 --- a/gallery/plot_visualization_utils.py +++ b/gallery/plot_visualization_utils.py @@ -94,38 +94,3 @@ def show(imgs): for dog_int, output in zip((dog1_int, dog2_int), outputs) ] show(dogs_with_boxes) - -##################################### -# Visualizing segmentation masks -# ------------------------------ -# The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to -# draw segmentation amasks on images. We can set the colors as well as -# transparency of masks. -# -# Here is demo with torchvision's FCN Resnet-50, loaded with -# :func:`~torchvision.models.segmentation.fcn_resnet50`. -# You can also try using -# DeepLabv3 (:func:`~torchvision.models.segmentation.deeplabv3_resnet50`) -# or lraspp mobilenet models -# (:func:`~torchvision.models.segmentation.lraspp_mobilenet_v3_large`). -# -# Like :func:`~torchvision.utils.draw_bounding_boxes`, -# :func:`~torchvision.utils.draw_segmentation_masks` requires a single RGB image -# of dtype `uint8`. - -from torchvision.models.segmentation import fcn_resnet50 -from torchvision.utils import draw_segmentation_masks - - -model = fcn_resnet50(pretrained=True, progress=False) -model = model.eval() - -# The model expects the batch to be normalized -batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) -outputs = model(batch) - -dogs_with_masks = [ - draw_segmentation_masks(dog_int, masks=masks, alpha=0.6) - for dog_int, masks in zip((dog1_int, dog2_int), outputs['out']) -] -show(dogs_with_masks) diff --git a/test/assets/fakedata/draw_segm_masks_colors_util.png b/test/assets/fakedata/draw_segm_masks_colors_util.png deleted file mode 100644 index 454b3555631..00000000000 Binary files a/test/assets/fakedata/draw_segm_masks_colors_util.png and /dev/null differ diff --git a/test/assets/fakedata/draw_segm_masks_no_colors_util.png b/test/assets/fakedata/draw_segm_masks_no_colors_util.png deleted file mode 100644 index f048d2469d2..00000000000 Binary files a/test/assets/fakedata/draw_segm_masks_no_colors_util.png and /dev/null differ diff --git a/test/test_utils.py b/test/test_utils.py index 8c4cc620229..7644915f2cb 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,3 +1,4 @@ +import pytest import numpy as np import os import sys @@ -7,7 +8,7 @@ import unittest from io import BytesIO import torchvision.transforms.functional as F -from PIL import Image, __version__ as PILLOW_VERSION +from PIL import Image, __version__ as PILLOW_VERSION, ImageColor PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.')) @@ -159,55 +160,112 @@ def test_draw_invalid_boxes(self): self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong1, boxes) self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong2, boxes) - def test_draw_segmentation_masks_colors(self): - img = torch.full((3, 5, 5), 255, dtype=torch.uint8) - img_cp = img.clone() - masks_cp = masks.clone() - colors = ["#FF00FF", (0, 255, 0), "red"] - result = utils.draw_segmentation_masks(img, masks, colors=colors) - - path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", - "fakedata", "draw_segm_masks_colors_util.png") - - if not os.path.exists(path): - res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()) - res.save(path) - - expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) - self.assertTrue(torch.equal(result, expected)) - # Check if modification is not in place - self.assertTrue(torch.all(torch.eq(img, img_cp)).item()) - self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item()) - - def test_draw_segmentation_masks_no_colors(self): - img = torch.full((3, 20, 20), 255, dtype=torch.uint8) - img_cp = img.clone() - masks_cp = masks.clone() - result = utils.draw_segmentation_masks(img, masks, colors=None) - - path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", - "fakedata", "draw_segm_masks_no_colors_util.png") - - if not os.path.exists(path): - res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()) - res.save(path) - - expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) - self.assertTrue(torch.equal(result, expected)) - # Check if modification is not in place - self.assertTrue(torch.all(torch.eq(img, img_cp)).item()) - self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item()) - - def test_draw_invalid_masks(self): - img_tp = ((1, 1, 1), (1, 2, 3)) - img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float) - img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8) - img_wrong3 = torch.full((4, 5, 5), 255, dtype=torch.uint8) - self.assertRaises(TypeError, utils.draw_segmentation_masks, img_tp, masks) - self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong1, masks) - self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong2, masks) - self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong3, masks) +@pytest.mark.parametrize('dtype', (torch.float, torch.uint8)) +@pytest.mark.parametrize('colors', [ + None, + ['red', 'blue'], + ['#FF00FF', (1, 34, 122)], +]) +@pytest.mark.parametrize('alpha', (0, .5, .7, 1)) +def test_draw_segmentation_masks(dtype, colors, alpha): + """This test makes sure that masks draw their corresponding color where they should""" + num_masks, h, w = 2, 100, 100 + img = torch.randint(0, 256, size=(3, h, w), dtype=dtype) + masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool) + + # For testing we enforce that there's no overlap between the masks. The + # current behaviour is that the last mask's color will take priority when + # masks overlap, but this makes testing slightly harder so we don't really + # care + overlap = masks[0] & masks[1] + masks[:, overlap] = False + + out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha) + assert out.dtype == dtype + assert out is not img + + # Make sure the image didn't change where there's no mask + masked_pixels = masks[0] | masks[1] + assert (img[:, ~masked_pixels] == out[:, ~masked_pixels]).all() + + if colors is None: + colors = utils._generate_color_palette(num_masks) + + # Make sure each mask draws with its own color + for mask, color in zip(masks, colors): + if isinstance(color, str): + color = ImageColor.getrgb(color) + color = torch.tensor(color, dtype=dtype) + if dtype == torch.float: + color /= 255 + + if alpha == 0: + assert (out[:, mask] == color[:, None]).all() + elif alpha == 1: + assert (out[:, mask] == img[:, mask]).all() + + interpolated_color = (img[:, mask] * alpha + color[:, None] * (1 - alpha)) + max_diff = (out[:, mask] - interpolated_color).abs().max() + if dtype == torch.uint8: + assert max_diff <= 1 + else: + assert max_diff <= 1e-5 + + +def test_draw_segmentation_masks_int_vs_float(): + """Make sure float and uint8 dtypes produce similar images""" + h, w = 100, 100 + masks = torch.randint(0, 2, size=(2, h, w), dtype=torch.bool) + img_int = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8) + img_float = F.convert_image_dtype(img_int, torch.float) + + out_int = utils.draw_segmentation_masks(image=img_int, masks=masks, colors=['red', 'blue']) + out_float = utils.draw_segmentation_masks(image=img_float, masks=masks, colors=['red', 'blue']) + + assert out_int.dtype == img_int.dtype + assert out_float.dtype == img_float.dtype + + out_float_int = F.convert_image_dtype(out_float, torch.uint8).int() + out_int = out_int.int() + + assert (out_int - out_float_int).abs().max() <= 1 + + +def test_draw_segmentation_masks_errors(): + h, w = 10, 10 + + masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool) + img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8) + + with pytest.raises(TypeError, match="The image must be a tensor"): + utils.draw_segmentation_masks(image="Not A Tensor Image", masks=masks) + with pytest.raises(ValueError, match="The image dtype must be"): + img_bad_dtype = torch.randint(0, 256, size=(3, h, w), dtype=torch.int64) + utils.draw_segmentation_masks(image=img_bad_dtype, masks=masks) + with pytest.raises(ValueError, match="Pass individual images, not batches"): + batch = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8) + utils.draw_segmentation_masks(image=batch, masks=masks) + with pytest.raises(ValueError, match="Pass an RGB image"): + one_channel = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8) + utils.draw_segmentation_masks(image=one_channel, masks=masks) + with pytest.raises(ValueError, match="The masks must be of dtype bool"): + masks_bad_dtype = torch.randint(0, 2, size=(h, w), dtype=torch.float) + utils.draw_segmentation_masks(image=img, masks=masks_bad_dtype) + with pytest.raises(ValueError, match="masks must be of shape"): + masks_bad_shape = torch.randint(0, 2, size=(3, 2, h, w), dtype=torch.bool) + utils.draw_segmentation_masks(image=img, masks=masks_bad_shape) + with pytest.raises(ValueError, match="must have the same height and width"): + masks_bad_shape = torch.randint(0, 2, size=(h + 4, w), dtype=torch.bool) + utils.draw_segmentation_masks(image=img, masks=masks_bad_shape) + with pytest.raises(ValueError, match="There are more masks"): + utils.draw_segmentation_masks(image=img, masks=masks, colors=[]) + with pytest.raises(ValueError, match="colors must be a tuple or a string, or a list thereof"): + bad_colors = np.array(['red', 'blue']) # should be a list + utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) + with pytest.raises(ValueError, match="It seems that you passed a tuple of colors instead of"): + bad_colors = ('red', 'blue') # should be a list + utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) if __name__ == '__main__': diff --git a/torchvision/utils.py b/torchvision/utils.py index 9d9bbdb3c80..c8e333b5767 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -226,52 +226,73 @@ def draw_segmentation_masks( """ Draws segmentation masks on given RGB image. - The values of the input image should be uint8 between 0 and 255. + The values of the input image should be uint8 between 0 and 255, or float values between 0 and 1. Args: - image (Tensor): Tensor of shape (3 x H x W) and dtype uint8. - masks (Tensor): Tensor of shape (num_masks, H, W). Each containing probability of predicted class. - alpha (float): Float number between 0 and 1 denoting factor of transparency of masks. - colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can - be represented as `str` or `Tuple[int, int, int]`. + image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float. + masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. + alpha (float): Float number between 0 and 1 denoting the transparency of the masks. + colors (list or None): List containing the colors of the masks. The colors can + be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + When ``masks`` has a single entry of shape (H, W), you can pass a single color instead of a list + with one element. By default, random colors are generated for each mask. Returns: - img (Tensor[C, H, W]): Image Tensor of dtype uint8 with segmentation masks plotted. + img (Tensor[C, H, W]): Image Tensor with the same dtype as the input image, with segmentation masks + drawn on top. """ if not isinstance(image, torch.Tensor): - raise TypeError(f"Tensor expected, got {type(image)}") - elif image.dtype != torch.uint8: - raise ValueError(f"Tensor uint8 expected, got {image.dtype}") + raise TypeError(f"The image must be a tensor, got {type(image)}") + elif image.dtype not in (torch.uint8, torch.float): + raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size()[0] != 3: raise ValueError("Pass an RGB image. Other Image formats are not supported") + if masks.ndim == 2: + masks = masks[None, :, :] + if masks.ndim != 3: + raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)") + if masks.dtype != torch.bool: + raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}") + if masks.shape[-2:] != image.shape[-2:]: + raise ValueError("The image and the masks must have the same height and width") num_masks = masks.size()[0] - masks = masks.argmax(0) + if colors is not None and num_masks > len(colors): + raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") if colors is None: - palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) - colors_t = torch.as_tensor([i for i in range(num_masks)])[:, None] * palette - color_arr = (colors_t % 255).numpy().astype("uint8") - else: - color_list = [] - for color in colors: - if isinstance(color, str): - # This will automatically raise Error if rgb cannot be parsed. - fill_color = ImageColor.getrgb(color) - color_list.append(fill_color) - elif isinstance(color, tuple): - color_list.append(color) - - color_arr = np.array(color_list).astype("uint8") - - _, h, w = image.size() - img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize((w, h)) - img_to_draw.putpalette(color_arr) - - img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGB'))) - img_to_draw = img_to_draw.permute((2, 0, 1)) - - return (image.float() * alpha + img_to_draw.float() * (1.0 - alpha)).to(dtype=torch.uint8) + colors = _generate_color_palette(num_masks) + + if not isinstance(colors, list): + colors = [colors] + if not isinstance(colors[0], (tuple, str)): + raise ValueError("colors must be a tuple or a string, or a list thereof") + if isinstance(colors[0], tuple) and len(colors[0]) != 3: + raise ValueError("It seems that you passed a tuple of colors instead of a list of colors") + + out_dtype = image.dtype + + colors_ = [] + for color in colors: + if isinstance(color, str): + color = ImageColor.getrgb(color) + color = torch.tensor(color, dtype=out_dtype) + if out_dtype == torch.float: + color /= 255 + colors_.append(color) + + img_to_draw = image.detach().clone() + # TODO: There might be a way to vectorize this + for mask, color in zip(masks, colors_): + img_to_draw[:, mask] = color[:, None] + + out = image * alpha + img_to_draw * (1 - alpha) + return out.to(out_dtype) + + +def _generate_color_palette(num_masks): + palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) + return [tuple((i * palette) % 255) for i in range(num_masks)]