From 53cd51619820f9c868e8c0779609a2b83c929d75 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 12 May 2021 17:37:26 +0100 Subject: [PATCH 1/9] WIP --- gallery/plot_visualization_utils.py | 69 +++++++++---- test/test_utils.py | 153 +++++++++++++++++++--------- torchvision/utils.py | 91 ++++++++++------- 3 files changed, 212 insertions(+), 101 deletions(-) diff --git a/gallery/plot_visualization_utils.py b/gallery/plot_visualization_utils.py index 58788437a28..06f7162d2a5 100644 --- a/gallery/plot_visualization_utils.py +++ b/gallery/plot_visualization_utils.py @@ -24,7 +24,8 @@ def show(imgs): imgs = [imgs] fix, axs = plt.subplots(ncols=len(imgs), squeeze=False) for i, img in enumerate(imgs): - img = F.to_pil_image(img.to('cpu')) + img = img.detach() + img = F.to_pil_image(img) axs[0, i].imshow(np.asarray(img)) axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) @@ -50,9 +51,8 @@ def show(imgs): # Visualizing bounding boxes # -------------------------- # We can use :func:`~torchvision.utils.draw_bounding_boxes` to draw boxes on an -# image. We can set the colors, labels, width as well as font and font size ! -# The boxes are in ``(xmin, ymin, xmax, ymax)`` format -# from torchvision.utils import draw_bounding_boxes +# image. We can set the colors, labels, width as well as font and font size. +# The boxes are in ``(xmin, ymin, xmax, ymax)`` format. from torchvision.utils import draw_bounding_boxes @@ -99,19 +99,17 @@ def show(imgs): # Visualizing segmentation masks # ------------------------------ # The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to -# draw segmentation amasks on images. We can set the colors as well as -# transparency of masks. +# draw segmentation amasks on images. # -# Here is demo with torchvision's FCN Resnet-50, loaded with +# We will see how to use it with torchvision's FCN Resnet-50, loaded with # :func:`~torchvision.models.segmentation.fcn_resnet50`. # You can also try using # DeepLabv3 (:func:`~torchvision.models.segmentation.deeplabv3_resnet50`) # or lraspp mobilenet models # (:func:`~torchvision.models.segmentation.lraspp_mobilenet_v3_large`). # -# Like :func:`~torchvision.utils.draw_bounding_boxes`, -# :func:`~torchvision.utils.draw_segmentation_masks` requires a single RGB image -# of dtype `uint8`. +# Let's start by looking at the ouput of the model. Remember that in general, +# images must be normalized before they're passed to the model. from torchvision.models.segmentation import fcn_resnet50 from torchvision.utils import draw_segmentation_masks @@ -120,12 +118,49 @@ def show(imgs): model = fcn_resnet50(pretrained=True, progress=False) model = model.eval() -# The model expects the batch to be normalized -batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) -outputs = model(batch) +normalized_batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) +output = model(normalized_batch)['out'] +print(output.shape, output.min().item(), output.max().item()) + +##################################### +# As we can see above, the output of the segmentation model is a tensor of shape +# ``(batch_size, num_classes, H, W)``. Each value is a non-normalized score +# and can normalize them into ``[0, 1]`` by using a softmax. After the softmax, +# we can interpret each value as a probability indicating how likely a given +# pixel is to belong to a given class. +# +# Let's plot the masks that have been detected for the dog class and for the +# boat class: + +seg_classes = [ + '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', + 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', + 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' +] +seg_class_to_idx = {cls: idx for (idx, cls) in enumerate(seg_classes)} + +# We normalize the masks of each image in the batch independently +normalized_masks = torch.stack([torch.nn.Softmax(dim=0)(masks) for masks in output]) -dogs_with_masks = [ - draw_segmentation_masks(dog_int, masks=masks, alpha=0.6) - for dog_int, masks in zip((dog1_int, dog2_int), outputs['out']) +dog_and_boat_masks = [ + normalized_masks[img_idx, seg_class_to_idx[cls]] + for img_idx in range(batch.shape[0]) + for cls in ('dog', 'boat') ] -show(dogs_with_masks) + +show(dog_and_boat_masks) + +##################################### +# As expected, the model is confident about the dog class, but not so much for +# the boat class. +# +# The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to +# plots those masks on top of the original image. This function expects the +# masks to be boolean masks, but our masks above contain probabilities in ``[0, +# 1]``. To get boolean masks, we can do the following: + +# dogs_with_dog_masks = [ +# draw_segmentation_masks(dog_int, masks=output[img_idx, seg_class_to_idx['dog']], alpha=0.6) +# for img_idx, dog_int in enumerate(dog1_int, dog2_int) +# ] +# show(dogs_with_masks) diff --git a/test/test_utils.py b/test/test_utils.py index 8c4cc620229..7d85c0cb1af 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,3 +1,4 @@ +import pytest import numpy as np import os import sys @@ -7,7 +8,7 @@ import unittest from io import BytesIO import torchvision.transforms.functional as F -from PIL import Image, __version__ as PILLOW_VERSION +from PIL import Image, __version__ as PILLOW_VERSION, ImageColor PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.')) @@ -159,55 +160,109 @@ def test_draw_invalid_boxes(self): self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong1, boxes) self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong2, boxes) - def test_draw_segmentation_masks_colors(self): - img = torch.full((3, 5, 5), 255, dtype=torch.uint8) - img_cp = img.clone() - masks_cp = masks.clone() - colors = ["#FF00FF", (0, 255, 0), "red"] - result = utils.draw_segmentation_masks(img, masks, colors=colors) - - path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", - "fakedata", "draw_segm_masks_colors_util.png") - - if not os.path.exists(path): - res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()) - res.save(path) - - expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) - self.assertTrue(torch.equal(result, expected)) - # Check if modification is not in place - self.assertTrue(torch.all(torch.eq(img, img_cp)).item()) - self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item()) - - def test_draw_segmentation_masks_no_colors(self): - img = torch.full((3, 20, 20), 255, dtype=torch.uint8) - img_cp = img.clone() - masks_cp = masks.clone() - result = utils.draw_segmentation_masks(img, masks, colors=None) - - path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", - "fakedata", "draw_segm_masks_no_colors_util.png") - - if not os.path.exists(path): - res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()) - res.save(path) - - expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) - self.assertTrue(torch.equal(result, expected)) - # Check if modification is not in place - self.assertTrue(torch.all(torch.eq(img, img_cp)).item()) - self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item()) - - def test_draw_invalid_masks(self): - img_tp = ((1, 1, 1), (1, 2, 3)) - img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float) - img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8) - img_wrong3 = torch.full((4, 5, 5), 255, dtype=torch.uint8) - self.assertRaises(TypeError, utils.draw_segmentation_masks, img_tp, masks) - self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong1, masks) - self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong2, masks) - self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong3, masks) +@pytest.mark.parametrize('dtype', (torch.float, torch.uint8)) +@pytest.mark.parametrize('colors', [ + # None, + ['red', 'blue'], + ['#FF00FF', (1, 34, 122)], +]) +@pytest.mark.parametrize('alpha', (0, 1)) +def test_draw_segmentation_masks(dtype, colors, alpha): + """This test makes sure that masks draw their corresponding color where they should""" + num_masks, h, w = 2, 100, 100 + img = torch.randint(0, 256, size=(3, h, w), dtype=dtype) + masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool) + + # For testing we enforce that there's no overlap between the masks. The + # current behaviour is that the last mask's color will take priority when + # masks overlap, but this makes testing slightly harder so we don't really + # care + overlap = masks[0] & masks[1] + masks[:, overlap] = False + + out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha) + assert out.dtype == dtype + assert out is not img + + if dtype == torch.float: + # makes comparisons below easier + img = F.convert_image_dtype(img, torch.uint8) + out = F.convert_image_dtype(out, torch.uint8) + img, out = img.float(), out.float() # avoids underflows etc. + + # Make sure the image didn't change where there's no mask + masked_pixels = masks[0] | masks[1] + assert (img[:, ~masked_pixels] == out[:, ~masked_pixels]).all() + + if colors is None: + colors = utils._generate_color_palette(num_masks) + + # Make sure each mask draws with its own color + for mask, color in zip(masks, colors): + if isinstance(color, str): + color = ImageColor.getrgb(color) + color = torch.tensor(color, dtype=dtype) + + if alpha == 0: + assert (out[:, mask] == color[:, None]).all() + else: + assert (out[:, mask] == img[:, mask]).all() + + +def test_draw_segmentation_masks_int_vs_float(): + """Make sure float and uint8 dtypes produce similar images""" + h, w = 100, 100 + masks = torch.randint(0, 2, size=(2, h, w), dtype=torch.bool) + img_int = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8) + img_float = F.convert_image_dtype(img_int, torch.float) + + out_int = utils.draw_segmentation_masks(image=img_int, masks=masks, colors=['red', 'blue']) + out_float = utils.draw_segmentation_masks(image=img_float, masks=masks, colors=['red', 'blue']) + + assert out_int.dtype == img_int.dtype + assert out_float.dtype == img_float.dtype + + out_float_int = F.convert_image_dtype(out_float, torch.uint8).int() + out_int = out_int.int() + + assert (out_int - out_float_int).abs().max() <= 1 + + +def test_draw_segmentation_masks_errors(): + h, w = 10, 10 + + masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool) + img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8) + + with pytest.raises(TypeError, match="The image must be a tensor"): + utils.draw_segmentation_masks(image="Not A Tensor Image", masks=masks) + with pytest.raises(ValueError, match="The image dtype must be"): + img_bad_dtype = torch.randint(0, 256, size=(3, h, w), dtype=torch.int64) + utils.draw_segmentation_masks(image=img_bad_dtype, masks=masks) + with pytest.raises(ValueError, match="Pass individual images, not batches"): + batch = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8) + utils.draw_segmentation_masks(image=batch, masks=masks) + with pytest.raises(ValueError, match="Pass an RGB image"): + one_channel = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8) + utils.draw_segmentation_masks(image=one_channel, masks=masks) + with pytest.raises(ValueError, match="The masks must be of dtype bool"): + masks_bad_dtype = torch.randint(0, 2, size=(h, w), dtype=torch.float) + utils.draw_segmentation_masks(image=img, masks=masks_bad_dtype) + with pytest.raises(ValueError, match="masks must be of shape"): + masks_bad_shape = torch.randint(0, 2, size=(3, 2, h, w), dtype=torch.bool) + utils.draw_segmentation_masks(image=img, masks=masks_bad_shape) + with pytest.raises(ValueError, match="must have the same height and width"): + masks_bad_shape = torch.randint(0, 2, size=(h + 4, w), dtype=torch.bool) + utils.draw_segmentation_masks(image=img, masks=masks_bad_shape) + with pytest.raises(ValueError, match="There are more masks"): + utils.draw_segmentation_masks(image=img, masks=masks, colors=[]) + with pytest.raises(ValueError, match="colors must be a tuple or a string, or a list thereof"): + bad_colors = np.array(['red', 'blue']) # should be a list + utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) + with pytest.raises(ValueError, match="It seems that you passed a tuple of colors instead of"): + bad_colors = ('red', 'blue') # should be a list + utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) if __name__ == '__main__': diff --git a/torchvision/utils.py b/torchvision/utils.py index 9d9bbdb3c80..0bbea6391cb 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -216,7 +216,7 @@ def draw_bounding_boxes( return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) -@torch.no_grad() +# @torch.no_grad() def draw_segmentation_masks( image: torch.Tensor, masks: torch.Tensor, @@ -229,49 +229,70 @@ def draw_segmentation_masks( The values of the input image should be uint8 between 0 and 255. Args: - image (Tensor): Tensor of shape (3 x H x W) and dtype uint8. - masks (Tensor): Tensor of shape (num_masks, H, W). Each containing probability of predicted class. - alpha (float): Float number between 0 and 1 denoting factor of transparency of masks. - colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can - be represented as `str` or `Tuple[int, int, int]`. + image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float. + masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. + alpha (float): Float number between 0 and 1 denoting the transparency of the masks. + colors (list or None): List containing the colors of the masks. The colors can + be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + When ``masks`` has a single entry of shape (H, W), you can pass a single color instead of a list + with one element. By default, random colors are generated for each mask. Returns: - img (Tensor[C, H, W]): Image Tensor of dtype uint8 with segmentation masks plotted. + img (Tensor[C, H, W]): Image Tensor with the same dtype as the input image, with segmentation masks + drawn on top. """ if not isinstance(image, torch.Tensor): - raise TypeError(f"Tensor expected, got {type(image)}") - elif image.dtype != torch.uint8: - raise ValueError(f"Tensor uint8 expected, got {image.dtype}") + raise TypeError(f"The image must be a tensor, got {type(image)}") + elif image.dtype not in (torch.uint8, torch.float): + raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size()[0] != 3: raise ValueError("Pass an RGB image. Other Image formats are not supported") + if masks.ndim == 2: + masks = masks[None, :, :] + if masks.ndim != 3: + raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)") + if masks.dtype != torch.bool: + raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}") + if masks.shape[-2:] != image.shape[-2:]: + raise ValueError(f"The image and the masks must have the same height and width") num_masks = masks.size()[0] - masks = masks.argmax(0) - + if colors is not None and num_masks > len(colors): + raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") + if colors is None: - palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) - colors_t = torch.as_tensor([i for i in range(num_masks)])[:, None] * palette - color_arr = (colors_t % 255).numpy().astype("uint8") - else: - color_list = [] - for color in colors: - if isinstance(color, str): - # This will automatically raise Error if rgb cannot be parsed. - fill_color = ImageColor.getrgb(color) - color_list.append(fill_color) - elif isinstance(color, tuple): - color_list.append(color) - - color_arr = np.array(color_list).astype("uint8") - - _, h, w = image.size() - img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize((w, h)) - img_to_draw.putpalette(color_arr) - - img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGB'))) - img_to_draw = img_to_draw.permute((2, 0, 1)) - - return (image.float() * alpha + img_to_draw.float() * (1.0 - alpha)).to(dtype=torch.uint8) + colors = _generate_color_palette(num_masks) + + if not isinstance(colors, list): + colors = [colors] + if not isinstance(colors[0], (tuple, str)): + raise ValueError("colors must be a tuple or a string, or a list thereof") + if isinstance(colors[0], tuple) and len(colors[0]) != 3: + raise ValueError("It seems that you passed a tuple of colors instead of a list of colors") + + out_dtype = image.dtype + + colors_ = [] + for color in colors: + if isinstance(color, str): + color = ImageColor.getrgb(color) + color = torch.tensor(color, dtype=out_dtype) + if out_dtype == torch.float: + color /= 255 + colors_.append(color) + + img_to_draw = image.detach().clone() + # TODO: There might be a way to vectorize this + for mask, color in zip(masks, colors_): + img_to_draw[:, mask] = color[:, None] + + out = image * alpha + img_to_draw * (1 - alpha) + return out.to(out_dtype) + + +def _generate_color_palette(num_masks): + palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) + return [tuple((i * palette) % 255) for i in range(num_masks)] From 9cf6247cbb0a97ca13712980721d80d3f8fb7dce Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 12 May 2021 17:40:16 +0100 Subject: [PATCH 2/9] rm images --- gallery/plot_visualization_utils.py | 75 +++++------------- .../fakedata/draw_segm_masks_colors_util.png | Bin 88 -> 0 bytes .../draw_segm_masks_no_colors_util.png | Bin 106 -> 0 bytes 3 files changed, 20 insertions(+), 55 deletions(-) delete mode 100644 test/assets/fakedata/draw_segm_masks_colors_util.png delete mode 100644 test/assets/fakedata/draw_segm_masks_no_colors_util.png diff --git a/gallery/plot_visualization_utils.py b/gallery/plot_visualization_utils.py index 06f7162d2a5..f196f0cdd6d 100644 --- a/gallery/plot_visualization_utils.py +++ b/gallery/plot_visualization_utils.py @@ -24,8 +24,7 @@ def show(imgs): imgs = [imgs] fix, axs = plt.subplots(ncols=len(imgs), squeeze=False) for i, img in enumerate(imgs): - img = img.detach() - img = F.to_pil_image(img) + img = F.to_pil_image(img.to('cpu')) axs[0, i].imshow(np.asarray(img)) axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) @@ -51,8 +50,9 @@ def show(imgs): # Visualizing bounding boxes # -------------------------- # We can use :func:`~torchvision.utils.draw_bounding_boxes` to draw boxes on an -# image. We can set the colors, labels, width as well as font and font size. -# The boxes are in ``(xmin, ymin, xmax, ymax)`` format. +# image. We can set the colors, labels, width as well as font and font size ! +# The boxes are in ``(xmin, ymin, xmax, ymax)`` format +# from torchvision.utils import draw_bounding_boxes from torchvision.utils import draw_bounding_boxes @@ -99,68 +99,33 @@ def show(imgs): # Visualizing segmentation masks # ------------------------------ # The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to -# draw segmentation amasks on images. +# draw segmentation amasks on images. We can set the colors as well as +# transparency of masks. # -# We will see how to use it with torchvision's FCN Resnet-50, loaded with +# Here is demo with torchvision's FCN Resnet-50, loaded with # :func:`~torchvision.models.segmentation.fcn_resnet50`. # You can also try using # DeepLabv3 (:func:`~torchvision.models.segmentation.deeplabv3_resnet50`) # or lraspp mobilenet models # (:func:`~torchvision.models.segmentation.lraspp_mobilenet_v3_large`). # -# Let's start by looking at the ouput of the model. Remember that in general, -# images must be normalized before they're passed to the model. +# Like :func:`~torchvision.utils.draw_bounding_boxes`, +# :func:`~torchvision.utils.draw_segmentation_masks` requires a single RGB image +# of dtype `uint8`. -from torchvision.models.segmentation import fcn_resnet50 -from torchvision.utils import draw_segmentation_masks +# from torchvision.models.segmentation import fcn_resnet50 +# from torchvision.utils import draw_segmentation_masks -model = fcn_resnet50(pretrained=True, progress=False) -model = model.eval() - -normalized_batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) -output = model(normalized_batch)['out'] -print(output.shape, output.min().item(), output.max().item()) - -##################################### -# As we can see above, the output of the segmentation model is a tensor of shape -# ``(batch_size, num_classes, H, W)``. Each value is a non-normalized score -# and can normalize them into ``[0, 1]`` by using a softmax. After the softmax, -# we can interpret each value as a probability indicating how likely a given -# pixel is to belong to a given class. -# -# Let's plot the masks that have been detected for the dog class and for the -# boat class: - -seg_classes = [ - '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', - 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', - 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' -] -seg_class_to_idx = {cls: idx for (idx, cls) in enumerate(seg_classes)} - -# We normalize the masks of each image in the batch independently -normalized_masks = torch.stack([torch.nn.Softmax(dim=0)(masks) for masks in output]) +# model = fcn_resnet50(pretrained=True, progress=False) +# model = model.eval() -dog_and_boat_masks = [ - normalized_masks[img_idx, seg_class_to_idx[cls]] - for img_idx in range(batch.shape[0]) - for cls in ('dog', 'boat') -] - -show(dog_and_boat_masks) - -##################################### -# As expected, the model is confident about the dog class, but not so much for -# the boat class. -# -# The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to -# plots those masks on top of the original image. This function expects the -# masks to be boolean masks, but our masks above contain probabilities in ``[0, -# 1]``. To get boolean masks, we can do the following: +# # The model expects the batch to be normalized +# batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) +# outputs = model(batch) -# dogs_with_dog_masks = [ -# draw_segmentation_masks(dog_int, masks=output[img_idx, seg_class_to_idx['dog']], alpha=0.6) -# for img_idx, dog_int in enumerate(dog1_int, dog2_int) +# dogs_with_masks = [ +# draw_segmentation_masks(dog_int, masks=masks, alpha=0.6) +# for dog_int, masks in zip((dog1_int, dog2_int), outputs['out']) # ] # show(dogs_with_masks) diff --git a/test/assets/fakedata/draw_segm_masks_colors_util.png b/test/assets/fakedata/draw_segm_masks_colors_util.png deleted file mode 100644 index 454b35556317dc1da1707fb234cf8563c1e8c707..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 88 zcmeAS@N?(olHy`uVBq!ia0vp^tRT$61SFYwH*Nw_@}4e^Ar*6yP5$MdX<(7~Fa6*B l;o_6Vh6|XE{XeGhiNULzE&Y{;O%+fngQu&X%Q~loCIFN+8JPe8 diff --git a/test/assets/fakedata/draw_segm_masks_no_colors_util.png b/test/assets/fakedata/draw_segm_masks_no_colors_util.png deleted file mode 100644 index f048d2469d2414d6e1e864111a6117a30a7d210b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 106 zcmeAS@N?(olHy`uVBq!ia0vp^A|TAc1SFYWcSQjyLr)jSkcv6UCi5Pib Date: Wed, 12 May 2021 17:41:56 +0100 Subject: [PATCH 3/9] cleanup --- test/test_utils.py | 4 ++-- torchvision/utils.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 7d85c0cb1af..e0689dc5342 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -173,7 +173,7 @@ def test_draw_segmentation_masks(dtype, colors, alpha): num_masks, h, w = 2, 100, 100 img = torch.randint(0, 256, size=(3, h, w), dtype=dtype) masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool) - + # For testing we enforce that there's no overlap between the masks. The # current behaviour is that the last mask's color will take priority when # masks overlap, but this makes testing slightly harder so we don't really @@ -208,7 +208,7 @@ def test_draw_segmentation_masks(dtype, colors, alpha): assert (out[:, mask] == color[:, None]).all() else: assert (out[:, mask] == img[:, mask]).all() - + def test_draw_segmentation_masks_int_vs_float(): """Make sure float and uint8 dtypes produce similar images""" diff --git a/torchvision/utils.py b/torchvision/utils.py index 0bbea6391cb..51a1f346036 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -216,7 +216,7 @@ def draw_bounding_boxes( return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) -# @torch.no_grad() +@torch.no_grad() def draw_segmentation_masks( image: torch.Tensor, masks: torch.Tensor, @@ -262,17 +262,17 @@ def draw_segmentation_masks( num_masks = masks.size()[0] if colors is not None and num_masks > len(colors): raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") - + if colors is None: colors = _generate_color_palette(num_masks) - + if not isinstance(colors, list): colors = [colors] if not isinstance(colors[0], (tuple, str)): raise ValueError("colors must be a tuple or a string, or a list thereof") if isinstance(colors[0], tuple) and len(colors[0]) != 3: raise ValueError("It seems that you passed a tuple of colors instead of a list of colors") - + out_dtype = image.dtype colors_ = [] @@ -283,7 +283,7 @@ def draw_segmentation_masks( if out_dtype == torch.float: color /= 255 colors_.append(color) - + img_to_draw = image.detach().clone() # TODO: There might be a way to vectorize this for mask, color in zip(masks, colors_): From e52fb082a1a926b82d5fa0ed6bda5cf32a5f4264 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 12 May 2021 18:07:54 +0100 Subject: [PATCH 4/9] pep --- torchvision/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 51a1f346036..949aa3b6406 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -257,7 +257,7 @@ def draw_segmentation_masks( if masks.dtype != torch.bool: raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}") if masks.shape[-2:] != image.shape[-2:]: - raise ValueError(f"The image and the masks must have the same height and width") + raise ValueError("The image and the masks must have the same height and width") num_masks = masks.size()[0] if colors is not None and num_masks > len(colors): From 131f1a444e08bfa7ee3874e43c907cda92b26a8d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 13 May 2021 08:30:01 +0100 Subject: [PATCH 5/9] temporarily remove mask example --- gallery/plot_visualization_utils.py | 36 +---------------------------- 1 file changed, 1 insertion(+), 35 deletions(-) diff --git a/gallery/plot_visualization_utils.py b/gallery/plot_visualization_utils.py index f196f0cdd6d..260b92270fd 100644 --- a/gallery/plot_visualization_utils.py +++ b/gallery/plot_visualization_utils.py @@ -77,6 +77,7 @@ def show(imgs): dog1_float = convert_image_dtype(dog1_int, dtype=torch.float) dog2_float = convert_image_dtype(dog2_int, dtype=torch.float) batch = torch.stack([dog1_float, dog2_float]) +batch = torch.stack([dog1_int, dog2_int]) model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False) model = model.eval() @@ -94,38 +95,3 @@ def show(imgs): for dog_int, output in zip((dog1_int, dog2_int), outputs) ] show(dogs_with_boxes) - -##################################### -# Visualizing segmentation masks -# ------------------------------ -# The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to -# draw segmentation amasks on images. We can set the colors as well as -# transparency of masks. -# -# Here is demo with torchvision's FCN Resnet-50, loaded with -# :func:`~torchvision.models.segmentation.fcn_resnet50`. -# You can also try using -# DeepLabv3 (:func:`~torchvision.models.segmentation.deeplabv3_resnet50`) -# or lraspp mobilenet models -# (:func:`~torchvision.models.segmentation.lraspp_mobilenet_v3_large`). -# -# Like :func:`~torchvision.utils.draw_bounding_boxes`, -# :func:`~torchvision.utils.draw_segmentation_masks` requires a single RGB image -# of dtype `uint8`. - -# from torchvision.models.segmentation import fcn_resnet50 -# from torchvision.utils import draw_segmentation_masks - - -# model = fcn_resnet50(pretrained=True, progress=False) -# model = model.eval() - -# # The model expects the batch to be normalized -# batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) -# outputs = model(batch) - -# dogs_with_masks = [ -# draw_segmentation_masks(dog_int, masks=masks, alpha=0.6) -# for dog_int, masks in zip((dog1_int, dog2_int), outputs['out']) -# ] -# show(dogs_with_masks) From 7876d47cb93f68a338e5af974490ee98e1ce3e9d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 13 May 2021 08:31:00 +0100 Subject: [PATCH 6/9] Add comment about float images expected in 0-1 --- torchvision/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 949aa3b6406..c8e333b5767 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -226,7 +226,7 @@ def draw_segmentation_masks( """ Draws segmentation masks on given RGB image. - The values of the input image should be uint8 between 0 and 255. + The values of the input image should be uint8 between 0 and 255, or float values between 0 and 1. Args: image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float. From e0e0fc679365ed03a620380dec93c7563c37ba7d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 13 May 2021 08:51:26 +0100 Subject: [PATCH 7/9] remove debug stuff --- gallery/plot_visualization_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gallery/plot_visualization_utils.py b/gallery/plot_visualization_utils.py index 260b92270fd..76e32bc586e 100644 --- a/gallery/plot_visualization_utils.py +++ b/gallery/plot_visualization_utils.py @@ -77,7 +77,6 @@ def show(imgs): dog1_float = convert_image_dtype(dog1_int, dtype=torch.float) dog2_float = convert_image_dtype(dog2_int, dtype=torch.float) batch = torch.stack([dog1_float, dog2_float]) -batch = torch.stack([dog1_int, dog2_int]) model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False) model = model.eval() From 5d287db8a20b727e5dc3e54903cdccf5bf9706a7 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 13 May 2021 09:11:11 +0100 Subject: [PATCH 8/9] Put back None testing --- test/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index e0689dc5342..a223088121c 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -163,7 +163,7 @@ def test_draw_invalid_boxes(self): @pytest.mark.parametrize('dtype', (torch.float, torch.uint8)) @pytest.mark.parametrize('colors', [ - # None, + None, ['red', 'blue'], ['#FF00FF', (1, 34, 122)], ]) From f91d0709058b0a5acad20d8636b5209efb6004db Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 13 May 2021 09:46:28 +0100 Subject: [PATCH 9/9] more robust test --- test/test_utils.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index a223088121c..7644915f2cb 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -167,7 +167,7 @@ def test_draw_invalid_boxes(self): ['red', 'blue'], ['#FF00FF', (1, 34, 122)], ]) -@pytest.mark.parametrize('alpha', (0, 1)) +@pytest.mark.parametrize('alpha', (0, .5, .7, 1)) def test_draw_segmentation_masks(dtype, colors, alpha): """This test makes sure that masks draw their corresponding color where they should""" num_masks, h, w = 2, 100, 100 @@ -185,12 +185,6 @@ def test_draw_segmentation_masks(dtype, colors, alpha): assert out.dtype == dtype assert out is not img - if dtype == torch.float: - # makes comparisons below easier - img = F.convert_image_dtype(img, torch.uint8) - out = F.convert_image_dtype(out, torch.uint8) - img, out = img.float(), out.float() # avoids underflows etc. - # Make sure the image didn't change where there's no mask masked_pixels = masks[0] | masks[1] assert (img[:, ~masked_pixels] == out[:, ~masked_pixels]).all() @@ -203,12 +197,21 @@ def test_draw_segmentation_masks(dtype, colors, alpha): if isinstance(color, str): color = ImageColor.getrgb(color) color = torch.tensor(color, dtype=dtype) + if dtype == torch.float: + color /= 255 if alpha == 0: assert (out[:, mask] == color[:, None]).all() - else: + elif alpha == 1: assert (out[:, mask] == img[:, mask]).all() + interpolated_color = (img[:, mask] * alpha + color[:, None] * (1 - alpha)) + max_diff = (out[:, mask] - interpolated_color).abs().max() + if dtype == torch.uint8: + assert max_diff <= 1 + else: + assert max_diff <= 1e-5 + def test_draw_segmentation_masks_int_vs_float(): """Make sure float and uint8 dtypes produce similar images"""