Adds image-guided object detection support to OWL-ViT #18891

Closed · wants to merge 49 commits

Commits (49)
423ece2  init commit (unography, Sep 4, 2022)
665ae95  fix review comments (unography, Sep 6, 2022)
ef10934  add main_input_name for OwlViTForObjectDetection (unography, Sep 6, 2022)
5906273  indexing fix (unography, Sep 7, 2022)
3795987  modeling tests (unography, Sep 7, 2022)
37dedeb  Merge branch 'main' of https://github.com/unography/transformers into… (unography, Sep 7, 2022)
8d96c8e  Update tests/models/owlvit/test_modeling_owlvit.py (unography, Sep 7, 2022)
3f772dd  choose query embeddings based on iou threshold (unography, Sep 7, 2022)
23c09ff  remove img guided tests for now (unography, Sep 8, 2022)
780ed99  return_projected remove (unography, Sep 8, 2022)
efdca2e  objdet trace fix input passing (unography, Sep 8, 2022)
bb61e30  style fixes (unography, Sep 8, 2022)
e31bc98  init commit (unography, Sep 4, 2022)
0f1e7ef  fix review comments (unography, Sep 6, 2022)
27a291e  add main_input_name for OwlViTForObjectDetection (unography, Sep 6, 2022)
55faa87  indexing fix (unography, Sep 7, 2022)
31c60f2  modeling tests (unography, Sep 7, 2022)
7669c06  Update tests/models/owlvit/test_modeling_owlvit.py (unography, Sep 7, 2022)
b442379  choose query embeddings based on iou threshold (unography, Sep 7, 2022)
380ec76  remove img guided tests for now (unography, Sep 8, 2022)
7f8a8d5  return_projected remove (unography, Sep 8, 2022)
409a97b  objdet trace fix input passing (unography, Sep 8, 2022)
b1d076e  style fixes (unography, Sep 8, 2022)
2483e8b  Update src/transformers/models/owlvit/modeling_owlvit.py (unography, Sep 25, 2022)
29e6e2a  Update src/transformers/models/owlvit/processing_owlvit.py (unography, Sep 25, 2022)
6a943f3  iou_threshold as param, updated docstrings (unography, Sep 25, 2022)
816f243  types, docstrings (unography, Sep 25, 2022)
0da80b9  types, docstrings (unography, Sep 25, 2022)
00e7b79  comment (unography, Sep 25, 2022)
8127026  var name change, types (unography, Sep 25, 2022)
3bcb2c2  var name change, types, docstrings (unography, Sep 25, 2022)
8f3aff0  no need of use_hidden_state anymore (unography, Sep 25, 2022)
09ebd53  add copied from statements (unography, Oct 10, 2022)
4822fc0  add copied from statements (unography, Oct 10, 2022)
8f6c3fd  squeeze fix (unography, Oct 10, 2022)
688fa99  access only when nonempty (unography, Oct 10, 2022)
e48506f  rebase to main, add tests (alaradirik, Oct 21, 2022)
8d67b42  fix inconsistencies, None pred_logits (alaradirik, Oct 21, 2022)
8834113  undo changes to copied function input signature (alaradirik, Oct 21, 2022)
03ff0b2  update docstrings (alaradirik, Oct 21, 2022)
8c74600  add OwlViTForImageGuidedObjectDetection (alaradirik, Nov 7, 2022)
ca21141  add image_guided_detection method (alaradirik, Nov 7, 2022)
02500ac  add image-guided detection and its postprocessing method (alaradirik, Nov 7, 2022)
dfb21e2  address reviews (alaradirik, Nov 8, 2022)
08804bf  address reviews (alaradirik, Nov 8, 2022)
5c5fa9f  address reviews (alaradirik, Nov 8, 2022)
9ebd950  return text and vision outputs (alaradirik, Nov 8, 2022)
a55f39c  fix merge conflict (alaradirik, Nov 9, 2022)
60aa449  run make fixup (alaradirik, Nov 9, 2022)
142 changes: 121 additions & 21 deletions src/transformers/models/owlvit/modeling_owlvit.py
@@ -964,6 +964,8 @@ def get_image_features(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
return_projected: Optional[bool] = True,
normalized: Optional[bool] = None,
return_base_image_embeds: Optional[bool] = None,
) -> torch.FloatTensor:
r"""
Returns:
@@ -1004,6 +1006,12 @@ def get_image_features(
image_features = self.visual_projection(pooled_output)
else:
image_features = pooled_output
if normalized:
image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
if return_base_image_embeds:
last_hidden_state = vision_outputs[0]
image_features = self.vision_model.post_layernorm(last_hidden_state)

return image_features

@add_start_docstrings_to_model_forward(OWLVIT_INPUTS_DOCSTRING)
@@ -1137,9 +1145,17 @@ def forward(
image_embeds: torch.FloatTensor,
query_embeds: torch.FloatTensor,
query_mask: torch.Tensor,
query_image_embeds: torch.FloatTensor,
) -> Tuple[torch.FloatTensor]:

image_class_embeds = self.dense0(image_embeds)
if query_image_embeds is not None:
# Get the most dissimilar embedding
# https://github.com/google-research/scenic/blob/3506cf25e012b17b81179d8abd256b2fd1b81638/scenic/projects/owl_vit/notebooks/inference.py#L158
query_all_embeds = self.dense0(query_image_embeds)
query_mean_embed = torch.mean(query_all_embeds, axis=0)
query_mean_sim = torch.einsum("bd,bid->bi", query_mean_embed, query_all_embeds)
query_embeds = query_all_embeds[:, torch.argmin(query_mean_sim, axis=1)]

# Normalize image and text features
image_class_embeds /= torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6
@@ -1243,6 +1259,7 @@ def class_predictor(
image_feats: torch.FloatTensor,
query_embeds: torch.FloatTensor,
query_mask: torch.Tensor,
query_image_embeds: torch.FloatTensor,
) -> Tuple[torch.FloatTensor]:
"""
Args:
@@ -1253,7 +1270,7 @@ def class_predictor(
query_mask:
Must be provided with query_embeddings. A mask indicating which query embeddings are valid.
"""
(pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask)
(pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask, query_image_embeds)

return (pred_logits, image_class_embeds)

@@ -1301,12 +1318,76 @@ def image_text_embedder(

return (text_embeds, image_embeds, text_model_last_hidden_state, vision_model_last_hidden_state)

def image_image_embedder(
self,
query_pixel_values: torch.FloatTensor,
pixel_values: torch.FloatTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> Tuple[torch.FloatTensor]:

query_image_embeds = self.owlvit.get_image_features(
pixel_values=query_pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_projected=True,
normalized=True,
return_base_image_embeds=True,
)
image_embeds = self.owlvit.get_image_features(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_projected=True,
normalized=True,
return_base_image_embeds=True,
)

# Resize class token
new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)

# Merge image embedding with class tokens
image_embeds = image_embeds[:, 1:, :] * class_token_out
image_embeds = self.layer_norm(image_embeds)

# Resize to [batch_size, num_patches, num_patches, hidden_size]
new_size = (
image_embeds.shape[0],
int(np.sqrt(image_embeds.shape[1])),
int(np.sqrt(image_embeds.shape[1])),
image_embeds.shape[-1],
)
image_embeds = image_embeds.reshape(new_size)

# Similar for query

# Resize class token
new_size = tuple(np.array(query_image_embeds.shape) - np.array((0, 1, 0)))
class_token_out = torch.broadcast_to(query_image_embeds[:, :1, :], new_size)

# Merge image embedding with class tokens
query_image_embeds = query_image_embeds[:, 1:, :] * class_token_out
query_image_embeds = self.layer_norm(query_image_embeds)

# Resize to [batch_size, num_patches, num_patches, hidden_size]
new_size = (
query_image_embeds.shape[0],
int(np.sqrt(query_image_embeds.shape[1])),
int(np.sqrt(query_image_embeds.shape[1])),
query_image_embeds.shape[-1],
)
query_image_embeds = query_image_embeds.reshape(new_size)

return (query_image_embeds, image_embeds)

@add_start_docstrings_to_model_forward(OWLVIT_OBJECT_DETECTION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=OwlViTObjectDetectionOutput, config_class=OwlViTConfig)
def forward(
self,
input_ids: torch.Tensor,
pixel_values: torch.FloatTensor,
input_ids: torch.Tensor = None,
query_pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@@ -1353,35 +1434,54 @@ def forward(
)
return_dict = return_dict if return_dict is not None else self.config.return_dict

# Embed images and text queries
outputs = self.image_text_embedder(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
output_attentions=output_attentions,

Review comment from sgugger (Collaborator): Why is this removed?

Reply from alaradirik (Contributor), Nov 8, 2022: @sgugger I removed this because OwlViTObjectDetectionOutput doesn't include text and vision model outputs, but re-adding it and returning model outputs make more sense indeed.

output_hidden_states=output_hidden_states,
)
query_embeds = None
query_mask = None
query_image_embeds = None
if input_ids is None and query_pixel_values is not None:
outputs = self.image_image_embedder(
query_pixel_values=query_pixel_values,
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
text_model_last_hidden_state = None
vision_model_last_hidden_state = None
query_feature_map = outputs[0]

# Last hidden states of text and vision transformers
text_model_last_hidden_state = outputs[2]
vision_model_last_hidden_state = outputs[3]
else:
# Embed images and text queries
outputs = self.image_text_embedder(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

# Last hidden states of text and vision transformers
text_model_last_hidden_state = outputs[2]
vision_model_last_hidden_state = outputs[3]
query_embeds = outputs[0]

query_embeds = outputs[0]
feature_map = outputs[1]

batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))

# Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim]
max_text_queries = input_ids.shape[0] // batch_size
query_embeds = query_embeds.reshape(batch_size, max_text_queries, query_embeds.shape[-1])
if input_ids is not None:
# Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim]
max_text_queries = input_ids.shape[0] // batch_size
query_embeds = query_embeds.reshape(batch_size, max_text_queries, query_embeds.shape[-1])

# If first token is 0, then this is a padded query [batch_size, num_queries].
input_ids = input_ids.reshape(batch_size, max_text_queries, input_ids.shape[-1])
query_mask = input_ids[..., 0] > 0
# If first token is 0, then this is a padded query [batch_size, num_queries].
input_ids = input_ids.reshape(batch_size, max_text_queries, input_ids.shape[-1])
query_mask = input_ids[..., 0] > 0
else:
batch_size, num_patches, num_patches, hidden_dim = query_feature_map.shape
query_image_embeds = torch.reshape(query_feature_map, (batch_size, num_patches * num_patches, hidden_dim))

# Predict object classes [batch_size, num_patches, num_queries+1]
(pred_logits, class_embeds) = self.class_predictor(image_feats, query_embeds, query_mask)
(pred_logits, class_embeds) = self.class_predictor(image_feats, query_embeds, query_mask, query_image_embeds)

# Predict object boxes
pred_boxes = self.box_predictor(image_feats, feature_map)
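
For readers following the modeling changes above: the image-guided branch in `OwlViTClassPredictionHead.forward` selects a single query embedding by projecting all query-box embeddings, measuring each one's similarity to their mean, and keeping the one least similar to that mean (the "most dissimilar embedding" referenced to the scenic inference notebook). A minimal standalone sketch of that heuristic; the shapes and variable names are illustrative assumptions, not the PR's exact tensors:

```python
import torch

# Illustrative shapes: one query image, 576 candidate boxes, 512-dim embeddings
batch_size, num_boxes, hidden_dim = 1, 576, 512
query_class_embeds = torch.randn(batch_size, num_boxes, hidden_dim)

# Similarity of each box embedding to the mean of all box embeddings
mean_embed = query_class_embeds.mean(dim=1)                               # [batch, dim]
sim_to_mean = torch.einsum("bd,bid->bi", mean_embed, query_class_embeds)  # [batch, num_boxes]

# Keep the embedding least similar to the mean ("most dissimilar")
best_idx = sim_to_mean.argmin(dim=1)                                      # [batch]
query_embeds = query_class_embeds[torch.arange(batch_size), best_idx]     # [batch, dim]
```

Note that later commits in this PR ("choose query embeddings based on iou threshold", "iou_threshold as param, updated docstrings") move the selection to an IoU-threshold scheme, so this mean-similarity heuristic reflects only the intermediate commit shown in this diff.
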
18 changes: 15 additions & 3 deletions src/transformers/models/owlvit/processing_owlvit.py
@@ -43,7 +43,7 @@ class OwlViTProcessor(ProcessorMixin):
def __init__(self, feature_extractor, tokenizer):
super().__init__(feature_extractor, tokenizer)

def __call__(self, text=None, images=None, padding="max_length", return_tensors="np", **kwargs):
def __call__(self, text=None, query_image=None, images=None, padding="max_length", return_tensors="np", **kwargs):
"""
Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and
`kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode:
@@ -76,8 +76,10 @@ def __call__(self, text=None, images=None, padding="max_length", return_tensors=
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""

if text is None and images is None:
raise ValueError("You have to specify at least one text or image. Both cannot be none.")
if text is None and query_image is None and images is None:
raise ValueError(
"You have to specify at least one text or query image or image. All three cannot be none."
)

if text is not None:
if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)):
@@ -127,15 +129,25 @@ def __call__(self, text=None, images=None, padding="max_length", return_tensors=
encoding = BatchEncoding()
encoding["input_ids"] = input_ids
encoding["attention_mask"] = attention_mask
elif query_image is not None:
encoding = BatchEncoding()
encoding["query_pixel_values"] = self.feature_extractor(
query_image, return_tensors=return_tensors, **kwargs
).pixel_values

if images is not None:
image_features = self.feature_extractor(images, return_tensors=return_tensors, **kwargs)

if text is not None and images is not None:
encoding["pixel_values"] = image_features.pixel_values
return encoding
elif query_image is not None and images is not None:
encoding["pixel_values"] = image_features.pixel_values
return encoding
elif text is not None:
return encoding
elif query_image is not None:
return encoding
else:
return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)

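
Putting the two files together, here is a hedged end-to-end usage sketch based on the signatures shown in this diff (`query_image=` in `OwlViTProcessor.__call__`, `query_pixel_values=` in `OwlViTForObjectDetection.forward`). The checkpoint is the public `google/owlvit-base-patch32`; the query image path and the output handling are placeholders, and the final API added in later commits (`add image_guided_detection method`) may differ:

```python
import requests
import torch
from PIL import Image
from transformers import OwlViTProcessor, OwlViTForObjectDetection

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

# Target image to search in (standard COCO demo image) and a query image
# containing the object to look for ("query.jpg" is a placeholder path)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
query_image = Image.open("query.jpg")

# With no `text`, the processor returns `query_pixel_values` and `pixel_values`
inputs = processor(query_image=query_image, images=image, return_tensors="pt")

# input_ids stays None, so the forward pass takes the image-guided branch
with torch.no_grad():
    outputs = model(**inputs)

# Box predictions for the target image; class logits come from similarity to
# the selected query embedding instead of text embeddings
pred_boxes = outputs.pred_boxes
```
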