Skip to content

Commit

Permalink
support of float dtypes for draw_segmentation_masks (#8150)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
  • Loading branch information
GsnMithra and NicolasHug authored Dec 18, 2023
1 parent c35d385 commit 6c2e0ae
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 9 deletions.
21 changes: 21 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torchvision.utils as utils
from common_utils import assert_equal, cpu_and_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageColor
from torchvision.transforms.v2.functional import to_dtype


PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))
Expand Down Expand Up @@ -246,6 +247,26 @@ def test_draw_segmentation_masks(colors, alpha, device):
torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0)


def test_draw_segmentation_masks_dtypes():
num_masks, h, w = 2, 100, 100

masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool)

img_uint8 = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)
out_uint8 = utils.draw_segmentation_masks(img_uint8, masks)

assert img_uint8 is not out_uint8
assert out_uint8.dtype == torch.uint8

img_float = to_dtype(img_uint8, torch.float, scale=True)
out_float = utils.draw_segmentation_masks(img_float, masks)

assert img_float is not out_float
assert out_float.is_floating_point()

torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_segmentation_masks_errors(device):
h, w = 10, 10
Expand Down
24 changes: 15 additions & 9 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch
from PIL import Image, ImageColor, ImageDraw, ImageFont


__all__ = [
"make_grid",
"save_image",
Expand Down Expand Up @@ -262,10 +263,10 @@ def draw_segmentation_masks(

"""
Draws segmentation masks on given RGB image.
The values of the input image should be uint8 between 0 and 255.
The image values should be uint8 in [0, 255] or float in [0, 1].
Args:
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
0 means full transparency, 1 means no transparency.
Expand All @@ -282,8 +283,8 @@ def draw_segmentation_masks(
_log_api_usage_once(draw_segmentation_masks)
if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
elif not (image.dtype == torch.uint8 or image.is_floating_point()):
raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size()[0] != 3:
Expand All @@ -303,10 +304,10 @@ def draw_segmentation_masks(
warnings.warn("masks doesn't contain any mask. No mask was drawn")
return image

out_dtype = torch.uint8
original_dtype = image.dtype
colors = [
torch.tensor(color, dtype=out_dtype, device=image.device)
for color in _parse_colors(colors, num_objects=num_masks)
torch.tensor(color, dtype=original_dtype, device=image.device)
for color in _parse_colors(colors, num_objects=num_masks, dtype=original_dtype)
]

img_to_draw = image.detach().clone()
Expand All @@ -315,7 +316,8 @@ def draw_segmentation_masks(
img_to_draw[:, mask] = color[:, None]

out = image * (1 - alpha) + img_to_draw * alpha
return out.to(out_dtype)
# Note: at this point, out is a float tensor in [0, 1] or [0, 255] depending on original_dtype
return out.to(original_dtype)


@torch.no_grad()
Expand Down Expand Up @@ -516,6 +518,7 @@ def _parse_colors(
colors: Union[None, str, Tuple[int, int, int], List[Union[str, Tuple[int, int, int]]]],
*,
num_objects: int,
dtype: torch.dtype = torch.uint8,
) -> List[Tuple[int, int, int]]:
"""
Parses a specification of colors for a set of objects.
Expand Down Expand Up @@ -552,7 +555,10 @@ def _parse_colors(
else: # colors specifies a single color for all objects
colors = [colors] * num_objects

return [ImageColor.getrgb(color) if isinstance(color, str) else color for color in colors]
colors = [ImageColor.getrgb(color) if isinstance(color, str) else color for color in colors]
if dtype.is_floating_point: # [0, 255] -> [0, 1]
colors = [tuple(v / 255 for v in color) for color in colors]
return colors


def _log_api_usage_once(obj: Any) -> None:
Expand Down

0 comments on commit 6c2e0ae

Please sign in to comment.