diff --git a/src/transformers/models/owlvit/feature_extraction_owlvit.py b/src/transformers/models/owlvit/feature_extraction_owlvit.py index d5a38206052f..a39e2f91dc56 100644 --- a/src/transformers/models/owlvit/feature_extraction_owlvit.py +++ b/src/transformers/models/owlvit/feature_extraction_owlvit.py @@ -156,7 +156,7 @@ def post_process(self, outputs, target_sizes): Args: outputs ([`OwlViTObjectDetectionOutput`]): Raw outputs of the model. - target_sizes (`torch.Tensor`, *optional*, defaults None): + target_sizes (`torch.Tensor`, *optional*): Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to None, predictions will not be unnormalized. @@ -200,7 +200,7 @@ def post_process_image_guided_detection(self, outputs, threshold=0.6, nms_thresh Minimum confidence threshold to use to filter out predicted boxes. nms_threshold (`float`, *optional*, defaults to 0.3): IoU threshold for non-maximum suppression of overlapping boxes. - target_sizes (`torch.Tensor`, *optional*, defaults None): + target_sizes (`torch.Tensor`, *optional*): Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to None, predictions will not be unnormalized. @@ -249,7 +249,8 @@ def post_process_image_guided_detection(self, outputs, threshold=0.6, nms_thresh if not query_scores.nonzero().numel(): continue - # Box alpha is scaled such that the best box for a query has alpha 1.0 and the worst box for which this query is still the top query has alpha 0.1. All other boxes will either belong to a different query, or will not be shown. + # Scale box alpha such that the best box for each query has alpha 1.0 and the worst box has alpha 0.1. + # All other boxes will either belong to a different query, or will not be shown. max_score = torch.max(query_scores) + 1e-6 query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9) query_alphas[query_alphas < threshold] = 0.0 diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index e3bd87ed0d53..6546845ed636 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -224,6 +224,10 @@ class OwlViTObjectDetectionOutput(ModelOutput): vision_model_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_patches + 1, hidden_size)`)): Last hidden states extracted from the [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches where the total number of patches is (image_size / patch_size)**2. + text_model_output (Tuple[`BaseModelOutputWithPooling`]): + The output of the [`OwlViTTextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`OwlViTVisionModel`]. """ loss: Optional[torch.FloatTensor] = None @@ -235,6 +239,8 @@ class OwlViTObjectDetectionOutput(ModelOutput): class_embeds: torch.FloatTensor = None text_model_last_hidden_state: Optional[torch.FloatTensor] = None vision_model_last_hidden_state: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None @dataclass @@ -1132,16 +1138,6 @@ def get_image_features( return image_features - def get_vision_output(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor: - # Get vision model outputs - vision_outputs = self.vision_model(pixel_values=pixel_values) - - # Apply post_layernorm to last_hidden_state, return non-projected output - last_hidden_state = vision_outputs[0] - image_features = self.vision_model.post_layernorm(last_hidden_state) - - return (image_features, last_hidden_state) - @add_start_docstrings_to_model_forward(OWLVIT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=OwlViTOutput, config_class=OwlViTConfig) def forward( @@ -1407,6 +1403,7 @@ def image_text_embedder( attention_mask: torch.Tensor, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ) -> Tuple[torch.FloatTensor]: # Encode text and image @@ -1417,6 +1414,7 @@ def image_text_embedder( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_base_image_embeds=True, + return_dict=return_dict, ) # Resize class token @@ -1438,11 +1436,7 @@ def image_text_embedder( image_embeds = image_embeds.reshape(new_size) text_embeds = outputs[-4] - # Last hidden states from text and vision transformers - text_model_last_hidden_state = outputs[-2][0] - vision_model_last_hidden_state = outputs[-1][0] - - return (text_embeds, image_embeds, text_model_last_hidden_state, vision_model_last_hidden_state) + return (text_embeds, image_embeds, outputs) def image_embedder( self, @@ -1450,7 +1444,11 @@ def image_embedder( output_hidden_states: Optional[bool] = None, ) -> Tuple[torch.FloatTensor]: # Get OwlViTModel vision embeddings (same as CLIP) - image_embeds, last_hidden_state = self.owlvit.get_vision_output(pixel_values=pixel_values) + vision_outputs = self.owlvit.vision_model(pixel_values=pixel_values) + + # Apply post_layernorm to last_hidden_state, return non-projected output + last_hidden_state = vision_outputs[0] + image_embeds = self.vision_model.post_layernorm(last_hidden_state) # Resize class token new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0))) @@ -1657,19 +1655,18 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.return_dict # Embed images and text queries - outputs = self.image_text_embedder( + query_embeds, feature_map, outputs = self.image_text_embedder( input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, + output_attentions=output_attentions, output_hidden_states=output_hidden_states, + return_dict=return_dict, ) - # Last hidden states of text and vision transformers - text_model_last_hidden_state = outputs[2] - vision_model_last_hidden_state = outputs[3] - - query_embeds = outputs[0] - feature_map = outputs[1] + # Text and vision model outputs + text_outputs = outputs[-2] + vision_outputs = outputs[-1] batch_size, num_patches, num_patches, hidden_dim = feature_map.shape image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim)) @@ -1695,8 +1692,10 @@ def forward( query_embeds, feature_map, class_embeds, - text_model_last_hidden_state, - vision_model_last_hidden_state, + text_outputs[0], + vision_outputs[0], + text_outputs, + vision_outputs, ) output = tuple(x for x in output if x is not None) return output @@ -1707,6 +1706,8 @@ def forward( pred_boxes=pred_boxes, logits=pred_logits, class_embeds=class_embeds, - text_model_last_hidden_state=text_model_last_hidden_state, - vision_model_last_hidden_state=vision_model_last_hidden_state, + text_model_last_hidden_state=text_outputs[0], + vision_model_last_hidden_state=vision_outputs[0], + text_model_output=text_outputs, + vision_model_output=vision_outputs, )