add KeyPoint feature for prototype datasets #5326

Open
pmeier wants to merge 9 commits into base: main
Conversation

@pmeier pmeier commented Feb 1, 2022

We currently have two different APIs for keypoints:

  • draw_keypoints requires the keypoints to be of shape (num_instances, num_keypoints_per_instance, 2), where the last channel contains the x and y coordinates.
  • KeypointRCNN models such as keypointrcnn_resnet50_fpn require the keypoints to be of shape (num_instances, num_keypoints_per_instance, 3), where the last channel contains the x and y coordinates and the visibility.

We currently also have two datasets that provide keypoints: COCO and CelebA. Of those, COCO provides a visibility flag while CelebA doesn't. Skimming through other datasets, it seems that visibility is not a regular part of the annotations. Thus, for the KeyPoint feature that this PR adds, I went with only the x and y coordinates.
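
For illustration, here is a minimal sketch (not the implementation in this PR) of what a KeyPoint feature holding only the x and y coordinates could look like. It assumes a torch.Tensor subclass pattern similar to the other prototype features; the constructor signature and the descriptions metadata are assumptions for this sketch, not the actual API:

import torch


class KeyPoint(torch.Tensor):
    # assumed metadata: one human-readable description per keypoint, e.g. ("left_eye", "right_eye", ...)
    descriptions: tuple

    def __new__(cls, data, *, descriptions=()):
        # keypoints are stored as (..., num_keypoints, 2) with x and y in the last channel
        tensor = torch.as_tensor(data, dtype=torch.float)
        if tensor.shape[-1] != 2:
            raise ValueError("expected the last dimension to be 2 (x and y)")
        keypoints = tensor.as_subclass(cls)
        keypoints.descriptions = tuple(descriptions)
        return keypoints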

Below you can find example implementations that call both APIs for COCO and CelebA:

COCO

import torch
from torchvision.models.detection import keypointrcnn_resnet50_fpn
from torchvision.prototype import datasets
from torchvision.utils import draw_keypoints
from torchvision.transforms.functional import to_pil_image

name = "coco"

dataset = datasets.load(name, annotations="person_keypoints")
sample = next(iter(dataset))

annotated_image = draw_keypoints(
    sample["image"],
    sample["keypoints"],
    radius=int(2e-2 * min(sample["image"].shape[-2:])),
    colors="red",
)
to_pil_image(annotated_image).save(f"{name}.jpg")


model = keypointrcnn_resnet50_fpn(
    num_classes=len(datasets.info(name).categories),
    num_keypoints=sample["keypoints"].shape[-2],
)

# this conversion will happen in a transform later
image = sample["image"].to(torch.float).div(255.0)

target = sample

# the model requires the bounding boxes as a 2d tensor in XYXY format at the 'boxes' key
target["boxes"] = target["bounding_boxes"].convert("xyxy")

# keypointrcnn_resnet50_fpn expects keypoints of shape (num_instances, num_keypoints, 3)
# with the visibility flag in the last channel, so append it to the x and y coordinates
keypoints_without_visibility = target["keypoints"]
keypoints_visibility = target["visibility"].unsqueeze(-1)
target["keypoints"] = torch.cat((keypoints_without_visibility, keypoints_visibility), dim=-1).to(torch.float)

# images and targets need to be lists of images and annotation dictionaries
loss = model([image], [target])

(attached image: coco.jpg)

Example for CelebA

import torch
from torchvision.models.detection import keypointrcnn_resnet50_fpn
from torchvision.prototype import datasets
from torchvision.utils import draw_keypoints
from torchvision.transforms.functional import to_pil_image

name = "celeba"

dataset = datasets.load(name)
sample = next(iter(dataset))

annotated_image = draw_keypoints(
    sample["image"],
    sample["keypoints"].unsqueeze(0),
    radius=int(2e-2 * min(sample["image"].shape[-2:])),
    colors="red",
)
to_pil_image(annotated_image).save(f"{name}.jpg")


model = keypointrcnn_resnet50_fpn(
    num_classes=len(datasets.info(name).categories),
    num_keypoints=sample["keypoints"].shape[-2],
)

# this conversion will happen in a transform later
image = sample["image"].to(torch.float).div(255.0)

target = sample

# the model requires the label as 1d tensor at the 'labels' key
target["labels"] = target["identity"].unsqueeze(0)
# the model requires the bounding boxes as a 2d tensor in XYXY format at the 'boxes' key
target["boxes"] = target["bbox"].convert("xyxy").unsqueeze(0)

keypoints_without_visibility = target["keypoints"]
# CelebA has no visibility annotation; visibility == 2 denotes visible in COCO annotations,
# so mark all keypoints as visible here
keypoints_visibility = torch.full((*target["keypoints"].shape[:-1], 1), 2)
target["keypoints"] = torch.cat((keypoints_without_visibility, keypoints_visibility), dim=-1).to(torch.float)

# images and targets need to be lists of images and annotation dictionaries
loss = model([image], [target])

(attached image: celeba.jpg)

@facebook-github-bot

facebook-github-bot commented Feb 1, 2022

💊 CI failures summary and remediations

As of commit 4bab7c3 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@pmeier changed the title from "add keypoints annotations to COCO prototype dataset" to "add KeyPoint feature for prototype datasets" on Feb 1, 2022
@pmeier requested a review from fmassa on February 1, 2022, 12:48
@fmassa fmassa left a comment


Thanks for the PR!

I've made some comments, let me know what you think

torchvision/prototype/transforms/_geometry.py (3 outdated review threads, resolved)
torchvision/prototype/features/_keypoint.py (outdated review thread, resolved)
@@ -129,6 +159,7 @@ def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str,
     _ANN_DECODERS = OrderedDict(
         [
             ("instances", _decode_instances_anns),
+            ("person_keypoints", _decode_person_keypoints_anns),
Member commented:

This is unrelated to this PR, but I would revisit whether we want a dataset to have a varying return type depending on the arguments passed to it.
It might be better for different flavors of the dataset to be different dataset classes.

pmeier (Collaborator, Author) replied:

Yeah, that is a design choice that we should address. cc @NicolasHug

@pmeier pmeier commented Feb 3, 2022

In response to #5326 (comment) I've added symmetry metadata to the KeyPoint feature. With this, each dataset can specify which symmetries exist between the keypoints.

Here is an example with HorizontalFlip:

from torchvision.prototype import datasets, transforms
from torchvision.utils import draw_keypoints

from torchvision.transforms.functional import to_pil_image


def draw_non_right_keypoints(sample, file):
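    # keep only the keypoints whose description does not start with "right" and draw them;
    # this makes it easy to see in the before/after images that the flip swapped left and right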
    image = sample["image"]
    keypoints = sample["keypoints"]
    non_right_keypoints = keypoints[[not descr.startswith("right") for descr in keypoints.descriptions], :]

    annotated_image = draw_keypoints(
        image,
        non_right_keypoints.unsqueeze(0),
        radius=int(2e-2 * min(sample["image"].shape[-2:])),
        colors="red",
    )
    to_pil_image(annotated_image).save(file)


dataset = datasets.load("celeba")
sample = next(iter(dataset))

draw_non_right_keypoints(sample, "before.jpg")

transform = transforms.HorizontalFlip()
transformed_sample = transform(sample)

draw_non_right_keypoints(transformed_sample, "after.jpg")

(attached images: before.jpg, after.jpg)
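
For illustration, a minimal sketch of how a horizontal flip kernel could use such symmetry metadata. It assumes the symmetries have already been resolved to pairs of keypoint indices that mirror each other; this is an assumption for the sketch, not the kernel added in this PR:

import torch


def horizontal_flip_keypoints(keypoints: torch.Tensor, image_width: int, symmetries):
    # keypoints: (..., num_keypoints, 2) with x and y in the last channel
    flipped = keypoints.clone()
    # mirror the x coordinate across the image
    flipped[..., 0] = image_width - 1 - flipped[..., 0]
    # swap each left/right pair so the keypoint labels stay consistent after the flip
    for left, right in symmetries:
        flipped[..., [left, right], :] = flipped[..., [right, left], :]
    return flipped

With CelebA-style descriptions, the index pairs could be derived from the left_*/right_* names, which is the information the symmetries built by the dataset are meant to encode.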

@fmassa fmassa left a comment


The proposal for the symmetry LGTM, thanks!

I only have one more comment regarding the resize transform, otherwise good to merge


 class Resize(Transform):
-    NO_OP_FEATURE_TYPES = {Label}
+    NO_OP_FEATURE_TYPES = {Label, Keypoint}
Member commented:

OK if we leave this as a TODO, but we should rescale the keypoint coordinates by the same factor as the image.

pmeier (Collaborator, Author) replied:

Yes, in general all geometric transforms need to support KeyPoints. We can probably share a lot of functionality with the bounding box kernels.
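
For reference, a minimal sketch of what a keypoint resize kernel could look like, scaling the coordinates the same way the bounding box kernels scale box coordinates; the function name and signature are assumptions for this sketch, not existing torchvision kernels:

import torch


def resize_keypoints(keypoints: torch.Tensor, old_size, new_size) -> torch.Tensor:
    # keypoints: (..., num_keypoints, 2) with x and y in the last channel
    old_height, old_width = old_size
    new_height, new_width = new_size
    # scale x by the width ratio and y by the height ratio; broadcasts over the last channel
    scale = torch.tensor([new_width / old_width, new_height / old_height])
    return keypoints * scale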

torchvision/prototype/datasets/_builtin/celeba.py (outdated review thread, resolved)
@datumbox datumbox left a comment


Thanks, I have some questions below

("vertical", description, description.replace("left", "right"))
for description in descriptions
if description.startswith("left")
]

Could you clarify the type of the resulting symmetries variable?

Below, the Keypoint expects Sequence[Tuple[KeypointSymmetry, Tuple[int, int]]]. It's not obvious to me how the current code produces that specific type.

torchvision/prototype/datasets/_builtin/celeba.py (outdated review thread, resolved)
torchvision/prototype/features/_keypoint.py (2 outdated review threads, resolved)
@pmeier pmeier commented Feb 7, 2022

As discussed offline with @datumbox, we'll put this PR on hold until we have full support for bounding boxes and segmentation masks.
