diff --git a/test/test_utils.py b/test/test_utils.py index 32b3db59631..b13bd0f0f5b 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -9,7 +9,7 @@ import torch import torchvision.transforms.functional as F import torchvision.utils as utils -from common_utils import assert_equal +from common_utils import assert_equal, cpu_and_cuda from PIL import __version__ as PILLOW_VERSION, Image, ImageColor @@ -203,12 +203,13 @@ def test_draw_no_boxes(): ], ) @pytest.mark.parametrize("alpha", (0, 0.5, 0.7, 1)) -def test_draw_segmentation_masks(colors, alpha): +@pytest.mark.parametrize("device", cpu_and_cuda()) +def test_draw_segmentation_masks(colors, alpha, device): """This test makes sure that masks draw their corresponding color where they should""" num_masks, h, w = 2, 100, 100 dtype = torch.uint8 - img = torch.randint(0, 256, size=(3, h, w), dtype=dtype) - masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool) + img = torch.randint(0, 256, size=(3, h, w), dtype=dtype, device=device) + masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool, device=device) # For testing we enforce that there's no overlap between the masks. The # current behaviour is that the last mask's color will take priority when @@ -234,7 +235,7 @@ def test_draw_segmentation_masks(colors, alpha): for mask, color in zip(masks, colors): if isinstance(color, str): color = ImageColor.getrgb(color) - color = torch.tensor(color, dtype=dtype) + color = torch.tensor(color, dtype=dtype, device=device) if alpha == 1: assert (out[:, mask] == color[:, None]).all() @@ -245,11 +246,12 @@ def test_draw_segmentation_masks(colors, alpha): torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0) -def test_draw_segmentation_masks_errors(): +@pytest.mark.parametrize("device", cpu_and_cuda()) +def test_draw_segmentation_masks_errors(device): h, w = 10, 10 - masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool) - img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8) + masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool, device=device) + img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8, device=device) with pytest.raises(TypeError, match="The image must be a tensor"): utils.draw_segmentation_masks(image="Not A Tensor Image", masks=masks) @@ -281,9 +283,10 @@ def test_draw_segmentation_masks_errors(): utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) -def test_draw_no_segmention_mask(): - img = torch.full((3, 100, 100), 0, dtype=torch.uint8) - masks = torch.full((0, 100, 100), 0, dtype=torch.bool) +@pytest.mark.parametrize("device", cpu_and_cuda()) +def test_draw_no_segmention_mask(device): + img = torch.full((3, 100, 100), 0, dtype=torch.uint8, device=device) + masks = torch.full((0, 100, 100), 0, dtype=torch.bool, device=device) with pytest.warns(UserWarning, match=re.escape("masks doesn't contain any mask. No mask was drawn")): res = utils.draw_segmentation_masks(img, masks) # Check that the function didn't change the image diff --git a/torchvision/utils.py b/torchvision/utils.py index 1418656a7f2..6ec19a0e0a1 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -304,7 +304,10 @@ def draw_segmentation_masks( return image out_dtype = torch.uint8 - colors = [torch.tensor(color, dtype=out_dtype) for color in _parse_colors(colors, num_objects=num_masks)] + colors = [ + torch.tensor(color, dtype=out_dtype, device=image.device) + for color in _parse_colors(colors, num_objects=num_masks) + ] img_to_draw = image.detach().clone() # TODO: There might be a way to vectorize this