From b2f6615a59c1047a91affef2ed893d38839c7e27 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Thu, 20 May 2021 12:47:27 +0530 Subject: [PATCH 01/17] fix --- docs/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Makefile b/docs/Makefile index 58daa471f5c..ef4cdc0cee0 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -2,7 +2,7 @@ # # You can set these variables from the command line. -SPHINXOPTS = -W # turn warnings into errors +SPHINXOPTS = # turn warnings into errors SPHINXBUILD = sphinx-build SPHINXPROJ = torchvision SOURCEDIR = source From 73abf840adb49d9f71a19785d1bb9b8d56e1735a Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Wed, 28 Jul 2021 16:15:15 +0530 Subject: [PATCH 02/17] Outline Keypoints API --- docs/source/utils.rst | 2 ++ torchvision/utils.py | 54 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/docs/source/utils.rst b/docs/source/utils.rst index acaf785d817..8dadd7465b3 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -10,3 +10,5 @@ torchvision.utils .. autofunction:: draw_bounding_boxes .. autofunction:: draw_segmentation_masks + +.. autofunction:: draw_keypoints diff --git a/torchvision/utils.py b/torchvision/utils.py index 494661e6ad8..f392d6fa3dc 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -6,7 +6,8 @@ import numpy as np from PIL import Image, ImageDraw, ImageFont, ImageColor -__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"] +__all__ = ["make_grid", "save_image", "draw_bounding_boxes", + "draw_segmentation_masks", "draw_keypoints"] @torch.no_grad() @@ -300,6 +301,57 @@ def draw_segmentation_masks( return out.to(out_dtype) +@torch.no_grad() +def draw_keypoints( + image: torch.Tensor, + keypoints: torch.Tensor, + labels: Optional[List[str]] = None, + colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, + radius: Optional[int] = 5, + connect: Optional[bool] = False, + font: Optional[str] = None, + font_size: int = 10 +) -> torch.Tensor: + + """ + Draws Keypoints on given RGB image. + The values of the input image should be uint8 between 0 and 255. + + Args: + image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + keypoints (Tensor): Tensor of shape (num_keypoints, H, W) or (H, W) and dtype bool. + labels(List[str]): List containing the labels for each Keypoint. + colors (Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]): List containing the colors + or a single color for all the keypoints. + The colors can be represented as `str` or `Tuple[int, int, int]`. + radius (int): Integer denoting radius of keypoint. + connect (bool): If True. It connects all the visible keypints. + font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may + also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, + `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. + font_size (int): The requested font size in points. + + Returns: + img (Tensor[C, H, W]): Image Tensor, with keypoints drawn. + """ + + if not isinstance(image, torch.Tensor): + raise TypeError(f"Tensor expected, got {type(image)}") + elif image.dtype != torch.uint8: + raise ValueError(f"Tensor uint8 expected, got {image.dtype}") + elif image.dim() != 3: + raise ValueError("Pass individual images, not batches") + elif image.size()[0] != 3: + raise ValueError("Pass an RGB image. Other Image formats are not supported") + + ndarr = image.permute(1, 2, 0).numpy() + img_to_draw = Image.fromarray(ndarr) + + # txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) + + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) + + def _generate_color_palette(num_masks): palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) return [tuple((i * palette) % 255) for i in range(num_masks)] From 2ebe0c72ce0530548a9eaa823109b4ff39ef7ca0 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Wed, 28 Jul 2021 18:36:36 +0530 Subject: [PATCH 03/17] Add utility --- torchvision/utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index f392d6fa3dc..a37fbae211e 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -319,8 +319,9 @@ def draw_keypoints( Args: image (Tensor): Tensor of shape (3, H, W) and dtype uint8. - keypoints (Tensor): Tensor of shape (num_keypoints, H, W) or (H, W) and dtype bool. - labels(List[str]): List containing the labels for each Keypoint. + keypoints (Tensor): Tensor of shape (num_keypoints, K, 3) the K keypoints location for each of the N instances, + in the format [x, y, visibility], where visibility=0 means that the keypoint is not visible. + labels (List[str]): List containing the labels for each Keypoint. colors (Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]): List containing the colors or a single color for all the keypoints. The colors can be represented as `str` or `Tuple[int, int, int]`. @@ -332,7 +333,7 @@ def draw_keypoints( font_size (int): The requested font size in points. Returns: - img (Tensor[C, H, W]): Image Tensor, with keypoints drawn. + img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. """ if not isinstance(image, torch.Tensor): @@ -346,10 +347,11 @@ def draw_keypoints( ndarr = image.permute(1, 2, 0).numpy() img_to_draw = Image.fromarray(ndarr) + out_dtype = torch.uint8 # txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) - return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=out_dtype) def _generate_color_palette(num_masks): From 77afa81502030e7904f72156f81b840640a5fded Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Wed, 28 Jul 2021 22:49:15 +0530 Subject: [PATCH 04/17] make it work :) --- torchvision/utils.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index a37fbae211e..dc0457effdd 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -307,7 +307,7 @@ def draw_keypoints( keypoints: torch.Tensor, labels: Optional[List[str]] = None, colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, - radius: Optional[int] = 5, + radius: Optional[int] = 2, connect: Optional[bool] = False, font: Optional[str] = None, font_size: int = 10 @@ -347,8 +347,19 @@ def draw_keypoints( ndarr = image.permute(1, 2, 0).numpy() img_to_draw = Image.fromarray(ndarr) + draw = ImageDraw.Draw(img_to_draw) out_dtype = torch.uint8 + img_kpts = keypoints.to(torch.int64).tolist() + + for i, kpt_inst in enumerate(img_kpts): + for kpt in kpt_inst: + x1 = kpt[0] - radius + x2 = kpt[0] + radius + y1 = kpt[1] - radius + y2 = kpt[1] + radius + draw.ellipse([x1, y1, x2, y2], fill="red", outline=None, width=0) + # txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=out_dtype) From cdeebdde681b5a227c0d96eb4ae07ed3fbd49e46 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Thu, 29 Jul 2021 14:12:53 +0530 Subject: [PATCH 05/17] Fix optional type --- torchvision/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index dc0457effdd..3bb3374b448 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -307,8 +307,8 @@ def draw_keypoints( keypoints: torch.Tensor, labels: Optional[List[str]] = None, colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, - radius: Optional[int] = 2, - connect: Optional[bool] = False, + radius: int = 2, + connect: bool = False, font: Optional[str] = None, font_size: int = 10 ) -> torch.Tensor: From 1291f447a139afb433e69c9a2652b88374181402 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Sun, 12 Sep 2021 21:50:12 +0530 Subject: [PATCH 06/17] Add connectivity, fmassa's advice :smiley: --- torchvision/utils.py | 46 +++++++++++++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 3bb3374b448..f7cd8b2eff6 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -303,14 +303,15 @@ def draw_segmentation_masks( @torch.no_grad() def draw_keypoints( - image: torch.Tensor, - keypoints: torch.Tensor, - labels: Optional[List[str]] = None, - colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, - radius: int = 2, - connect: bool = False, - font: Optional[str] = None, - font_size: int = 10 + image: torch.Tensor, + keypoints: torch.Tensor, + labels: Optional[List[str]] = None, + colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, + radius: int = 2, + width: int = 3, + connectivity: Optional[Tuple[Tuple[int, int]]] = None, + font: Optional[str] = None, + font_size: int = 10 ) -> torch.Tensor: """ @@ -319,14 +320,16 @@ def draw_keypoints( Args: image (Tensor): Tensor of shape (3, H, W) and dtype uint8. - keypoints (Tensor): Tensor of shape (num_keypoints, K, 3) the K keypoints location for each of the N instances, - in the format [x, y, visibility], where visibility=0 means that the keypoint is not visible. + keypoints (Tensor): Tensor of shape (num_instances, K, 3) the K keypoints location for each of the N instances, + in the format [x, y, visibility], where `visibility=0` means that the keypoint is not visible. labels (List[str]): List containing the labels for each Keypoint. colors (Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]): List containing the colors or a single color for all the keypoints. The colors can be represented as `str` or `Tuple[int, int, int]`. radius (int): Integer denoting radius of keypoint. - connect (bool): If True. It connects all the visible keypints. + width (int): Integer denoting width of line connecting keypoints. + connectivity (Tuple[Tuple[int, int]]]): A Tuple of tuple where, + each tuple contains pair of keypoints to be connected. font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. @@ -352,13 +355,30 @@ def draw_keypoints( img_kpts = keypoints.to(torch.int64).tolist() - for i, kpt_inst in enumerate(img_kpts): + for kpt_inst in img_kpts: for kpt in kpt_inst: x1 = kpt[0] - radius x2 = kpt[0] + radius y1 = kpt[1] - radius y2 = kpt[1] + radius - draw.ellipse([x1, y1, x2, y2], fill="red", outline=None, width=0) + + if isinstance(colors, str) or isinstance(colors, tuple): + draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0) + + for kpt_inst in img_kpts: + for connection in connectivity: + start_pt_x = kpt_inst[connection[0]][0] + start_pt_y = kpt_inst[connection[0]][1] + + end_pt_x = kpt_inst[connection[1]][0] + end_pt_y = kpt_inst[connection[1]][1] + + draw.line( + ( + (start_pt_x, start_pt_y), (end_pt_x, end_pt_y) + ), + width=width, + ) # txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) From b9af8740fcbb98c161d790aab1cf3a085ad5d699 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Sun, 12 Sep 2021 22:11:03 +0530 Subject: [PATCH 07/17] Minor code improvement --- torchvision/utils.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index f7cd8b2eff6..26ebc830200 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -305,11 +305,11 @@ def draw_segmentation_masks( def draw_keypoints( image: torch.Tensor, keypoints: torch.Tensor, + connectivity: Optional[Tuple[Tuple[int, int]]] = None, labels: Optional[List[str]] = None, colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, radius: int = 2, width: int = 3, - connectivity: Optional[Tuple[Tuple[int, int]]] = None, font: Optional[str] = None, font_size: int = 10 ) -> torch.Tensor: @@ -365,20 +365,20 @@ def draw_keypoints( if isinstance(colors, str) or isinstance(colors, tuple): draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0) - for kpt_inst in img_kpts: - for connection in connectivity: - start_pt_x = kpt_inst[connection[0]][0] - start_pt_y = kpt_inst[connection[0]][1] - - end_pt_x = kpt_inst[connection[1]][0] - end_pt_y = kpt_inst[connection[1]][1] - - draw.line( - ( - (start_pt_x, start_pt_y), (end_pt_x, end_pt_y) - ), - width=width, - ) + if connectivity: + for connection in connectivity: + start_pt_x = kpt_inst[connection[0]][0] + start_pt_y = kpt_inst[connection[0]][1] + + end_pt_x = kpt_inst[connection[1]][0] + end_pt_y = kpt_inst[connection[1]][1] + + draw.line( + ( + (start_pt_x, start_pt_y), (end_pt_x, end_pt_y) + ), + width=width, + ) # txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) From 28ebcc0ed2ccddcd84c387b4a6d12ad11ac66770 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Mon, 13 Sep 2021 20:06:55 +0530 Subject: [PATCH 08/17] small fix --- torchvision/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 26ebc830200..4987312c88f 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -322,14 +322,14 @@ def draw_keypoints( image (Tensor): Tensor of shape (3, H, W) and dtype uint8. keypoints (Tensor): Tensor of shape (num_instances, K, 3) the K keypoints location for each of the N instances, in the format [x, y, visibility], where `visibility=0` means that the keypoint is not visible. + connectivity (Tuple[Tuple[int, int]]]): A Tuple of tuple where, + each tuple contains pair of keypoints to be connected. labels (List[str]): List containing the labels for each Keypoint. colors (Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]): List containing the colors or a single color for all the keypoints. The colors can be represented as `str` or `Tuple[int, int, int]`. radius (int): Integer denoting radius of keypoint. width (int): Integer denoting width of line connecting keypoints. - connectivity (Tuple[Tuple[int, int]]]): A Tuple of tuple where, - each tuple contains pair of keypoints to be connected. font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. From ebe7a25764fa78896320604008afcbac70420fd3 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Sat, 2 Oct 2021 21:22:13 +0530 Subject: [PATCH 09/17] fix implementation --- torchvision/utils.py | 33 +++++++++++++-------------------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index 4987312c88f..a714ee265ab 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -303,15 +303,12 @@ def draw_segmentation_masks( @torch.no_grad() def draw_keypoints( - image: torch.Tensor, - keypoints: torch.Tensor, - connectivity: Optional[Tuple[Tuple[int, int]]] = None, - labels: Optional[List[str]] = None, - colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, - radius: int = 2, - width: int = 3, - font: Optional[str] = None, - font_size: int = 10 + image: torch.Tensor, + keypoints: torch.Tensor, + connectivity: Optional[Tuple[Tuple[int, int]]] = None, + colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, + radius: int = 2, + width: int = 3, ) -> torch.Tensor: """ @@ -320,34 +317,32 @@ def draw_keypoints( Args: image (Tensor): Tensor of shape (3, H, W) and dtype uint8. - keypoints (Tensor): Tensor of shape (num_instances, K, 3) the K keypoints location for each of the N instances, - in the format [x, y, visibility], where `visibility=0` means that the keypoint is not visible. + keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, + in the format [x, y]. connectivity (Tuple[Tuple[int, int]]]): A Tuple of tuple where, each tuple contains pair of keypoints to be connected. - labels (List[str]): List containing the labels for each Keypoint. colors (Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]): List containing the colors or a single color for all the keypoints. The colors can be represented as `str` or `Tuple[int, int, int]`. radius (int): Integer denoting radius of keypoint. width (int): Integer denoting width of line connecting keypoints. - font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may - also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, - `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. - font_size (int): The requested font size in points. Returns: img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. """ if not isinstance(image, torch.Tensor): - raise TypeError(f"Tensor expected, got {type(image)}") + raise TypeError(f"The image must be a tensor, got {type(image)}") elif image.dtype != torch.uint8: - raise ValueError(f"Tensor uint8 expected, got {image.dtype}") + raise ValueError(f"The image dtype must be uint8, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size()[0] != 3: raise ValueError("Pass an RGB image. Other Image formats are not supported") + if keypoints.ndim != 3: + raise ValueError("keypoints must be of shape (num_instances, K, 2)") + ndarr = image.permute(1, 2, 0).numpy() img_to_draw = Image.fromarray(ndarr) draw = ImageDraw.Draw(img_to_draw) @@ -380,8 +375,6 @@ def draw_keypoints( width=width, ) - # txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) - return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=out_dtype) From e6afa37af2c91debf673256cb5c997c65a8bda38 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Sat, 2 Oct 2021 21:24:46 +0530 Subject: [PATCH 10/17] Add tests --- .../assets/fakedata/draw_keypoint_vanilla.png | Bin 0 -> 283 bytes test/test_utils.py | 69 ++++++++++++++++++ 2 files changed, 69 insertions(+) create mode 100644 test/assets/fakedata/draw_keypoint_vanilla.png diff --git a/test/assets/fakedata/draw_keypoint_vanilla.png b/test/assets/fakedata/draw_keypoint_vanilla.png new file mode 100644 index 0000000000000000000000000000000000000000..8cd34f84539c9b5a054274f21c176b0e2bd9ee60 GIT binary patch literal 283 zcmeAS@N?(olHy`uVBq!ia0vp^DIm(xHTrHeIoMt$43>Tzjm zZ>Z9uH|{gys(S;fWk*s!7=8`MNLb>0sV0cpS-*Nhf z$e(|0Uk9DDzhj=fSM~ODn~tgZ^NWqc&efi~{IVv)F_RoWq% zS@pG3)Uv8{-tpx3?|08#bMwaU-{&?az4%-EKSJ>EyZ7~9tA*aF0R4jw5}vaj=CxOO SzHhq#NXpaI&t;ucLK6Ui9D7p$ literal 0 HcmV?d00001 diff --git a/test/test_utils.py b/test/test_utils.py index 37829b906f1..46c5dc3da57 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -17,6 +17,16 @@ boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) +keypoints = torch.tensor( + [ + [ + [10, 10, 1.0], [5, 5, 1.0] + ], + [ + [20, 20, 1.0], [30, 30, 1.0] + ] + ], dtype=torch.float) + def test_make_grid_not_inplace(): t = torch.rand(5, 3, 10, 10) @@ -248,5 +258,64 @@ def test_draw_segmentation_masks_errors(): utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) +def test_draw_keypoints_vanilla(): + # Keypoints is declared on top as global variable + keypoints_cp = keypoints.clone() + + img = torch.full((3, 100, 100), 0, dtype=torch.uint8) + img_cp = img.clone() + result = utils.draw_keypoints(img, keypoints, colors="red", connectivity=((0, 1), )) + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png") + if not os.path.exists(path): + res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()) + res.save(path) + + expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) + assert_equal(result, expected) + # Check that keypoints are not modified inplace + assert_equal(keypoints, keypoints_cp) + # Check that image is not modified in place + assert_equal(img, img_cp) + + +@pytest.mark.parametrize('colors', [ + 'red', + '#FF00FF', + (1, 34, 122) +]) +def test_draw_keypoints_colored(colors): + # Keypoints is declared on top as global variable + keypoints_cp = keypoints.clone() + + img = torch.full((3, 100, 100), 0, dtype=torch.uint8) + img_cp = img.clone() + result = utils.draw_keypoints(img, keypoints, colors=colors, connectivity=((0, 1), )) + assert result.size(0) == 3 + assert_equal(keypoints, keypoints_cp) + assert_equal(img, img_cp) + + +def test_draw_keypoints_errors(): + h, w = 10, 10 + + img = torch.full((3, 100, 100), 0, dtype=torch.uint8) + keypoints = torch.tensor([[[10, 10, 1.0], [5, 5, 1.0]], [[20, 20, 1.0], [30, 30, 1.0]]], dtype=torch.float) + + with pytest.raises(TypeError, match="The image must be a tensor"): + utils.draw_keypoints(image="Not A Tensor Image", keypoints=keypoints) + with pytest.raises(ValueError, match="The image dtype must be"): + img_bad_dtype = torch.full((3, h, w), 0, dtype=torch.int64) + utils.draw_keypoints(image=img_bad_dtype, keypoints=keypoints) + with pytest.raises(ValueError, match="Pass individual images, not batches"): + batch = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8) + utils.draw_keypoints(image=batch, keypoints=keypoints) + with pytest.raises(ValueError, match="Pass an RGB image"): + one_channel = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8) + utils.draw_keypoints(image=one_channel, keypoints=keypoints) + with pytest.raises(ValueError, match="keypoints must be of shape"): + invalid_keypoints = torch.tensor([[10, 10, 10, 10], [5, 6, 7, 8]], dtype=torch.float) + utils.draw_keypoints(image=img, keypoints=invalid_keypoints) + + if __name__ == "__main__": pytest.main([__file__]) From 691562ccf2d3814aacd985eabe5e9adde7522527 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Tue, 5 Oct 2021 23:29:24 +0530 Subject: [PATCH 11/17] Fix tests --- test/test_utils.py | 22 ++++------------------ torchvision/utils.py | 13 +++++-------- 2 files changed, 9 insertions(+), 26 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 7552e346861..b1ad8fdb888 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -16,15 +16,7 @@ boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) -keypoints = torch.tensor( - [ - [ - [10, 10, 1.0], [5, 5, 1.0] - ], - [ - [20, 20, 1.0], [30, 30, 1.0] - ] - ], dtype=torch.float) +keypoints = torch.tensor([[[10, 10], [5, 5]], [[20, 20], [30, 30]]], dtype=torch.float) def test_make_grid_not_inplace(): @@ -259,7 +251,7 @@ def test_draw_keypoints_vanilla(): img = torch.full((3, 100, 100), 0, dtype=torch.uint8) img_cp = img.clone() - result = utils.draw_keypoints(img, keypoints, colors="red", connectivity=((0, 1), )) + result = utils.draw_keypoints(img, keypoints, colors="red", connectivity=((0, 1),)) path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png") if not os.path.exists(path): res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()) @@ -273,18 +265,14 @@ def test_draw_keypoints_vanilla(): assert_equal(img, img_cp) -@pytest.mark.parametrize('colors', [ - 'red', - '#FF00FF', - (1, 34, 122) -]) +@pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)]) def test_draw_keypoints_colored(colors): # Keypoints is declared on top as global variable keypoints_cp = keypoints.clone() img = torch.full((3, 100, 100), 0, dtype=torch.uint8) img_cp = img.clone() - result = utils.draw_keypoints(img, keypoints, colors=colors, connectivity=((0, 1), )) + result = utils.draw_keypoints(img, keypoints, colors=colors, connectivity=((0, 1),)) assert result.size(0) == 3 assert_equal(keypoints, keypoints_cp) assert_equal(img, img_cp) @@ -292,9 +280,7 @@ def test_draw_keypoints_colored(colors): def test_draw_keypoints_errors(): h, w = 10, 10 - img = torch.full((3, 100, 100), 0, dtype=torch.uint8) - keypoints = torch.tensor([[[10, 10, 1.0], [5, 5, 1.0]], [[20, 20, 1.0], [30, 30, 1.0]]], dtype=torch.float) with pytest.raises(TypeError, match="The image must be a tensor"): utils.draw_keypoints(image="Not A Tensor Image", keypoints=keypoints) diff --git a/torchvision/utils.py b/torchvision/utils.py index 06c03343672..f0588108ab6 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -7,8 +7,7 @@ import torch from PIL import Image, ImageDraw, ImageFont, ImageColor -__all__ = ["make_grid", "save_image", "draw_bounding_boxes", - "draw_segmentation_masks", "draw_keypoints"] +__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks", "draw_keypoints"] @torch.no_grad() @@ -322,9 +321,9 @@ def draw_keypoints( in the format [x, y]. connectivity (Tuple[Tuple[int, int]]]): A Tuple of tuple where, each tuple contains pair of keypoints to be connected. - colors (Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]): List containing the colors - or a single color for all the keypoints. - The colors can be represented as `str` or `Tuple[int, int, int]`. + colors (color or list of colors, optional): List containing the colors + of the keypoints or single color for all keypoints. The color can be represented as + PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. radius (int): Integer denoting radius of keypoint. width (int): Integer denoting width of line connecting keypoints. @@ -370,9 +369,7 @@ def draw_keypoints( end_pt_y = kpt_inst[connection[1]][1] draw.line( - ( - (start_pt_x, start_pt_y), (end_pt_x, end_pt_y) - ), + ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)), width=width, ) From a0341377b586fd9bcae6d6f1227f66f62f1aedfc Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Thu, 28 Oct 2021 22:01:55 +0530 Subject: [PATCH 12/17] Update colors --- test/test_utils.py | 4 +++- torchvision/utils.py | 29 ++++++++++++++++++++++++----- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 8b0db317ee8..79d0e9f564d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -270,7 +270,9 @@ def test_draw_keypoints_vanilla(): assert_equal(img, img_cp) -@pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)]) +@pytest.mark.parametrize( + "colors", ["red", "#FF00FF", (1, 34, 122), ["red", "blue"], [["red", "blue"], ["orange", "green"]]] +) def test_draw_keypoints_colored(colors): # Keypoints is declared on top as global variable keypoints_cp = keypoints.clone() diff --git a/torchvision/utils.py b/torchvision/utils.py index 98624da2de6..b3f4e7624ce 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -305,7 +305,14 @@ def draw_keypoints( image: torch.Tensor, keypoints: torch.Tensor, connectivity: Optional[Tuple[Tuple[int, int]]] = None, - colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, + colors: Optional[ + Union[ + List[Union[str, Tuple[int, int, int]]], + str, + Tuple[int, int, int], + List[List[Union[str, Tuple[int, int, int]]]], + ] + ] = None, radius: int = 2, width: int = 3, ) -> torch.Tensor: @@ -320,8 +327,9 @@ def draw_keypoints( in the format [x, y]. connectivity (Tuple[Tuple[int, int]]]): A Tuple of tuple where, each tuple contains pair of keypoints to be connected. - colors (color or list of colors, optional): List containing the colors - of the keypoints or single color for all keypoints. The color can be represented as + colors (color or list of colors or list of list containing colors, optional): List containing the colors + of each instance of keypoints or a nested list containing color for every keypoint id + or single color for all keypoints. The color can be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. radius (int): Integer denoting radius of keypoint. width (int): Integer denoting width of line connecting keypoints. @@ -349,8 +357,16 @@ def draw_keypoints( img_kpts = keypoints.to(torch.int64).tolist() - for kpt_inst in img_kpts: - for kpt in kpt_inst: + # Specifying color for every keypoint id + if isinstance(colors[0], list): + keypoints_id_color = True + else: + keypoints_id_color = False + + for i, kpt_inst in enumerate(img_kpts): + if keypoints_id_color: + colors = colors[i] + for inst_id, kpt in enumerate(kpt_inst): x1 = kpt[0] - radius x2 = kpt[0] + radius y1 = kpt[1] - radius @@ -359,6 +375,9 @@ def draw_keypoints( if isinstance(colors, str) or isinstance(colors, tuple): draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0) + if isinstance(colors, list): + draw.ellipse([x1, y1, x2, y2], fill=colors[inst_id], outline=None, width=0) + if connectivity: for connection in connectivity: start_pt_x = kpt_inst[connection[0]][0] From 1f41550942d36f25d1c4fe3c2300e4363891c14a Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Thu, 28 Oct 2021 23:59:00 +0530 Subject: [PATCH 13/17] Fix bug and test more robustly --- test/assets/fakedata/draw_keypoint_vanilla.png | Bin 283 -> 300 bytes test/test_utils.py | 4 ++-- torchvision/utils.py | 9 ++++++--- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/test/assets/fakedata/draw_keypoint_vanilla.png b/test/assets/fakedata/draw_keypoint_vanilla.png index 8cd34f84539c9b5a054274f21c176b0e2bd9ee60..6cd6d943b6c02f9e439410a0a8f9a43081d216d5 100644 GIT binary patch delta 257 zcmV+c0sj7*0;~d%B!BZsL_t(|obA{_j)O1|MZs_Ge`jU^XozFb2s|6RYKN2!Pfi0} zlC0U4aA#Rh7whHENRp+x*6T(cD|%g98gpu0^=evD&GjvmTrK<Q|di9BwIeK-T z$7ApByPd4f*!%m=u`=}jzJIJ_*!}%~-Oq$O36b#=7x?f49?L>1_u6Z-00000NkvXX Hu0mjfxnzT8 delta 255 zcmZ3(G@EIHO8p~G7srr_Id5;A=4&z#aJZ=X|Nr^Y#uKYIHAqc9XQ{1{wzqYS~&e`9oFi+mAdi%Le$JG4! z#l~UhYR_GMS(9P%&FcBhGgen6f7jUdiq5Pm?GVkZ`r0XKSyejkc=G%AyXUUCdE@u* wa~qRh{H^^TA$a)R`}(idLhn= Date: Tue, 2 Nov 2021 20:02:55 +0530 Subject: [PATCH 14/17] Add a comment, merge stuff --- torchvision/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/utils.py b/torchvision/utils.py index 9c76232924b..826b694f6bd 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -365,6 +365,7 @@ def draw_keypoints( for kpt_id, kpt_inst in enumerate(img_kpts): if keypoints_id_color: + # Get the color from nested list. colors_draw = colors[kpt_id] for inst_id, kpt in enumerate(kpt_inst): x1 = kpt[0] - radius From 0002677a8996c22338f9af8526c47d9247ef29cf Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Mon, 8 Nov 2021 18:23:45 +0530 Subject: [PATCH 15/17] Fix fmt --- test/test_utils.py | 2 +- torchvision/utils.py | 34 +++++++++++++++++----------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 4518e18828f..1e79c6934c0 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -271,7 +271,7 @@ def test_draw_keypoints_vanilla(): @pytest.mark.parametrize( - "colors", ["red", "#FF00FF", (1, 34, 122), ["red", "blue"], [["red", "blue", "pink"], ["orange", "green", "red"]]] + "colors", ["red", "#FF00FF", (1, 34, 122), [["red", "blue", "pink"], ["orange", "green", "red"]]] ) def test_draw_keypoints_colored(colors): # Keypoints is declared on top as global variable diff --git a/torchvision/utils.py b/torchvision/utils.py index 826b694f6bd..7cc6e97384a 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -307,7 +307,6 @@ def draw_keypoints( connectivity: Optional[Tuple[Tuple[int, int]]] = None, colors: Optional[ Union[ - List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int], List[List[Union[str, Tuple[int, int, int]]]], @@ -327,8 +326,7 @@ def draw_keypoints( in the format [x, y]. connectivity (Tuple[Tuple[int, int]]]): A Tuple of tuple where, each tuple contains pair of keypoints to be connected. - colors (color or list of colors or list of list containing colors, optional): List containing the colors - of each instance of keypoints or a nested list containing color for every keypoint id + colors (color list of list containing colors, optional): A nested list containing color for every keypoint id or single color for all keypoints. The color can be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. radius (int): Integer denoting radius of keypoint. @@ -355,32 +353,34 @@ def draw_keypoints( draw = ImageDraw.Draw(img_to_draw) out_dtype = torch.uint8 + num_instances = keypoints.shape[0] + num_keypoints = keypoints.shape[1] + img_kpts = keypoints.to(torch.int64).tolist() - # Specifying color for every keypoint id - if isinstance(colors[0], list): - keypoints_id_color = True + # need to use np.array to handle strings as well + colors = np.array(colors) + + # if colors is specified as a string, the number of elements n = 1 + if issubclass(colors.dtype.type, np.integer): + shape = (num_instances, num_keypoints, 3) else: - keypoints_id_color = False + shape = (num_instances, num_keypoints) + + colors = np.broadcast_to(colors, shape) for kpt_id, kpt_inst in enumerate(img_kpts): - if keypoints_id_color: - # Get the color from nested list. - colors_draw = colors[kpt_id] for inst_id, kpt in enumerate(kpt_inst): + color = colors[kpt_id, inst_id] x1 = kpt[0] - radius x2 = kpt[0] + radius y1 = kpt[1] - radius y2 = kpt[1] + radius - if isinstance(colors, str) or isinstance(colors, tuple): - draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0) + if isinstance(color, np.ndarray): + color = tuple(color) - if isinstance(colors, list): - if keypoints_id_color: - draw.ellipse([x1, y1, x2, y2], fill=colors_draw[inst_id], outline=None, width=0) - else: - draw.ellipse([x1, y1, x2, y2], fill=colors[kpt_id], outline=None, width=0) + draw.ellipse([x1, y1, x2, y2], fill=color, outline=None, width=0) if connectivity: for connection in connectivity: From 949e42cd45321327f155aa6bb8568df85da8affb Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Mon, 8 Nov 2021 19:09:59 +0530 Subject: [PATCH 16/17] Support single str for merging --- test/test_utils.py | 4 +--- torchvision/utils.py | 33 +++------------------------------ 2 files changed, 4 insertions(+), 33 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 1e79c6934c0..64f45c697c6 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -270,9 +270,7 @@ def test_draw_keypoints_vanilla(): assert_equal(img, img_cp) -@pytest.mark.parametrize( - "colors", ["red", "#FF00FF", (1, 34, 122), [["red", "blue", "pink"], ["orange", "green", "red"]]] -) +@pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)]) def test_draw_keypoints_colored(colors): # Keypoints is declared on top as global variable keypoints_cp = keypoints.clone() diff --git a/torchvision/utils.py b/torchvision/utils.py index 7cc6e97384a..fceda99a813 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -305,13 +305,7 @@ def draw_keypoints( image: torch.Tensor, keypoints: torch.Tensor, connectivity: Optional[Tuple[Tuple[int, int]]] = None, - colors: Optional[ - Union[ - str, - Tuple[int, int, int], - List[List[Union[str, Tuple[int, int, int]]]], - ] - ] = None, + colors: Optional[Union[str, Tuple[int, int, int]]] = None, radius: int = 2, width: int = 3, ) -> torch.Tensor: @@ -326,8 +320,7 @@ def draw_keypoints( in the format [x, y]. connectivity (Tuple[Tuple[int, int]]]): A Tuple of tuple where, each tuple contains pair of keypoints to be connected. - colors (color list of list containing colors, optional): A nested list containing color for every keypoint id - or single color for all keypoints. The color can be represented as + colors (str, Tuple): The color can be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. radius (int): Integer denoting radius of keypoint. width (int): Integer denoting width of line connecting keypoints. @@ -352,35 +345,15 @@ def draw_keypoints( img_to_draw = Image.fromarray(ndarr) draw = ImageDraw.Draw(img_to_draw) out_dtype = torch.uint8 - - num_instances = keypoints.shape[0] - num_keypoints = keypoints.shape[1] - img_kpts = keypoints.to(torch.int64).tolist() - # need to use np.array to handle strings as well - colors = np.array(colors) - - # if colors is specified as a string, the number of elements n = 1 - if issubclass(colors.dtype.type, np.integer): - shape = (num_instances, num_keypoints, 3) - else: - shape = (num_instances, num_keypoints) - - colors = np.broadcast_to(colors, shape) - for kpt_id, kpt_inst in enumerate(img_kpts): for inst_id, kpt in enumerate(kpt_inst): - color = colors[kpt_id, inst_id] x1 = kpt[0] - radius x2 = kpt[0] + radius y1 = kpt[1] - radius y2 = kpt[1] + radius - - if isinstance(color, np.ndarray): - color = tuple(color) - - draw.ellipse([x1, y1, x2, y2], fill=color, outline=None, width=0) + draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0) if connectivity: for connection in connectivity: From 00db80b7a89698c43aa987807765b4a3b59f9c23 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 9 Nov 2021 11:48:14 +0000 Subject: [PATCH 17/17] Remove unnecessary vars. --- torchvision/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index fceda99a813..6c29767a7ce 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -344,7 +344,6 @@ def draw_keypoints( ndarr = image.permute(1, 2, 0).numpy() img_to_draw = Image.fromarray(ndarr) draw = ImageDraw.Draw(img_to_draw) - out_dtype = torch.uint8 img_kpts = keypoints.to(torch.int64).tolist() for kpt_id, kpt_inst in enumerate(img_kpts): @@ -368,7 +367,7 @@ def draw_keypoints( width=width, ) - return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=out_dtype) + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) def _generate_color_palette(num_masks: int):