diff --git a/inference/models/owlv2/owlv2.py b/inference/models/owlv2/owlv2.py index a3d05ff40..8ff543453 100644 --- a/inference/models/owlv2/owlv2.py +++ b/inference/models/owlv2/owlv2.py @@ -171,7 +171,7 @@ def get_class_preds_from_embeds( survival_indices = torchvision.ops.nms( to_corners(pred_boxes), pred_scores, iou_threshold ) - # put on numpy and filter to post-nms + # filter to post-nms pred_boxes = pred_boxes[survival_indices, :] pred_classes = pred_classes[survival_indices] pred_scores = pred_scores[survival_indices] @@ -371,6 +371,9 @@ def infer_from_embed( all_predicted_classes.append(classes) all_predicted_scores.append(scores) + if not all_predicted_boxes: + return [] + all_predicted_boxes = torch.cat(all_predicted_boxes, dim=0) all_predicted_classes = torch.cat(all_predicted_classes, dim=0) all_predicted_scores = torch.cat(all_predicted_scores, dim=0) diff --git a/tests/inference/models_predictions_tests/test_owlv2.py b/tests/inference/models_predictions_tests/test_owlv2.py index 6bbcbcfd5..33e8b7f44 100644 --- a/tests/inference/models_predictions_tests/test_owlv2.py +++ b/tests/inference/models_predictions_tests/test_owlv2.py @@ -1,7 +1,10 @@ +import pytest + from inference.core.entities.requests.owlv2 import OwlV2InferenceRequest from inference.models.owlv2.owlv2 import OwlV2 +@pytest.mark.slow def test_owlv2(): image = { "type": "url", @@ -49,6 +52,14 @@ def test_owlv2(): assert abs(532 - posts[3].x) < 1.5 assert abs(572 - posts[4].x) < 1.5 + +@pytest.mark.slow +def test_owlv2_multiple_prompts(): + image = { + "type": "url", + "value": "https://media.roboflow.com/inference/seawithdock.jpeg", + } + # test we can handle multiple (positive and negative) prompts for the same image request = OwlV2InferenceRequest( image=image, @@ -96,7 +107,15 @@ def test_owlv2(): assert abs(532 - posts[2].x) < 1.5 assert abs(572 - posts[3].x) < 1.5 - # test that we can handle no prompts for an image + +@pytest.mark.slow +def test_owlv2_image_without_prompts(): + image 
= { + "type": "url", + "value": "https://media.roboflow.com/inference/seawithdock.jpeg", + } + + # test that we can handle an image without any prompts request = OwlV2InferenceRequest( image=image, training_data=[ @@ -124,3 +143,147 @@ def test_owlv2(): response = OwlV2().infer_from_request(request) assert len(response.predictions) == 5 + + +@pytest.mark.slow +def test_owlv2_bad_prompt(): + image = { + "type": "url", + "value": "https://media.roboflow.com/inference/seawithdock.jpeg", + } + + # test that we can handle a bad prompt + request = OwlV2InferenceRequest( + image=image, + training_data=[ + { + "image": image, + "boxes": [ + { + "x": 1, + "y": 1, + "w": 1, + "h": 1, + "cls": "post", + "negative": False, + } + ], + } + ], + visualize_predictions=True, + confidence=0.9, + ) + + response = OwlV2().infer_from_request(request) + assert len(response.predictions) == 0 + + +@pytest.mark.slow +def test_owlv2_bad_prompt_hidden_among_good_prompts(): + image = { + "type": "url", + "value": "https://media.roboflow.com/inference/seawithdock.jpeg", + } + + # test that a bad prompt mixed in with good prompts does not break the good ones + request = OwlV2InferenceRequest( + image=image, + training_data=[ + { + "image": image, + "boxes": [ + { + "x": 1, + "y": 1, + "w": 1, + "h": 1, + "cls": "post", + "negative": False, + }, + { + "x": 223, + "y": 306, + "w": 40, + "h": 226, + "cls": "post", + "negative": False, + }, + ], + } + ], + visualize_predictions=True, + confidence=0.9, + ) + + response = OwlV2().infer_from_request(request) + assert len(response.predictions) == 5 + + +@pytest.mark.slow +def test_owlv2_no_training_data(): + image = { + "type": "url", + "value": "https://media.roboflow.com/inference/seawithdock.jpeg", + } + + # test that we can handle no training data + request = OwlV2InferenceRequest( + image=image, + training_data=[], + ) + + response = OwlV2().infer_from_request(request) + assert len(response.predictions) == 0 + + +@pytest.mark.slow +def test_owlv2_multiple_training_images(): + image = { + 
"type": "url", + "value": "https://media.roboflow.com/inference/seawithdock.jpeg", + } + second_image = { + "type": "url", + "value": "https://media.roboflow.com/inference/dock2.jpg", + } + + request = OwlV2InferenceRequest( + image=image, + training_data=[ + { + "image": image, + "boxes": [ + { + "x": 223, + "y": 306, + "w": 40, + "h": 226, + "cls": "post", + "negative": False, + } + ], + }, + { + "image": second_image, + "boxes": [ + { + "x": 3009, + "y": 1873, + "w": 289, + "h": 811, + "cls": "post", + "negative": True, + } + ], + }, + ], + visualize_predictions=True, + confidence=0.9, + ) + + response = OwlV2().infer_from_request(request) + assert len(response.predictions) == 5 + + +if __name__ == "__main__": + test_owlv2()