From 05a0133720607c36aeed9a9fda5d41bddc5de6df Mon Sep 17 00:00:00 2001
From: Isaac Robinson
Date: Fri, 1 Nov 2024 21:54:36 +0000
Subject: [PATCH 1/4] fix case where there are no good matches for the prompt

---
 inference/models/owlv2/owlv2.py             | 61 +++++++++++--------
 .../models_predictions_tests/test_owlv2.py  | 22 +++++++
 2 files changed, 58 insertions(+), 25 deletions(-)

diff --git a/inference/models/owlv2/owlv2.py b/inference/models/owlv2/owlv2.py
index a3d05ff40b..56a4b753bb 100644
--- a/inference/models/owlv2/owlv2.py
+++ b/inference/models/owlv2/owlv2.py
@@ -171,7 +171,7 @@ def get_class_preds_from_embeds(
     survival_indices = torchvision.ops.nms(
         to_corners(pred_boxes), pred_scores, iou_threshold
     )
-    # put on numpy and filter to post-nms
+    # filter to post-nms
     pred_boxes = pred_boxes[survival_indices, :]
     pred_classes = pred_classes[survival_indices]
     pred_scores = pred_scores[survival_indices]
@@ -313,9 +313,10 @@ def embed_image(self, image: np.ndarray) -> Hash:
 
     def get_query_embedding(
         self, query_spec: QuerySpecType, iou_threshold: float
-    ) -> torch.Tensor:
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         # NOTE: for now we're handling each image separately
         query_embeds = []
+        has_overlap = []
         for image_hash, query_boxes in query_spec.items():
             try:
                 _objectness, image_boxes, image_class_embeds, _, _ = (
@@ -329,22 +330,19 @@ def get_query_embedding(
             )
             if image_boxes.numel() == 0 or query_boxes_tensor.numel() == 0:
                 continue
-            iou, _ = box_iou(
-                to_corners(image_boxes), to_corners(query_boxes_tensor)
-            )  # 3000, k
+            iou, _ = box_iou(to_corners(image_boxes), to_corners(query_boxes_tensor))
             ious, indices = torch.max(iou, dim=0)
-            # filter for only iou > 0.4
-            iou_mask = ious > iou_threshold
-            indices = indices[iou_mask]
-            if not indices.numel() > 0:
-                continue
             embeds = image_class_embeds[indices]
+
+            iou_mask = ious > iou_threshold
+
+            # we don't filter by the mask here so as to maintain parallel structure
+            # with the metadata in the external calling function
             query_embeds.append(embeds)
-        if not query_embeds:
-            return None
-        query = torch.cat(query_embeds, dim=0)
-        return query
+            has_overlap.append(iou_mask)
+
+        return query_embeds, has_overlap
 
     def infer_from_embed(
         self,
@@ -371,6 +369,9 @@ def infer_from_embed(
             all_predicted_classes.append(classes)
             all_predicted_scores.append(scores)
 
+        if len(all_predicted_boxes) == 0:
+            return []
+
         all_predicted_boxes = torch.cat(all_predicted_boxes, dim=0)
         all_predicted_classes = torch.cat(all_predicted_classes, dim=0)
         all_predicted_scores = torch.cat(all_predicted_scores, dim=0)
@@ -440,12 +441,16 @@ def make_class_embeddings_dict(
         bool_to_literal = {True: "positive", False: "negative"}
         for train_image in training_data:
+            boxes = train_image["boxes"]
+            if len(boxes) == 0:
+                # no prompts for this image, so we skip it
+                continue
+
             # grab and embed image
             image = load_image_rgb(train_image["image"])
             image_hash = self.embed_image(image)
 
-            # grab and normalize box prompts for this image
-            boxes = train_image["boxes"]
+            # normalize box prompts
             coords = [[box["x"], box["y"], box["w"], box["h"]] for box in boxes]
             coords = [
                 tuple([c / max(image.shape[:2]) for c in coord]) for coord in coords
@@ -456,18 +461,24 @@ def make_class_embeddings_dict(
             # compute the embeddings for the box prompts
             query_spec = {image_hash: coords}
             # NOTE: because we just computed the embedding for this image, this should never result in a KeyError
-            embeddings = self.get_query_embedding(query_spec, iou_threshold)
-
-            if embeddings is None:
-                continue
+            batched_embeddings, batched_has_overlap = self.get_query_embedding(
+                query_spec, iou_threshold
+            )
+            # get_query_embedding is designed to handle multiple images
+            # so we take the first (and only) element
+            embeddings = batched_embeddings[0]
+            has_overlap = batched_has_overlap[0]
 
             # add the embeddings to their appropriate class and positive/negative list
-            for embedding, class_name, is_positive in zip(
-                embeddings, classes, is_positive
+            for embedding, class_name, is_positive, fits_the_prompt in zip(
+                embeddings, classes, is_positive, has_overlap
             ):
-                class_embeddings_dict[class_name][bool_to_literal[is_positive]].append(
-                    embedding
-                )
+                # we checked if the found box sufficiently overlaps with the prompt
+                # so we skip adding it to the class embeddings dict if it doesn't
+                if fits_the_prompt:
+                    class_embeddings_dict[class_name][
+                        bool_to_literal[is_positive]
+                    ].append(embedding)
 
         # convert the lists of embeddings to tensors
diff --git a/tests/inference/models_predictions_tests/test_owlv2.py b/tests/inference/models_predictions_tests/test_owlv2.py
index 6bbcbcfd5b..d57651b780 100644
--- a/tests/inference/models_predictions_tests/test_owlv2.py
+++ b/tests/inference/models_predictions_tests/test_owlv2.py
@@ -124,3 +124,25 @@ def test_owlv2():
 
     response = OwlV2().infer_from_request(request)
     assert len(response.predictions) == 5
+
+    # test that it can handle a bad prompt
+    request = OwlV2InferenceRequest(
+        image=image,
+        training_data=[
+            {
+                "image": image,
+                "boxes": [
+                    {"x": 1, "y": 1, "w": 1, "h": 1, "cls": "post", "negative": False}
+                ],
+            }
+        ],
+        visualize_predictions=True,
+        confidence=0.9,
+    )
+
+    response = OwlV2().infer_from_request(request)
+    assert len(response.predictions) == 0
+
+
+if __name__ == "__main__":
+    test_owlv2()

From 953d84b05f0936353ed79c69c9acc7f22416609d Mon Sep 17 00:00:00 2001
From: Isaac Robinson
Date: Fri, 1 Nov 2024 22:27:44 +0000
Subject: [PATCH 2/4] additional integration tests

---
 .../models_predictions_tests/test_owlv2.py | 97 ++++++++++++++++++-
 1 file changed, 95 insertions(+), 2 deletions(-)

diff --git a/tests/inference/models_predictions_tests/test_owlv2.py b/tests/inference/models_predictions_tests/test_owlv2.py
index d57651b780..f68b698a3b 100644
--- a/tests/inference/models_predictions_tests/test_owlv2.py
+++ b/tests/inference/models_predictions_tests/test_owlv2.py
@@ -1,7 +1,10 @@
+import pytest
+
 from inference.core.entities.requests.owlv2 import OwlV2InferenceRequest
 from inference.models.owlv2.owlv2 import OwlV2
 
 
+@pytest.mark.slow
 def test_owlv2():
     image = {
         "type": "url",
@@ -49,6 +52,14 @@ def test_owlv2():
     assert abs(532 - posts[3].x) < 1.5
     assert abs(572 - posts[4].x) < 1.5
 
+
+@pytest.mark.slow
+def test_owlv2_multiple_prompts():
+    image = {
+        "type": "url",
+        "value": "https://media.roboflow.com/inference/seawithdock.jpeg",
+    }
+
     # test we can handle multiple (positive and negative) prompts for the same image
     request = OwlV2InferenceRequest(
         image=image,
@@ -96,7 +107,15 @@ def test_owlv2():
     assert abs(532 - posts[2].x) < 1.5
     assert abs(572 - posts[3].x) < 1.5
 
-    # test that we can handle no prompts for an image
+
+@pytest.mark.slow
+def test_owlv2_image_without_prompts():
+    image = {
+        "type": "url",
+        "value": "https://media.roboflow.com/inference/seawithdock.jpeg",
+    }
+
+    # test that we can handle an image without any prompts
     request = OwlV2InferenceRequest(
         image=image,
         training_data=[
@@ -125,7 +144,15 @@ def test_owlv2():
     response = OwlV2().infer_from_request(request)
     assert len(response.predictions) == 5
 
-    # test that it can handle a bad prompt
+
+@pytest.mark.slow
+def test_owlv2_bad_prompt():
+    image = {
+        "type": "url",
+        "value": "https://media.roboflow.com/inference/seawithdock.jpeg",
+    }
+
+    # test that we can handle a bad prompt
     request = OwlV2InferenceRequest(
         image=image,
         training_data=[
@@ -144,5 +171,71 @@ def test_owlv2():
     assert len(response.predictions) == 0
 
 
+@pytest.mark.slow
+def test_owlv2_no_training_data():
+    image = {
+        "type": "url",
+        "value": "https://media.roboflow.com/inference/seawithdock.jpeg",
+    }
+
+    # test that we can handle no training data
+    request = OwlV2InferenceRequest(
+        image=image,
+        training_data=[],
+    )
+
+    response = OwlV2().infer_from_request(request)
+    assert len(response.predictions) == 0
+
+
+@pytest.mark.slow
+def test_owlv2_multiple_training_images():
+    image = {
+        "type": "url",
+        "value": "https://media.roboflow.com/inference/seawithdock.jpeg",
+    }
+    second_image = {
+        "type": "url",
+        "value": "https://media.roboflow.com/inference/dock2.jpg",
+    }
+
+    request = OwlV2InferenceRequest(
+        image=image,
+        training_data=[
+            {
+                "image": image,
+                "boxes": [
+                    {
+                        "x": 223,
+                        "y": 306,
+                        "w": 40,
+                        "h": 226,
+                        "cls": "post",
+                        "negative": False,
+                    }
+                ],
+            },
+            {
+                "image": second_image,
+                "boxes": [
+                    {
+                        "x": 3009,
+                        "y": 1873,
+                        "w": 289,
+                        "h": 811,
+                        "cls": "post",
+                        "negative": True,
+                    }
+                ],
+            },
+        ],
+        visualize_predictions=True,
+        confidence=0.9,
+    )
+
+    response = OwlV2().infer_from_request(request)
+    assert len(response.predictions) == 5
+
+
 if __name__ == "__main__":
     test_owlv2()

From 2714e351c147fd842e2653e8a3b160f293143cb6 Mon Sep 17 00:00:00 2001
From: Isaac Robinson
Date: Fri, 1 Nov 2024 22:38:08 +0000
Subject: [PATCH 3/4] no need for such a big refactor

---
 inference/models/owlv2/owlv2.py             | 58 ++++++++-----------
 .../models_predictions_tests/test_owlv2.py  | 41 +++++++++++++
 2 files changed, 66 insertions(+), 33 deletions(-)

diff --git a/inference/models/owlv2/owlv2.py b/inference/models/owlv2/owlv2.py
index 56a4b753bb..8ff5434531 100644
--- a/inference/models/owlv2/owlv2.py
+++ b/inference/models/owlv2/owlv2.py
@@ -313,10 +313,9 @@ def embed_image(self, image: np.ndarray) -> Hash:
 
     def get_query_embedding(
         self, query_spec: QuerySpecType, iou_threshold: float
-    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    ) -> torch.Tensor:
         # NOTE: for now we're handling each image separately
         query_embeds = []
-        has_overlap = []
         for image_hash, query_boxes in query_spec.items():
             try:
                 _objectness, image_boxes, image_class_embeds, _, _ = (
@@ -330,19 +329,22 @@ def get_query_embedding(
             )
             if image_boxes.numel() == 0 or query_boxes_tensor.numel() == 0:
                 continue
-            iou, _ = box_iou(to_corners(image_boxes), to_corners(query_boxes_tensor))
+            iou, _ = box_iou(
+                to_corners(image_boxes), to_corners(query_boxes_tensor)
+            )  # 3000, k
             ious, indices = torch.max(iou, dim=0)
+            # filter for only iou > 0.4
+            iou_mask = ious > iou_threshold
+            indices = indices[iou_mask]
+            if not indices.numel() > 0:
+                continue
             embeds = image_class_embeds[indices]
-
-            iou_mask = ious > iou_threshold
-
-            # we don't filter by the mask here so as to maintain parallel structure
-            # with the metadata in the external calling function
             query_embeds.append(embeds)
-            has_overlap.append(iou_mask)
-
-        return query_embeds, has_overlap
+        if not query_embeds:
+            return None
+        query = torch.cat(query_embeds, dim=0)
+        return query
 
     def infer_from_embed(
         self,
@@ -369,7 +371,7 @@ def infer_from_embed(
             all_predicted_classes.append(classes)
             all_predicted_scores.append(scores)
 
-        if len(all_predicted_boxes) == 0:
+        if not all_predicted_boxes:
             return []
 
         all_predicted_boxes = torch.cat(all_predicted_boxes, dim=0)
@@ -441,16 +443,12 @@ def make_class_embeddings_dict(
         bool_to_literal = {True: "positive", False: "negative"}
         for train_image in training_data:
-            boxes = train_image["boxes"]
-            if len(boxes) == 0:
-                # no prompts for this image, so we skip it
-                continue
-
             # grab and embed image
             image = load_image_rgb(train_image["image"])
             image_hash = self.embed_image(image)
 
-            # normalize box prompts
+            # grab and normalize box prompts for this image
+            boxes = train_image["boxes"]
             coords = [[box["x"], box["y"], box["w"], box["h"]] for box in boxes]
             coords = [
                 tuple([c / max(image.shape[:2]) for c in coord]) for coord in coords
@@ -461,24 +459,18 @@ def make_class_embeddings_dict(
             # compute the embeddings for the box prompts
             query_spec = {image_hash: coords}
             # NOTE: because we just computed the embedding for this image, this should never result in a KeyError
-            batched_embeddings, batched_has_overlap = self.get_query_embedding(
-                query_spec, iou_threshold
-            )
-            # get_query_embedding is designed to handle multiple images
-            # so we take the first (and only) element
-            embeddings = batched_embeddings[0]
-            has_overlap = batched_has_overlap[0]
+            embeddings = self.get_query_embedding(query_spec, iou_threshold)
+
+            if embeddings is None:
+                continue
 
             # add the embeddings to their appropriate class and positive/negative list
-            for embedding, class_name, is_positive, fits_the_prompt in zip(
-                embeddings, classes, is_positive, has_overlap
+            for embedding, class_name, is_positive in zip(
+                embeddings, classes, is_positive
             ):
-                # we checked if the found box sufficiently overlaps with the prompt
-                # so we skip adding it to the class embeddings dict if it doesn't
-                if fits_the_prompt:
-                    class_embeddings_dict[class_name][
-                        bool_to_literal[is_positive]
-                    ].append(embedding)
+                class_embeddings_dict[class_name][bool_to_literal[is_positive]].append(
+                    embedding
+                )
 
         # convert the lists of embeddings to tensors
diff --git a/tests/inference/models_predictions_tests/test_owlv2.py b/tests/inference/models_predictions_tests/test_owlv2.py
index f68b698a3b..53cbf7639f 100644
--- a/tests/inference/models_predictions_tests/test_owlv2.py
+++ b/tests/inference/models_predictions_tests/test_owlv2.py
@@ -171,6 +171,47 @@ def test_owlv2_bad_prompt():
     assert len(response.predictions) == 0
 
 
+@pytest.mark.slow
+def test_owlv2_bad_prompt_hidden_among_good_prompts():
+    image = {
+        "type": "url",
+        "value": "https://media.roboflow.com/inference/seawithdock.jpeg",
+    }
+
+    # test that we can handle a bad prompt
+    request = OwlV2InferenceRequest(
+        image=image,
+        training_data=[
+            {
+                "image": image,
+                "boxes": [
+                    {
+                        "x": 1,
+                        "y": 1,
+                        "w": 1,
+                        "h": 1,
+                        "cls": "post",
+                        "negative": False,
+                    },
+                    {
+                        "x": 223,
+                        "y": 306,
+                        "w": 40,
+                        "h": 226,
+                        "cls": "post",
+                        "negative": False,
+                    },
+                ],
+            }
+        ],
+        visualize_predictions=True,
+        confidence=0.9,
+    )
+
+    response = OwlV2().infer_from_request(request)
+    assert len(response.predictions) == 5
+
+
 @pytest.mark.slow
 def test_owlv2_no_training_data():
     image = {

From be40f2da88748b6b8c4ad2ee158c98badb75a931 Mon Sep 17 00:00:00 2001
From: Isaac Robinson
Date: Fri, 1 Nov 2024 22:39:40 +0000
Subject: [PATCH 4/4] style change

---
 tests/inference/models_predictions_tests/test_owlv2.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/tests/inference/models_predictions_tests/test_owlv2.py b/tests/inference/models_predictions_tests/test_owlv2.py
index 53cbf7639f..33e8b7f443 100644
--- a/tests/inference/models_predictions_tests/test_owlv2.py
+++ b/tests/inference/models_predictions_tests/test_owlv2.py
@@ -159,7 +159,14 @@ def test_owlv2_bad_prompt():
             {
                 "image": image,
                 "boxes": [
-                    {"x": 1, "y": 1, "w": 1, "h": 1, "cls": "post", "negative": False}
+                    {
+                        "x": 1,
+                        "y": 1,
+                        "w": 1,
+                        "h": 1,
+                        "cls": "post",
+                        "negative": False,
+                    }
                 ],
             }
         ],
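
Note on the behavior this series lands on (a minimal sketch, not code from the
patches above): after PATCH 3, get_query_embedding keeps a prompt box only when
some predicted box overlaps it with IoU above the threshold, and returns None
when nothing survives, which is what lets test_owlv2_bad_prompt expect zero
predictions. The helper below is hypothetical (the real method takes a
query_spec keyed by image hash and converts center-format boxes with
to_corners); it assumes plain xyxy tensors:

    import torch
    from torchvision.ops import box_iou

    def select_query_embeddings(image_boxes, class_embeds, query_boxes, iou_threshold=0.4):
        # image_boxes: (N, 4) candidate boxes, class_embeds: (N, D), query_boxes: (K, 4)
        if image_boxes.numel() == 0 or query_boxes.numel() == 0:
            return None
        iou = box_iou(image_boxes, query_boxes)  # (N, K) pairwise IoU
        ious, indices = torch.max(iou, dim=0)    # best candidate per prompt box
        indices = indices[ious > iou_threshold]  # drop prompts nothing overlaps
        if indices.numel() == 0:
            return None                          # the "no good matches" case
        return class_embeds[indices]             # one embedding per surviving prompt

    # a prompt box far from every candidate yields None rather than a bogus embedding
    boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
    embeds = torch.randn(1, 512)
    bad_prompt = torch.tensor([[500.0, 500.0, 501.0, 501.0]])
    assert select_query_embeddings(boxes, embeds, bad_prompt) is None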