Skip to content

Commit d6d2d03

Browse files
sbucaille
authored and qubvel committed
🚨 [lightglue] fix: matches order changed because of early stopped indices (#40859)
* fix: bug that made early stop change order of matches * fix: applied code suggestion Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * fix: applied code suggestion to modular * fix: integration tests --------- Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
1 parent b164209 commit d6d2d03

File tree

3 files changed

+46
-14
lines changed

3 files changed

+46
-14
lines changed

‎src/transformers/models/lightglue/modeling_lightglue.py‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,10 @@ def _concat_early_stopped_outputs(
628628
matching_scores,
629629
):
630630
early_stops_indices = torch.stack(early_stops_indices)
631+
# Rearrange tensors to have the same order as the input batch
632+
ids = torch.arange(early_stops_indices.shape[0])
633+
order_indices = early_stops_indices[ids]
634+
early_stops_indices = early_stops_indices[order_indices]
631635
matches, final_pruned_keypoints_indices = (
632636
pad_sequence(tensor, batch_first=True, padding_value=-1)
633637
for tensor in [matches, final_pruned_keypoints_indices]

‎src/transformers/models/lightglue/modular_lightglue.py‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,10 @@ def _concat_early_stopped_outputs(
786786
matching_scores,
787787
):
788788
early_stops_indices = torch.stack(early_stops_indices)
789+
# Rearrange tensors to have the same order as the input batch
790+
ids = torch.arange(early_stops_indices.shape[0])
791+
order_indices = early_stops_indices[ids]
792+
early_stops_indices = early_stops_indices[order_indices]
789793
matches, final_pruned_keypoints_indices = (
790794
pad_sequence(tensor, batch_first=True, padding_value=-1)
791795
for tensor in [matches, final_pruned_keypoints_indices]

‎tests/models/lightglue/test_modeling_lightglue.py‎

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -331,31 +331,31 @@ def test_inference(self):
331331
predicted_matches_values1 = outputs.matches[1, 0, 10:30]
332332
predicted_matching_scores_values1 = outputs.matching_scores[1, 0, 10:30]
333333

334-
expected_number_of_matches0 = 140
334+
expected_number_of_matches0 = 866
335335
expected_matches_values0 = torch.tensor(
336-
[14, -1, -1, 15, 17, 13, -1, -1, -1, -1, -1, -1, 5, -1, -1, 19, -1, 10, -1, 11],
337-
dtype=torch.int64,
338-
device=torch_device,
339-
)
340-
expected_matching_scores_values0 = torch.tensor(
341-
[0.3796, 0, 0, 0.3772, 0.4439, 0.2411, 0, 0, 0.0032, 0, 0, 0, 0.2997, 0, 0, 0.6762, 0, 0.8826, 0, 0.5583],
342-
device=torch_device,
343-
)
344-
345-
expected_number_of_matches1 = 866
346-
expected_matches_values1 = torch.tensor(
347336
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
348337
dtype=torch.int64,
349338
device=torch_device,
350339
)
351-
expected_matching_scores_values1 = torch.tensor(
340+
expected_matching_scores_values0 = torch.tensor(
352341
[
353342
0.6188,0.7817,0.5686,0.9353,0.9801,0.9193,0.8632,0.9111,0.9821,0.5496,
354343
0.9906,0.8682,0.9679,0.9914,0.9318,0.1910,0.9669,0.3240,0.9971,0.9923,
355344
],
356345
device=torch_device
357346
) # fmt:skip
358347

348+
expected_number_of_matches1 = 140
349+
expected_matches_values1 = torch.tensor(
350+
[14, -1, -1, 15, 17, 13, -1, -1, -1, -1, -1, -1, 5, -1, -1, 19, -1, 10, -1, 11],
351+
dtype=torch.int64,
352+
device=torch_device,
353+
)
354+
expected_matching_scores_values1 = torch.tensor(
355+
[0.3796, 0, 0, 0.3772, 0.4439, 0.2411, 0, 0, 0.0032, 0, 0, 0, 0.2997, 0, 0, 0.6762, 0, 0.8826, 0, 0.5583],
356+
device=torch_device,
357+
)
358+
359359
# expected_early_stopping_layer = 2
360360
# predicted_early_stopping_layer = torch.max(outputs.prune[1]).item()
361361
# self.assertEqual(predicted_early_stopping_layer, expected_early_stopping_layer)
@@ -375,7 +375,6 @@ def test_inference(self):
375375
Such CUDA inconsistencies can be found
376376
[here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300)
377377
"""
378-
379378
self.assertTrue(abs(predicted_number_of_matches0 - expected_number_of_matches0) < 4)
380379
self.assertTrue(abs(predicted_number_of_matches1 - expected_number_of_matches1) < 4)
381380
self.assertTrue(
@@ -590,3 +589,28 @@ def test_inference_without_early_stop_and_keypoint_pruning(self):
590589
)
591590
self.assertTrue(torch.sum(predicted_matches_values0 != expected_matches_values0) < 4)
592591
self.assertTrue(torch.sum(predicted_matches_values1 != expected_matches_values1) < 4)
592+
593+
@slow
594+
def test_inference_order_with_early_stop(self):
595+
model = LightGlueForKeypointMatching.from_pretrained(
596+
"ETH-CVG/lightglue_superpoint", attn_implementation="eager"
597+
).to(torch_device)
598+
preprocessor = self.default_image_processor
599+
images = prepare_imgs()
600+
# [[image2, image0], [image1, image1]] -> [[image2, image0], [image2, image0], [image1, image1]]
601+
images = [images[0]] + images # adding a 3rd pair to test batching with early stopping
602+
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
603+
with torch.no_grad():
604+
outputs = model(**inputs, output_hidden_states=True, output_attentions=True)
605+
606+
predicted_number_of_matches_pair0 = torch.sum(outputs.matches[0][0] != -1).item()
607+
predicted_number_of_matches_pair1 = torch.sum(outputs.matches[1][0] != -1).item()
608+
predicted_number_of_matches_pair2 = torch.sum(outputs.matches[2][0] != -1).item()
609+
610+
# pair 0 and 1 are the same, so should have the same number of matches
611+
# pair 2 is [image1, image1] so should have more matches than first two pairs
612+
# This ensures that early stopping does not affect the order of the outputs
613+
# See : https://huggingface.co/ETH-CVG/lightglue_superpoint/discussions/6
614+
# The bug made the pairs switch order when early stopping was activated
615+
self.assertTrue(predicted_number_of_matches_pair0 == predicted_number_of_matches_pair1)
616+
self.assertTrue(predicted_number_of_matches_pair0 < predicted_number_of_matches_pair2)

0 commit comments

Comments
 (0)