Skip to content

Add utility to draw keypoints #4216

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 35 commits into from
Nov 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
b2f6615
fix
oke-aditya May 20, 2021
4fb038d
Merge branch 'master' of https://github.com/pytorch/vision
oke-aditya May 20, 2021
deda5d7
Merge branch 'master' of https://github.com/pytorch/vision
oke-aditya May 21, 2021
5490821
Merge branch 'master' of https://github.com/pytorch/vision
oke-aditya May 21, 2021
4cfc220
Merge branch 'master' of https://github.com/pytorch/vision
oke-aditya May 21, 2021
6306746
Merge branch 'master' of https://github.com/pytorch/vision
oke-aditya Jul 24, 2021
e8c93cf
Merge branch 'master' of https://github.com/pytorch/vision
oke-aditya Jul 28, 2021
73abf84
Outline Keypoints API
oke-aditya Jul 28, 2021
2ebe0c7
Add utility
oke-aditya Jul 28, 2021
77afa81
make it work :)
oke-aditya Jul 28, 2021
cdeebdd
Fix optional type
oke-aditya Jul 29, 2021
fad0d44
Merge branch 'master' of https://github.com/pytorch/vision into add_kypt
oke-aditya Jul 29, 2021
76af22e
Merge branch 'main' of https://github.com/pytorch/vision into add_kypt
oke-aditya Sep 12, 2021
1291f44
Add connectivity, fmassa's advice :smiley:
oke-aditya Sep 12, 2021
b9af874
Minor code improvement
oke-aditya Sep 12, 2021
28ebcc0
small fix
oke-aditya Sep 13, 2021
8db61f7
Merge branch 'main' of https://github.com/pytorch/vision into add_kypt
oke-aditya Oct 2, 2021
ebe7a25
fix implementation
oke-aditya Oct 2, 2021
e6afa37
Add tests
oke-aditya Oct 2, 2021
c46d9db
Merge branch 'main' of https://github.com/pytorch/vision into add_kypt
oke-aditya Oct 5, 2021
691562c
Fix tests
oke-aditya Oct 5, 2021
4a65900
Merge branch 'main' of https://github.com/pytorch/vision into add_kypt
oke-aditya Oct 5, 2021
d5747d3
Merge branch 'main' of https://github.com/pytorch/vision into add_kypt
oke-aditya Oct 28, 2021
a034137
Update colors
oke-aditya Oct 28, 2021
1f41550
Fix bug and test more robustly
oke-aditya Oct 28, 2021
d9d96cb
Merge branch 'main' of https://github.com/pytorch/vision into add_kypt
oke-aditya Nov 2, 2021
e6e7428
Add a comment, merge stuff
oke-aditya Nov 2, 2021
c8da898
Merge branch 'main' of https://github.com/pytorch/vision into add_kypt
oke-aditya Nov 2, 2021
4385643
Merge branch 'main' of https://github.com/pytorch/vision into add_kypt
oke-aditya Nov 2, 2021
8997b58
Merge branch 'main' of https://github.com/pytorch/vision into add_kypt
oke-aditya Nov 8, 2021
0002677
Fix fmt
oke-aditya Nov 8, 2021
949e42c
Support single str for merging
oke-aditya Nov 8, 2021
693ffdc
Merge branch 'main' of https://github.com/pytorch/vision into add_kypt
oke-aditya Nov 8, 2021
00db80b
Remove unnecessary vars.
datumbox Nov 9, 2021
41ecc44
Merge branch 'main' into add_kypt
datumbox Nov 9, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ vizualization <sphx_glr_auto_examples_plot_visualization_utils.py>`.

draw_bounding_boxes
draw_segmentation_masks
draw_keypoints
make_grid
save_image
Binary file added test/assets/fakedata/draw_keypoint_vanilla.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
55 changes: 55 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)

keypoints = torch.tensor([[[10, 10], [5, 5], [2, 2]], [[20, 20], [30, 30], [3, 3]]], dtype=torch.float)


def test_make_grid_not_inplace():
t = torch.rand(5, 3, 10, 10)
Expand Down Expand Up @@ -248,5 +250,58 @@ def test_draw_segmentation_masks_errors():
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)


def test_draw_keypoints_vanilla():
# Keypoints is declared on top as global variable
keypoints_cp = keypoints.clone()

img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
result = utils.draw_keypoints(img, keypoints, colors="red", connectivity=((0, 1),))
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png")
if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)

expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
assert_equal(result, expected)
# Check that keypoints are not modified inplace
assert_equal(keypoints, keypoints_cp)
# Check that image is not modified in place
assert_equal(img, img_cp)


@pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)])
def test_draw_keypoints_colored(colors):
# Keypoints is declared on top as global variable
keypoints_cp = keypoints.clone()

img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
result = utils.draw_keypoints(img, keypoints, colors=colors, connectivity=((0, 1),))
assert result.size(0) == 3
assert_equal(keypoints, keypoints_cp)
assert_equal(img, img_cp)


def test_draw_keypoints_errors():
h, w = 10, 10
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)

with pytest.raises(TypeError, match="The image must be a tensor"):
utils.draw_keypoints(image="Not A Tensor Image", keypoints=keypoints)
with pytest.raises(ValueError, match="The image dtype must be"):
img_bad_dtype = torch.full((3, h, w), 0, dtype=torch.int64)
utils.draw_keypoints(image=img_bad_dtype, keypoints=keypoints)
with pytest.raises(ValueError, match="Pass individual images, not batches"):
batch = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8)
utils.draw_keypoints(image=batch, keypoints=keypoints)
with pytest.raises(ValueError, match="Pass an RGB image"):
one_channel = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8)
utils.draw_keypoints(image=one_channel, keypoints=keypoints)
with pytest.raises(ValueError, match="keypoints must be of shape"):
invalid_keypoints = torch.tensor([[10, 10, 10, 10], [5, 6, 7, 8]], dtype=torch.float)
utils.draw_keypoints(image=img, keypoints=invalid_keypoints)


if __name__ == "__main__":
pytest.main([__file__])
72 changes: 71 additions & 1 deletion torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from PIL import Image, ImageDraw, ImageFont, ImageColor

__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"]
__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks", "draw_keypoints"]


@torch.no_grad()
Expand Down Expand Up @@ -300,6 +300,76 @@ def draw_segmentation_masks(
return out.to(out_dtype)


@torch.no_grad()
def draw_keypoints(
image: torch.Tensor,
keypoints: torch.Tensor,
connectivity: Optional[Tuple[Tuple[int, int]]] = None,
colors: Optional[Union[str, Tuple[int, int, int]]] = None,
radius: int = 2,
width: int = 3,
) -> torch.Tensor:

"""
Draws Keypoints on given RGB image.
The values of the input image should be uint8 between 0 and 255.

Args:
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances,
in the format [x, y].
connectivity (Tuple[Tuple[int, int]]]): A Tuple of tuple where,
each tuple contains pair of keypoints to be connected.
colors (str, Tuple): The color can be represented as
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
radius (int): Integer denoting radius of keypoint.
width (int): Integer denoting width of line connecting keypoints.

Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
"""

if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size()[0] != 3:
raise ValueError("Pass an RGB image. Other Image formats are not supported")

if keypoints.ndim != 3:
raise ValueError("keypoints must be of shape (num_instances, K, 2)")

ndarr = image.permute(1, 2, 0).numpy()
img_to_draw = Image.fromarray(ndarr)
draw = ImageDraw.Draw(img_to_draw)
img_kpts = keypoints.to(torch.int64).tolist()

for kpt_id, kpt_inst in enumerate(img_kpts):
for inst_id, kpt in enumerate(kpt_inst):
x1 = kpt[0] - radius
x2 = kpt[0] + radius
y1 = kpt[1] - radius
y2 = kpt[1] + radius
draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0)

if connectivity:
for connection in connectivity:
start_pt_x = kpt_inst[connection[0]][0]
start_pt_y = kpt_inst[connection[0]][1]

end_pt_x = kpt_inst[connection[1]][0]
end_pt_y = kpt_inst[connection[1]][1]

draw.line(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What should be the line color? Should there be a parameter for line color? Or should it be same as color of keypoints?
Currently it is white which looks bad.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the interest of getting this merged soon, let's leave it white for now, and in the future create an issue to enable this to be configurable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we have sufficient time before the next release. Let's discuss about this as well.

((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)),
width=width,
)

return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)


def _generate_color_palette(num_masks: int):
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
return [tuple((i * palette) % 255) for i in range(num_masks)]
Expand Down