diff --git a/datumaro/components/annotations/matcher.py b/datumaro/components/annotations/matcher.py index c12875c1c4..a6f6b0307c 100644 --- a/datumaro/components/annotations/matcher.py +++ b/datumaro/components/annotations/matcher.py @@ -59,34 +59,34 @@ def match_segments( matches = [] mispred = [] - for a_idx, a_segm in enumerate(a_segms): - if len(b_segms) == 0: - break - matched_b = -1 - max_dist = -1 - b_indices = np.argsort( - [not label_matcher(a_segm, b_segm) for b_segm in b_segms], kind="stable" - ) # prioritize those with same label, keep score order - for b_idx in b_indices: - if 0 <= b_matches[b_idx]: # assign a_segm with max conf + # It needs len(a_segms) > 0 and len(b_segms) > 0 + if len(b_segms) > 0: + for a_idx, a_segm in enumerate(a_segms): + matched_b = -1 + max_dist = -1 + b_indices = np.argsort( + [not label_matcher(a_segm, b_segm) for b_segm in b_segms], kind="stable" + ) # prioritize those with same label, keep score order + for b_idx in b_indices: + if 0 <= b_matches[b_idx]: # assign a_segm with max conf + continue + d = distances[a_idx, b_idx] + if d < dist_thresh or d <= max_dist: + continue + max_dist = d + matched_b = b_idx + + if matched_b < 0: continue - d = distances[a_idx, b_idx] - if d < dist_thresh or d <= max_dist: - continue - max_dist = d - matched_b = b_idx - - if matched_b < 0: - continue - a_matches[a_idx] = matched_b - b_matches[matched_b] = a_idx + a_matches[a_idx] = matched_b + b_matches[matched_b] = a_idx - b_segm = b_segms[matched_b] + b_segm = b_segms[matched_b] - if label_matcher(a_segm, b_segm): - matches.append((a_segm, b_segm)) - else: - mispred.append((a_segm, b_segm)) + if label_matcher(a_segm, b_segm): + matches.append((a_segm, b_segm)) + else: + mispred.append((a_segm, b_segm)) # *_umatched: boxes of (*) we failed to match a_unmatched = [a_segms[i] for i, m in enumerate(a_matches) if m < 0]