Commit 218f441

Fix image post-processing for OWLv2 (#30686)

Authored by jla524 and amyeroberts

* feat: add note about owlv2
* fix: post processing coordinates
* remove: workaround document
* fix: extra quotes
* update: owlv2 docstrings
* fix: copies check
* feat: add unit test for resize
* Update tests/models/owlv2/test_image_processor_owlv2.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

1 parent: df53c6e

File tree: 3 files changed, +53 -50 lines changed


src/transformers/models/owlv2/image_processing_owlv2.py (+12 -1)

@@ -481,7 +481,6 @@ def preprocess(
         data = {"pixel_values": images}
         return BatchFeature(data=data, tensor_type=return_tensors)
 
-    # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process_object_detection
     def post_process_object_detection(
         self, outputs, threshold: float = 0.1, target_sizes: Union[TensorType, List[Tuple]] = None
     ):
@@ -525,6 +524,18 @@ def post_process_object_detection(
             else:
                 img_h, img_w = target_sizes.unbind(1)
 
+            # rescale coordinates
+            width_ratio = 1
+            height_ratio = 1
+
+            if img_w < img_h:
+                width_ratio = img_w / img_h
+            elif img_h < img_w:
+                height_ratio = img_h / img_w
+
+            img_w = img_w / width_ratio
+            img_h = img_h / height_ratio
+
             scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
             boxes = boxes * scale_fct[:, None, :]

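The added block is the heart of the fix: OWLv2's image processor pads the input to a square before resizing, so the model's normalized box coordinates are fractions of that padded square rather than of the original image. Dividing by the aspect ratio makes both scale factors equal to the longer side, which maps the boxes back into the original image's pixel frame. Below is a minimal sketch of the same arithmetic, using plain Python numbers instead of tensors; the helper name and example values are mine, for illustration only.

```python
# Illustrative sketch of the coordinate rescaling added above. Assumptions:
# plain Python numbers instead of torch tensors, a single image instead of a
# batch, and the helper name is hypothetical.
def rescale_target_size(img_h: float, img_w: float) -> tuple[float, float]:
    """Return the effective (height, width) scale used for post-processing.

    OWLv2 pads the shorter side so the image becomes square before resizing,
    so normalized box coordinates are fractions of that padded square.
    Dividing the shorter dimension by the aspect ratio makes both scale
    factors equal to the longer side, i.e. the side of the padded square.
    """
    width_ratio = 1
    height_ratio = 1

    if img_w < img_h:
        width_ratio = img_w / img_h
    elif img_h < img_w:
        height_ratio = img_h / img_w

    return img_h / height_ratio, img_w / width_ratio


# Example: the 640x480 (width x height) COCO cats image used in the docstring
# examples. The padded square is 640x640, so both scale factors become 640 and
# boxes land in the original image's pixel frame instead of being squashed.
print(rescale_target_size(480, 640))  # -> (640.0, 640.0)
```
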
src/transformers/models/owlv2/modeling_owlv2.py (+17 -47)

@@ -1540,9 +1540,7 @@ def image_guided_detection(
         >>> import requests
         >>> from PIL import Image
         >>> import torch
-        >>> import numpy as np
         >>> from transformers import AutoProcessor, Owlv2ForObjectDetection
-        >>> from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 
         >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
         >>> model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
@@ -1557,20 +1555,7 @@ def image_guided_detection(
         >>> with torch.no_grad():
         ...     outputs = model.image_guided_detection(**inputs)
 
-        >>> # Note: boxes need to be visualized on the padded, unnormalized image
-        >>> # hence we'll set the target image sizes (height, width) based on that
-
-        >>> def get_preprocessed_image(pixel_values):
-        ...     pixel_values = pixel_values.squeeze().numpy()
-        ...     unnormalized_image = (pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]) + np.array(OPENAI_CLIP_MEAN)[:, None, None]
-        ...     unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
-        ...     unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
-        ...     unnormalized_image = Image.fromarray(unnormalized_image)
-        ...     return unnormalized_image
-
-        >>> unnormalized_image = get_preprocessed_image(inputs.pixel_values)
-
-        >>> target_sizes = torch.Tensor([unnormalized_image.size[::-1]])
+        >>> target_sizes = torch.Tensor([image.size[::-1]])
 
         >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
         >>> results = processor.post_process_image_guided_detection(
@@ -1581,19 +1566,19 @@ def image_guided_detection(
         >>> for box, score in zip(boxes, scores):
         ...     box = [round(i, 2) for i in box.tolist()]
         ...     print(f"Detected similar object with confidence {round(score.item(), 3)} at location {box}")
-        Detected similar object with confidence 0.938 at location [490.96, 109.89, 821.09, 536.11]
-        Detected similar object with confidence 0.959 at location [8.67, 721.29, 928.68, 732.78]
-        Detected similar object with confidence 0.902 at location [4.27, 720.02, 941.45, 761.59]
-        Detected similar object with confidence 0.985 at location [265.46, -58.9, 1009.04, 365.66]
-        Detected similar object with confidence 1.0 at location [9.79, 28.69, 937.31, 941.64]
-        Detected similar object with confidence 0.998 at location [869.97, 58.28, 923.23, 978.1]
-        Detected similar object with confidence 0.985 at location [309.23, 21.07, 371.61, 932.02]
-        Detected similar object with confidence 0.947 at location [27.93, 859.45, 969.75, 915.44]
-        Detected similar object with confidence 0.996 at location [785.82, 41.38, 880.26, 966.37]
-        Detected similar object with confidence 0.998 at location [5.08, 721.17, 925.93, 998.41]
-        Detected similar object with confidence 0.969 at location [6.7, 898.1, 921.75, 949.51]
-        Detected similar object with confidence 0.966 at location [47.16, 927.29, 981.99, 942.14]
-        Detected similar object with confidence 0.924 at location [46.4, 936.13, 953.02, 950.78]
+        Detected similar object with confidence 0.938 at location [327.31, 54.94, 547.39, 268.06]
+        Detected similar object with confidence 0.959 at location [5.78, 360.65, 619.12, 366.39]
+        Detected similar object with confidence 0.902 at location [2.85, 360.01, 627.63, 380.8]
+        Detected similar object with confidence 0.985 at location [176.98, -29.45, 672.69, 182.83]
+        Detected similar object with confidence 1.0 at location [6.53, 14.35, 624.87, 470.82]
+        Detected similar object with confidence 0.998 at location [579.98, 29.14, 615.49, 489.05]
+        Detected similar object with confidence 0.985 at location [206.15, 10.53, 247.74, 466.01]
+        Detected similar object with confidence 0.947 at location [18.62, 429.72, 646.5, 457.72]
+        Detected similar object with confidence 0.996 at location [523.88, 20.69, 586.84, 483.18]
+        Detected similar object with confidence 0.998 at location [3.39, 360.59, 617.29, 499.21]
+        Detected similar object with confidence 0.969 at location [4.47, 449.05, 614.5, 474.76]
+        Detected similar object with confidence 0.966 at location [31.44, 463.65, 654.66, 471.07]
+        Detected similar object with confidence 0.924 at location [30.93, 468.07, 635.35, 475.39]
         ```"""
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1665,10 +1650,8 @@ def forward(
         ```python
         >>> import requests
         >>> from PIL import Image
-        >>> import numpy as np
         >>> import torch
         >>> from transformers import AutoProcessor, Owlv2ForObjectDetection
-        >>> from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 
         >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
         >>> model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
@@ -1682,20 +1665,7 @@ def forward(
         >>> with torch.no_grad():
         ...     outputs = model(**inputs)
 
-        >>> # Note: boxes need to be visualized on the padded, unnormalized image
-        >>> # hence we'll set the target image sizes (height, width) based on that
-
-        >>> def get_preprocessed_image(pixel_values):
-        ...     pixel_values = pixel_values.squeeze().numpy()
-        ...     unnormalized_image = (pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]) + np.array(OPENAI_CLIP_MEAN)[:, None, None]
-        ...     unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
-        ...     unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
-        ...     unnormalized_image = Image.fromarray(unnormalized_image)
-        ...     return unnormalized_image
-
-        >>> unnormalized_image = get_preprocessed_image(inputs.pixel_values)
-
-        >>> target_sizes = torch.Tensor([unnormalized_image.size[::-1]])
+        >>> target_sizes = torch.Tensor([image.size[::-1]])
         >>> # Convert outputs (bounding boxes and class logits) to final bounding boxes and scores
         >>> results = processor.post_process_object_detection(
         ...     outputs=outputs, threshold=0.2, target_sizes=target_sizes
@@ -1708,8 +1678,8 @@ def forward(
         >>> for box, score, label in zip(boxes, scores, labels):
        ...     box = [round(i, 2) for i in box.tolist()]
         ...     print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
-        Detected a photo of a cat with confidence 0.614 at location [512.5, 35.08, 963.48, 557.02]
-        Detected a photo of a cat with confidence 0.665 at location [10.13, 77.94, 489.93, 709.69]
+        Detected a photo of a cat with confidence 0.614 at location [341.67, 23.39, 642.32, 371.35]
+        Detected a photo of a cat with confidence 0.665 at location [6.75, 51.96, 326.62, 473.13]
         ```"""
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (

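Because post-processing now returns boxes in the original image's coordinate frame, the docstring examples can pass `image.size[::-1]` directly as `target_sizes`, and detections can be drawn straight onto the original picture instead of the padded, unnormalized one. Below is a small visualization sketch; the drawing code is my own addition, not part of the commit, and it assumes the `image`, `boxes`, `scores`, `labels`, and `text` variables from the `forward` docstring example above.

```python
# Sketch: draw post-processed OWLv2 detections on the original image.
# Illustrative only; `boxes`, `scores`, `labels`, and `text` are assumed to
# come from the forward() docstring example shown in the diff above.
from PIL import Image, ImageDraw


def draw_detections(image: Image.Image, boxes, scores, labels, text) -> Image.Image:
    """Draw each detected box with its label and score on a copy of the image."""
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    for box, score, label in zip(boxes, scores, labels):
        xmin, ymin, xmax, ymax = box.tolist()
        draw.rectangle((xmin, ymin, xmax, ymax), outline="red", width=3)
        draw.text((xmin, ymin), f"{text[label]}: {score.item():.2f}", fill="red")
    return annotated


# Usage (after running the forward() docstring example):
# draw_detections(image, boxes, scores, labels, text).save("owlv2_detections.png")
```
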
tests/models/owlv2/test_image_processor_owlv2.py (+24 -2)

@@ -17,15 +17,18 @@
 import unittest
 
 from transformers.testing_utils import require_torch, require_vision, slow
-from transformers.utils import is_vision_available
+from transformers.utils import is_torch_available, is_vision_available
 
 from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
 
 
 if is_vision_available():
     from PIL import Image
 
-    from transformers import Owlv2ImageProcessor
+    from transformers import AutoProcessor, Owlv2ForObjectDetection, Owlv2ImageProcessor
+
+if is_torch_available():
+    import torch
 
 
 class Owlv2ImageProcessingTester(unittest.TestCase):
@@ -120,6 +123,25 @@ def test_image_processor_integration_test(self):
         mean_value = round(pixel_values.mean().item(), 4)
         self.assertEqual(mean_value, 0.2353)
 
+    @slow
+    def test_image_processor_integration_test_resize(self):
+        checkpoint = "google/owlv2-base-patch16-ensemble"
+        processor = AutoProcessor.from_pretrained(checkpoint)
+        model = Owlv2ForObjectDetection.from_pretrained(checkpoint)
+
+        image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+        inputs = processor(text=["cat"], images=image, return_tensors="pt")
+
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        target_sizes = torch.tensor([image.size[::-1]])
+        results = processor.post_process_object_detection(outputs, threshold=0.2, target_sizes=target_sizes)[0]
+
+        boxes = results["boxes"].tolist()
+        self.assertEqual(boxes[0], [341.66656494140625, 23.38756561279297, 642.321044921875, 371.3482971191406])
+        self.assertEqual(boxes[1], [6.753320693969727, 51.96149826049805, 326.61810302734375, 473.12982177734375])
+
     @unittest.skip("OWLv2 doesn't treat 4 channel PIL and numpy consistently yet")  # FIXME Amy
     def test_call_numpy_4_channels(self):
         pass

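As a sanity check, the expected boxes hard-coded in the new slow test are the same two cat detections printed in the updated `forward` docstring, just before rounding. A tiny sketch confirming the correspondence, with the values copied from the diffs above:

```python
# The unrounded boxes asserted in the new test round to the values shown in the
# updated forward() docstring example.
expected_boxes = [
    [341.66656494140625, 23.38756561279297, 642.321044921875, 371.3482971191406],
    [6.753320693969727, 51.96149826049805, 326.61810302734375, 473.12982177734375],
]
print([[round(coord, 2) for coord in box] for box in expected_boxes])
# [[341.67, 23.39, 642.32, 371.35], [6.75, 51.96, 326.62, 473.13]]
```
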