Owl v2 speedup #759

Merged 8 commits on Oct 30, 2024
52 changes: 52 additions & 0 deletions development/benchmark_scripts/benchmark_owlv2_inference_time.py
@@ -0,0 +1,52 @@
import requests
import tempfile
from PIL import Image
import time

from inference.models.owlv2.owlv2 import OwlV2

# run a simple latency test
image_via_url = {
    "type": "url",
    "value": "https://media.roboflow.com/inference/seawithdock.jpeg",
}

# Download the image
response = requests.get(image_via_url["value"])
response.raise_for_status()

# Create a temporary file
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
    temp_file.write(response.content)
    temp_file_path = temp_file.name

img = Image.open(temp_file_path)
img = img.convert("RGB")

request_dict = dict(
    image=img,
    training_data=[
        {
            "image": img,
            "boxes": [{"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post"}],
        }
    ],
    visualize_predictions=False,
)

model = OwlV2()

for _ in range(10):
    print("pre cache fill try")
    time_start = time.time()
    response = model.infer(**request_dict)
    time_end = time.time()
    print(f"Time taken: {time_end - time_start} seconds")

    print("post cache fill try")
    time_start = time.time()
    response = model.infer(**request_dict)
    time_end = time.time()
    print(f"Time taken: {time_end - time_start} seconds")

    model.reset_cache()
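Since the cache is cleared at the end of every pass, each iteration prints one cold ("pre cache fill") latency, where the image embedding has to be computed, and one warm ("post cache fill") latency, where it is served from image_embed_cache. If aggregate numbers are more useful, a small hedged variant of the loop (reusing the model and request_dict defined above) could collect them:

cold_times, warm_times = [], []
for _ in range(10):
    t0 = time.time()
    model.infer(**request_dict)  # cache miss: the image embedding is computed
    cold_times.append(time.time() - t0)

    t0 = time.time()
    model.infer(**request_dict)  # cache hit: the embedding is reused
    warm_times.append(time.time() - t0)

    model.reset_cache()  # start the next iteration cold again

print(f"cold mean: {sum(cold_times) / len(cold_times):.3f} s")
print(f"warm mean: {sum(warm_times) / len(warm_times):.3f} s")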
109 changes: 91 additions & 18 deletions inference/models/owlv2/owlv2.py
@@ -1,12 +1,12 @@
import hashlib
import os
from collections import defaultdict
from typing import Dict, List, NewType
from typing import Dict, List, NewType, Tuple

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
from transformers import Owlv2ForObjectDetection, Owlv2Processor
from transformers.models.owlv2.modeling_owlv2 import box_iou

@@ -56,15 +56,79 @@ def _check_size_limit(self):
                self.popitem(last=False)


def preprocess_image(
    np_image: np.ndarray,
    image_size: Tuple[int, int],
    image_mean: torch.Tensor,
    image_std: torch.Tensor,
) -> torch.Tensor:
    """Preprocess an image for OWLv2 by resizing, normalizing, and padding it.
    This is much faster than using the Owlv2Processor directly, as we ensure we use GPU if available.

    Args:
        np_image (np.ndarray): The image to preprocess, with shape (H, W, 3)
        image_size (tuple[int, int]): The target size of the image
        image_mean (torch.Tensor): The mean of the image, on DEVICE, with shape (1, 3, 1, 1)
        image_std (torch.Tensor): The standard deviation of the image, on DEVICE, with shape (1, 3, 1, 1)

    Returns:
        torch.Tensor: The preprocessed image, on DEVICE, with shape (1, 3, H, W)
    """
    current_size = np_image.shape[:2]

    r = min(image_size[0] / current_size[0], image_size[1] / current_size[1])
    target_size = (int(r * current_size[0]), int(r * current_size[1]))

    torch_image = (
        torch.tensor(np_image)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(DEVICE)
        .to(dtype=torch.float32)
        / 255.0
    )
    torch_image = F.interpolate(
        torch_image, size=target_size, mode="bilinear", align_corners=False
    )

    padded_image_tensor = torch.ones((1, 3, *image_size), device=DEVICE) * 0.5
    padded_image_tensor[:, :, : torch_image.shape[2], : torch_image.shape[3]] = (
        torch_image
    )

    padded_image_tensor = (padded_image_tensor - image_mean) / image_std

    return padded_image_tensor


class OwlV2(RoboflowCoreModel):
    task_type = "object-detection"
    box_format = "xywh"

    def __init__(self, *args, model_id="owlv2/owlv2-base-patch16-ensemble", **kwargs):
        super().__init__(*args, model_id=model_id, **kwargs)
        hf_id = os.path.join("google", self.version_id)
        self.processor = Owlv2Processor.from_pretrained(hf_id)
        processor = Owlv2Processor.from_pretrained(hf_id)
        self.image_size = tuple(processor.image_processor.size.values())
        self.image_mean = torch.tensor(
            processor.image_processor.image_mean, device=DEVICE
        ).view(1, 3, 1, 1)
        self.image_std = torch.tensor(
            processor.image_processor.image_std, device=DEVICE
        ).view(1, 3, 1, 1)
        self.model = Owlv2ForObjectDetection.from_pretrained(hf_id).eval().to(DEVICE)
        self.reset_cache()

        # compile the forward pass of the model's visual backbone
        # NOTE: compilation fixes up the manual attention implementation used in OWLv2,
        # so we don't have to force in flash attention ourselves.
        # However, that is only true on torch 2.4 or later;
        # for torch < 2.4 this is a LOT slower and forcing flash attention ourselves is faster.
        # Compilation also breaks on torch < 2.1, so we suppress torch._dynamo errors.
        torch._dynamo.config.suppress_errors = True
        self.model.owlv2.vision_model = torch.compile(self.model.owlv2.vision_model)

    def reset_cache(self):
        self.image_embed_cache = LimitedSizeDict(
            size_limit=50
        )  # NOTE: this should have a max size
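For orientation, a hedged usage sketch of the new preprocess_image helper, outside the diff; "example.jpg" and the model variable are placeholders for an arbitrary RGB image and an initialized OwlV2 instance, and the expected shapes follow the docstring above:

import numpy as np
from PIL import Image

img = np.asarray(Image.open("example.jpg").convert("RGB"))  # (H, W, 3) uint8
pixel_values = preprocess_image(
    img, model.image_size, model.image_mean, model.image_std
)
# resized with preserved aspect ratio, padded to image_size, normalized, on DEVICE
assert pixel_values.shape == (1, 3, *model.image_size)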
@@ -100,22 +164,31 @@ def download_weights(self) -> None:
        pass

    @torch.no_grad()
    def embed_image(self, image: Image.Image) -> Hash:
        image_hash = hashlib.sha256(np.array(image).tobytes()).hexdigest()
    def embed_image(self, image: np.ndarray) -> Hash:
        image_hash = hashlib.sha256(image.tobytes()).hexdigest()

        if (image_embeds := self.image_embed_cache.get(image_hash)) is not None:
            return image_hash

        pixel_values = self.processor(
            images=image, return_tensors="pt"
        ).pixel_values.to(DEVICE)
        image_embeds, _ = self.model.image_embedder(pixel_values=pixel_values)
        batch_size, h, w, dim = image_embeds.shape
        image_features = image_embeds.reshape(batch_size, h * w, dim)
        objectness = self.model.objectness_predictor(image_features)
        boxes = self.model.box_predictor(image_features, feature_map=image_embeds)
        pixel_values = preprocess_image(
            image, self.image_size, self.image_mean, self.image_std
        )

        # torch 2.4 lets you use "cuda:0" as device_type
        # but this crashes in 2.3
        # so we parse DEVICE as a string to make it work in both 2.3 and 2.4
        # as we don't know a priori our torch version
        device_str = "cuda" if str(DEVICE).startswith("cuda") else "cpu"
        # we disable autocast on CPU for stability, although it's possible using bfloat16 would work
        with torch.autocast(
            device_type=device_str, dtype=torch.float16, enabled=device_str == "cuda"
        ):
            image_embeds, _ = self.model.image_embedder(pixel_values=pixel_values)
            batch_size, h, w, dim = image_embeds.shape
            image_features = image_embeds.reshape(batch_size, h * w, dim)
            objectness = self.model.objectness_predictor(image_features)
            boxes = self.model.box_predictor(image_features, feature_map=image_embeds)

        # class_embeddings = model.class_predictor(image_features)[1]
        image_class_embeds = self.model.class_head.dense0(image_features)
        image_class_embeds /= (
            torch.linalg.norm(image_class_embeds, ord=2, dim=-1, keepdim=True) + 1e-6
@@ -149,7 +222,7 @@ def get_query_embedding(self, query_spec: Dict[Hash, List[List[int]]]):
                raise KeyError("We didn't embed the image first!") from error

            query_boxes_tensor = torch.tensor(
                query_boxes, dtype=torch.float, device=image_boxes.device
                query_boxes, dtype=image_boxes.dtype, device=image_boxes.device
            )
            iou, union = box_iou(
                to_corners(image_boxes), to_corners(query_boxes_tensor)
@@ -192,9 +265,9 @@ def infer_from_embed(self, image_hash: Hash, query_embeddings, confidence):
            class_ind = class_map[class_name]
            predicted_classes.append(class_ind * torch.ones_like(scores))

        all_boxes = torch.cat(predicted_boxes, dim=0)
        all_classes = torch.cat(predicted_classes, dim=0)
        all_scores = torch.cat(predicted_scores, dim=0)
        all_boxes = torch.cat(predicted_boxes, dim=0).float()
        all_classes = torch.cat(predicted_classes, dim=0).float()
        all_scores = torch.cat(predicted_scores, dim=0).float()
        survival_indices = torchvision.ops.nms(to_corners(all_boxes), all_scores, 0.3)
        pred_boxes = all_boxes[survival_indices].detach().cpu().numpy()
        pred_classes = all_classes[survival_indices].detach().cpu().numpy()
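Taken together, the two precision-related changes above (normalizing the autocast device_type string in embed_image and the .float() casts before NMS in infer_from_embed) follow a pattern that can be sketched in isolation; the tensors below are made-up stand-ins for model outputs:

import torch
import torchvision

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# torch 2.3's autocast only accepts "cuda"/"cpu" as device_type, while 2.4 also
# takes "cuda:0", so normalize the string before entering autocast
device_str = "cuda" if str(DEVICE).startswith("cuda") else "cpu"

head = torch.nn.Linear(16, 4).to(DEVICE)
feats = torch.randn(100, 16, device=DEVICE)

with torch.autocast(
    device_type=device_str, dtype=torch.float16, enabled=device_str == "cuda"
):
    raw = head(feats).abs()  # float16 on CUDA under autocast, float32 otherwise
    boxes = torch.cat([raw[:, :2], raw[:, :2] + raw[:, 2:]], dim=1)  # x1, y1, x2, y2
    scores = raw.amax(dim=-1)

# cast back to float32 before NMS, mirroring the .float() calls in infer_from_embed,
# so half-precision outputs don't leak into downstream ops
keep = torchvision.ops.nms(boxes.float(), scores.float(), iou_threshold=0.3)
print(keep.shape, boxes.dtype)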
7 changes: 5 additions & 2 deletions tests/inference/models_predictions_tests/test_owlv2.py
@@ -1,5 +1,4 @@
from inference.core.entities.requests.owlv2 import OwlV2InferenceRequest
from inference.core.entities.responses.inference import ObjectDetectionInferenceResponse
from inference.models.owlv2.owlv2 import OwlV2


@@ -20,4 +19,8 @@ def test_owlv2():
    )

    response = OwlV2().infer_from_request(request)
    assert abs(221.4 - response.predictions[0].x) < 0.1
    # the exact value here is highly sensitive to the image interpolation mode used,
    # the data type used in the model (e.g. bfloat16 vs float16 vs float32),
    # and of course the size of the model itself (base vs large),
    # so we set a tolerance of 1.5 pixels from the expected value, which should cover most cases
    assert abs(223 - response.predictions[0].x) < 1.5