Merged
4 changes: 4 additions & 0 deletions src/transformers/models/lightglue/modeling_lightglue.py
@@ -628,6 +628,10 @@ def _concat_early_stopped_outputs(
         matching_scores,
     ):
         early_stops_indices = torch.stack(early_stops_indices)
+        # Rearrange tensors to have the same order as the input batch
+        ids = torch.arange(early_stops_indices.shape[0])
+        order_indices = early_stops_indices[ids]
+        early_stops_indices = early_stops_indices[order_indices]
         matches, final_pruned_keypoints_indices = (
             pad_sequence(tensor, batch_first=True, padding_value=-1)
             for tensor in [matches, final_pruned_keypoints_indices]
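For context, a minimal standalone sketch of the reordering problem this hunk addresses (the tensor values and variable layout are illustrative, and `torch.argsort` is used here as the textbook way to invert a permutation — it is not lifted from the merged change):

```python
import torch

# With early stopping, outputs accumulate in the order pairs finish,
# not in input-batch order. early_stops_indices[k] records the original
# batch index of the k-th stored row.
early_stops_indices = torch.tensor([2, 0, 1])  # pair 2 finished first
stored_matches = torch.tensor([[20], [0], [10]])  # rows in early-stop order

# argsort yields the inverse permutation: for each batch index, the
# position its row occupies in the stored order.
order_indices = torch.argsort(early_stops_indices)
restored = stored_matches[order_indices]
print(restored.flatten().tolist())  # [0, 10, 20] -> batch order 0, 1, 2
```

For a stop order like `[2, 0, 1]` (a 3-cycle, and the one exercised by the new test below), indexing with the permutation composed with itself — which is what `early_stops_indices[order_indices]` computes in the hunk, since `order_indices` equals `early_stops_indices` there — agrees with this inverse, because a 3-cycle satisfies p∘p = p⁻¹.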
4 changes: 4 additions & 0 deletions src/transformers/models/lightglue/modular_lightglue.py
@@ -786,6 +786,10 @@ def _concat_early_stopped_outputs(
         matching_scores,
     ):
         early_stops_indices = torch.stack(early_stops_indices)
+        # Rearrange tensors to have the same order as the input batch
+        ids = torch.arange(early_stops_indices.shape[0])
+        order_indices = early_stops_indices[ids]
+        early_stops_indices = early_stops_indices[order_indices]
         matches, final_pruned_keypoints_indices = (
             pad_sequence(tensor, batch_first=True, padding_value=-1)
             for tensor in [matches, final_pruned_keypoints_indices]
52 changes: 38 additions & 14 deletions tests/models/lightglue/test_modeling_lightglue.py
@@ -331,31 +331,31 @@ def test_inference(self):
         predicted_matches_values1 = outputs.matches[1, 0, 10:30]
         predicted_matching_scores_values1 = outputs.matching_scores[1, 0, 10:30]

-        expected_number_of_matches0 = 140
+        expected_number_of_matches0 = 866
         expected_matches_values0 = torch.tensor(
-            [14, -1, -1, 15, 17, 13, -1, -1, -1, -1, -1, -1, 5, -1, -1, 19, -1, 10, -1, 11],
-            dtype=torch.int64,
-            device=torch_device,
-        )
-        expected_matching_scores_values0 = torch.tensor(
-            [0.3796, 0, 0, 0.3772, 0.4439, 0.2411, 0, 0, 0.0032, 0, 0, 0, 0.2997, 0, 0, 0.6762, 0, 0.8826, 0, 0.5583],
-            device=torch_device,
-        )
-
-        expected_number_of_matches1 = 866
-        expected_matches_values1 = torch.tensor(
             [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
             dtype=torch.int64,
             device=torch_device,
         )
-        expected_matching_scores_values1 = torch.tensor(
+        expected_matching_scores_values0 = torch.tensor(
             [
                 0.6188,0.7817,0.5686,0.9353,0.9801,0.9193,0.8632,0.9111,0.9821,0.5496,
                 0.9906,0.8682,0.9679,0.9914,0.9318,0.1910,0.9669,0.3240,0.9971,0.9923,
             ],
             device=torch_device
         )  # fmt:skip
+
+        expected_number_of_matches1 = 140
+        expected_matches_values1 = torch.tensor(
+            [14, -1, -1, 15, 17, 13, -1, -1, -1, -1, -1, -1, 5, -1, -1, 19, -1, 10, -1, 11],
+            dtype=torch.int64,
+            device=torch_device,
+        )
+        expected_matching_scores_values1 = torch.tensor(
+            [0.3796, 0, 0, 0.3772, 0.4439, 0.2411, 0, 0, 0.0032, 0, 0, 0, 0.2997, 0, 0, 0.6762, 0, 0.8826, 0, 0.5583],
+            device=torch_device,
+        )

         # expected_early_stopping_layer = 2
         # predicted_early_stopping_layer = torch.max(outputs.prune[1]).item()
         # self.assertEqual(predicted_early_stopping_layer, expected_early_stopping_layer)
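As a reading aid for these expectations (following the conventions visible in this test): `matches[b, 0, i]` holds the keypoint index in image 1 matched to keypoint `i` of image 0, with -1 meaning "no match", and `matching_scores` holds the corresponding confidences. An identical image pair therefore yields the identity mapping `[10, 11, ..., 29]` over the sliced window with mostly high scores, while a genuine two-view pair is sparse, with many -1 entries and zero scores. A hypothetical helper along these lines (not part of the test suite) makes that structure explicit:

```python
def iter_matches(outputs, pair_index):
    """Yield (source_kpt, target_kpt, score) triplets for one pair in the batch."""
    matches = outputs.matches[pair_index, 0]          # image0 -> image1 indices, -1 = unmatched
    scores = outputs.matching_scores[pair_index, 0]   # confidence per source keypoint
    for i, (j, score) in enumerate(zip(matches.tolist(), scores.tolist())):
        if j != -1:
            yield i, j, score
```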
@@ -375,7 +375,6 @@ def test_inference(self):
         Such CUDA inconsistencies can be found
         [here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300)
         """
-
         self.assertTrue(abs(predicted_number_of_matches0 - expected_number_of_matches0) < 4)
         self.assertTrue(abs(predicted_number_of_matches1 - expected_number_of_matches1) < 4)
         self.assertTrue(
@@ -590,3 +589,28 @@ def test_inference_without_early_stop_and_keypoint_pruning(self):
         )
         self.assertTrue(torch.sum(predicted_matches_values0 != expected_matches_values0) < 4)
         self.assertTrue(torch.sum(predicted_matches_values1 != expected_matches_values1) < 4)
+
+    @slow
+    def test_inference_order_with_early_stop(self):
+        model = LightGlueForKeypointMatching.from_pretrained(
+            "ETH-CVG/lightglue_superpoint", attn_implementation="eager"
+        ).to(torch_device)
+        preprocessor = self.default_image_processor
+        images = prepare_imgs()
+        # [[image2, image0], [image1, image1]] -> [[image2, image0], [image2, image0], [image1, image1]]
+        images = [images[0]] + images  # adding a 3rd pair to test batching with early stopping
+        inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
+        with torch.no_grad():
+            outputs = model(**inputs, output_hidden_states=True, output_attentions=True)
+
+        predicted_number_of_matches_pair0 = torch.sum(outputs.matches[0][0] != -1).item()
+        predicted_number_of_matches_pair1 = torch.sum(outputs.matches[1][0] != -1).item()
+        predicted_number_of_matches_pair2 = torch.sum(outputs.matches[2][0] != -1).item()
+
+        # pair 0 and 1 are the same, so should have the same number of matches
+        # pair 2 is [image1, image1] so should have more matches than first two pairs
+        # This ensures that early stopping does not affect the order of the outputs
+        # See: https://huggingface.co/ETH-CVG/lightglue_superpoint/discussions/6
+        # The bug made the pairs switch order when early stopping was activated
+        self.assertTrue(predicted_number_of_matches_pair0 == predicted_number_of_matches_pair1)
+        self.assertTrue(predicted_number_of_matches_pair0 < predicted_number_of_matches_pair2)
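For reference, a hedged sketch of the user-facing pattern the new test protects, outside the test harness (the checkpoint id comes from the test; the image paths are placeholders you would substitute):

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, LightGlueForKeypointMatching

# Placeholder paths: a.jpg/b.jpg are two views of one scene, c.jpg is any image.
image_a, image_b, image_c = (Image.open(p) for p in ("a.jpg", "b.jpg", "c.jpg"))

processor = AutoImageProcessor.from_pretrained("ETH-CVG/lightglue_superpoint")
model = LightGlueForKeypointMatching.from_pretrained("ETH-CVG/lightglue_superpoint")

# The identical pair early-stops before the two-view pairs; with the fix,
# outputs should nevertheless stay aligned with the input pair order.
pairs = [[image_a, image_b], [image_a, image_b], [image_c, image_c]]
inputs = processor(images=pairs, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

counts = [torch.sum(outputs.matches[i][0] != -1).item() for i in range(3)]
assert counts[0] == counts[1]  # duplicated pair: identical results at both indices
assert counts[0] < counts[2]   # identical-image pair yields the most matches
```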