# Copyright (c) OpenMMLab. All rights reserved.
- from collections import namedtuple
from itertools import product
from typing import Any, List, Optional, Tuple
                    refine_keypoints_dark_udp)


+ def _py_max_match(scores):
+     """Apply munkres algorithm to get the best match.
+
+     Args:
+         scores(np.ndarray): cost matrix.
+
+     Returns:
+         np.ndarray: best match.
+     """
+     m = Munkres()
+     tmp = m.compute(scores)
+     tmp = np.array(tmp).astype(int)
+     return tmp
+
+
def _group_keypoints_by_tags(vals: np.ndarray,
                             tags: np.ndarray,
                             locs: np.ndarray,
@@ -54,89 +68,78 @@ def _group_keypoints_by_tags(vals: np.ndarray,
        np.ndarray: grouped keypoints in shape (G, K, D+1), where the last
        dimension is the concatenated keypoint coordinates and scores.
    """
+
+     tag_k, loc_k, val_k = tags, locs, vals
    K, M, D = locs.shape
    assert vals.shape == tags.shape[:2] == (K, M)
    assert len(keypoint_order) == K

-     # Build Munkres instance
-     munkres = Munkres()
-
-     # Build a group pool, each group contains the keypoints of an instance
-     groups = []
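+     # template for a new instance: one row per keypoint holding
+     # [x, y, score, tag...], left as zeros for keypoints that are never matched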
+     default_ = np.zeros((K, 3 + tag_k.shape[2]), dtype=np.float32)

-     Group = namedtuple('Group', field_names=['kpts', 'scores', 'tag_list'])
+     joint_dict = {}
+     tag_dict = {}
+     for i in range(K):
+         idx = keypoint_order[i]

-     def _init_group():
-         """Initialize a group, which is composed of the keypoints, keypoint
-         scores and the tag of each keypoint."""
-         _group = Group(
-             kpts=np.zeros((K, D), dtype=np.float32),
-             scores=np.zeros(K, dtype=np.float32),
-             tag_list=[])
-         return _group
+         tags = tag_k[idx]
+         joints = np.concatenate((loc_k[idx], val_k[idx, :, None], tags), 1)
+         mask = joints[:, 2] > val_thr
+         tags = tags[mask]  # shape: [M, L]
+         joints = joints[mask]  # shape: [M, 3 + L], 3: x, y, val

-     for i in keypoint_order:
-         # Get all valid candidate of the i-th keypoints
-         valid = vals[i] > val_thr
-         if not valid.any():
+         if joints.shape[0] == 0:
            continue

-         tags_i = tags[i, valid]  # (M', L)
-         vals_i = vals[i, valid]  # (M',)
-         locs_i = locs[i, valid]  # (M', D)
-
-         if len(groups) == 0:  # Initialize the group pool
-             for tag, val, loc in zip(tags_i, vals_i, locs_i):
-                 group = _init_group()
-                 group.kpts[i] = loc
-                 group.scores[i] = val
-                 group.tag_list.append(tag)
-
-                 groups.append(group)
-
-         else:  # Match keypoints to existing groups
-             groups = groups[:max_groups]
-             group_tags = [np.mean(g.tag_list, axis=0) for g in groups]
-
-             # Calculate distance matrix between group tags and tag candidates
-             # of the i-th keypoint
-             # Shape: (M', 1, L) , (1, G, L) -> (M', G, L)
-             diff = tags_i[:, None] - np.array(group_tags)[None]
-             dists = np.linalg.norm(diff, ord=2, axis=2)
-             num_kpts, num_groups = dists.shape[:2]
-
-             # Experimental cost function for keypoint-group matching
-             costs = np.round(dists) * 100 - vals_i[..., None]
-             if num_kpts > num_groups:
-                 padding = np.full((num_kpts, num_kpts - num_groups),
-                                   1e10,
-                                   dtype=np.float32)
-                 costs = np.concatenate((costs, padding), axis=1)
-
-             # Match keypoints and groups by Munkres algorithm
-             matches = munkres.compute(costs)
-             for kpt_idx, group_idx in matches:
-                 if group_idx < num_groups and dists[kpt_idx,
-                                                     group_idx] < tag_thr:
-                     # Add the keypoint to the matched group
-                     group = groups[group_idx]
+         if i == 0 or len(joint_dict) == 0:
+             for tag, joint in zip(tags, joints):
+                 key = tag[0]
+                 joint_dict.setdefault(key, np.copy(default_))[idx] = joint
+                 tag_dict[key] = [tag]
+         else:
+             # shape: [M]
+             grouped_keys = list(joint_dict.keys())
+             # shape: [M, L]
+             grouped_tags = [np.mean(tag_dict[i], axis=0) for i in grouped_keys]
+
+             # shape: [M, M, L]
+             diff = joints[:, None, 3:] - np.array(grouped_tags)[None, :, :]
+             # shape: [M, M]
+             diff_normed = np.linalg.norm(diff, ord=2, axis=2)
+             diff_saved = np.copy(diff_normed)
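+             # matching cost: the rounded tag distance dominates (x100); the
+             # candidate score is subtracted so that, when not every candidate
+             # can be matched, higher-scoring candidates are preferred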
+             diff_normed = np.round(diff_normed) * 100 - joints[:, 2:3]
+
+             num_added = diff.shape[0]
+             num_grouped = diff.shape[1]
+
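+             # if there are more candidates than existing instances, pad the
+             # cost matrix with a very large cost; candidates assigned to the
+             # padded columns fall through to the `else` branch below and
+             # start new instances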
+             if num_added > num_grouped:
+                 diff_normed = np.concatenate(
+                     (diff_normed,
+                      np.zeros((num_added, num_added - num_grouped),
+                               dtype=np.float32) + 1e10),
+                     axis=1)
+
+             pairs = _py_max_match(diff_normed)
+             for row, col in pairs:
+                 if (row < num_added and col < num_grouped
+                         and diff_saved[row][col] < tag_thr):
+                     key = grouped_keys[col]
+                     joint_dict[key][idx] = joints[row]
+                     tag_dict[key].append(tags[row])
                else:
-                     # Initialize a new group with unmatched keypoint
-                     group = _init_group()
-                     groups.append(group)
-
-                 group.kpts[i] = locs_i[kpt_idx]
-                 group.scores[i] = vals_i[kpt_idx]
-                 group.tag_list.append(tags_i[kpt_idx])
-
-     groups = groups[:max_groups]
-     if groups:
-         grouped_keypoints = np.stack(
-             [np.r_['1', g.kpts, g.scores[:, None]] for g in groups])
-     else:
-         grouped_keypoints = np.empty((0, K, D + 1))
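+                     # unmatched candidate, or the best match is farther than
+                     # tag_thr: start a new instance keyed by the first tag value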
+                     key = tags[row][0]
+                     joint_dict.setdefault(key, np.copy(default_))[idx] = \
+                         joints[row]
+                     tag_dict[key] = [tags[row]]

-     return grouped_keypoints
+     joint_dict_keys = list(joint_dict.keys())[:max_groups]
+
+     if joint_dict_keys:
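+         # stack the per-instance arrays and drop the tag dimensions, keeping
+         # only the D coordinates and the score of each keypoint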
+         results = np.array([joint_dict[i]
+                             for i in joint_dict_keys]).astype(np.float32)
+         results = results[..., :D + 1]
+     else:
+         results = np.empty((0, K, D + 1), dtype=np.float32)
+     return results


@KEYPOINT_CODECS.register_module()
@@ -210,7 +213,8 @@ def __init__(
        decode_gaussian_kernel: int = 3,
        decode_keypoint_thr: float = 0.1,
        decode_tag_thr: float = 1.0,
-         decode_topk: int = 20,
+         decode_topk: int = 30,
+         decode_center_shift=0.0,
        decode_max_instances: Optional[int] = None,
    ) -> None:
        super().__init__()
@@ -222,8 +226,9 @@ def __init__(
        self.decode_keypoint_thr = decode_keypoint_thr
        self.decode_tag_thr = decode_tag_thr
        self.decode_topk = decode_topk
+         self.decode_center_shift = decode_center_shift
        self.decode_max_instances = decode_max_instances
-         self.dedecode_keypoint_order = decode_keypoint_order.copy()
+         self.decode_keypoint_order = decode_keypoint_order.copy()

        if self.use_udp:
            self.scale_factor = ((np.array(input_size) - 1) /
@@ -376,7 +381,7 @@ def _group_func(inputs: Tuple):
                vals,
                tags,
                locs,
-                 keypoint_order=self.dedecode_keypoint_order,
+                 keypoint_order=self.decode_keypoint_order,
                val_thr=self.decode_keypoint_thr,
                tag_thr=self.decode_tag_thr,
                max_groups=self.decode_max_instances)
@@ -463,13 +468,13 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
            f'tagging map ({batch_tags.shape})')

        # Heatmap NMS
-         batch_heatmaps = batch_heatmap_nms(batch_heatmaps,
-                                            self.decode_nms_kernel)
+         batch_heatmaps_peak = batch_heatmap_nms(batch_heatmaps,
+                                                 self.decode_nms_kernel)
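+         # the NMS-suppressed copy is only used to pick top-k candidates; the
+         # original heatmaps are kept for the later keypoint refinement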

        # Get top-k in each heatmap and convert to numpy
        batch_topk_vals, batch_topk_tags, batch_topk_locs = to_numpy(
            self._get_batch_topk(
-                 batch_heatmaps, batch_tags, k=self.decode_topk))
+                 batch_heatmaps_peak, batch_tags, k=self.decode_topk))

        # Group keypoint candidates into groups (instances)
        batch_groups = self._group_keypoints(batch_topk_vals, batch_topk_tags,
@@ -482,16 +487,14 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
        # Refine the keypoint prediction
        batch_keypoints = []
        batch_keypoint_scores = []
+         batch_instance_scores = []
        for i, (groups, heatmaps, tags) in enumerate(
                zip(batch_groups, batch_heatmaps_np, batch_tags_np)):

            keypoints, scores = groups[..., :-1], groups[..., -1]
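+             # instance score: mean over all keypoint scores, taken before
+             # refinement and before missing keypoints are filled in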
+             instance_scores = scores.mean(axis=-1)

            if keypoints.size > 0:
-                 # identify missing keypoints
-                 keypoints, scores = self._fill_missing_keypoints(
-                     keypoints, scores, heatmaps, tags)
-
                # refine keypoint coordinates according to heatmap distribution
                if self.use_udp:
                    keypoints = refine_keypoints_dark_udp(
@@ -500,13 +503,20 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
                        blur_kernel_size=self.decode_gaussian_kernel)
                else:
                    keypoints = refine_keypoints(keypoints, heatmaps)
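+                 # shift every detected keypoint by a constant offset;
+                 # keypoints with zero score (undetected) are left untouched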
+                 keypoints += self.decode_center_shift * \
+                     (scores > 0).astype(keypoints.dtype)[..., None]
+
+                 # identify missing keypoints
+                 keypoints, scores = self._fill_missing_keypoints(
+                     keypoints, scores, heatmaps, tags)

            batch_keypoints.append(keypoints)
            batch_keypoint_scores.append(scores)
+             batch_instance_scores.append(instance_scores)

        # restore keypoint scale
        batch_keypoints = [
            kpts * self.scale_factor for kpts in batch_keypoints
        ]

-         return batch_keypoints, batch_keypoint_scores
+         return batch_keypoints, batch_keypoint_scores, batch_instance_scores