-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add KeyPoint feature for prototype datasets #5326
base: main
Are you sure you want to change the base?
Changes from all commits
20d2c3a
61d1d73
aa1b920
db99617
ea7aafa
8cd0136
8d20406
0abf12f
4bab7c3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,7 +31,7 @@ | |
hint_sharding, | ||
hint_shuffling, | ||
) | ||
from torchvision.prototype.features import BoundingBox, Label, _Feature, EncodedImage | ||
from torchvision.prototype.features import BoundingBox, Label, _Feature, EncodedImage, Keypoint | ||
from torchvision.prototype.utils._internal import FrozenMapping | ||
|
||
|
||
|
@@ -117,6 +117,45 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st | |
ann_ids=[ann["id"] for ann in anns], | ||
) | ||
|
||
def _decode_person_keypoints_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]:
    """Decode COCO person-keypoints annotations for one image.

    Returns the instance-annotation dict extended with a ``keypoints``
    feature and a per-keypoint ``visibility`` tensor.
    """
    image_size = (image_meta["height"], image_meta["width"])

    # Raw COCO keypoints are a flat [x1, y1, v1, x2, y2, v2, ...] list per
    # annotation; reshape to (num_annotations, num_keypoints, 3).
    raw_keypoints = torch.tensor([ann["keypoints"] for ann in anns]).view(len(anns), -1, 3)
    coordinates = raw_keypoints[..., :2]
    visibility = _Feature(raw_keypoints[..., 2])

    # As of COCO 2017, only the person category has keypoints.
    descriptions = (
        "nose",
        "left_eye",
        "right_eye",
        "left_ear",
        "right_ear",
        "left_shoulder",
        "right_shoulder",
        "left_elbow",
        "right_elbow",
        "left_wrist",
        "right_wrist",
        "left_hip",
        "right_hip",
        "left_knee",
        "right_knee",
        "left_ankle",
        "right_ankle",
    )
    # Every "left_*" keypoint mirrors the corresponding "right_*" one.
    symmetries = []
    for description in descriptions:
        if description.startswith("left"):
            symmetries.append(("vertical", description, description.replace("left", "right")))

    return {
        **self._decode_instances_anns(anns, image_meta),
        "keypoints": Keypoint(coordinates, image_size=image_size, descriptions=descriptions, symmetries=symmetries),
        "visibility": visibility,
    }
|
||
def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]: | ||
return dict( | ||
captions=[ann["caption"] for ann in anns], | ||
|
@@ -126,6 +165,7 @@ def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str, | |
# Ordered mapping from annotation type to its decoder method.
_ANN_DECODERS = OrderedDict(
    instances=_decode_instances_anns,
    person_keypoints=_decode_person_keypoints_anns,
    captions=_decode_captions_ann,
)
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any, Tuple, Optional, Sequence, Union, Collection | ||
|
||
import torch | ||
from torchvision.prototype.utils._internal import StrEnum | ||
|
||
from ._feature import _Feature | ||
|
||
|
||
class KeypointSymmetry(StrEnum):
    """Axis along which two keypoints mirror each other."""

    VERTICAL = "vertical"
    HORIZONTAL = "horizontal"
|
||
|
||
class Keypoint(_Feature):
    """Tensor feature holding ``(x, y)`` keypoint coordinates for an image.

    The last dimension of the data is always of size 2 (one ``(x, y)`` pair
    per keypoint); a single keypoint passed as a 1-D tensor is promoted to
    shape ``(1, 2)``.

    Metadata:
        image_size: ``(height, width)`` of the image the keypoints refer to.
        descriptions: Optional human-readable name per keypoint index
            (e.g. ``"nose"``), or ``None`` if no names were given.
        symmetries: Parsed ``(symmetry, first_index, second_index)`` triples
            describing pairs of keypoints that mirror each other.
    """

    image_size: Tuple[int, int]
    # descriptions is a flat sequence of names, one per keypoint index (it
    # mirrors the `descriptions` parameter of __new__), or None.
    descriptions: Optional[Sequence[str]]
    symmetries: Sequence[Tuple[KeypointSymmetry, int, int]]

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        image_size: Tuple[int, int],
        descriptions: Optional[Sequence[str]] = None,
        symmetries: Collection[
            Tuple[Union[str, KeypointSymmetry], Union[str, int], Union[str, int]],
        ] = (),
    ) -> Keypoint:
        """Create a ``Keypoint`` feature.

        Args:
            data: Keypoint coordinates; the last dimension must be of size 2.
            dtype: Optional dtype for the underlying tensor.
            device: Optional device for the underlying tensor.
            image_size: ``(height, width)`` of the corresponding image.
            symmetries: Triples of ``(axis, first, second)`` where the two
                endpoints may be given either as integer keypoint indices or
                as description names.

        Raises:
            ValueError: If a symmetry endpoint is given by name but no
                ``descriptions`` were provided, or if the name is unknown.
        """
        keypoint = super().__new__(cls, data, dtype=dtype, device=device)

        def resolve_index(endpoint: Union[str, int]) -> int:
            # Symmetry endpoints may be given by description name; translate
            # them to integer keypoint indices.
            if isinstance(endpoint, str):
                if not descriptions:
                    raise ValueError(
                        f"Cannot resolve keypoint '{endpoint}' by name, since no descriptions were provided."
                    )
                return descriptions.index(endpoint)
            return endpoint

        parsed_symmetries = []
        for symmetry, first, second in symmetries:
            if isinstance(symmetry, str):
                symmetry = KeypointSymmetry[symmetry]
            parsed_symmetries.append((symmetry, resolve_index(first), resolve_index(second)))

        keypoint._metadata.update(dict(image_size=image_size, descriptions=descriptions, symmetries=parsed_symmetries))

        return keypoint

    @classmethod
    def _to_tensor(cls, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor:
        """Convert ``data`` to a tensor and validate/normalize its shape."""
        tensor = torch.as_tensor(data, dtype=dtype, device=device)
        if tensor.shape[-1] != 2:
            raise ValueError(
                f"Expected the last dimension to be of size 2 ((x, y) coordinates), but got {tensor.shape[-1]}."
            )
        if tensor.ndim == 1:
            # Promote a single keypoint to a batch of one.
            tensor = tensor.view(1, -1)
        return tensor

    @property
    def num_keypoints(self) -> int:
        """Number of keypoints per sample (size of the second-to-last dimension)."""
        return self.shape[-2]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you clarify the type of the resulting `symmetries` variable? Below, `Keypoint` expects `Sequence[Tuple[KeypointSymmetry, Tuple[int, int]]]`. It's non-obvious to me how the current code produces that specific type.