Merge pull request #57 from SangbumChoi/vitpose
Scipy transformation: replaces the remaining cv2 calls (cv2.GaussianBlur and cv2.warpAffine) with SciPy equivalents, removing the cv2 dependency from the ViTPose image processing code.
NielsRogge authored Jul 23, 2024
2 parents 6238277 + b0a488e commit c20462e
Showing 2 changed files with 44 additions and 16 deletions.
src/transformers/models/vitpose/convert_vitpose_to_hf.py (18 changes: 10 additions & 8 deletions)
@@ -176,10 +176,10 @@ def prepare_img():


name_to_path = {
-"vitpose-base-simple": "/Users/nielsrogge/Documents/ViTPose/vitpose-b-simple.pth",
-"vitpose-base": "/Users/nielsrogge/Documents/ViTPose/vitpose-b.pth",
-"vitpose-base-coco-aic-mpii": "/Users/nielsrogge/Documents/ViTPose/vitpose_base_coco_aic_mpii.pth",
-"vitpose+-base": "/Users/nielsrogge/Documents/ViTPose/vitpose+_base.pth",
+"vitpose-base-simple": "vitpose-b-simple.pth",
+"vitpose-base": "vitpose-b.pth",
+"vitpose-base-coco-aic-mpii": "vitpose_base_coco_aic_mpii.pth",
+"vitpose+-base": "vitpose+_base.pth",
}


@@ -200,9 +200,6 @@ def convert_vitpose_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub
checkpoint_path = name_to_path[model_name]
state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]

-# for name, param in state_dict.items():
-# print(name, param.shape)
-
# rename some keys
new_state_dict = convert_state_dict(state_dict, dim=config.backbone_config.hidden_size, config=config)
missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
@@ -226,7 +223,7 @@ def convert_vitpose_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub

filepath = hf_hub_download(repo_id="nielsr/test-image", filename="vitpose_batch_data.pt", repo_type="dataset")
original_pixel_values = torch.load(filepath, map_location="cpu")["img"]
-assert torch.allclose(pixel_values, original_pixel_values)
+assert torch.allclose(pixel_values, original_pixel_values, atol=1e-1)

img_metas = torch.load(filepath, map_location="cpu")["img_metas"]
dataset_index = torch.tensor([0])
@@ -263,21 +260,25 @@ def convert_vitpose_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub
assert torch.allclose(
torch.from_numpy(pose_results[1]["keypoints"][0, :3]),
torch.tensor([3.98180511e02, 1.81808380e02, 8.66642594e-01]),
+atol=5e-2,
)
elif model_name == "vitpose-base":
assert torch.allclose(
torch.from_numpy(pose_results[1]["keypoints"][0, :3]),
torch.tensor([3.9807913e02, 1.8182812e02, 8.8235235e-01]),
+atol=5e-2,
)
elif model_name == "vitpose-base-coco-aic-mpii":
assert torch.allclose(
torch.from_numpy(pose_results[1]["keypoints"][0, :3]),
torch.tensor([3.98305542e02, 1.81741592e02, 8.69966745e-01]),
+atol=5e-2,
)
elif model_name == "vitpose+-base":
assert torch.allclose(
torch.from_numpy(pose_results[1]["keypoints"][0, :3]),
torch.tensor([3.98201294e02, 1.81728302e02, 8.75046968e-01]),
+atol=5e-2,
)
else:
raise ValueError("Model not supported")
@@ -290,6 +291,7 @@ def convert_vitpose_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub
assert torch.allclose(
torch.tensor(hf_pose_results[1]["keypoints"][0, :3]),
torch.tensor([3.9813846e02, 1.8180725e02, 8.7446749e-01]),
+atol=5e-2,
)
assert hf_pose_results[0]["keypoints"].shape == (17, 3)
assert hf_pose_results[1]["keypoints"].shape == (17, 3)
src/transformers/models/vitpose/image_processing_vitpose.py (42 changes: 34 additions & 8 deletions)
@@ -32,16 +32,15 @@
to_numpy_array,
valid_images,
)
-from ...utils import TensorType, is_cv2_available, is_vision_available, logging
+from ...utils import TensorType, is_scipy_available, is_vision_available, logging


if is_vision_available():
import PIL

-if is_cv2_available():
-# TODO get rid of cv2?
-import cv2

+if is_scipy_available():
+from scipy.linalg import inv
+from scipy.ndimage import affine_transform, gaussian_filter

logger = logging.get_logger(__name__)

@@ -127,9 +126,10 @@ def post_dark_udp(coords, batch_heatmaps, kernel=3):
num_coords = coords.shape[0]
if not (batch_size == 1 or batch_size == num_coords):
raise ValueError("The batch size of heatmaps should be 1 or equal to the batch size of coordinates.")
+radius = int((kernel - 1) // 2)
for heatmaps in batch_heatmaps:
for heatmap in heatmaps:
-cv2.GaussianBlur(heatmap, (kernel, kernel), 0, heatmap)
+gaussian_filter(heatmap, sigma=0.8, output=heatmap, radius=(radius, radius), axes=(0, 1))
np.clip(batch_heatmaps, 0.001, 50, batch_heatmaps)
np.log(batch_heatmaps, batch_heatmaps)

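The hunk above swaps cv2.GaussianBlur for scipy.ndimage.gaussian_filter in the heatmap modulation step. For a 3x3 kernel with sigma=0, OpenCV derives sigma = 0.3 * ((3 - 1) * 0.5 - 1) + 0.8 = 0.8, and radius = (kernel - 1) // 2 = 1 gives the same 3-tap window, so the two blurs should agree on interior pixels. A minimal sketch of the new call in isolation (not part of this commit; assumes SciPy >= 1.11, which introduced the axes keyword):

```python
import numpy as np
from scipy.ndimage import gaussian_filter

# Stand-in for a single keypoint heatmap (height x width).
heatmap = np.random.rand(64, 48).astype(np.float32)

# Same parameters as in post_dark_udp above: a 3-tap Gaussian (radius 1) with sigma 0.8,
# blurred in place via output=heatmap, matching the cv2.GaussianBlur call it replaces.
gaussian_filter(heatmap, sigma=0.8, output=heatmap, radius=(1, 1), axes=(0, 1))

# The old call was cv2.GaussianBlur(heatmap, (3, 3), 0, heatmap); with sigma=0, OpenCV
# derives sigma 0.8 for ksize=3, so results should match up to border handling.
print(heatmap.shape)  # (64, 48)
```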
@@ -247,6 +247,32 @@ def get_warp_matrix(theta: float, size_input: np.ndarray, size_dst: np.ndarray,
return matrix


+def scipy_warp_affine(src, M, size):
+"""
+Reimplements cv2.warpAffine from the original implementation using SciPy.
+Note: the original implementation uses cv2.INTER_LINEAR; order=1 in
+scipy.ndimage.affine_transform is the equivalent bilinear interpolation.
+"""
+channels = [src[..., i] for i in range(src.shape[-1])]
+
+# Convert the 2x3 warp matrix to the 3x3 homogeneous form used by SciPy
+M_scipy = np.vstack([M, [0, 0, 1]])
+# affine_transform applies a "pull" (output -> input) mapping, so the "push" matrix must be inverted
+M_inv = inv(M_scipy)
+# affine_transform indexes in (row, col) order while the warp matrix is in (x, y) order,
+# so swap the coordinate axes of the inverted matrix
+M_inv[0, 0], M_inv[0, 1], M_inv[1, 0], M_inv[1, 1], M_inv[0, 2], M_inv[1, 2] = (
+M_inv[1, 1],
+M_inv[1, 0],
+M_inv[0, 1],
+M_inv[0, 0],
+M_inv[1, 2],
+M_inv[0, 2],
+)
+
+new_src = [affine_transform(channel, M_inv, output_shape=size, order=1) for channel in channels]
+new_src = np.stack(new_src, axis=-1)
+return new_src

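scipy_warp_affine above bridges two conventions: cv2.warpAffine takes a forward ("push") 2x3 matrix in (x, y) order, while scipy.ndimage.affine_transform applies a backward ("pull") mapping in (row, col) order, hence the inversion followed by the axis swap. A hypothetical sanity check against the old cv2 path (not part of this commit; assumes cv2 is still installed and scipy_warp_affine is in scope; interior pixels should agree closely, borders may differ slightly):

```python
import cv2
import numpy as np

src = np.random.rand(64, 48, 3).astype(np.float32)  # channels-last image
# Any 2x3 affine works; here, a small rotation about the image center.
M = cv2.getRotationMatrix2D((24.0, 32.0), 15.0, 1.0)

out_cv2 = cv2.warpAffine(src, M, (48, 64), flags=cv2.INTER_LINEAR)  # dsize is (width, height)
out_scipy = scipy_warp_affine(src=src, M=M, size=(64, 48))          # size is (height, width)

# Expect small interpolation differences rather than exact equality.
print(np.abs(out_cv2 - out_scipy).max())
```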

class ViTPoseImageProcessor(BaseImageProcessor):
r"""
Constructs a ViTPose image processor.
@@ -330,12 +356,12 @@ def affine_transform(
transformation = get_warp_matrix(rotation, center * 2.0, np.array(size) - 1.0, scale * 200.0)

# cv2 requires channels last format
-cv2_image = (
+image = (
image
if input_data_format == ChannelDimension.LAST
else to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format)
)
-image = cv2.warpAffine(cv2_image, transformation, size, flags=cv2.INTER_LINEAR)
+image = scipy_warp_affine(src=image, M=transformation, size=(size[1], size[0]))

image = to_channel_dimension_format(image, data_format, ChannelDimension.LAST)

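Taken together, the last hunk has the processor build the warp with get_warp_matrix and apply it with scipy_warp_affine. A hypothetical standalone version of that crop (box values invented; scale follows the COCO convention of box size divided by 200, mirroring the scale * 200.0 factor in the method):

```python
import numpy as np

# Assumes this branch is installed; both helpers are module-level functions in
# image_processing_vitpose.py.
from transformers.models.vitpose.image_processing_vitpose import get_warp_matrix, scipy_warp_affine

image = np.random.rand(480, 640, 3).astype(np.float32)  # channels-last, as the method expects
center = np.array([320.0, 240.0])  # person box center (x, y)
scale = np.array([1.6, 2.0])       # box (width, height) / 200
size = (192, 256)                  # model input size as (width, height)

# Same call pattern as the method: no rotation, UDP-style "size - 1" destination.
transformation = get_warp_matrix(0.0, center * 2.0, np.array(size) - 1.0, scale * 200.0)
crop = scipy_warp_affine(src=image, M=transformation, size=(size[1], size[0]))
print(crop.shape)  # (256, 192, 3)
```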
