add KeyPoint feature for prototype datasets #5326

Open · wants to merge 9 commits into main
24 changes: 23 additions & 1 deletion test/builtin_dataset_mocks.py
@@ -544,6 +544,16 @@ def make_rle_segmentation():
category_id=int(make_scalar(dtype=torch.int64)),
)

@staticmethod
def _make_person_keypoints_data(ann_id, image_meta):
data = CocoMockData._make_instances_data(ann_id, image_meta)
keypoints = [
(*make_tensor((2,), dtype=torch.int, low=0).tolist(), visibility) if visibility else (0, 0, 0)
for visibility in torch.randint(3, (17,)).tolist()
]
data["keypoints"] = list(itertools.chain.from_iterable(keypoints))
return data
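
For context, a real COCO person_keypoints annotation stores each instance's 17 keypoints as a flat list of (x, y, visibility) triplets, which is the layout the mock above reproduces. A hypothetical annotation (made-up values, only 3 of the 17 triplets shown):

# One person instance from a person_keypoints_*.json annotation file.
ann = {
    "id": 1,
    "category_id": 1,  # as of COCO 2017, only "person" has keypoints
    "keypoints": [
        230, 94, 2,  # nose: x, y, visibility (2 = labeled and visible)
        236, 89, 2,  # left_eye (1 would mean labeled but not visible)
        0, 0, 0,     # right_eye: visibility 0 = not labeled, coordinates zeroed
    ],
}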

@staticmethod
def _make_captions_data(ann_id, image_meta):
return dict(caption=f"Caption {ann_id} describing image {image_meta['id']}.")
@@ -553,6 +563,7 @@ def _make_annotations(cls, root, name, *, images_meta):
num_anns_per_image = torch.zeros((len(images_meta),), dtype=torch.int64)
for annotations, fn in (
("instances", cls._make_instances_data),
("person_keypoints", cls._make_person_keypoints_data),
("captions", cls._make_captions_data),
):
num_anns_per_image += cls._make_annotations_json(
@@ -840,7 +851,18 @@ def _make_bounding_boxes_file(cls, root, image_file_names):

@classmethod
def _make_landmarks_file(cls, root, image_file_names):
field_names = ("lefteye_x", "lefteye_y", "rightmouth_x", "rightmouth_y")
field_names = (
"lefteye_x",
"lefteye_y",
"righteye_x",
"righteye_y",
"nose_x",
"nose_y",
"leftmouth_x",
"leftmouth_y",
"rightmouth_x",
"rightmouth_y",
)
data = [
[
name,
43 changes: 30 additions & 13 deletions torchvision/prototype/datasets/_builtin/celeba.py
@@ -23,7 +23,7 @@
hint_sharding,
hint_shuffling,
)
from torchvision.prototype.features import EncodedImage, _Feature, Label, BoundingBox
from torchvision.prototype.features import EncodedImage, Label, BoundingBox, Keypoint


csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
@@ -64,6 +64,7 @@ class CelebA(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"celeba",
categories=[str(label + 1) for label in range(10177)],
homepage="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html",
valid_options=dict(split=("train", "val", "test")),
)
@@ -125,24 +126,40 @@ def _prepare_sample(
split_and_image_data, ann_data = data
_, (_, image_data) = split_and_image_data
path, buffer = image_data
(_, identity), (_, attributes), (_, bounding_box), (_, landmarks) = ann_data

image = EncodedImage.from_file(buffer)
(_, identity), (_, attributes), (_, bounding_box), (_, landmarks) = ann_data

identity = Label(int(identity["identity"]))

attributes = {attr: value == "1" for attr, value in attributes.items()}

bounding_box = BoundingBox(
[int(bounding_box[key]) for key in ("x_1", "y_1", "width", "height")],
format="xywh",
image_size=image.image_size,
)

descriptions = list({key[:-2] for key in landmarks.keys()})
symmetries = [
("vertical", description, description.replace("left", "right"))
for description in descriptions
if description.startswith("left")
]
keypoints = Keypoint(
[(int(landmarks[f"{name}_x"]), int(landmarks[f"{name}_y"])) for name in descriptions],
image_size=image.image_size,
descriptions=descriptions,
symmetries=symmetries,
)

return dict(
path=path,
image=image,
identity=Label(int(identity["identity"])),
attributes={attr: value == "1" for attr, value in attributes.items()},
bounding_box=BoundingBox(
[int(bounding_box[key]) for key in ("x_1", "y_1", "width", "height")],
format="xywh",
image_size=image.image_size,
),
landmarks={
landmark: _Feature((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
for landmark in {key[:-2] for key in landmarks.keys()}
},
identity=identity,
attributes=attributes,
bounding_box=bounding_box,
keypoints=keypoints,
)

def _make_datapipe(
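
For concreteness, a sketch (with made-up pixel values) of what the landmark handling above produces. Note that because descriptions is built from a set comprehension, its order is not deterministic across runs, so the keypoint order can vary:

landmarks = {
    "lefteye_x": "69", "lefteye_y": "109",
    "righteye_x": "106", "righteye_y": "113",
    "nose_x": "77", "nose_y": "142",
    "leftmouth_x": "73", "leftmouth_y": "152",
    "rightmouth_x": "108", "rightmouth_y": "154",
}

descriptions = list({key[:-2] for key in landmarks.keys()})
# e.g. ["nose", "lefteye", "rightmouth", "leftmouth", "righteye"]

symmetries = [
    ("vertical", description, description.replace("left", "right"))
    for description in descriptions
    if description.startswith("left")
]
# -> [("vertical", "lefteye", "righteye"), ("vertical", "leftmouth", "rightmouth")]
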
42 changes: 41 additions & 1 deletion torchvision/prototype/datasets/_builtin/coco.py
@@ -31,7 +31,7 @@
hint_sharding,
hint_shuffling,
)
from torchvision.prototype.features import BoundingBox, Label, _Feature, EncodedImage
from torchvision.prototype.features import BoundingBox, Label, _Feature, EncodedImage, Keypoint
from torchvision.prototype.utils._internal import FrozenMapping


@@ -117,6 +117,45 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]:
ann_ids=[ann["id"] for ann in anns],
)

def _decode_person_keypoints_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]:
image_size = (image_meta["height"], image_meta["width"])

keypoints_meta = torch.tensor([ann["keypoints"] for ann in anns]).view(len(anns), -1, 3)
coordinates = keypoints_meta[..., :2]
visibility = _Feature(keypoints_meta[..., 2])

# As of COCO 2017, only the person category has keypoints.
descriptions = (
"nose",
"left_eye",
"right_eye",
"left_ear",
"right_ear",
"left_shoulder",
"right_shoulder",
"left_elbow",
"right_elbow",
"left_wrist",
"right_wrist",
"left_hip",
"right_hip",
"left_knee",
"right_knee",
"left_ankle",
"right_ankle",
)
symmetries = [
("vertical", description, description.replace("left", "right"))
for description in descriptions
if description.startswith("left")
]
Contributor:

Could you clarify the type of the resulting symmetries variable?

Below, Keypoint expects Sequence[Tuple[KeypointSymmetry, Tuple[int, int]]]. It's not obvious to me how the current code ends up with that specific type.


return dict(
self._decode_instances_anns(anns, image_meta),
keypoints=Keypoint(coordinates, image_size=image_size, descriptions=descriptions, symmetries=symmetries),
visibility=visibility,
)
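
To the type question above: with the descriptions tuple, the comprehension evaluates to plain (str, str, str) triples; Keypoint.__new__ (see _keypoint.py below) then replaces the names with indices into descriptions and the leading string with a KeypointSymmetry member. A sketch of the values involved:

symmetries = [
    ("vertical", "left_eye", "right_eye"),
    ("vertical", "left_ear", "right_ear"),
    ("vertical", "left_shoulder", "right_shoulder"),
    # ... one triple per "left_*" description, 8 in total
]
# after parsing inside Keypoint.__new__:
# [(KeypointSymmetry.VERTICAL, 1, 2), (KeypointSymmetry.VERTICAL, 3, 4), ...]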

def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]:
return dict(
captions=[ann["caption"] for ann in anns],
@@ -126,6 +165,7 @@ def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]:
_ANN_DECODERS = OrderedDict(
[
("instances", _decode_instances_anns),
("person_keypoints", _decode_person_keypoints_anns),
Member:

This is unrelated to this PR, but I would maybe revisit whether we want a dataset to have a varying return type depending on the arguments passed to it. It might be better for different flavors of the dataset to be different dataset classes.

Collaborator (Author):

Yeah, that is a design choice that we should address. cc @NicolasHug

("captions", _decode_captions_ann),
]
)
1 change: 1 addition & 0 deletions torchvision/prototype/features/__init__.py
@@ -2,5 +2,6 @@
from ._encoded import EncodedData, EncodedImage, EncodedVideo
from ._feature import _Feature
from ._image import ColorSpace, Image
from ._keypoint import Keypoint, KeypointSymmetry
from ._label import Label, OneHotLabel
from ._segmentation_mask import SegmentationMask
69 changes: 69 additions & 0 deletions torchvision/prototype/features/_keypoint.py
@@ -0,0 +1,69 @@
from __future__ import annotations

from typing import Any, Tuple, Optional, Sequence, Union, Collection

import torch
from torchvision.prototype.utils._internal import StrEnum

from ._feature import _Feature


class KeypointSymmetry(StrEnum):
VERTICAL = "vertical"
HORIZONTAL = "horizontal"


class Keypoint(_Feature):
image_size: Tuple[int, int]
    descriptions: Optional[Sequence[str]]
symmetries: Sequence[Tuple[KeypointSymmetry, int, int]]

def __new__(
cls,
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
image_size: Tuple[int, int],
descriptions: Optional[Sequence[str]] = None,
symmetries: Collection[
Tuple[Union[str, KeypointSymmetry], Union[str, int], Union[str, int]],
] = (),
) -> Keypoint:
keypoint = super().__new__(cls, data, dtype=dtype, device=device)

parsed_symmetries = []
for symmetry, first, second in symmetries:
if isinstance(symmetry, str):
symmetry = KeypointSymmetry[symmetry]

            if isinstance(first, str):
                if not descriptions:
                    raise ValueError("symmetries reference keypoints by description, but no descriptions were passed")

                first = descriptions.index(first)

            if isinstance(second, str):
                if not descriptions:
                    raise ValueError("symmetries reference keypoints by description, but no descriptions were passed")

                second = descriptions.index(second)

parsed_symmetries.append((symmetry, first, second))

keypoint._metadata.update(dict(image_size=image_size, descriptions=descriptions, symmetries=parsed_symmetries))

return keypoint

@classmethod
def _to_tensor(cls, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor:
tensor = torch.as_tensor(data, dtype=dtype, device=device)
if tensor.shape[-1] != 2:
            raise ValueError("keypoint coordinates should have shape (..., 2), one (x, y) pair per keypoint")
if tensor.ndim == 1:
tensor = tensor.view(1, -1)
return tensor

@property
def num_keypoints(self) -> int:
return self.shape[-2]
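
A minimal usage sketch of the new feature (values are made up; it assumes the prototype StrEnum resolves the lowercase "vertical" string, which the COCO and CelebA code above already relies on):

import torch
from torchvision.prototype.features import Keypoint, KeypointSymmetry

# Two (x, y) keypoints on a 100x100 image, declared as a mirror pair.
keypoint = Keypoint(
    torch.tensor([[10.0, 20.0], [90.0, 20.0]]),
    image_size=(100, 100),
    descriptions=["left_eye", "right_eye"],
    symmetries=[("vertical", "left_eye", "right_eye")],
)

print(keypoint.num_keypoints)  # 2
print(keypoint.symmetries)     # [(KeypointSymmetry.VERTICAL, 0, 1)]
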
12 changes: 12 additions & 0 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -17,13 +17,25 @@
PIL.Image.Image: _F.hflip,
features.Image: K.horizontal_flip_image,
features.BoundingBox: None,
features.Keypoint: None,
},
)
def horizontal_flip(input: T, *args: Any, **kwargs: Any) -> T:
"""TODO: add docstring"""
if isinstance(input, features.BoundingBox):
output = K.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size)
return cast(T, features.BoundingBox.new_like(input, output))
elif isinstance(input, features.Keypoint):
coordinates = K.horizontal_flip_keypoint(input, image_size=input.image_size)

        idcs = list(range(input.num_keypoints))
        for symmetry, first, second in input.symmetries:
            if symmetry != "vertical":
                continue

            # A horizontal flip turns left keypoints into right ones and vice
            # versa, so swap the indices of every vertically symmetric pair.
            idcs[first], idcs[second] = second, first

return cast(T, features.Keypoint.new_like(input, coordinates[..., idcs, :]))

raise RuntimeError

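A sketch of the keypoint path through this dispatcher (made-up values; assumes horizontal_flip is re-exported from torchvision.prototype.transforms.functional). Mirroring moves the left eye to the right side of the image, and the symmetry-driven index swap then reorders the keypoints so that "left_eye" keeps referring to the subject's left eye:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

keypoint = features.Keypoint(
    torch.tensor([[10.0, 20.0], [90.0, 20.0]]),  # left_eye, right_eye
    image_size=(100, 100),
    descriptions=["left_eye", "right_eye"],
    symmetries=[("vertical", "left_eye", "right_eye")],
)

flipped = F.horizontal_flip(keypoint)
# kernel mirrors x:        x' = width - x  -> [[90., 20.], [10., 20.]]
# symmetry swap reorders:                  -> [[10., 20.], [90., 20.]]
# For a perfectly mirrored pair the coordinates end up unchanged; in general
# only the left/right naming is preserved, not the positions.
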
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/kernels/__init__.py
@@ -24,6 +24,7 @@
from ._geometry import (
horizontal_flip_bounding_box,
horizontal_flip_image,
horizontal_flip_keypoint,
resize_bounding_box,
resize_image,
resize_segmentation_mask,
6 changes: 6 additions & 0 deletions torchvision/prototype/transforms/kernels/_geometry.py
@@ -29,6 +29,12 @@ def horizontal_flip_bounding_box(
).view(shape)


def horizontal_flip_keypoint(keypoint: torch.Tensor, *, image_size: Tuple[int, int]) -> torch.Tensor:
    # image_size is (height, width): mirror each x coordinate across the image width
    keypoint = keypoint.clone()
    keypoint[..., 0] = image_size[1] - keypoint[..., 0]
    return keypoint


def resize_image(
image: torch.Tensor,
size: List[int],