Commit f99a2fc

fix owlvit tests, update docstring examples (huggingface#18586)

alaradirik authored and oneraghavan committed Sep 26, 2022
1 parent 6b7f271
Showing 3 changed files with 7 additions and 8 deletions.
4 changes: 2 additions & 2 deletions docs/source/en/model_doc/owlvit.mdx
@@ -57,8 +57,8 @@ OWL-ViT is a zero-shot text-conditioned object detection model. OWL-ViT uses [CL
 ...     box = [round(i, 2) for i in box.tolist()]
 ...     if score >= score_threshold:
 ...         print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
-Detected a photo of a cat with confidence 0.243 at location [1.42, 50.69, 308.58, 370.48]
-Detected a photo of a cat with confidence 0.298 at location [348.06, 20.56, 642.33, 372.61]
+Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
+Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
 ```
 This model was contributed by [adirik](https://huggingface.co/adirik). The original code can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit).
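For context, the two changed lines above are doctest output. Below is a minimal sketch of the surrounding OWL-ViT example, assuming the standard usage pattern from the docs at the time of this commit; the checkpoint, image URL, query texts, and threshold are illustrative assumptions, not part of this diff:

```python
import requests
import torch
from PIL import Image
from transformers import OwlViTProcessor, OwlViTForObjectDetection

# Assumed checkpoint; the docstring example uses an OWL-ViT base model
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # COCO image with two cats
image = Image.open(requests.get(url, stream=True).raw)
texts = [["a photo of a cat", "a photo of a dog"]]

inputs = processor(text=texts, images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Rescale normalized box predictions to the original image size (height, width)
target_sizes = torch.Tensor([image.size[::-1]])
results = processor.post_process(outputs=outputs, target_sizes=target_sizes)

text = texts[0]
boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
score_threshold = 0.1  # assumed threshold
for box, score, label in zip(boxes, scores, labels):
    box = [round(i, 2) for i in box.tolist()]
    if score >= score_threshold:
        print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
```

Note that later transformers releases superseded `processor.post_process` with `post_process_object_detection`, so the exact call may differ from this 2022-era docstring.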
4 changes: 2 additions & 2 deletions src/transformers/models/owlvit/modeling_owlvit.py
@@ -1323,8 +1323,8 @@ def forward(
 ...     box = [round(i, 2) for i in box.tolist()]
 ...     if score >= score_threshold:
 ...         print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
-Detected a photo of a cat with confidence 0.243 at location [1.42, 50.69, 308.58, 370.48]
-Detected a photo of a cat with confidence 0.298 at location [348.06, 20.56, 642.33, 372.61]
+Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
+Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
 ```"""
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
7 changes: 3 additions & 4 deletions tests/models/owlvit/test_modeling_owlvit.py
@@ -733,7 +733,6 @@ def prepare_img():

 @require_vision
 @require_torch
-@unittest.skip("These tests are broken, fix me Alara")
 class OwlViTModelIntegrationTest(unittest.TestCase):
     @slow
     def test_inference(self):
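(Dropping the `unittest.skip` decorator re-enables this integration test class; since its tests are also marked `@slow`, they still only run when slow tests are enabled, typically via `RUN_SLOW=1 python -m pytest tests/models/owlvit/test_modeling_owlvit.py` in this repository.)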
@@ -763,8 +762,7 @@ def test_inference(self):
             outputs.logits_per_text.shape,
             torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
         )
-        expected_logits = torch.tensor([[4.4420, 0.6181]], device=torch_device)
-
+        expected_logits = torch.tensor([[3.4613, 0.9403]], device=torch_device)
         self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
 
     @slow
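As an aside, the assertion above is a tolerance check rather than exact equality. A minimal sketch of how `torch.allclose` behaves with `atol=1e-3` (the tensors here are made-up values for illustration):

```python
import torch

# allclose passes when |actual - expected| <= atol + rtol * |expected|
# element-wise (rtol defaults to 1e-5)
actual = torch.tensor([[3.4615, 0.9401]])    # hypothetical model output
expected = torch.tensor([[3.4613, 0.9403]])  # reference values from the test
print(torch.allclose(actual, expected, atol=1e-3))  # True: both diffs are 2e-4
```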
@@ -788,7 +786,8 @@ def test_inference_object_detection(self):

         num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
         self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
+
         expected_slice_boxes = torch.tensor(
-            [[0.0948, 0.0471, 0.1915], [0.3194, 0.0583, 0.6498], [0.1441, 0.0452, 0.2197]]
+            [[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
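The `num_queries` computation above is simple arithmetic over the vision config. A sketch with assumed values matching the default `google/owlvit-base-patch32` configuration (768×768 input, 32×32 patches is an assumption about that checkpoint, not stated in this diff):

```python
# Assumed config values for google/owlvit-base-patch32
image_size = 768
patch_size = 32

# One box prediction per image patch token
num_queries = int((image_size / patch_size) ** 2)
print(num_queries)  # 576, so outputs.pred_boxes is expected to be (1, 576, 4)
```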
