Adds image-guided object detection support to OWL-ViT #20136

Merged: 11 commits, Nov 16, 2022
3 changes: 3 additions & 0 deletions docs/source/en/model_doc/owlvit.mdx
@@ -80,6 +80,8 @@ This model was contributed by [adirik](https://huggingface.co/adirik). The origi

[[autodoc]] OwlViTFeatureExtractor
- __call__
- post_process
- post_process_image_guided_detection

## OwlViTProcessor

@@ -106,3 +108,4 @@ This model was contributed by [adirik](https://huggingface.co/adirik). The origi

[[autodoc]] OwlViTForObjectDetection
- forward
- image_guided_detection
147 changes: 132 additions & 15 deletions src/transformers/models/owlvit/feature_extraction_owlvit.py
@@ -32,14 +32,56 @@
logger = logging.get_logger(__name__)


# Copied from transformers.models.detr.feature_extraction_detr.center_to_corners_format
def center_to_corners_format(x):
"""
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
-    (left, top, right, bottom).
+    (x_0, y_0, x_1, y_1).
    """
-    x_center, y_center, width, height = x.unbind(-1)
-    boxes = [(x_center - 0.5 * width), (y_center - 0.5 * height), (x_center + 0.5 * width), (y_center + 0.5 * height)]
-    return torch.stack(boxes, dim=-1)
+    center_x, center_y, width, height = x.unbind(-1)
+    b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
+    return torch.stack(b, dim=-1)
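
# Sketch: a box centered at (0.5, 0.5) with width 0.4 and height 0.2 maps to
# corners (0.3, 0.4, 0.7, 0.6):
#
#     >>> boxes = torch.tensor([[0.5, 0.5, 0.4, 0.2]])
#     >>> center_to_corners_format(boxes)
#     tensor([[0.3000, 0.4000, 0.7000, 0.6000]])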


# Copied from transformers.models.detr.modeling_detr._upcast
def _upcast(t):
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
if t.is_floating_point():
return t if t.dtype in (torch.float32, torch.float64) else t.float()
else:
return t if t.dtype in (torch.int32, torch.int64) else t.int()
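
# Sketch: products like width * height can overflow float16 (max ~65504), so
# boxes are promoted before multiplication:
#
#     >>> t = torch.tensor([3000.0, 4000.0], dtype=torch.float16)
#     >>> _upcast(t).dtype
#     torch.float32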


def box_area(boxes):
"""
Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.

Args:
boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
< x2` and `0 <= y1 < y2`.

Returns:
`torch.FloatTensor`: a tensor containing the area for each box.
"""
boxes = _upcast(boxes)
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)

left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]

width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]

union = area1[:, None] + area2 - inter

iou = inter / union
return iou, union
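
# Sketch: two 2x2 squares offset by half a side overlap in a 1x2 strip, so the
# intersection is 2.0 and the union is 4.0 + 4.0 - 2.0 = 6.0:
#
#     >>> boxes1 = torch.tensor([[0.0, 0.0, 2.0, 2.0]])
#     >>> boxes2 = torch.tensor([[1.0, 0.0, 3.0, 2.0]])
#     >>> iou, union = box_iou(boxes1, boxes2)
#     >>> iou
#     tensor([[0.3333]])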


class OwlViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
@@ -56,10 +98,11 @@ class OwlViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin
The size to use for resizing the image. Only has an effect if `do_resize` is set to `True`. If `size` is a
sequence like (h, w), output size will be matched to this. If `size` is an int, then image will be resized
to (size, size).
-        resample (`int`, *optional*, defaults to `PILImageResampling.BICUBIC`):
-            An optional resampling filter. This can be one of `PILImageResampling.NEAREST`, `PILImageResampling.BOX`,
-            `PILImageResampling.BILINEAR`, `PILImageResampling.HAMMING`, `PILImageResampling.BICUBIC` or
-            `PILImageResampling.LANCZOS`. Only has an effect if `do_resize` is set to `True`.
+        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
+            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
+            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
+            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
+            to `True`.
do_center_crop (`bool`, *optional*, defaults to `False`):
Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
image is padded with 0's and then center cropped.
@@ -111,10 +154,11 @@ def post_process(self, outputs, target_sizes):
Args:
outputs ([`OwlViTObjectDetectionOutput`]):
Raw outputs of the model.
-            target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
-                Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
-                image size (before any data augmentation). For visualization, this should be the image size after data
-                augment, but before padding.
+            target_sizes (`torch.Tensor`, *optional*):
+                Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in
+                the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left as
+                None, predictions will not be rescaled to absolute coordinates.

Returns:
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
in the batch as predicted by the model.
@@ -142,6 +186,82 @@ def post_process(self, outputs, target_sizes):

return results
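
# Usage sketch for text-conditioned post-processing. The checkpoint name and
# the COCO image URL are illustrative; the call pattern follows the OWL-ViT
# docs:
#
#     >>> import requests, torch
#     >>> from PIL import Image
#     >>> from transformers import OwlViTProcessor, OwlViTForObjectDetection
#     >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
#     >>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
#     >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
#     >>> image = Image.open(requests.get(url, stream=True).raw)
#     >>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
#     >>> outputs = model(**inputs)
#     >>> target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
#     >>> results = processor.feature_extractor.post_process(outputs=outputs, target_sizes=target_sizes)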

def post_process_image_guided_detection(self, outputs, threshold=0.6, nms_threshold=0.3, target_sizes=None):
"""
Converts the output of [`OwlViTForObjectDetection.image_guided_detection`] into the format expected by the COCO
api.

Args:
outputs ([`OwlViTImageGuidedObjectDetectionOutput`]):
Raw outputs of the model.
threshold (`float`, *optional*, defaults to 0.6):
Minimum confidence threshold to use to filter out predicted boxes.
nms_threshold (`float`, *optional*, defaults to 0.3):
IoU threshold for non-maximum suppression of overlapping boxes.
target_sizes (`torch.Tensor`, *optional*):
Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in
the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left as
None, predictions will not be rescaled to absolute coordinates.

Returns:
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
in the batch as predicted by the model. All labels are set to None as
`OwlViTForObjectDetection.image_guided_detection` performs one-shot object detection.
"""
logits, target_boxes = outputs.logits, outputs.target_pred_boxes

if target_sizes is not None:
    if len(logits) != len(target_sizes):
        raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
    if target_sizes.shape[1] != 2:
        raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

probs = torch.max(logits, dim=-1)
scores = torch.sigmoid(probs.values)

# Convert to [x0, y0, x1, y1] format
target_boxes = center_to_corners_format(target_boxes)

# Apply non-maximum suppression (NMS)
if nms_threshold < 1.0:
for idx in range(target_boxes.shape[0]):
for i in torch.argsort(-scores[idx]):
if not scores[idx][i]:
continue

ious = box_iou(target_boxes[idx][i, :].unsqueeze(0), target_boxes[idx])[0][0]
ious[i] = -1.0 # Mask self-IoU.
scores[idx][ious > nms_threshold] = 0.0

# Convert from relative [0, 1] to absolute [0, height] coordinates
if target_sizes is not None:
    img_h, img_w = target_sizes.unbind(1)
    scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
    target_boxes = target_boxes * scale_fct[:, None, :]

# Compute box display alphas based on prediction scores
results = []
alphas = torch.zeros_like(scores)

for idx in range(target_boxes.shape[0]):
# Select scores for boxes matching the current query:
query_scores = scores[idx]
if not query_scores.nonzero().numel():
continue

# Scale box alpha such that the best box for each query has alpha 1.0 and the worst box has alpha 0.1.
# All other boxes will either belong to a different query, or will not be shown.
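# For example, with max_score = 1.0 a box scoring 1.0 gets alpha 1.0 and a box scoring 0.1 gets alpha 0.0.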
max_score = torch.max(query_scores) + 1e-6
query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9)
query_alphas[query_alphas < threshold] = 0.0
query_alphas = torch.clip(query_alphas, 0.0, 1.0)
alphas[idx] = query_alphas

mask = alphas[idx] > 0
box_scores = alphas[idx][mask]
boxes = target_boxes[idx][mask]
results.append({"scores": box_scores, "labels": None, "boxes": boxes})

return results
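
# End-to-end usage sketch for image-guided (one-shot) detection. The checkpoint
# name and query image URL are illustrative; `query_images` and
# `image_guided_detection` follow the API added in this PR:
#
#     >>> import requests, torch
#     >>> from PIL import Image
#     >>> from transformers import OwlViTProcessor, OwlViTForObjectDetection
#     >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch16")
#     >>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16")
#     >>> image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
#     >>> query_image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000524280.jpg", stream=True).raw)
#     >>> inputs = processor(images=image, query_images=query_image, return_tensors="pt")
#     >>> with torch.no_grad():
#     ...     outputs = model.image_guided_detection(**inputs)
#     >>> target_sizes = torch.tensor([image.size[::-1]])
#     >>> results = processor.feature_extractor.post_process_image_guided_detection(
#     ...     outputs=outputs, threshold=0.6, nms_threshold=0.3, target_sizes=target_sizes
#     ... )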

def __call__(
self,
images: Union[
@@ -165,18 +285,15 @@ def __call__(
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W) or (H, W, C),
where C is a number of channels, H and W are image height and width.

[Review thread]
Contributor: Any reason those spaces are removed?

Contributor Author: I think they are removed by `make style` but it seemed like there were extra blank lines to begin with.

return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:

- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.

Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:

- **pixel_values** -- Pixel values to be fed to a model.
"""
# Input type checking for clearer error