return text and vision outputs
alaradirik committed Nov 9, 2022
1 parent 5c5fa9f commit 9ebd950
Showing 2 changed files with 32 additions and 30 deletions.
7 changes: 4 additions & 3 deletions src/transformers/models/owlvit/feature_extraction_owlvit.py
@@ -156,7 +156,7 @@ def post_process(self, outputs, target_sizes):
Args:
outputs ([`OwlViTObjectDetectionOutput`]):
Raw outputs of the model.
- target_sizes (`torch.Tensor`, *optional*, defaults None):
+ target_sizes (`torch.Tensor`, *optional*):
Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in
the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to
None, predictions will not be unnormalized.
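
For orientation, here is a toy sketch (not part of the diff; box values and image size are made up) of the rescaling that `target_sizes` enables: normalized corner-format boxes are multiplied elementwise by (width, height, width, height). The real `post_process` also converts boxes from center format before scaling.

import torch

boxes = torch.tensor([[0.25, 0.25, 0.75, 0.75]])  # normalized (x0, y0, x1, y1)
target_sizes = torch.tensor([[480, 640]])         # one (height, width) row per image
img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
abs_boxes = boxes * scale_fct                     # tensor([[160., 120., 480., 360.]])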
@@ -200,7 +200,7 @@ def post_process_image_guided_detection(self, outputs, threshold=0.6, nms_thresh
Minimum confidence threshold to use to filter out predicted boxes.
nms_threshold (`float`, *optional*, defaults to 0.3):
IoU threshold for non-maximum suppression of overlapping boxes.
- target_sizes (`torch.Tensor`, *optional*, defaults None):
+ target_sizes (`torch.Tensor`, *optional*):
Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in
the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to
None, predictions will not be unnormalized.
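
As a hedged aside on what `nms_threshold` controls (numbers made up; torchvision's `nms` stands in for the method's own suppression loop): a box overlapping a higher-scoring box with IoU above the threshold is dropped.

import torch
from torchvision.ops import nms

boxes = torch.tensor(
    [[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 10.0, 10.0], [20.0, 20.0, 30.0, 30.0]]
)
scores = torch.tensor([0.9, 0.8, 0.7])
keep = nms(boxes, scores, iou_threshold=0.3)
# tensor([0, 2]): the second box (IoU ~0.81 with the first) is suppressed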
@@ -249,7 +249,8 @@ def post_process_image_guided_detection(self, outputs, threshold=0.6, nms_thresh
if not query_scores.nonzero().numel():
continue

- # Box alpha is scaled such that the best box for a query has alpha 1.0 and the worst box for which this query is still the top query has alpha 0.1. All other boxes will either belong to a different query, or will not be shown.
+ # Scale box alpha such that the best box for each query has alpha 1.0 and the worst box has alpha 0.1.
+ # All other boxes will either belong to a different query, or will not be shown.
max_score = torch.max(query_scores) + 1e-6
query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9)
query_alphas[query_alphas < threshold] = 0.0
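
Plugging toy scores (made up) into the lines above illustrates the scaling: the best box lands at alpha ~1.0 and a box at 10% of the top score at ~0.0, before the `threshold` line zeroes out faint boxes.

import torch

query_scores = torch.tensor([0.9, 0.5, 0.09])
max_score = torch.max(query_scores) + 1e-6
query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9)
# ~tensor([1.0000, 0.5062, 0.0000])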
55 changes: 28 additions & 27 deletions src/transformers/models/owlvit/modeling_owlvit.py
@@ -224,6 +224,10 @@ class OwlViTObjectDetectionOutput(ModelOutput):
vision_model_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_patches + 1, hidden_size)`):
Last hidden states extracted from the [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image
patches where the total number of patches is (image_size / patch_size)**2.
+ text_model_output (Tuple[`BaseModelOutputWithPooling`]):
+ The output of the [`OwlViTTextModel`].
+ vision_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`OwlViTVisionModel`].
"""

loss: Optional[torch.FloatTensor] = None
@@ -235,6 +239,8 @@ class OwlViTObjectDetectionOutput(ModelOutput):
class_embeds: torch.FloatTensor = None
text_model_last_hidden_state: Optional[torch.FloatTensor] = None
vision_model_last_hidden_state: Optional[torch.FloatTensor] = None
+ text_model_output: BaseModelOutputWithPooling = None
+ vision_model_output: BaseModelOutputWithPooling = None


@dataclass
@@ -1132,16 +1138,6 @@ def get_image_features(

return image_features

- def get_vision_output(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
- # Get vision model outputs
- vision_outputs = self.vision_model(pixel_values=pixel_values)
-
- # Apply post_layernorm to last_hidden_state, return non-projected output
- last_hidden_state = vision_outputs[0]
- image_features = self.vision_model.post_layernorm(last_hidden_state)
-
- return (image_features, last_hidden_state)

@add_start_docstrings_to_model_forward(OWLVIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=OwlViTOutput, config_class=OwlViTConfig)
def forward(
@@ -1407,6 +1403,7 @@ def image_text_embedder(
attention_mask: torch.Tensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
) -> Tuple[torch.FloatTensor]:

# Encode text and image
@@ -1417,6 +1414,7 @@
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_base_image_embeds=True,
+ return_dict=return_dict,
)

# Resize class token
@@ -1438,19 +1436,19 @@
image_embeds = image_embeds.reshape(new_size)
text_embeds = outputs[-4]

- # Last hidden states from text and vision transformers
- text_model_last_hidden_state = outputs[-2][0]
- vision_model_last_hidden_state = outputs[-1][0]
-
- return (text_embeds, image_embeds, text_model_last_hidden_state, vision_model_last_hidden_state)
+ return (text_embeds, image_embeds, outputs)

def image_embedder(
self,
pixel_values: torch.FloatTensor,
output_hidden_states: Optional[bool] = None,
) -> Tuple[torch.FloatTensor]:
# Get OwlViTModel vision embeddings (same as CLIP)
- image_embeds, last_hidden_state = self.owlvit.get_vision_output(pixel_values=pixel_values)
+ vision_outputs = self.owlvit.vision_model(pixel_values=pixel_values)

+ # Apply post_layernorm to last_hidden_state, return non-projected output
+ last_hidden_state = vision_outputs[0]
+ image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state)

# Resize class token
new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
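
A shape-only sketch of the class-token arithmetic above (sizes assume the base patch-32 checkpoint with 768x768 inputs, i.e. 24x24 patches plus one class token; the slicing below only demonstrates shapes, while the surrounding code, not shown in this hunk, combines the class token with the patch tokens before reshaping):

import numpy as np
import torch

image_embeds = torch.zeros(1, 577, 768)  # [CLS] token + 24 * 24 patch tokens
new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
# new_size == (1, 576, 768): one token slot dropped
feature_map = image_embeds[:, 1:, :].reshape(1, 24, 24, 768)  # spatial patch grid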
@@ -1657,19 +1655,18 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.return_dict

# Embed images and text queries
- outputs = self.image_text_embedder(
+ query_embeds, feature_map, outputs = self.image_text_embedder(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
)

- # Last hidden states of text and vision transformers
- text_model_last_hidden_state = outputs[2]
- vision_model_last_hidden_state = outputs[3]
-
- query_embeds = outputs[0]
- feature_map = outputs[1]
+ # Text and vision model outputs
+ text_outputs = outputs[-2]
+ vision_outputs = outputs[-1]

batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))
@@ -1695,8 +1692,10 @@ def forward(
query_embeds,
feature_map,
class_embeds,
- text_model_last_hidden_state,
- vision_model_last_hidden_state,
+ text_outputs[0],
+ vision_outputs[0],
+ text_outputs,
+ vision_outputs,
)
output = tuple(x for x in output if x is not None)
return output
@@ -1707,6 +1706,8 @@ def forward(
pred_boxes=pred_boxes,
logits=pred_logits,
class_embeds=class_embeds,
- text_model_last_hidden_state=text_model_last_hidden_state,
- vision_model_last_hidden_state=vision_model_last_hidden_state,
+ text_model_last_hidden_state=text_outputs[0],
+ vision_model_last_hidden_state=vision_outputs[0],
+ text_model_output=text_outputs,
+ vision_model_output=vision_outputs,
)
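
Putting it together, a hedged end-to-end sketch of what the commit exposes (checkpoint name, image URL, and query text are illustrative assumptions; index access is used because the sub-model outputs may surface as plain tuples depending on `return_dict` handling):

import requests
import torch
from PIL import Image
from transformers import OwlViTProcessor, OwlViTForObjectDetection

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=[["a photo of a cat"]], images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

print(outputs.logits.shape)       # per-patch detection logits
print(outputs.pred_boxes.shape)   # per-patch normalized box predictions
text_out = outputs.text_model_output      # new: full text sub-model output
vision_out = outputs.vision_model_output  # new: full vision sub-model output
print(text_out[0].shape, vision_out[0].shape)  # their last hidden states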
