diff --git a/docs/source/utils.rst b/docs/source/utils.rst index 651dc49a59a..4c2ece97708 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -14,5 +14,6 @@ vizualization `. draw_bounding_boxes draw_segmentation_masks + draw_keypoints make_grid save_image diff --git a/test/assets/fakedata/draw_keypoint_vanilla.png b/test/assets/fakedata/draw_keypoint_vanilla.png new file mode 100644 index 00000000000..6cd6d943b6c Binary files /dev/null and b/test/assets/fakedata/draw_keypoint_vanilla.png differ diff --git a/test/test_utils.py b/test/test_utils.py index 1439a974368..64f45c697c6 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -16,6 +16,8 @@ boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) +keypoints = torch.tensor([[[10, 10], [5, 5], [2, 2]], [[20, 20], [30, 30], [3, 3]]], dtype=torch.float) + def test_make_grid_not_inplace(): t = torch.rand(5, 3, 10, 10) @@ -248,5 +250,58 @@ def test_draw_segmentation_masks_errors(): utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) +def test_draw_keypoints_vanilla(): + # Keypoints is declared on top as global variable + keypoints_cp = keypoints.clone() + + img = torch.full((3, 100, 100), 0, dtype=torch.uint8) + img_cp = img.clone() + result = utils.draw_keypoints(img, keypoints, colors="red", connectivity=((0, 1),)) + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png") + if not os.path.exists(path): + res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()) + res.save(path) + + expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) + assert_equal(result, expected) + # Check that keypoints are not modified inplace + assert_equal(keypoints, keypoints_cp) + # Check that image is not modified in place + assert_equal(img, img_cp) + + +@pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)]) +def test_draw_keypoints_colored(colors): + # Keypoints is declared on top as global variable + keypoints_cp = keypoints.clone() + + img = torch.full((3, 100, 100), 0, dtype=torch.uint8) + img_cp = img.clone() + result = utils.draw_keypoints(img, keypoints, colors=colors, connectivity=((0, 1),)) + assert result.size(0) == 3 + assert_equal(keypoints, keypoints_cp) + assert_equal(img, img_cp) + + +def test_draw_keypoints_errors(): + h, w = 10, 10 + img = torch.full((3, 100, 100), 0, dtype=torch.uint8) + + with pytest.raises(TypeError, match="The image must be a tensor"): + utils.draw_keypoints(image="Not A Tensor Image", keypoints=keypoints) + with pytest.raises(ValueError, match="The image dtype must be"): + img_bad_dtype = torch.full((3, h, w), 0, dtype=torch.int64) + utils.draw_keypoints(image=img_bad_dtype, keypoints=keypoints) + with pytest.raises(ValueError, match="Pass individual images, not batches"): + batch = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8) + utils.draw_keypoints(image=batch, keypoints=keypoints) + with pytest.raises(ValueError, match="Pass an RGB image"): + one_channel = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8) + utils.draw_keypoints(image=one_channel, keypoints=keypoints) + with pytest.raises(ValueError, match="keypoints must be of shape"): + invalid_keypoints = torch.tensor([[10, 10, 10, 10], [5, 6, 7, 8]], dtype=torch.float) + utils.draw_keypoints(image=img, keypoints=invalid_keypoints) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/torchvision/utils.py b/torchvision/utils.py index 9dd3c9c5ab3..6c29767a7ce 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -7,7 +7,7 @@ import torch from PIL import Image, ImageDraw, ImageFont, ImageColor -__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"] +__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks", "draw_keypoints"] @torch.no_grad() @@ -300,6 +300,76 @@ def draw_segmentation_masks( return out.to(out_dtype) +@torch.no_grad() +def draw_keypoints( + image: torch.Tensor, + keypoints: torch.Tensor, + connectivity: Optional[Tuple[Tuple[int, int]]] = None, + colors: Optional[Union[str, Tuple[int, int, int]]] = None, + radius: int = 2, + width: int = 3, +) -> torch.Tensor: + + """ + Draws Keypoints on given RGB image. + The values of the input image should be uint8 between 0 and 255. + + Args: + image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, + in the format [x, y]. + connectivity (Tuple[Tuple[int, int]]]): A Tuple of tuple where, + each tuple contains pair of keypoints to be connected. + colors (str, Tuple): The color can be represented as + PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + radius (int): Integer denoting radius of keypoint. + width (int): Integer denoting width of line connecting keypoints. + + Returns: + img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. + """ + + if not isinstance(image, torch.Tensor): + raise TypeError(f"The image must be a tensor, got {type(image)}") + elif image.dtype != torch.uint8: + raise ValueError(f"The image dtype must be uint8, got {image.dtype}") + elif image.dim() != 3: + raise ValueError("Pass individual images, not batches") + elif image.size()[0] != 3: + raise ValueError("Pass an RGB image. Other Image formats are not supported") + + if keypoints.ndim != 3: + raise ValueError("keypoints must be of shape (num_instances, K, 2)") + + ndarr = image.permute(1, 2, 0).numpy() + img_to_draw = Image.fromarray(ndarr) + draw = ImageDraw.Draw(img_to_draw) + img_kpts = keypoints.to(torch.int64).tolist() + + for kpt_id, kpt_inst in enumerate(img_kpts): + for inst_id, kpt in enumerate(kpt_inst): + x1 = kpt[0] - radius + x2 = kpt[0] + radius + y1 = kpt[1] - radius + y2 = kpt[1] + radius + draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0) + + if connectivity: + for connection in connectivity: + start_pt_x = kpt_inst[connection[0]][0] + start_pt_y = kpt_inst[connection[0]][1] + + end_pt_x = kpt_inst[connection[1]][0] + end_pt_y = kpt_inst[connection[1]][1] + + draw.line( + ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)), + width=width, + ) + + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) + + def _generate_color_palette(num_masks: int): palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) return [tuple((i * palette) % 255) for i in range(num_masks)]