Added the KeyPoints TVTensor #8817

Open · wants to merge 10 commits into base: main
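At a glance, here is what the new TVTensor looks like in use. This is a minimal sketch assembled from the diff that follows: the constructor and its `canvas_size` metadata appear verbatim in the gallery changes, while the dispatch of v2 transforms to the new `*_keypoints` kernels (`resize_keypoints`, `horizontal_flip_keypoints`, ...) is an assumption based on the kernel/functional pairs registered in the tests below.

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

# Construction mirrors BoundingBoxes: raw (x, y) coordinates plus the
# canvas_size metadata, as in the gallery diff below.
image = tv_tensors.Image(torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8))
keypoints = tv_tensors.KeyPoints([[17, 16], [344, 460]], canvas_size=image.shape[-2:])

# Assumption: v2 transforms dispatch KeyPoints inputs to the new *_keypoints
# kernels, so the points are remapped together with the image.
out_image, out_keypoints = v2.Resize(size=(224, 224))(image, keypoints)
print(out_keypoints.canvas_size)  # expected: (224, 224)
```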
1 change: 1 addition & 0 deletions docs/source/tv_tensors.rst
@@ -21,6 +21,7 @@ info.

Image
Video
KeyPoints
BoundingBoxFormat
BoundingBoxes
Mask
11 changes: 10 additions & 1 deletion gallery/transforms/plot_tv_tensors.py
@@ -46,11 +46,12 @@
# Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
# for the input data.
#
# :mod:`torchvision.tv_tensors` supports four types of TVTensors:
# :mod:`torchvision.tv_tensors` supports five types of TVTensors:
#
# * :class:`~torchvision.tv_tensors.Image`
# * :class:`~torchvision.tv_tensors.Video`
# * :class:`~torchvision.tv_tensors.BoundingBoxes`
# * :class:`~torchvision.tv_tensors.KeyPoints`
# * :class:`~torchvision.tv_tensors.Mask`
#
# What can I do with a TVTensor?
@@ -96,6 +97,7 @@
# :class:`~torchvision.tv_tensors.BoundingBoxes` requires the coordinate format as well as the size of the
# corresponding image (``canvas_size``) alongside the actual values. These
# metadata are required to properly transform the bounding boxes.
# Similarly, :class:`~torchvision.tv_tensors.KeyPoints` also requires the ``canvas_size`` metadata.

bboxes = tv_tensors.BoundingBoxes(
[[17, 16, 344, 495], [0, 10, 0, 10]],
@@ -104,6 +106,13 @@
)
print(bboxes)


keypoints = tv_tensors.KeyPoints(
    [[17, 16], [344, 495], [0, 10], [0, 10]],
    canvas_size=image.shape[-2:],
)
print(keypoints)

# %%
# Using ``tv_tensors.wrap()``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
15 changes: 15 additions & 0 deletions test/common_utils.py
@@ -8,6 +8,7 @@
import shutil
import sys
import tempfile
from typing import Sequence, Tuple, Union
import warnings
from subprocess import CalledProcessError, check_output, STDOUT

@@ -402,6 +403,20 @@ def make_image_pil(*args, **kwargs):
return to_pil_image(make_image(*args, **kwargs))


def make_keypoints(
    canvas_size: Tuple[int, int] = DEFAULT_SIZE, *, num_points: Union[int, Sequence[int]] = 4, dtype=None, device="cpu"
) -> tv_tensors.KeyPoints:
    """Make a random KeyPoints TVTensor of shape (*num_points, 2) for testing purposes."""
    if isinstance(num_points, int):
        num_points = [num_points]
    # One trailing dimension per coordinate axis, so x and y can be concatenated below.
    coord_shape: Tuple[int, ...] = tuple(num_points) + (1,)
    y = torch.randint(0, canvas_size[0] - 1, coord_shape, dtype=dtype, device=device)
    x = torch.randint(0, canvas_size[1] - 1, coord_shape, dtype=dtype, device=device)
    points = torch.cat((x, y), dim=-1)
    return tv_tensors.KeyPoints(points, canvas_size=canvas_size)

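A quick usage sketch of the helper above; the output shape follows from the implementation, and the `canvas_size` attribute is assumed to be exposed the way `BoundingBoxes` exposes it:

```python
# num_points=4 -> per-axis coordinate shape (4, 1); x and y concatenate to (4, 2).
kps = make_keypoints((32, 32), num_points=4)
print(kps.shape)        # torch.Size([4, 2])
print(kps.canvas_size)  # (32, 32), assuming the metadata attribute matches BoundingBoxes
```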

def make_bounding_boxes(
canvas_size=DEFAULT_SIZE,
*,
32 changes: 28 additions & 4 deletions test/test_transforms_v2.py
@@ -31,6 +31,7 @@
make_image,
make_image_pil,
make_image_tensor,
make_keypoints,
make_segmentation_mask,
make_video,
make_video_tensor,
@@ -223,6 +224,7 @@ def check_functional_kernel_signature_match(functional, *, kernel, input_type):
# explicitly passed to the kernel.
explicit_metadata = {
tv_tensors.BoundingBoxes: {"format", "canvas_size"},
tv_tensors.KeyPoints: {"canvas_size"},
}
kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]

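The `explicit_metadata` entry above encodes the convention this check enforces: functionals read metadata off the TVTensor, while low-level kernels receive it as explicit arguments. A hedged sketch of the two call styles; the exact `resize_keypoints` signature is an assumption, mirrored from `resize_bounding_boxes` (which also returns the updated canvas size):

```python
# Functional: canvas_size travels inside the KeyPoints TVTensor.
resized = F.resize(keypoints, size=[64, 64])

# Kernel: the metadata is passed explicitly (assumed signature).
out, new_canvas_size = F.resize_keypoints(
    keypoints.as_subclass(torch.Tensor),
    size=[64, 64],
    canvas_size=keypoints.canvas_size,
)
```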
@@ -327,6 +329,18 @@ def _make_transform_sample(transform, *, image_or_video, adapter):
canvas_size=size,
device=device,
),
keypoints=make_keypoints(canvas_size=size),
keypoints_degenerate=tv_tensors.KeyPoints(
    [
        [0, 1],  # left edge
        [1, 0],  # top edge
        [0, 0],  # top left corner
        [size[1], 1],  # right edge
        [size[1], 0],  # top right corner
        [1, size[0]],  # bottom edge
        [0, size[0]],  # bottom left corner
        [size[1], size[0]],  # bottom right corner
    ],
    canvas_size=size,
    device=device,
),
detection_mask=make_detection_masks(size, device=device),
segmentation_mask=make_segmentation_mask(size, device=device),
int=0,
@@ -680,6 +694,7 @@ def test_functional(self, size, make_input):
(F.resize_image, torch.Tensor),
(F._geometry._resize_image_pil, PIL.Image.Image),
(F.resize_image, tv_tensors.Image),
(F.resize_keypoints, tv_tensors.KeyPoints),
(F.resize_bounding_boxes, tv_tensors.BoundingBoxes),
(F.resize_mask, tv_tensors.Mask),
(F.resize_video, tv_tensors.Video),
@@ -1035,6 +1050,7 @@ def test_functional(self, make_input):
(F.horizontal_flip_image, torch.Tensor),
(F._geometry._horizontal_flip_image_pil, PIL.Image.Image),
(F.horizontal_flip_image, tv_tensors.Image),
(F.horizontal_flip_keypoints, tv_tensors.KeyPoints),
(F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes),
(F.horizontal_flip_mask, tv_tensors.Mask),
(F.horizontal_flip_video, tv_tensors.Video),
@@ -1203,6 +1219,7 @@ def test_functional(self, make_input):
(F.affine_image, torch.Tensor),
(F._geometry._affine_image_pil, PIL.Image.Image),
(F.affine_image, tv_tensors.Image),
(F.affine_keypoints, tv_tensors.KeyPoints),
(F.affine_bounding_boxes, tv_tensors.BoundingBoxes),
(F.affine_mask, tv_tensors.Mask),
(F.affine_video, tv_tensors.Video),
@@ -1485,6 +1502,7 @@ def test_functional(self, make_input):
(F.vertical_flip_image, torch.Tensor),
(F._geometry._vertical_flip_image_pil, PIL.Image.Image),
(F.vertical_flip_image, tv_tensors.Image),
(F.vertical_flip_keypoints, tv_tensors.KeyPoints),
(F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes),
(F.vertical_flip_mask, tv_tensors.Mask),
(F.vertical_flip_video, tv_tensors.Video),
@@ -1627,6 +1645,7 @@ def test_functional(self, make_input):
(F.rotate_image, torch.Tensor),
(F._geometry._rotate_image_pil, PIL.Image.Image),
(F.rotate_image, tv_tensors.Image),
(F.rotate_keypoints, tv_tensors.KeyPoints),
(F.rotate_bounding_boxes, tv_tensors.BoundingBoxes),
(F.rotate_mask, tv_tensors.Mask),
(F.rotate_video, tv_tensors.Video),
@@ -2332,7 +2351,9 @@ def test_error(self, T):
F.to_pil_image(imgs[0]),
tv_tensors.Mask(torch.rand(12, 12)),
tv_tensors.BoundingBoxes(torch.rand(2, 4), format="XYXY", canvas_size=12),
tv_tensors.KeyPoints(torch.rand(2, 2), canvas_size=(12, 12)),
):
with pytest.raises(ValueError, match="does not support PIL images, "):
with pytest.raises(ValueError, match="does not support PIL images, "):
cutmix_mixup(input_with_bad_type)

@@ -2740,8 +2761,9 @@ def test_functional_signature(self, kernel, input_type):
check_functional_kernel_signature_match(F.elastic, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
"make_input", [
make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints
],
)
def test_displacement_error(self, make_input):
input = make_input()
@@ -2753,8 +2775,10 @@
F.elastic(input, displacement=torch.rand(F.get_size(input)))

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
"make_input", [
make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video,
make_keypoints
],
)
# ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
@pytest.mark.parametrize("size", [(163, 163), (72, 333), (313, 95)])
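To make the displacement contract concrete: the error test above passes a field of shape (H, W), whereas elastic expects (1, H, W, 2). A minimal sketch, assuming a zero field acts as the identity on keypoints:

```python
# A correctly shaped displacement field: (1, H, W, 2); zeros = identity warp.
kps = make_keypoints((163, 163))
displacement = torch.zeros(1, 163, 163, 2)
out = F.elastic(kps, displacement=displacement)  # points should come back unchanged
```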
54 changes: 30 additions & 24 deletions test/test_transforms_v2_utils.py
@@ -4,7 +4,7 @@
import torch

import torchvision.transforms.v2._utils
from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_masks, make_image
from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_masks, make_image, make_keypoints

from torchvision import tv_tensors
from torchvision.transforms.v2._utils import has_all, has_any
@@ -14,29 +14,32 @@
IMAGE = make_image(DEFAULT_SIZE, color_space="RGB")
BOUNDING_BOX = make_bounding_boxes(DEFAULT_SIZE, format=tv_tensors.BoundingBoxFormat.XYXY)
MASK = make_detection_masks(DEFAULT_SIZE)
KEYPOINTS = make_keypoints(DEFAULT_SIZE)


@pytest.mark.parametrize(
("sample", "types", "expected"),
[
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
((MASK,), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask), False),
((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image,), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.KeyPoints,), True),
((MASK,), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints), False),
((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask, tv_tensors.KeyPoints), False),
((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), False),
((KEYPOINTS,), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
(
(IMAGE, BOUNDING_BOX, MASK),
(tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
(IMAGE, BOUNDING_BOX, MASK, KEYPOINTS),
(tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints),
True,
),
((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, tv_tensors.Image),), True),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), False),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (lambda obj: isinstance(obj, tv_tensors.Image),), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (lambda _: True,), True),
((IMAGE,), (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True),
(
(torch.Tensor(IMAGE),),
@@ -57,15 +60,18 @@ def test_has_any(sample, types, expected):
@pytest.mark.parametrize(
("sample", "types", "expected"),
[
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image,), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Mask, tv_tensors.KeyPoints), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints), True),
((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), True),
(
(IMAGE, BOUNDING_BOX, MASK),
(tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
(IMAGE, BOUNDING_BOX, MASK, KEYPOINTS),
(tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints),
True,
),
((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
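For readers skimming the parametrizations, a minimal sketch of the two helpers with the new type; the call convention (flat inputs first, then types or callables) is inferred from the tests above:

```python
# has_any: at least one input matches one of the given types/checks.
# has_all: every given type/check is matched by at least one input.
sample = [IMAGE, KEYPOINTS]
assert has_any(sample, tv_tensors.KeyPoints)
assert has_all(sample, tv_tensors.Image, tv_tensors.KeyPoints)
assert not has_any(sample, tv_tensors.BoundingBoxes)
```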