diff --git a/docs/source/en/model_doc/blip-2.md b/docs/source/en/model_doc/blip-2.md index 0890e612561a69..b6045d2b4f4c6e 100644 --- a/docs/source/en/model_doc/blip-2.md +++ b/docs/source/en/model_doc/blip-2.md @@ -87,4 +87,9 @@ If you're interested in submitting a resource to be included here, please feel f [[autodoc]] Blip2ForConditionalGeneration - forward - - generate \ No newline at end of file + - generate + +## Blip2ForImageTextRetrieval + +[[autodoc]] Blip2ForImageTextRetrieval + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index db097b0ba63b9b..1e2be47de4e030 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1334,6 +1334,7 @@ [ "BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST", "Blip2ForConditionalGeneration", + "Blip2ForImageTextRetrieval", "Blip2Model", "Blip2PreTrainedModel", "Blip2QFormerModel", @@ -5346,6 +5347,7 @@ from .models.blip_2 import ( BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST, Blip2ForConditionalGeneration, + Blip2ForImageTextRetrieval, Blip2Model, Blip2PreTrainedModel, Blip2QFormerModel, diff --git a/src/transformers/models/blip_2/__init__.py b/src/transformers/models/blip_2/__init__.py index 6fbfd53b3703fd..cd2396d3288837 100644 --- a/src/transformers/models/blip_2/__init__.py +++ b/src/transformers/models/blip_2/__init__.py @@ -38,6 +38,7 @@ "Blip2QFormerModel", "Blip2PreTrainedModel", "Blip2ForConditionalGeneration", + "Blip2ForImageTextRetrieval", "Blip2VisionModel", ] @@ -59,6 +60,7 @@ from .modeling_blip_2 import ( BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST, Blip2ForConditionalGeneration, + Blip2ForImageTextRetrieval, Blip2Model, Blip2PreTrainedModel, Blip2QFormerModel, diff --git a/src/transformers/models/blip_2/configuration_blip_2.py b/src/transformers/models/blip_2/configuration_blip_2.py index 1b375e147f780b..ef246b129cad16 100644 --- a/src/transformers/models/blip_2/configuration_blip_2.py +++ b/src/transformers/models/blip_2/configuration_blip_2.py @@ -209,6 +209,7 @@ def __init__( position_embedding_type="absolute", cross_attention_frequency=2, encoder_hidden_size=1408, + qformer_text_input=False, **kwargs, ): super().__init__(pad_token_id=pad_token_id, **kwargs) @@ -227,6 +228,7 @@ def __init__( self.position_embedding_type = position_embedding_type self.cross_attention_frequency = cross_attention_frequency self.encoder_hidden_size = encoder_hidden_size + self.qformer_text_input = qformer_text_input @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": @@ -302,7 +304,15 @@ class Blip2Config(PretrainedConfig): model_type = "blip-2" - def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs): + def __init__( + self, + vision_config=None, + qformer_config=None, + text_config=None, + num_query_tokens=32, + image_text_hidden_size=256, + **kwargs, + ): super().__init__(**kwargs) if vision_config is None: @@ -326,6 +336,7 @@ def __init__(self, vision_config=None, qformer_config=None, text_config=None, nu self.is_encoder_decoder = self.text_config.is_encoder_decoder self.num_query_tokens = num_query_tokens + self.image_text_hidden_size = image_text_hidden_size self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES self.initializer_factor = 1.0 diff --git a/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py 
b/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py index c2e6eceae53273..d2e1b27dfb0869 100644 --- a/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py +++ b/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py @@ -31,9 +31,12 @@ from transformers import ( AutoTokenizer, + BertTokenizer, Blip2Config, Blip2ForConditionalGeneration, + Blip2ForImageTextRetrieval, Blip2Processor, + Blip2QFormerConfig, Blip2VisionConfig, BlipImageProcessor, OPTConfig, @@ -51,7 +54,7 @@ def load_demo_image(): # here we list all keys to be renamed (original name on the left, our name on the right) -def create_rename_keys(config): +def create_rename_keys(config, model_name): rename_keys = [] # fmt: off @@ -77,8 +80,9 @@ def create_rename_keys(config): rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias")) # QFormer - rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight")) - rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias")) + if "itm" not in model_name: + rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight")) + rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias")) # fmt: on return rename_keys @@ -114,8 +118,19 @@ def get_blip2_config(model_name, eos_token_id): text_config = T5Config.from_pretrained("google/flan-t5-xl", dense_act_fn="gelu", bos_token_id=1).to_dict() elif "t5-xxl" in model_name: text_config = T5Config.from_pretrained("google/flan-t5-xxl", dense_act_fn="gelu", bos_token_id=1).to_dict() - - config = Blip2Config(vision_config=vision_config, text_config=text_config) + elif "itm" in model_name: + text_config = {} + else: + raise ValueError("Model name not supported") + + if "itm" in model_name: + config = Blip2Config( + vision_config=vision_config, + text_config=text_config, + qformer_config=Blip2QFormerConfig(vocab_size=30523, qformer_text_input=True).to_dict(), + ) + else: + config = Blip2Config(vision_config=vision_config, text_config=text_config) return config, image_size @@ -125,15 +140,24 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ """ Copy/paste/tweak model's weights to Transformers design. 
""" - tokenizer = ( - AutoTokenizer.from_pretrained("facebook/opt-2.7b") - if "opt" in model_name - else AutoTokenizer.from_pretrained("google/flan-t5-xl") - ) - eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0] + if "opt" in model_name: + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b") + elif "itm" in model_name: + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right") + tokenizer.add_special_tokens({"bos_token": "[DEC]"}) + else: + tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl") + + if "itm" in model_name: + eos_token_id = None + else: + eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0] config, image_size = get_blip2_config(model_name, eos_token_id=eos_token_id) - hf_model = Blip2ForConditionalGeneration(config).eval() + if "itm" in model_name: + hf_model = Blip2ForImageTextRetrieval(config).eval() + else: + hf_model = Blip2ForConditionalGeneration(config).eval() model_name_to_original = { "blip2-opt-2.7b": ("blip2_opt", "pretrain_opt2.7b"), @@ -143,6 +167,9 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ "blip2-flan-t5-xl": ("blip2_t5", "pretrain_flant5xl"), "blip2-flan-t5-xl-coco": ("blip2_t5", "caption_coco_flant5xl"), "blip2-flan-t5-xxl": ("blip2_t5", "pretrain_flant5xxl"), + "blip2-itm-vit-g": ("blip2_image_text_matching", "pretrain"), + # "blip2-itm-vit-large": ("blip2_image_text_matching", "pretrain_vitL"), + "blip2-itm-vit-g-coco": ("blip2_image_text_matching", "coco"), } name, type = model_name_to_original[model_name] @@ -163,13 +190,15 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ # update state dict keys state_dict = original_model.state_dict() - rename_keys = create_rename_keys(config) + rename_keys = create_rename_keys(config, model_name) for src, dest in rename_keys: rename_key(state_dict, src, dest) # some keys can be renamed efficiently for key, val in state_dict.copy().items(): val = state_dict.pop(key) + if key.startswith("Qformer.cls"): + key = key.replace("Qformer.cls", "cls") if key.startswith("Qformer.bert"): key = key.replace("Qformer.bert", "qformer") if "attention.self" in key: @@ -193,7 +222,6 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ image = load_demo_image() original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device) - input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device) # create processor image_processor = BlipImageProcessor( @@ -207,50 +235,100 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ original_model.to(lavis_device) hf_model.to(hf_model_device) - with torch.no_grad(): - if "opt" in model_name: - original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits - logits = hf_model(pixel_values, input_ids).logits - else: - original_logits = original_model( - {"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]} - ).logits - labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100) - logits = hf_model(pixel_values, input_ids, labels=labels).logits - assert original_logits.shape == logits.shape - print("First values of original logits:", original_logits[0, :3, :3]) - print("First values of HF logits:", logits[0, :3, :3]) + if "itm" in model_name: + caption = "a large fountain spewing water into the air" + input_ids = tokenizer([caption], return_tensors="pt").input_ids.to(hf_model_device) 
+ attention_mask = processor(text=caption, return_tensors="pt").attention_mask.to(hf_model_device) + + with torch.no_grad(): + original_logits = original_model( + {"image": original_pixel_values, "text_input": [caption]}, match_head="itm" + ) + logits = hf_model(pixel_values=original_pixel_values, input_ids=input_ids, attention_mask=attention_mask) - # assert values - assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4) - print("Looks ok!") + assert original_logits.shape == logits.itm_score.shape + print("First values of original logits:", original_logits[0, :3]) + print("First values of HF logits:", logits.itm_score[0, :3]) - print("Generating a caption...") - prompt = "Question: what object is in this image? Answer:" - input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device) + # assert values + # cast to same type + target_dtype = logits.itm_score.dtype + assert torch.allclose(original_logits.to(target_dtype), logits.itm_score, atol=1e-2) - set_seed(42) + original_itm_scores = torch.nn.functional.softmax(original_logits, dim=1) + itm_scores = torch.nn.functional.softmax(logits.itm_score, dim=1) + assert torch.allclose(original_itm_scores.to(target_dtype), itm_scores, atol=1e-2) + print("Looks ok!") - original_outputs = original_model.generate( - {"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True - ) - outputs = hf_model.generate( - pixel_values, - input_ids, - do_sample=True, - num_beams=5, - max_length=30, - min_length=1, - top_p=0.9, - repetition_penalty=1.0, - length_penalty=1.0, - temperature=1, - ) - output_text = processor.batch_decode(outputs, skip_special_tokens=True) - output_text = [text.strip() for text in output_text] - print("Original generation:", original_outputs) - print("HF generation:", output_text) + with torch.no_grad(): + original_logits = original_model( + {"image": original_pixel_values, "text_input": [caption]}, match_head="itc" + ) + logits = hf_model( + pixel_values=original_pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + use_itm_head=False, + ) + + assert original_logits.shape == logits.itm_score.shape + print("First values of original logits:", original_logits[0, :3]) + print("First values of HF logits:", logits.itm_score[0, :3]) + + # assert values + # cast to same type + target_dtype = logits.itm_score.dtype + assert torch.allclose(original_logits.to(target_dtype), logits.itm_score, atol=1e-2) + print("Looks ok!") + + else: + input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device) + + with torch.no_grad(): + if "opt" in model_name: + original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits + logits = hf_model(pixel_values, input_ids).logits + else: + original_logits = original_model( + {"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]} + ).logits + labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100) + logits = hf_model(pixel_values, input_ids, labels=labels).logits + + assert original_logits.shape == logits.shape + print("First values of original logits:", original_logits[0, :3, :3]) + print("First values of HF logits:", logits[0, :3, :3]) + + # assert values + assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4) + print("Looks ok!") + + print("Generating a caption...") + prompt = "Question: what object is in this image? 
Answer:" + input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device) + + set_seed(42) + + original_outputs = original_model.generate( + {"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True + ) + outputs = hf_model.generate( + pixel_values, + input_ids, + do_sample=True, + num_beams=5, + max_length=30, + min_length=1, + top_p=0.9, + repetition_penalty=1.0, + length_penalty=1.0, + temperature=1, + ) + output_text = processor.batch_decode(outputs, skip_special_tokens=True) + output_text = [text.strip() for text in output_text] + print("Original generation:", original_outputs) + print("HF generation:", output_text) if pytorch_dump_folder_path is not None: processor.save_pretrained(pytorch_dump_folder_path) @@ -271,6 +349,9 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ "blip2-flan-t5-xl", "blip2-flan-t5-xl-coco", "blip2-flan-t5-xxl", + "blip2-itm-vit-g", + # "blip2-itm-vit-large", + "blip2-itm-vit-g-coco", ] parser.add_argument( "--model_name", diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 87c8132ff4fd86..b71724d8db0865 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -22,6 +22,7 @@ import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss +from torch.nn.functional import normalize from ...activations import ACT2FN from ...modeling_outputs import ( @@ -86,6 +87,50 @@ def to_tuple(self) -> Tuple[Any]: ) +@dataclass +# Copied from transformers.models.blip.modeling_blip.BlipImageTextMatchingModelOutput with Blip->Blip2 +class Blip2ImageTextMatchingModelOutput(ModelOutput): + """ + Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the + last hidden states. This class also adds the loss term from the text decoder as well as the image-text similarity + scores. + + Args: + itm_score (`torch.FloatTensor`): + The image-text similarity scores. + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Languge modeling loss from the text decoder. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + vision_pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*): + Last layer hidden-state of the vision of the vision-only branch of the model. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. 
+ + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + question_embeds (`torch.FloatTensor`): + The question embeddings obtained by the text projection layer. + """ + + itm_score: Optional[torch.FloatTensor] = None + loss: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + vision_pooler_output: Optional[torch.FloatTensor] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + question_embeds: Optional[Tuple[torch.FloatTensor]] = None + + # Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2 class Blip2VisionEmbeddings(nn.Module): def __init__(self, config: Blip2VisionConfig): @@ -816,6 +861,10 @@ def __init__(self, config, layer_idx): else: self.has_cross_attention = False + if config.qformer_text_input: + self.intermediate = Blip2QFormerIntermediate(config) + self.output = Blip2QFormerOutput(config) + self.intermediate_query = Blip2QFormerIntermediate(config) self.output_query = Blip2QFormerOutput(config) @@ -1003,6 +1052,59 @@ def custom_forward(*inputs): ) +# Adapted from https://github.com/salesforce/LAVIS/blob/main/lavis/models/blip2_models/Qformer.py#L51 +class Blip2TextEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, + input_ids: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if input_ids is not None: + input_ids = input_ids.to(self.word_embeddings.weight.device) + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + class Blip2QFormerModel(Blip2PreTrainedModel): """ Querying Transformer (Q-Former), used in BLIP-2. 
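Note on the new `Blip2TextEmbeddings` module above: it prepends the expanded query tokens to the embedded text tokens before LayerNorm and dropout, so the Q-Former processes a sequence of `num_query_tokens + text_length` positions, and the retrieval head later builds its attention mask with the same layout. A minimal shape-only sketch with toy sizes (illustrative only, not tied to any checkpoint):

```python
import torch

# toy sizes, for illustration only
batch_size, num_query_tokens, text_len, hidden_size = 2, 32, 8, 768

query_embeds = torch.zeros(batch_size, num_query_tokens, hidden_size)  # expanded query_tokens parameter
text_embeds = torch.randn(batch_size, text_len, hidden_size)           # word + position embeddings

# Blip2TextEmbeddings concatenates queries and text along the sequence dimension
embeddings = torch.cat((query_embeds, text_embeds), dim=1)

# the matching attention mask: all-ones over the query positions, followed by the text mask
attention_mask = torch.cat(
    (
        torch.ones(batch_size, num_query_tokens, dtype=torch.long),
        torch.ones(batch_size, text_len, dtype=torch.long),
    ),
    dim=1,
)
print(embeddings.shape, attention_mask.shape)  # torch.Size([2, 40, 768]) torch.Size([2, 40])
```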
@@ -1012,8 +1114,11 @@ def __init__(self, config: Blip2QFormerConfig): super().__init__(config) self.config = config - self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + if config.qformer_text_input: + self.embeddings = Blip2TextEmbeddings(config) + else: + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) self.encoder = Blip2QFormerEncoder(config) @@ -1080,7 +1185,7 @@ def get_extended_attention_mask( def forward( self, - query_embeds: torch.FloatTensor, + query_embeds: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, @@ -1090,6 +1195,8 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + input_ids: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`): @@ -1116,6 +1223,9 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if input_ids is None and query_embeds is None: + raise ValueError("You have to specify query_embeds when input_ids is None") + # past_key_values_length past_key_values_length = ( past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0 @@ -1123,15 +1233,25 @@ def forward( query_length = query_embeds.shape[1] if query_embeds is not None else 0 - embedding_output = self.layernorm(query_embeds) - embedding_output = self.dropout(embedding_output) + if hasattr(self, "embeddings"): + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = self.layernorm(query_embeds) + embedding_output = self.dropout(embedding_output) input_shape = embedding_output.size()[:-1] batch_size, seq_length = input_shape device = embedding_output.device if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device, dtype=torch.long + ) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
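With `query_embeds` now optional and `input_ids` accepted, the Q-Former supports two call patterns: the original image-grounded one (query tokens cross-attending to vision features) and a text-only one, which the contrastive branch of the retrieval head relies on. A small sketch with a randomly initialized toy config; the sizes are illustrative and assume `qformer_text_input=True`:

```python
import torch
from transformers import Blip2QFormerConfig, Blip2QFormerModel

# tiny, randomly initialized config just to show the two call patterns
config = Blip2QFormerConfig(
    vocab_size=30523,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    encoder_hidden_size=32,  # must match the hidden size of encoder_hidden_states
    qformer_text_input=True,
)
qformer = Blip2QFormerModel(config).eval()

query_tokens = torch.zeros(1, 8, config.hidden_size)          # normally model.query_tokens expanded to the batch
image_embeds = torch.randn(1, 16, config.encoder_hidden_size)  # stand-in for vision encoder outputs

with torch.no_grad():
    # 1) image-grounded queries (pre-existing path)
    out_image = qformer(query_embeds=query_tokens, encoder_hidden_states=image_embeds)
    # 2) text-only encoding (new path); passing neither input_ids nor query_embeds raises a ValueError
    out_text = qformer(input_ids=torch.tensor([[101, 2023, 2003, 1037, 3231, 102]]))

print(out_image.last_hidden_state.shape)  # torch.Size([1, 8, 32])
print(out_text.last_hidden_state.shape)   # torch.Size([1, 6, 32])
```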
@@ -1149,7 +1269,7 @@ def forward( if type(encoder_attention_mask) == list: encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] elif encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device, dtype=torch.long) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) @@ -1192,6 +1312,56 @@ def forward( ) +# Copied from transformers.models.blip.modeling_blip_text.BlipTextPredictionHeadTransform with Blip->Blip2 +class Blip2TextPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# Copied from transformers.models.blip.modeling_blip_text.BlipTextLMPredictionHead with Blip->Blip2 +class Blip2TextLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = Blip2TextPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Copied from transformers.models.blip.modeling_blip_text.BlipTextOnlyMLMHead with Blip->Blip2 +class Blip2TextOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = Blip2TextLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + @add_start_docstrings( """ BLIP-2 Model for generating text and image features. 
The model consists of a vision encoder, Querying Transformer @@ -1211,17 +1381,27 @@ def __init__(self, config: Blip2Config): self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) self.qformer = Blip2QFormerModel(config.qformer_config) - self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) - if config.use_decoder_only_language_model: - language_model = AutoModelForCausalLM.from_config(config.text_config) + if self.config.qformer_config.qformer_text_input: + self._keep_in_fp32_modules = [] # fp16 compatibility + + # vision projection layer + self.vision_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + + # text projection layer + self.text_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + else: - language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) + self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) + if config.use_decoder_only_language_model: + language_model = AutoModelForCausalLM.from_config(config.text_config) + else: + language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) - # Update _tied_weights_keys using the base model used. - if language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] + # Update _tied_weights_keys using the base model used. + if language_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] - self.language_model = language_model + self.language_model = language_model # Initialize weights and apply final processing self.post_init() @@ -1236,7 +1416,10 @@ def set_output_embeddings(self, new_embeddings): self.language_model.set_output_embeddings(new_embeddings) def get_output_embeddings(self) -> nn.Module: - return self.language_model.get_output_embeddings() + if hasattr(self, "language_model"): + return self.language_model.get_output_embeddings() + else: + return self.qformer.get_output_embeddings() def get_encoder(self): return self.language_model.get_encoder() @@ -1245,7 +1428,7 @@ def get_decoder(self): return self.language_model.get_decoder() def _tie_weights(self): - if not self.config.use_decoder_only_language_model: + if not self.config.use_decoder_only_language_model and hasattr(self, "language_model"): self.language_model.encoder.embed_tokens = self.language_model.shared self.language_model.decoder.embed_tokens = self.language_model.shared @@ -1288,29 +1471,42 @@ def get_text_features( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.config.use_decoder_only_language_model: - text_outputs = self.language_model( + if hasattr(self, "language_model"): + if self.config.use_decoder_only_language_model: + text_outputs = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + else: + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + + text_outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + 
labels=labels, + ) + + return text_outputs + else: + query_outputs = self.qformer( input_ids=input_ids, + # query_embeds=query_tokens, attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, return_dict=return_dict, ) - else: - inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + embeds = query_outputs.last_hidden_state + text_features = self.text_proj(embeds) + text_features = normalize(text_features, dim=-1) - text_outputs = self.language_model( - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - labels=labels, - ) - - return text_outputs + return text_features @add_start_docstrings_to_model_forward(BLIP_2_VISION_INPUTS_DOCSTRING) def get_image_features( @@ -1358,7 +1554,26 @@ def get_image_features( return_dict=return_dict, ) - return vision_outputs + if not hasattr(self, "vision_proj"): + return vision_outputs + + else: + image_embeds = vision_outputs[0] + + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_outputs = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=return_dict, + ) + + embeds = query_outputs.last_hidden_state + image_features = normalize(self.vision_proj(embeds), dim=-1) + + return image_features @add_start_docstrings_to_model_forward(BLIP_2_INPUTS_DOCSTRING) def get_qformer_features( @@ -1884,3 +2099,141 @@ def generate( ) return outputs + + +@add_start_docstrings( + """ + BLIP-2 Model with a vision and text projector, and a classification head on top. The model is used in the context + of image-text retrieval. Given an image and a text, the model returns the probability of the text being relevant to + the image. 
+ """, + BLIP_2_START_DOCSTRING, +) +class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): + main_input_name = "pixel_values" + config_class = Blip2Config + + def __init__(self, config: Blip2Config): + super().__init__(config) + + self.vision_model = Blip2VisionModel(config.vision_config) + + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + + self.qformer = Blip2QFormerModel(config.qformer_config) + self.cls = Blip2TextOnlyMLMHead(config.qformer_config) + + # vision projection layer + self.vision_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + + # text projection layer + self.text_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + + # image text matching head + self.itm_head = nn.Linear(config.qformer_config.hidden_size, 2) + + self.temp = nn.Parameter(0.07 * torch.ones([])) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BLIP_2_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Blip2ImageTextMatchingModelOutput, config_class=Blip2VisionConfig) + def forward( + self, + pixel_values: torch.FloatTensor, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + use_itm_head: Optional[bool] = True, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Blip2ImageTextMatchingModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Blip2ForImageTextRetrieval + + >>> model = Blip2ForImageTextRetrieval.from_pretrained("jpizarrom/blip2-itm-vit-g") + >>> processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text = "an image of a cat" + + >>> inputs = processor(images=image, text=text, return_tensors="pt") + >>> outputs = model(**inputs) + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[0] + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + if use_itm_head: + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(query_tokens.device) + attention_mask = torch.cat([query_attention_mask, attention_mask], dim=1) + + question_embeds = self.qformer( + input_ids=input_ids, + query_embeds=query_tokens, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=return_dict, + ) + question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state + + output = self.itm_head(question_embeds[:, : query_tokens.size(1), :]) + output = output.mean(dim=1) + else: + query_tokens = 
self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_outputs = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=return_dict, + ) + image_embeds = query_outputs[0] if not return_dict else query_outputs.last_hidden_state + + question_embeds = self.qformer( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=return_dict, + ) + question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state + + image_feat = normalize(self.vision_proj(image_embeds), dim=-1) + text_feat = normalize(self.text_proj(question_embeds[:, 0, :]), dim=-1) + + outputs = torch.bmm(image_feat, text_feat.unsqueeze(-1)) + output, _ = torch.max(outputs, dim=1) + + if not return_dict: + outputs = (output, vision_outputs[0]) + vision_outputs[2:] + (question_embeds,) + return tuple(output for output in outputs if output is not None) + + return Blip2ImageTextMatchingModelOutput( + itm_score=output, + last_hidden_state=vision_outputs.last_hidden_state, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + question_embeds=question_embeds, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index f30b04e0e75e70..21ce7b760f8761 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1503,6 +1503,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Blip2ForImageTextRetrieval(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Blip2Model(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index c5bdb70791eb56..812b90fb2ca473 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -23,7 +23,14 @@ import requests from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig -from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device +from transformers.testing_utils import ( + require_torch, + require_torch_gpu, + require_torch_multi_gpu, + require_vision, + slow, + torch_device, +) from transformers.utils import is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -41,7 +48,7 @@ import torch from torch import nn - from transformers import Blip2ForConditionalGeneration, Blip2Model, Blip2VisionModel + from transformers import Blip2ForConditionalGeneration, Blip2ForImageTextRetrieval, Blip2Model, Blip2VisionModel from transformers.models.blip_2.modeling_blip_2 import BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST @@ -224,6 +231,7 @@ def __init__( initializer_range=0.02, bos_token_id=0, scope=None, + qformer_text_input=False, ): self.parent = parent self.batch_size = batch_size @@ -243,6 +251,7 @@ def __init__( self.initializer_range = initializer_range self.scope = scope self.bos_token_id = bos_token_id + self.qformer_text_input = qformer_text_input def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -275,6 +284,7 @@ def get_config(self): max_position_embeddings=self.max_position_embeddings, initializer_range=self.initializer_range, bos_token_id=self.bos_token_id, + qformer_text_input=self.qformer_text_input, ) 
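The new head exposes both scoring modes of the original LAVIS model: `use_itm_head=True` runs the image-text matching classifier (two logits per pair), while `use_itm_head=False` returns the contrastive image-text similarity (maximum over the query tokens). A usage sketch; the checkpoint name is the one used in the docstring and tests of this change, and the caption is illustrative only:

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForImageTextRetrieval

processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g")
model = Blip2ForImageTextRetrieval.from_pretrained("jpizarrom/blip2-itm-vit-g")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="two cats sleeping on a couch", return_tensors="pt")

with torch.no_grad():
    # ITM head: two logits per image-text pair; softmax gives (no-match, match) probabilities
    itm_outputs = model(**inputs, use_itm_head=True)
    itm_probs = torch.softmax(itm_outputs.itm_score, dim=1)

    # contrastive head: a single cosine-similarity-style score per pair
    itc_outputs = model(**inputs, use_itm_head=False)
    similarity = itc_outputs.itm_score

print(itm_probs)   # shape (1, 2); by BLIP convention index 1 is the "match" class
print(similarity)  # shape (1, 1)
```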
@@ -813,6 +823,211 @@ def test_initialization(self): ) +class Blip2TextRetrievalModelTester: + def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training=True): + if vision_kwargs is None: + vision_kwargs = {} + if qformer_kwargs is None: + qformer_kwargs = {"qformer_text_input": True} + text_kwargs = {} + + self.parent = parent + self.vision_model_tester = Blip2VisionModelTester(parent, **vision_kwargs) + self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs) + self.text_model_tester = Blip2TextModelTester(parent, **text_kwargs) + self.is_training = is_training + + def prepare_config_and_inputs(self): + _, input_ids, attention_mask = self.qformer_model_tester.prepare_config_and_inputs() + _, pixel_values = self.vision_model_tester.prepare_config_and_inputs() + + config = self.get_config() + + return config, input_ids, attention_mask, pixel_values + + def get_config(self): + return Blip2Config.from_vision_qformer_text_configs( + vision_config=self.vision_model_tester.get_config(), + qformer_config=self.qformer_model_tester.get_config(), + text_config=self.text_model_tester.get_config(), + ) + + def create_and_check_model(self, config, input_ids, attention_mask, pixel_values): + model = Blip2ForImageTextRetrieval(config).to(torch_device).eval() + with torch.no_grad(): + result = model(pixel_values, input_ids, attention_mask) + + image_size = (self.vision_model_tester.image_size, self.vision_model_tester.image_size) + patch_size = (self.vision_model_tester.patch_size, self.vision_model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + + self.parent.assertEqual( + result.itm_score.shape, + (self.vision_model_tester.batch_size, 2), + ) + self.parent.assertEqual( + result.last_hidden_state.shape, + (self.vision_model_tester.batch_size, num_patches + 1, self.qformer_model_tester.hidden_size), + ) + self.parent.assertEqual( + result.question_embeds.shape, + ( + self.text_model_tester.batch_size, + self.vision_model_tester.hidden_size + input_ids.shape[1], + self.qformer_model_tester.hidden_size, + ), + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, attention_mask, pixel_values = config_and_inputs + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + } + return config, inputs_dict + + +@require_torch +class Blip2TextRetrievalModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (Blip2ForImageTextRetrieval,) if is_torch_available() else () + fx_compatible = False + test_head_masking = False + test_pruning = False + test_resize_embeddings = False + test_attention_outputs = False + test_torchscript = False + + # TODO add or skip tests + test_model_outputs_equivalence = False + test_tied_weights_keys = False + test_hidden_states_output = False + test_inputs_embeds = False + test_model_common_attributes = False + test_retain_grad_hidden_states_attentions = False + + def setUp(self): + self.model_tester = Blip2TextRetrievalModelTester(self) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is 
an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + if model.config.is_encoder_decoder: + expected_arg_names = [ + "pixel_values", + "input_ids", + "attention_mask", + ] + + expected_arg_names.extend(["use_itm_head"] if "use_itm_head" in arg_names else ["decoder_input_ids"]) + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + else: + # TODO + raise NotImplementedError + + def test_load_vision_qformer_text_config(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + # Save Blip2Config and check if we can load Blip2VisionConfig from it + with tempfile.TemporaryDirectory() as tmp_dir_name: + config.save_pretrained(tmp_dir_name) + vision_config = Blip2VisionConfig.from_pretrained(tmp_dir_name) + self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict()) + + # Save Blip2Config and check if we can load Blip2QFormerConfig from it + with tempfile.TemporaryDirectory() as tmp_dir_name: + config.save_pretrained(tmp_dir_name) + qformer_config = Blip2QFormerConfig.from_pretrained(tmp_dir_name) + self.assertDictEqual(config.qformer_config.to_dict(), qformer_config.to_dict()) + + @slow + def test_model_from_pretrained(self): + for model_name in ["jpizarrom/blip2-itm-vit-g"]: + for model_class in self.all_model_classes + (Blip2Model,): + model = model_class.from_pretrained(model_name) + self.assertIsNotNone(model) + + def test_get_text_features(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + inputs_dict = { + "input_ids": torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]).to(torch_device), + "attention_mask": torch.LongTensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]).to(torch_device), + } + + model = Blip2Model(config).to(torch_device) + model.eval() + text_features = model.get_text_features(**inputs_dict) + self.assertEqual(text_features[0].shape, (10, config.image_text_hidden_size)) + + def test_get_image_features(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + keys_to_pop = ["input_ids", "attention_mask"] + + for key in keys_to_pop: + inputs_dict.pop(key) + + model = Blip2Model(config).to(torch_device) + model.eval() + image_features = model.get_image_features(**inputs_dict) + self.assertEqual( + image_features.shape, # [12, 32, 256] + ( + self.model_tester.vision_model_tester.batch_size, + config.vision_config.hidden_size, + config.image_text_hidden_size, + ), + ) + + def test_training(self): + pass + + def test_training_gradient_checkpointing(self): + pass + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + if param.requires_grad: + # check if `logit_scale` is initilized as per the original implementation + if name == "logit_scale": + self.assertAlmostEqual( + param.data.item(), + np.log(1 / 0.07), + delta=1e-3, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + elif name == "temp": + self.assertAlmostEqual( + param.data.item(), + 0.07, + delta=1e-3, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + else: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + # We will 
verify our results on an image of cute cats def prepare_img(): url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg" @@ -989,3 +1204,87 @@ def test_inference_t5_multi_gpu(self): [0, 3, 7, 152, 67, 839, 1], ) self.assertEqual(generated_text, "san diego") + + @require_torch_gpu + def test_inference_itm_features(self): + processor = Blip2Processor.from_pretrained("jpizarrom/blip2-itm-vit-g") + model = Blip2Model.from_pretrained( + "jpizarrom/blip2-itm-vit-g", + ).to(torch_device) + + # image features + image = prepare_img() + image_inputs = processor(images=image, return_tensors="pt").to(torch_device) + image_features = model.get_image_features(**image_inputs) + expected_image_features = torch.tensor( + [ + -0.0946953147649765, + -0.07541415840387344, + 0.03312666341662407, + 0.053536128252744675, + 0.03368198126554489, + -0.013867943547666073, + ] + ).to(torch_device) + self.assertTrue(torch.allclose(image_features[0][0][:6], expected_image_features, atol=1e-4)) + + # text features + text_inputs = processor(text="a woman sitting on the beach with a dog", padding=True, return_tensors="pt").to( + torch_device + ) + text_features = model.get_text_features(**text_inputs) + expected_text_features = torch.tensor( + [ + -0.10836730897426605, + 0.05315554141998291, + -0.028310950845479965, + 0.016979066655039787, + 0.0865054652094841, + -0.046645939350128174, + ] + ).to(torch_device) + self.assertTrue(torch.allclose(text_features[0][0][:6], expected_text_features, atol=1e-4)) + + # check similarity + similarity = (image_features @ text_features[:, 0, :].t()).max() + expected_similarity = torch.tensor(0.44385525584220886).to(torch_device) + self.assertTrue(torch.allclose(similarity, expected_similarity, atol=1e-4)) + + @require_torch_gpu + def test_inference_itm_features_fp16(self): + processor = Blip2Processor.from_pretrained("jpizarrom/blip2-itm-vit-g") + model = Blip2Model.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16).to(torch_device) + + # image features + image = prepare_img() + image_inputs = processor(images=image, return_tensors="pt").to(torch_device, dtype=torch.float16) + image_features = model.get_image_features(**image_inputs) + expected_image_features = [ + -0.093994140625, + -0.075927734375, + 0.031890869140625, + 0.053009033203125, + 0.0352783203125, + -0.01190185546875, + ] + self.assertTrue(np.allclose(image_features[0][0][:6].tolist(), expected_image_features, atol=1e-4)) + + # text features + text_inputs = processor(text="a woman sitting on the beach with a dog", padding=True, return_tensors="pt").to( + torch_device + ) + text_features = model.get_text_features(**text_inputs) + expected_text_features = [ + -0.1082763671875, + 0.053192138671875, + -0.02825927734375, + 0.0169830322265625, + 0.08648681640625, + -0.04656982421875, + ] + self.assertTrue(np.allclose(text_features[0][0][:6].tolist(), expected_text_features, atol=1e-4)) + + # check similarity + similarity = (image_features @ text_features[:, 0, :].t()).max() + expected_similarity = 0.44384765625 + self.assertTrue(np.allclose(similarity.item(), expected_similarity, atol=1e-4)) diff --git a/utils/check_repo.py b/utils/check_repo.py index c8bd228eaa776e..792501b9d540d2 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -146,6 +146,7 @@ "ClapAudioModel", "ClapAudioModelWithProjection", "Blip2ForConditionalGeneration", + "Blip2ForImageTextRetrieval", "Blip2QFormerModel", "Blip2VisionModel", "ErnieMForInformationExtraction",
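The new integration tests exercise the same retrieval-style feature extraction path through `Blip2Model`; a condensed sketch is below (checkpoint and image URL are taken from those tests, and the exact similarity value will vary slightly with hardware and dtype):

```python
import requests
import torch
from PIL import Image
from transformers import Blip2Model, Blip2Processor

processor = Blip2Processor.from_pretrained("jpizarrom/blip2-itm-vit-g")
model = Blip2Model.from_pretrained("jpizarrom/blip2-itm-vit-g").eval()

url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
image = Image.open(requests.get(url, stream=True).raw)

with torch.no_grad():
    image_features = model.get_image_features(**processor(images=image, return_tensors="pt"))
    text_features = model.get_text_features(
        **processor(text="a woman sitting on the beach with a dog", return_tensors="pt")
    )

# image_features: (1, num_query_tokens, image_text_hidden_size)
# text_features:  (1, seq_len, image_text_hidden_size)
similarity = (image_features @ text_features[:, 0, :].t()).max()
print(similarity)  # the fp32 integration test expects roughly 0.44 for this image/caption pair
```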