@@ -331,31 +331,31 @@ def test_inference(self):
331331 predicted_matches_values1 = outputs .matches [1 , 0 , 10 :30 ]
332332 predicted_matching_scores_values1 = outputs .matching_scores [1 , 0 , 10 :30 ]
333333
334- expected_number_of_matches0 = 140
334+ expected_number_of_matches0 = 866
335335 expected_matches_values0 = torch .tensor (
336- [14 , - 1 , - 1 , 15 , 17 , 13 , - 1 , - 1 , - 1 , - 1 , - 1 , - 1 , 5 , - 1 , - 1 , 19 , - 1 , 10 , - 1 , 11 ],
337- dtype = torch .int64 ,
338- device = torch_device ,
339- )
340- expected_matching_scores_values0 = torch .tensor (
341- [0.3796 , 0 , 0 , 0.3772 , 0.4439 , 0.2411 , 0 , 0 , 0.0032 , 0 , 0 , 0 , 0.2997 , 0 , 0 , 0.6762 , 0 , 0.8826 , 0 , 0.5583 ],
342- device = torch_device ,
343- )
344-
345- expected_number_of_matches1 = 866
346- expected_matches_values1 = torch .tensor (
347336 [10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 ],
348337 dtype = torch .int64 ,
349338 device = torch_device ,
350339 )
351- expected_matching_scores_values1 = torch .tensor (
340+ expected_matching_scores_values0 = torch .tensor (
352341 [
353342 0.6188 ,0.7817 ,0.5686 ,0.9353 ,0.9801 ,0.9193 ,0.8632 ,0.9111 ,0.9821 ,0.5496 ,
354343 0.9906 ,0.8682 ,0.9679 ,0.9914 ,0.9318 ,0.1910 ,0.9669 ,0.3240 ,0.9971 ,0.9923 ,
355344 ],
356345 device = torch_device
357346 ) # fmt:skip
358347
348+ expected_number_of_matches1 = 140
349+ expected_matches_values1 = torch .tensor (
350+ [14 , - 1 , - 1 , 15 , 17 , 13 , - 1 , - 1 , - 1 , - 1 , - 1 , - 1 , 5 , - 1 , - 1 , 19 , - 1 , 10 , - 1 , 11 ],
351+ dtype = torch .int64 ,
352+ device = torch_device ,
353+ )
354+ expected_matching_scores_values1 = torch .tensor (
355+ [0.3796 , 0 , 0 , 0.3772 , 0.4439 , 0.2411 , 0 , 0 , 0.0032 , 0 , 0 , 0 , 0.2997 , 0 , 0 , 0.6762 , 0 , 0.8826 , 0 , 0.5583 ],
356+ device = torch_device ,
357+ )
358+
359359 # expected_early_stopping_layer = 2
360360 # predicted_early_stopping_layer = torch.max(outputs.prune[1]).item()
361361 # self.assertEqual(predicted_early_stopping_layer, expected_early_stopping_layer)
@@ -375,7 +375,6 @@ def test_inference(self):
375375 Such CUDA inconsistencies can be found
376376 [here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300)
377377 """
378-
379378 self .assertTrue (abs (predicted_number_of_matches0 - expected_number_of_matches0 ) < 4 )
380379 self .assertTrue (abs (predicted_number_of_matches1 - expected_number_of_matches1 ) < 4 )
381380 self .assertTrue (
@@ -590,3 +589,28 @@ def test_inference_without_early_stop_and_keypoint_pruning(self):
590589 )
591590 self .assertTrue (torch .sum (predicted_matches_values0 != expected_matches_values0 ) < 4 )
592591 self .assertTrue (torch .sum (predicted_matches_values1 != expected_matches_values1 ) < 4 )
592+
593+ @slow
594+ def test_inference_order_with_early_stop (self ):
595+ model = LightGlueForKeypointMatching .from_pretrained (
596+ "ETH-CVG/lightglue_superpoint" , attn_implementation = "eager"
597+ ).to (torch_device )
598+ preprocessor = self .default_image_processor
599+ images = prepare_imgs ()
600+ # [[image2, image0], [image1, image1]] -> [[image2, image0], [image2, image0], [image1, image1]]
601+ images = [images [0 ]] + images # adding a 3rd pair to test batching with early stopping
602+ inputs = preprocessor (images = images , return_tensors = "pt" ).to (torch_device )
603+ with torch .no_grad ():
604+ outputs = model (** inputs , output_hidden_states = True , output_attentions = True )
605+
606+ predicted_number_of_matches_pair0 = torch .sum (outputs .matches [0 ][0 ] != - 1 ).item ()
607+ predicted_number_of_matches_pair1 = torch .sum (outputs .matches [1 ][0 ] != - 1 ).item ()
608+ predicted_number_of_matches_pair2 = torch .sum (outputs .matches [2 ][0 ] != - 1 ).item ()
609+
610+ # pair 0 and 1 are the same, so should have the same number of matches
611+ # pair 2 is [image1, image1] so should have more matches than first two pairs
612+ # This ensures that early stopping does not affect the order of the outputs
613+ # See : https://huggingface.co/ETH-CVG/lightglue_superpoint/discussions/6
614+ # The bug made the pairs switch order when early stopping was activated
615+ self .assertTrue (predicted_number_of_matches_pair0 == predicted_number_of_matches_pair1 )
616+ self .assertTrue (predicted_number_of_matches_pair0 < predicted_number_of_matches_pair2 )
0 commit comments