[ViTPose] Convert more checkpoints #35638

Merged: 10 commits, Jan 20, 2025
4 changes: 2 additions & 2 deletions docs/source/en/index.md
@@ -359,8 +359,8 @@ Flax), PyTorch, and/or TensorFlow.
| [ViTMAE](model_doc/vit_mae) | ✅ | ✅ | ❌ |
| [ViTMatte](model_doc/vitmatte) | ✅ | ❌ | ❌ |
| [ViTMSN](model_doc/vit_msn) | ✅ | ❌ | ❌ |
-| [VitPose](model_doc/vitpose) | ✅ | ❌ | ❌ |
-| [VitPoseBackbone](model_doc/vitpose_backbone) | ✅ | ❌ | ❌ |
+| [ViTPose](model_doc/vitpose) | ✅ | ❌ | ❌ |
+| [ViTPoseBackbone](model_doc/vitpose_backbone) | ✅ | ❌ | ❌ |
| [VITS](model_doc/vits) | ✅ | ❌ | ❌ |
| [ViViT](model_doc/vivit) | ✅ | ❌ | ❌ |
| [Wav2Vec2](model_doc/wav2vec2) | ✅ | ✅ | ✅ |
88 changes: 61 additions & 27 deletions docs/source/en/model_doc/vitpose.md
@@ -10,24 +10,28 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->

-# VitPose
+# ViTPose

## Overview

-The VitPose model was proposed in [ViTPose: Simple Vision Transformer Baselines for Human Pose Estimation](https://arxiv.org/abs/2204.12484) by Yufei Xu, Jing Zhang, Qiming Zhang, Dacheng Tao. VitPose employs a standard, non-hierarchical [Vision Transformer](https://arxiv.org/pdf/2010.11929v2) as backbone for the task of keypoint estimation. A simple decoder head is added on top to predict the heatmaps from a given image. Despite its simplicity, the model gets state-of-the-art results on the challenging MS COCO Keypoint Detection benchmark.
+The ViTPose model was proposed in [ViTPose: Simple Vision Transformer Baselines for Human Pose Estimation](https://arxiv.org/abs/2204.12484) by Yufei Xu, Jing Zhang, Qiming Zhang, Dacheng Tao. ViTPose employs a standard, non-hierarchical [Vision Transformer](vit) as backbone for the task of keypoint estimation. A simple decoder head is added on top to predict the heatmaps from a given image. Despite its simplicity, the model gets state-of-the-art results on the challenging MS COCO Keypoint Detection benchmark. The model was further improved in [ViTPose++: Vision Transformer for Generic Body Pose Estimation](https://arxiv.org/abs/2212.04246) where the authors employ
+a mixture-of-experts (MoE) module in the ViT backbone along with pre-training on more data, which further enhances the performance.

The abstract from the paper is the following:

*Although no specific domain knowledge is considered in the design, plain vision transformers have shown excellent performance in visual recognition tasks. However, little effort has been made to reveal the potential of such simple structures for pose estimation tasks. In this paper, we show the surprisingly good capabilities of plain vision transformers for pose estimation from various aspects, namely simplicity in model structure, scalability in model size, flexibility in training paradigm, and transferability of knowledge between models, through a simple baseline model called ViTPose. Specifically, ViTPose employs plain and non-hierarchical vision transformers as backbones to extract features for a given person instance and a lightweight decoder for pose estimation. It can be scaled up from 100M to 1B parameters by taking the advantages of the scalable model capacity and high parallelism of transformers, setting a new Pareto front between throughput and performance. Besides, ViTPose is very flexible regarding the attention type, input resolution, pre-training and finetuning strategy, as well as dealing with multiple pose tasks. We also empirically demonstrate that the knowledge of large ViTPose models can be easily transferred to small ones via a simple knowledge token. Experimental results show that our basic ViTPose model outperforms representative methods on the challenging MS COCO Keypoint Detection benchmark, while the largest model sets a new state-of-the-art.*

-![vitpose-architecture](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/vitpose-architecture.png)
+<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/vitpose-architecture.png"
+alt="drawing" width="600"/>

<small> ViTPose architecture. Taken from the <a href="https://arxiv.org/abs/2204.12484">original paper.</a> </small>

This model was contributed by [nielsr](https://huggingface.co/nielsr) and [sangbumchoi](https://github.com/SangbumChoi).
The original code can be found [here](https://github.com/ViTAE-Transformer/ViTPose).

## Usage Tips

-ViTPose is a so-called top-down keypoint detection model. This means that one first uses an object detector, like [RT-DETR](rt_detr.md), to detect people (or other instances) in an image. Next, ViTPose takes the cropped images as input and predicts the keypoints.
+ViTPose is a so-called top-down keypoint detection model. This means that one first uses an object detector, like [RT-DETR](rt_detr.md), to detect people (or other instances) in an image. Next, ViTPose takes the cropped images as input and predicts the keypoints for each of them.

```py
import torch
@@ -36,11 +40,7 @@ import numpy as np

from PIL import Image

-from transformers import (
-    AutoProcessor,
-    RTDetrForObjectDetection,
-    VitPoseForPoseEstimation,
-)
+from transformers import AutoProcessor, RTDetrForObjectDetection, VitPoseForPoseEstimation

device = "cuda" if torch.cuda.is_available() else "cpu"

Expand All @@ -51,7 +51,7 @@ image = Image.open(requests.get(url, stream=True).raw)
# Stage 1. Detect humans on the image
# ------------------------------------------------------------------------

-# You can choose detector by your choice
+# You can choose any detector of your choice
person_image_processor = AutoProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
person_model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365", device_map=device)

@@ -89,9 +89,50 @@ pose_results = image_processor.post_process_pose_estimation(outputs, boxes=[person_boxes])
image_pose_result = pose_results[0] # results for first image
```
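The hunk above collapses the middle of this example: turning the detector outputs into person boxes and running ViTPose itself, ending right where the `post_process_pose_estimation` call picks up. For orientation, here is a minimal sketch of those collapsed steps, continuing from the snippet above; the 0.3 detection threshold, the corner-to-width/height box conversion, and the `usyd-community/vitpose-base-simple` checkpoint name are assumptions rather than lines taken from this diff:

```python
# Keep only "person" detections (label 0 in the COCO label map) as (x1, y1, x2, y2) boxes
inputs = person_image_processor(images=image, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = person_model(**inputs)

results = person_image_processor.post_process_object_detection(
    outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.3
)
person_boxes = results[0]["boxes"][results[0]["labels"] == 0].cpu().numpy()

# The pose processor expects boxes as (x1, y1, width, height), so convert from corner format
person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]

# Stage 2. Predict keypoints for each detected person
image_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-base-simple")
model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple", device_map=device)

inputs = image_processor(image, boxes=[person_boxes], return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)
```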

-### Visualization for supervision user
-```py
+### ViTPose++ models
+
+The best [checkpoints](https://huggingface.co/collections/usyd-community/vitpose-677fcfd0a0b2b5c8f79c4335) are those of the [ViTPose++ paper](https://arxiv.org/abs/2212.04246). ViTPose++ models employ a so-called [Mixture-of-Experts (MoE)](https://huggingface.co/blog/moe) architecture for the ViT backbone, resulting in better performance.

The ViTPose+ checkpoints use 6 experts, hence 6 different dataset indices can be passed.
An overview of the various dataset indices is provided below:

- 0: [COCO validation 2017](https://cocodataset.org/#overview) dataset, using an object detector that gets 56 AP on the "person" class
- 1: [AiC](https://github.com/fabbrimatteo/AiC-Dataset) dataset
- 2: [MPII](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/software-and-datasets/mpii-human-pose-dataset) dataset
- 3: [AP-10K](https://github.com/AlexTheBad/AP-10K) dataset
- 4: [APT-36K](https://github.com/pandorgan/APT-36K) dataset
- 5: [COCO-WholeBody](https://github.com/jin-s13/COCO-WholeBody) dataset

Pass the `dataset_index` argument in the forward pass of the model to indicate which expert to use for each example in the batch. Example usage is shown below:

```python
image_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-plus-base")
model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-plus-base", device_map=device)

inputs = image_processor(image, boxes=[person_boxes], return_tensors="pt").to(device)

dataset_index = torch.tensor([0], device=device) # must be a tensor of shape (batch_size,)

with torch.no_grad():
    outputs = model(**inputs, dataset_index=dataset_index)
```
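From here, post-processing is identical to the non-MoE checkpoints. A short continuation, assuming `person_boxes` was obtained from the detection stage shown earlier:

```python
pose_results = image_processor.post_process_pose_estimation(outputs, boxes=[person_boxes])
image_pose_result = pose_results[0]  # keypoints and scores for the first image
```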

### Visualization

To visualize the various keypoints, one can either leverage the `supervision` [library](https://github.com/roboflow/supervision) (requires `pip install supervision`):

```python
import supervision as sv

xy = torch.stack([pose_result['keypoints'] for pose_result in image_pose_result]).cpu().numpy()
@@ -119,8 +160,9 @@ annotated_frame = vertex_annotator.annotate(
)
```
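The annotators return the annotated scene in the same type they were given, so with a PIL image as the initial scene the result can be displayed or saved directly. A small usage sketch, assuming the scene passed to `edge_annotator.annotate` above was a copy of the original PIL image (the output filename is illustrative):

```python
# annotated_frame is a PIL image here, so it can be saved or shown as usual
annotated_frame.save("vitpose_keypoints.jpg")
```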

-### Visualization for advanced user
-```py
+Alternatively, one can also visualize the keypoints using [OpenCV](https://opencv.org/) (requires `pip install opencv-python`):
+
+```python
import math
import cv2

@@ -223,26 +265,18 @@ pose_image
```
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/vitpose-coco.jpg" alt="drawing" width="600"/>

-### MoE backbone
-
-To enable MoE (Mixture of Experts) function in the backbone, user has to give appropriate configuration such as `num_experts` and input value `dataset_index` to the backbone model. However, it is not used in default parameters. Below is the code snippet for usage of MoE function.
-
-```py
->>> from transformers import VitPoseBackboneConfig, VitPoseBackbone
->>> import torch
-
->>> config = VitPoseBackboneConfig(num_experts=3, out_indices=[-1])
->>> model = VitPoseBackbone(config)
-
->>> pixel_values = torch.randn(3, 3, 256, 192)
->>> dataset_index = torch.tensor([1, 2, 3])
->>> outputs = model(pixel_values, dataset_index)
-```
+## Resources
+
+A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ViTPose. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
+
+- A demo of ViTPose on images and video can be found [here](https://huggingface.co/spaces/hysts/ViTPose-transformers).
+- A notebook illustrating inference and visualization can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/ViTPose/Inference_with_ViTPose_for_human_pose_estimation.ipynb).

## VitPoseImageProcessor

[[autodoc]] VitPoseImageProcessor
- preprocess
- post_process_pose_estimation

## VitPoseConfig

4 changes: 2 additions & 2 deletions src/transformers/models/auto/configuration_auto.py
@@ -650,8 +650,8 @@
("vit_msn", "ViTMSN"),
("vitdet", "VitDet"),
("vitmatte", "ViTMatte"),
("vitpose", "VitPose"),
("vitpose_backbone", "VitPoseBackbone"),
("vitpose", "ViTPose"),
("vitpose_backbone", "ViTPoseBackbone"),
("vits", "VITS"),
("vivit", "ViViT"),
("wav2vec2", "Wav2Vec2"),