Skip to content

Commit

Permalink
update keypoint convert
Browse files Browse the repository at this point in the history
  • Loading branch information
xiexinch committed Sep 19, 2023
1 parent 748432d commit b296a2a
Showing 1 changed file with 13 additions and 26 deletions.
39 changes: 13 additions & 26 deletions mmpose/datasets/transforms/converting.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,10 @@ def transform(self, results: dict) -> dict:
num_instances = results['keypoints'].shape[0]

# Initialize output arrays
keypoints = np.zeros((num_instances, self.num_keypoints, 2))
keypoints = np.zeros((num_instances, self.num_keypoints, 3))
keypoints_visible = np.zeros((num_instances, self.num_keypoints))
key = 'keypoints_3d' if 'keypoints_3d' in results else 'keypoints'

if 'keypoints_3d' in results:
keypoints_3d = np.zeros((num_instances, self.num_keypoints, 3),
dtype=np.float32)
flip_indices = results.get('flip_indices', None)

# Create a mask to weight visibility loss
Expand All @@ -106,43 +104,32 @@ def transform(self, results: dict) -> dict:
# Interpolate keypoints if pairs of source indexes provided
if self.interpolation:
keypoints[:, self.target_index] = 0.5 * (
results['keypoints'][:, self.source_index] +
results['keypoints'][:, self.source_index2])

results[key][:, self.source_index] +
results[key][:, self.source_index2])
keypoints_visible[:, self.target_index] = results[
'keypoints_visible'][:, self.source_index] * \
results['keypoints_visible'][:, self.source_index2]

if 'keypoints_3d' in results:
keypoints_3d[:, self.target_index] = 0.5 * (
results['keypoints_3d'][:, self.source_index] +
results['keypoints_3d'][:, self.source_index2])

'keypoints_visible'][:, self.source_index] * results[
'keypoints_visible'][:, self.source_index2]
# Flip keypoints if flip_indices provided
if flip_indices is not None:
for i, (x1, x2) in enumerate(
zip(self.source_index, self.source_index2)):
id = flip_indices[x1] if x1 == x2 else i
flip_indices[i] = id if id < self.num_keypoints else i
idx = flip_indices[x1] if x1 == x2 else i
flip_indices[i] = idx if idx < self.num_keypoints else i
flip_indices = flip_indices[:len(self.source_index)]
# Otherwise just copy from the source index
else:
keypoints[:,
self.target_index] = results['keypoints'][:, self.
source_index]
keypoints[:, self.target_index] = results[key][:,
self.source_index]
keypoints_visible[:, self.target_index] = results[
'keypoints_visible'][:, self.source_index]
if 'keypoints_3d' in results:
keypoints_3d[:, self.target_index] = results[
'keypoints_3d'][:, self.source_index]

# Update the results dict
results['keypoints'] = keypoints
results['keypoints'] = keypoints[..., :2]
results['keypoints_visible'] = np.stack(
[keypoints_visible, keypoints_visible_weights], axis=2)
if 'keypoints_3d' in results:
results['keypoints_3d'] = keypoints_3d
results['lifting_target'] = keypoints_3d[results['target_idx']]
results['keypoints_3d'] = keypoints
results['lifting_target'] = keypoints[results['target_idx']]
results['lifting_target_visible'] = keypoints_visible[
results['target_idx']]
results['flip_indices'] = flip_indices
Expand Down

0 comments on commit b296a2a

Please sign in to comment.