diff --git a/docs/source/en/model_doc/owlvit.mdx b/docs/source/en/model_doc/owlvit.mdx
index 0b61d7b274a0c7..ddbc2826d7a655 100644
--- a/docs/source/en/model_doc/owlvit.mdx
+++ b/docs/source/en/model_doc/owlvit.mdx
@@ -57,8 +57,8 @@ OWL-ViT is a zero-shot text-conditioned object detection model. OWL-ViT uses [CL
 ...     box = [round(i, 2) for i in box.tolist()]
 ...     if score >= score_threshold:
 ...         print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
-Detected a photo of a cat with confidence 0.243 at location [1.42, 50.69, 308.58, 370.48]
-Detected a photo of a cat with confidence 0.298 at location [348.06, 20.56, 642.33, 372.61]
+Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
+Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
 ```

 This model was contributed by [adirik](https://huggingface.co/adirik). The original code can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit).
diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py
index 73ee2597f1b163..c0386ab23d3fba 100644
--- a/src/transformers/models/owlvit/modeling_owlvit.py
+++ b/src/transformers/models/owlvit/modeling_owlvit.py
@@ -1323,8 +1323,8 @@ def forward(
         ...     box = [round(i, 2) for i in box.tolist()]
         ...     if score >= score_threshold:
         ...         print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
-        Detected a photo of a cat with confidence 0.243 at location [1.42, 50.69, 308.58, 370.48]
-        Detected a photo of a cat with confidence 0.298 at location [348.06, 20.56, 642.33, 372.61]
+        Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
+        Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
         ```"""
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
diff --git a/tests/models/owlvit/test_modeling_owlvit.py b/tests/models/owlvit/test_modeling_owlvit.py
index 7564d192ad9898..e8f615ec8e54f0 100644
--- a/tests/models/owlvit/test_modeling_owlvit.py
+++ b/tests/models/owlvit/test_modeling_owlvit.py
@@ -733,7 +733,6 @@ def prepare_img():

 @require_vision
 @require_torch
-@unittest.skip("These tests are broken, fix me Alara")
 class OwlViTModelIntegrationTest(unittest.TestCase):
     @slow
     def test_inference(self):
@@ -763,8 +762,7 @@ def test_inference(self):
             outputs.logits_per_text.shape,
             torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
         )
-        expected_logits = torch.tensor([[4.4420, 0.6181]], device=torch_device)
-
+        expected_logits = torch.tensor([[3.4613, 0.9403]], device=torch_device)
         self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))

     @slow
@@ -788,7 +786,8 @@ def test_inference_object_detection(self):

         num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
         self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
+
         expected_slice_boxes = torch.tensor(
-            [[0.0948, 0.0471, 0.1915], [0.3194, 0.0583, 0.6498], [0.1441, 0.0452, 0.2197]]
+            [[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
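
For context, the `Detected a photo of a cat ...` lines updated above are the tail of the OWL-ViT object-detection example in the docs and the `forward` docstring. Below is a minimal sketch of that full example, assuming the `google/owlvit-base-patch32` checkpoint and the COCO cats image used throughout the docs; the post-processing helper and the exact scores/boxes can vary across `transformers` versions and hardware, so treat the printed values as illustrative rather than canonical.

```python
import requests
import torch
from PIL import Image

from transformers import OwlViTForObjectDetection, OwlViTProcessor

# Load the processor and the zero-shot detection model (assumed checkpoint).
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

# COCO cats image used in the docs' example.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Text queries condition the detector; one list of queries per image.
texts = [["a photo of a cat", "a photo of a dog"]]
inputs = processor(text=texts, images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Rescale normalized box predictions to the original image size (height, width).
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(
    outputs=outputs, target_sizes=target_sizes, threshold=0.1
)

# Print detections for the first image, mirroring the loop shown in the diff.
text = texts[0]
boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
for box, score, label in zip(boxes, scores, labels):
    box = [round(i, 2) for i in box.tolist()]
    print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
```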