From 402373213070fc36e4a4248a00b6fce1fcb72eab Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Sat, 19 Aug 2023 16:45:17 +0200 Subject: [PATCH 01/25] Add Blip2ForImageTextRetrieval --- docs/source/en/model_doc/blip-2.md | 7 +- src/transformers/__init__.py | 2 + src/transformers/models/blip_2/__init__.py | 2 + .../models/blip_2/configuration_blip_2.py | 13 +- .../convert_blip_2_original_to_pytorch.py | 187 +++++--- .../models/blip_2/modeling_blip_2.py | 425 ++++++++++++++++-- src/transformers/utils/dummy_pt_objects.py | 7 + tests/models/blip_2/test_modeling_blip_2.py | 303 ++++++++++++- utils/check_repo.py | 1 + 9 files changed, 854 insertions(+), 93 deletions(-) diff --git a/docs/source/en/model_doc/blip-2.md b/docs/source/en/model_doc/blip-2.md index 0890e612561a69..b6045d2b4f4c6e 100644 --- a/docs/source/en/model_doc/blip-2.md +++ b/docs/source/en/model_doc/blip-2.md @@ -87,4 +87,9 @@ If you're interested in submitting a resource to be included here, please feel f [[autodoc]] Blip2ForConditionalGeneration - forward - - generate \ No newline at end of file + - generate + +## Blip2ForImageTextRetrieval + +[[autodoc]] Blip2ForImageTextRetrieval + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 5c679164157575..3b9d6bc87137db 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1338,6 +1338,7 @@ [ "BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST", "Blip2ForConditionalGeneration", + "Blip2ForImageTextRetrieval", "Blip2Model", "Blip2PreTrainedModel", "Blip2QFormerModel", @@ -5357,6 +5358,7 @@ from .models.blip_2 import ( BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST, Blip2ForConditionalGeneration, + Blip2ForImageTextRetrieval, Blip2Model, Blip2PreTrainedModel, Blip2QFormerModel, diff --git a/src/transformers/models/blip_2/__init__.py b/src/transformers/models/blip_2/__init__.py index 6fbfd53b3703fd..cd2396d3288837 100644 --- a/src/transformers/models/blip_2/__init__.py +++ b/src/transformers/models/blip_2/__init__.py @@ -38,6 +38,7 @@ "Blip2QFormerModel", "Blip2PreTrainedModel", "Blip2ForConditionalGeneration", + "Blip2ForImageTextRetrieval", "Blip2VisionModel", ] @@ -59,6 +60,7 @@ from .modeling_blip_2 import ( BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST, Blip2ForConditionalGeneration, + Blip2ForImageTextRetrieval, Blip2Model, Blip2PreTrainedModel, Blip2QFormerModel, diff --git a/src/transformers/models/blip_2/configuration_blip_2.py b/src/transformers/models/blip_2/configuration_blip_2.py index 1b375e147f780b..ef246b129cad16 100644 --- a/src/transformers/models/blip_2/configuration_blip_2.py +++ b/src/transformers/models/blip_2/configuration_blip_2.py @@ -209,6 +209,7 @@ def __init__( position_embedding_type="absolute", cross_attention_frequency=2, encoder_hidden_size=1408, + qformer_text_input=False, **kwargs, ): super().__init__(pad_token_id=pad_token_id, **kwargs) @@ -227,6 +228,7 @@ def __init__( self.position_embedding_type = position_embedding_type self.cross_attention_frequency = cross_attention_frequency self.encoder_hidden_size = encoder_hidden_size + self.qformer_text_input = qformer_text_input @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": @@ -302,7 +304,15 @@ class Blip2Config(PretrainedConfig): model_type = "blip-2" - def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs): + def __init__( + self, + vision_config=None, + qformer_config=None, + text_config=None, + num_query_tokens=32, + 
image_text_hidden_size=256, + **kwargs, + ): super().__init__(**kwargs) if vision_config is None: @@ -326,6 +336,7 @@ def __init__(self, vision_config=None, qformer_config=None, text_config=None, nu self.is_encoder_decoder = self.text_config.is_encoder_decoder self.num_query_tokens = num_query_tokens + self.image_text_hidden_size = image_text_hidden_size self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES self.initializer_factor = 1.0 diff --git a/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py b/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py index c2e6eceae53273..d2e1b27dfb0869 100644 --- a/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py +++ b/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py @@ -31,9 +31,12 @@ from transformers import ( AutoTokenizer, + BertTokenizer, Blip2Config, Blip2ForConditionalGeneration, + Blip2ForImageTextRetrieval, Blip2Processor, + Blip2QFormerConfig, Blip2VisionConfig, BlipImageProcessor, OPTConfig, @@ -51,7 +54,7 @@ def load_demo_image(): # here we list all keys to be renamed (original name on the left, our name on the right) -def create_rename_keys(config): +def create_rename_keys(config, model_name): rename_keys = [] # fmt: off @@ -77,8 +80,9 @@ def create_rename_keys(config): rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias")) # QFormer - rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight")) - rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias")) + if "itm" not in model_name: + rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight")) + rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias")) # fmt: on return rename_keys @@ -114,8 +118,19 @@ def get_blip2_config(model_name, eos_token_id): text_config = T5Config.from_pretrained("google/flan-t5-xl", dense_act_fn="gelu", bos_token_id=1).to_dict() elif "t5-xxl" in model_name: text_config = T5Config.from_pretrained("google/flan-t5-xxl", dense_act_fn="gelu", bos_token_id=1).to_dict() - - config = Blip2Config(vision_config=vision_config, text_config=text_config) + elif "itm" in model_name: + text_config = {} + else: + raise ValueError("Model name not supported") + + if "itm" in model_name: + config = Blip2Config( + vision_config=vision_config, + text_config=text_config, + qformer_config=Blip2QFormerConfig(vocab_size=30523, qformer_text_input=True).to_dict(), + ) + else: + config = Blip2Config(vision_config=vision_config, text_config=text_config) return config, image_size @@ -125,15 +140,24 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ """ Copy/paste/tweak model's weights to Transformers design. 
""" - tokenizer = ( - AutoTokenizer.from_pretrained("facebook/opt-2.7b") - if "opt" in model_name - else AutoTokenizer.from_pretrained("google/flan-t5-xl") - ) - eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0] + if "opt" in model_name: + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b") + elif "itm" in model_name: + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right") + tokenizer.add_special_tokens({"bos_token": "[DEC]"}) + else: + tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl") + + if "itm" in model_name: + eos_token_id = None + else: + eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0] config, image_size = get_blip2_config(model_name, eos_token_id=eos_token_id) - hf_model = Blip2ForConditionalGeneration(config).eval() + if "itm" in model_name: + hf_model = Blip2ForImageTextRetrieval(config).eval() + else: + hf_model = Blip2ForConditionalGeneration(config).eval() model_name_to_original = { "blip2-opt-2.7b": ("blip2_opt", "pretrain_opt2.7b"), @@ -143,6 +167,9 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ "blip2-flan-t5-xl": ("blip2_t5", "pretrain_flant5xl"), "blip2-flan-t5-xl-coco": ("blip2_t5", "caption_coco_flant5xl"), "blip2-flan-t5-xxl": ("blip2_t5", "pretrain_flant5xxl"), + "blip2-itm-vit-g": ("blip2_image_text_matching", "pretrain"), + # "blip2-itm-vit-large": ("blip2_image_text_matching", "pretrain_vitL"), + "blip2-itm-vit-g-coco": ("blip2_image_text_matching", "coco"), } name, type = model_name_to_original[model_name] @@ -163,13 +190,15 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ # update state dict keys state_dict = original_model.state_dict() - rename_keys = create_rename_keys(config) + rename_keys = create_rename_keys(config, model_name) for src, dest in rename_keys: rename_key(state_dict, src, dest) # some keys can be renamed efficiently for key, val in state_dict.copy().items(): val = state_dict.pop(key) + if key.startswith("Qformer.cls"): + key = key.replace("Qformer.cls", "cls") if key.startswith("Qformer.bert"): key = key.replace("Qformer.bert", "qformer") if "attention.self" in key: @@ -193,7 +222,6 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ image = load_demo_image() original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device) - input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device) # create processor image_processor = BlipImageProcessor( @@ -207,50 +235,100 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ original_model.to(lavis_device) hf_model.to(hf_model_device) - with torch.no_grad(): - if "opt" in model_name: - original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits - logits = hf_model(pixel_values, input_ids).logits - else: - original_logits = original_model( - {"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]} - ).logits - labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100) - logits = hf_model(pixel_values, input_ids, labels=labels).logits - assert original_logits.shape == logits.shape - print("First values of original logits:", original_logits[0, :3, :3]) - print("First values of HF logits:", logits[0, :3, :3]) + if "itm" in model_name: + caption = "a large fountain spewing water into the air" + input_ids = tokenizer([caption], return_tensors="pt").input_ids.to(hf_model_device) 
+ attention_mask = processor(text=caption, return_tensors="pt").attention_mask.to(hf_model_device) + + with torch.no_grad(): + original_logits = original_model( + {"image": original_pixel_values, "text_input": [caption]}, match_head="itm" + ) + logits = hf_model(pixel_values=original_pixel_values, input_ids=input_ids, attention_mask=attention_mask) - # assert values - assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4) - print("Looks ok!") + assert original_logits.shape == logits.itm_score.shape + print("First values of original logits:", original_logits[0, :3]) + print("First values of HF logits:", logits.itm_score[0, :3]) - print("Generating a caption...") - prompt = "Question: what object is in this image? Answer:" - input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device) + # assert values + # cast to same type + target_dtype = logits.itm_score.dtype + assert torch.allclose(original_logits.to(target_dtype), logits.itm_score, atol=1e-2) - set_seed(42) + original_itm_scores = torch.nn.functional.softmax(original_logits, dim=1) + itm_scores = torch.nn.functional.softmax(logits.itm_score, dim=1) + assert torch.allclose(original_itm_scores.to(target_dtype), itm_scores, atol=1e-2) + print("Looks ok!") - original_outputs = original_model.generate( - {"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True - ) - outputs = hf_model.generate( - pixel_values, - input_ids, - do_sample=True, - num_beams=5, - max_length=30, - min_length=1, - top_p=0.9, - repetition_penalty=1.0, - length_penalty=1.0, - temperature=1, - ) - output_text = processor.batch_decode(outputs, skip_special_tokens=True) - output_text = [text.strip() for text in output_text] - print("Original generation:", original_outputs) - print("HF generation:", output_text) + with torch.no_grad(): + original_logits = original_model( + {"image": original_pixel_values, "text_input": [caption]}, match_head="itc" + ) + logits = hf_model( + pixel_values=original_pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + use_itm_head=False, + ) + + assert original_logits.shape == logits.itm_score.shape + print("First values of original logits:", original_logits[0, :3]) + print("First values of HF logits:", logits.itm_score[0, :3]) + + # assert values + # cast to same type + target_dtype = logits.itm_score.dtype + assert torch.allclose(original_logits.to(target_dtype), logits.itm_score, atol=1e-2) + print("Looks ok!") + + else: + input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device) + + with torch.no_grad(): + if "opt" in model_name: + original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits + logits = hf_model(pixel_values, input_ids).logits + else: + original_logits = original_model( + {"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]} + ).logits + labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100) + logits = hf_model(pixel_values, input_ids, labels=labels).logits + + assert original_logits.shape == logits.shape + print("First values of original logits:", original_logits[0, :3, :3]) + print("First values of HF logits:", logits[0, :3, :3]) + + # assert values + assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4) + print("Looks ok!") + + print("Generating a caption...") + prompt = "Question: what object is in this image? 
Answer:" + input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device) + + set_seed(42) + + original_outputs = original_model.generate( + {"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True + ) + outputs = hf_model.generate( + pixel_values, + input_ids, + do_sample=True, + num_beams=5, + max_length=30, + min_length=1, + top_p=0.9, + repetition_penalty=1.0, + length_penalty=1.0, + temperature=1, + ) + output_text = processor.batch_decode(outputs, skip_special_tokens=True) + output_text = [text.strip() for text in output_text] + print("Original generation:", original_outputs) + print("HF generation:", output_text) if pytorch_dump_folder_path is not None: processor.save_pretrained(pytorch_dump_folder_path) @@ -271,6 +349,9 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ "blip2-flan-t5-xl", "blip2-flan-t5-xl-coco", "blip2-flan-t5-xxl", + "blip2-itm-vit-g", + # "blip2-itm-vit-large", + "blip2-itm-vit-g-coco", ] parser.add_argument( "--model_name", diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 87c8132ff4fd86..b71724d8db0865 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -22,6 +22,7 @@ import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss +from torch.nn.functional import normalize from ...activations import ACT2FN from ...modeling_outputs import ( @@ -86,6 +87,50 @@ def to_tuple(self) -> Tuple[Any]: ) +@dataclass +# Copied from transformers.models.blip.modeling_blip.BlipImageTextMatchingModelOutput with Blip->Blip2 +class Blip2ImageTextMatchingModelOutput(ModelOutput): + """ + Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the + last hidden states. This class also adds the loss term from the text decoder as well as the image-text similarity + scores. + + Args: + itm_score (`torch.FloatTensor`): + The image-text similarity scores. + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Languge modeling loss from the text decoder. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + vision_pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*): + Last layer hidden-state of the vision of the vision-only branch of the model. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. 
+ + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + question_embeds (`torch.FloatTensor`): + The question embeddings obtained by the text projection layer. + """ + + itm_score: Optional[torch.FloatTensor] = None + loss: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + vision_pooler_output: Optional[torch.FloatTensor] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + question_embeds: Optional[Tuple[torch.FloatTensor]] = None + + # Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2 class Blip2VisionEmbeddings(nn.Module): def __init__(self, config: Blip2VisionConfig): @@ -816,6 +861,10 @@ def __init__(self, config, layer_idx): else: self.has_cross_attention = False + if config.qformer_text_input: + self.intermediate = Blip2QFormerIntermediate(config) + self.output = Blip2QFormerOutput(config) + self.intermediate_query = Blip2QFormerIntermediate(config) self.output_query = Blip2QFormerOutput(config) @@ -1003,6 +1052,59 @@ def custom_forward(*inputs): ) +# Adapted from https://github.com/salesforce/LAVIS/blob/main/lavis/models/blip2_models/Qformer.py#L51 +class Blip2TextEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, + input_ids: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if input_ids is not None: + input_ids = input_ids.to(self.word_embeddings.weight.device) + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + class Blip2QFormerModel(Blip2PreTrainedModel): """ Querying Transformer (Q-Former), used in BLIP-2. 
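The new `Blip2TextEmbeddings` above mirrors the embedding module of the original LAVIS Q-Former: word and absolute position embeddings are built for the text prompt and, when learned query tokens are passed in, the queries are prepended so the Q-Former's self-attention runs over `[queries; text]` while cross-attention only looks at the image features. A minimal, self-contained shape sketch of that concatenation (dimensions and tensors are illustrative only, not tied to this patch or to any checkpoint):

```python
import torch
from torch import nn

batch_size, num_query_tokens, seq_len, hidden_size = 2, 32, 8, 768
vocab_size, max_position_embeddings = 30523, 512

# learned query tokens, shared across the batch (zeros purely for illustration)
query_tokens = torch.zeros(1, num_query_tokens, hidden_size).expand(batch_size, -1, -1)

# word + absolute position embeddings for the text prompt
word_embeddings = nn.Embedding(vocab_size, hidden_size)
position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)

input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
position_ids = torch.arange(seq_len).unsqueeze(0)
text_embeds = word_embeddings(input_ids) + position_embeddings(position_ids)

# queries are prepended to the text embeddings before LayerNorm and dropout are applied
embedding_output = torch.cat((query_tokens, text_embeds), dim=1)
print(embedding_output.shape)  # torch.Size([2, 40, 768])
```

Because the sequence now starts with the query block, the attention mask has to cover both segments, which is why `Blip2ForImageTextRetrieval.forward` later in this patch concatenates a query attention mask in front of the text attention mask.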
@@ -1012,8 +1114,11 @@ def __init__(self, config: Blip2QFormerConfig): super().__init__(config) self.config = config - self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + if config.qformer_text_input: + self.embeddings = Blip2TextEmbeddings(config) + else: + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) self.encoder = Blip2QFormerEncoder(config) @@ -1080,7 +1185,7 @@ def get_extended_attention_mask( def forward( self, - query_embeds: torch.FloatTensor, + query_embeds: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, @@ -1090,6 +1195,8 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + input_ids: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`): @@ -1116,6 +1223,9 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if input_ids is None and query_embeds is None: + raise ValueError("You have to specify query_embeds when input_ids is None") + # past_key_values_length past_key_values_length = ( past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0 @@ -1123,15 +1233,25 @@ def forward( query_length = query_embeds.shape[1] if query_embeds is not None else 0 - embedding_output = self.layernorm(query_embeds) - embedding_output = self.dropout(embedding_output) + if hasattr(self, "embeddings"): + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = self.layernorm(query_embeds) + embedding_output = self.dropout(embedding_output) input_shape = embedding_output.size()[:-1] batch_size, seq_length = input_shape device = embedding_output.device if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device, dtype=torch.long + ) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
@@ -1149,7 +1269,7 @@ def forward( if type(encoder_attention_mask) == list: encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] elif encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device, dtype=torch.long) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) @@ -1192,6 +1312,56 @@ def forward( ) +# Copied from transformers.models.blip.modeling_blip_text.BlipTextPredictionHeadTransform with Blip->Blip2 +class Blip2TextPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# Copied from transformers.models.blip.modeling_blip_text.BlipTextLMPredictionHead with Blip->Blip2 +class Blip2TextLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = Blip2TextPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Copied from transformers.models.blip.modeling_blip_text.BlipTextOnlyMLMHead with Blip->Blip2 +class Blip2TextOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = Blip2TextLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + @add_start_docstrings( """ BLIP-2 Model for generating text and image features. 
The model consists of a vision encoder, Querying Transformer @@ -1211,17 +1381,27 @@ def __init__(self, config: Blip2Config): self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) self.qformer = Blip2QFormerModel(config.qformer_config) - self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) - if config.use_decoder_only_language_model: - language_model = AutoModelForCausalLM.from_config(config.text_config) + if self.config.qformer_config.qformer_text_input: + self._keep_in_fp32_modules = [] # fp16 compatibility + + # vision projection layer + self.vision_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + + # text projection layer + self.text_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + else: - language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) + self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) + if config.use_decoder_only_language_model: + language_model = AutoModelForCausalLM.from_config(config.text_config) + else: + language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) - # Update _tied_weights_keys using the base model used. - if language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] + # Update _tied_weights_keys using the base model used. + if language_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] - self.language_model = language_model + self.language_model = language_model # Initialize weights and apply final processing self.post_init() @@ -1236,7 +1416,10 @@ def set_output_embeddings(self, new_embeddings): self.language_model.set_output_embeddings(new_embeddings) def get_output_embeddings(self) -> nn.Module: - return self.language_model.get_output_embeddings() + if hasattr(self, "language_model"): + return self.language_model.get_output_embeddings() + else: + return self.qformer.get_output_embeddings() def get_encoder(self): return self.language_model.get_encoder() @@ -1245,7 +1428,7 @@ def get_decoder(self): return self.language_model.get_decoder() def _tie_weights(self): - if not self.config.use_decoder_only_language_model: + if not self.config.use_decoder_only_language_model and hasattr(self, "language_model"): self.language_model.encoder.embed_tokens = self.language_model.shared self.language_model.decoder.embed_tokens = self.language_model.shared @@ -1288,29 +1471,42 @@ def get_text_features( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.config.use_decoder_only_language_model: - text_outputs = self.language_model( + if hasattr(self, "language_model"): + if self.config.use_decoder_only_language_model: + text_outputs = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + else: + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + + text_outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + 
labels=labels, + ) + + return text_outputs + else: + query_outputs = self.qformer( input_ids=input_ids, + # query_embeds=query_tokens, attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, return_dict=return_dict, ) - else: - inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + embeds = query_outputs.last_hidden_state + text_features = self.text_proj(embeds) + text_features = normalize(text_features, dim=-1) - text_outputs = self.language_model( - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - labels=labels, - ) - - return text_outputs + return text_features @add_start_docstrings_to_model_forward(BLIP_2_VISION_INPUTS_DOCSTRING) def get_image_features( @@ -1358,7 +1554,26 @@ def get_image_features( return_dict=return_dict, ) - return vision_outputs + if not hasattr(self, "vision_proj"): + return vision_outputs + + else: + image_embeds = vision_outputs[0] + + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_outputs = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=return_dict, + ) + + embeds = query_outputs.last_hidden_state + image_features = normalize(self.vision_proj(embeds), dim=-1) + + return image_features @add_start_docstrings_to_model_forward(BLIP_2_INPUTS_DOCSTRING) def get_qformer_features( @@ -1884,3 +2099,141 @@ def generate( ) return outputs + + +@add_start_docstrings( + """ + BLIP-2 Model with a vision and text projector, and a classification head on top. The model is used in the context + of image-text retrieval. Given an image and a text, the model returns the probability of the text being relevant to + the image. 
+ """, + BLIP_2_START_DOCSTRING, +) +class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): + main_input_name = "pixel_values" + config_class = Blip2Config + + def __init__(self, config: Blip2Config): + super().__init__(config) + + self.vision_model = Blip2VisionModel(config.vision_config) + + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + + self.qformer = Blip2QFormerModel(config.qformer_config) + self.cls = Blip2TextOnlyMLMHead(config.qformer_config) + + # vision projection layer + self.vision_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + + # text projection layer + self.text_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + + # image text matching head + self.itm_head = nn.Linear(config.qformer_config.hidden_size, 2) + + self.temp = nn.Parameter(0.07 * torch.ones([])) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BLIP_2_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Blip2ImageTextMatchingModelOutput, config_class=Blip2VisionConfig) + def forward( + self, + pixel_values: torch.FloatTensor, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + use_itm_head: Optional[bool] = True, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Blip2ImageTextMatchingModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Blip2ForImageTextRetrieval + + >>> model = Blip2ForImageTextRetrieval.from_pretrained("jpizarrom/blip2-itm-vit-g") + >>> processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text = "an image of a cat" + + >>> inputs = processor(images=image, text=text, return_tensors="pt") + >>> outputs = model(**inputs) + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[0] + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + if use_itm_head: + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(query_tokens.device) + attention_mask = torch.cat([query_attention_mask, attention_mask], dim=1) + + question_embeds = self.qformer( + input_ids=input_ids, + query_embeds=query_tokens, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=return_dict, + ) + question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state + + output = self.itm_head(question_embeds[:, : query_tokens.size(1), :]) + output = output.mean(dim=1) + else: + query_tokens = 
self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_outputs = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=return_dict, + ) + image_embeds = query_outputs[0] if not return_dict else query_outputs.last_hidden_state + + question_embeds = self.qformer( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=return_dict, + ) + question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state + + image_feat = normalize(self.vision_proj(image_embeds), dim=-1) + text_feat = normalize(self.text_proj(question_embeds[:, 0, :]), dim=-1) + + outputs = torch.bmm(image_feat, text_feat.unsqueeze(-1)) + output, _ = torch.max(outputs, dim=1) + + if not return_dict: + outputs = (output, vision_outputs[0]) + vision_outputs[2:] + (question_embeds,) + return tuple(output for output in outputs if output is not None) + + return Blip2ImageTextMatchingModelOutput( + itm_score=output, + last_hidden_state=vision_outputs.last_hidden_state, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + question_embeds=question_embeds, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index aad1018e193752..e9e6b2c6c748d3 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1503,6 +1503,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Blip2ForImageTextRetrieval(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Blip2Model(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index 66d59465a7c58e..d8fc2690302564 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -23,7 +23,14 @@ import requests from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig -from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device +from transformers.testing_utils import ( + require_torch, + require_torch_gpu, + require_torch_multi_gpu, + require_vision, + slow, + torch_device, +) from transformers.utils import is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -41,7 +48,7 @@ import torch from torch import nn - from transformers import Blip2ForConditionalGeneration, Blip2Model, Blip2VisionModel + from transformers import Blip2ForConditionalGeneration, Blip2ForImageTextRetrieval, Blip2Model, Blip2VisionModel from transformers.models.blip_2.modeling_blip_2 import BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST @@ -224,6 +231,7 @@ def __init__( initializer_range=0.02, bos_token_id=0, scope=None, + qformer_text_input=False, ): self.parent = parent self.batch_size = batch_size @@ -243,6 +251,7 @@ def __init__( self.initializer_range = initializer_range self.scope = scope self.bos_token_id = bos_token_id + self.qformer_text_input = qformer_text_input def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -275,6 +284,7 @@ def get_config(self): max_position_embeddings=self.max_position_embeddings, initializer_range=self.initializer_range, bos_token_id=self.bos_token_id, + qformer_text_input=self.qformer_text_input, ) 
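Zooming out from the diff: the retrieval head defined above exposes two scoring modes, the image-text matching (ITM) classifier and the image-text contrastive (ITC) similarity, and the conversion script earlier in this patch validates both against LAVIS. A hedged usage sketch of the two modes follows; the `jpizarrom/blip2-itm-vit-g` checkpoint name is simply the one used throughout this patch and may change before merge, and reading column 1 of `itm_score` as the "match" class follows the BLIP convention rather than anything asserted in this diff.

```python
import torch
import requests
from PIL import Image
from transformers import AutoProcessor, Blip2ForImageTextRetrieval

processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g")
model = Blip2ForImageTextRetrieval.from_pretrained("jpizarrom/blip2-itm-vit-g")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="two cats sleeping on a couch", return_tensors="pt")

with torch.no_grad():
    # ITM head: 2-way logits produced from the fused query representations
    itm_logits = model(**inputs).itm_score                      # shape (1, 2)
    match_probability = itm_logits.softmax(dim=1)[:, 1]

    # ITC path: max cosine similarity between the projected query features
    # and the projected text [CLS] feature
    itc_score = model(**inputs, use_itm_head=False).itm_score   # shape (1, 1)

print(float(match_probability), float(itc_score))
```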
@@ -817,6 +827,211 @@ def test_initialization(self): ) +class Blip2TextRetrievalModelTester: + def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training=True): + if vision_kwargs is None: + vision_kwargs = {} + if qformer_kwargs is None: + qformer_kwargs = {"qformer_text_input": True} + text_kwargs = {} + + self.parent = parent + self.vision_model_tester = Blip2VisionModelTester(parent, **vision_kwargs) + self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs) + self.text_model_tester = Blip2TextModelTester(parent, **text_kwargs) + self.is_training = is_training + + def prepare_config_and_inputs(self): + _, input_ids, attention_mask = self.qformer_model_tester.prepare_config_and_inputs() + _, pixel_values = self.vision_model_tester.prepare_config_and_inputs() + + config = self.get_config() + + return config, input_ids, attention_mask, pixel_values + + def get_config(self): + return Blip2Config.from_vision_qformer_text_configs( + vision_config=self.vision_model_tester.get_config(), + qformer_config=self.qformer_model_tester.get_config(), + text_config=self.text_model_tester.get_config(), + ) + + def create_and_check_model(self, config, input_ids, attention_mask, pixel_values): + model = Blip2ForImageTextRetrieval(config).to(torch_device).eval() + with torch.no_grad(): + result = model(pixel_values, input_ids, attention_mask) + + image_size = (self.vision_model_tester.image_size, self.vision_model_tester.image_size) + patch_size = (self.vision_model_tester.patch_size, self.vision_model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + + self.parent.assertEqual( + result.itm_score.shape, + (self.vision_model_tester.batch_size, 2), + ) + self.parent.assertEqual( + result.last_hidden_state.shape, + (self.vision_model_tester.batch_size, num_patches + 1, self.qformer_model_tester.hidden_size), + ) + self.parent.assertEqual( + result.question_embeds.shape, + ( + self.text_model_tester.batch_size, + self.vision_model_tester.hidden_size + input_ids.shape[1], + self.qformer_model_tester.hidden_size, + ), + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, attention_mask, pixel_values = config_and_inputs + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + } + return config, inputs_dict + + +@require_torch +class Blip2TextRetrievalModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (Blip2ForImageTextRetrieval,) if is_torch_available() else () + fx_compatible = False + test_head_masking = False + test_pruning = False + test_resize_embeddings = False + test_attention_outputs = False + test_torchscript = False + + # TODO add or skip tests + test_model_outputs_equivalence = False + test_tied_weights_keys = False + test_hidden_states_output = False + test_inputs_embeds = False + test_model_common_attributes = False + test_retain_grad_hidden_states_attentions = False + + def setUp(self): + self.model_tester = Blip2TextRetrievalModelTester(self) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is 
an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + if model.config.is_encoder_decoder: + expected_arg_names = [ + "pixel_values", + "input_ids", + "attention_mask", + ] + + expected_arg_names.extend(["use_itm_head"] if "use_itm_head" in arg_names else ["decoder_input_ids"]) + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + else: + # TODO + raise NotImplementedError + + def test_load_vision_qformer_text_config(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + # Save Blip2Config and check if we can load Blip2VisionConfig from it + with tempfile.TemporaryDirectory() as tmp_dir_name: + config.save_pretrained(tmp_dir_name) + vision_config = Blip2VisionConfig.from_pretrained(tmp_dir_name) + self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict()) + + # Save Blip2Config and check if we can load Blip2QFormerConfig from it + with tempfile.TemporaryDirectory() as tmp_dir_name: + config.save_pretrained(tmp_dir_name) + qformer_config = Blip2QFormerConfig.from_pretrained(tmp_dir_name) + self.assertDictEqual(config.qformer_config.to_dict(), qformer_config.to_dict()) + + @slow + def test_model_from_pretrained(self): + for model_name in ["jpizarrom/blip2-itm-vit-g"]: + for model_class in self.all_model_classes + (Blip2Model,): + model = model_class.from_pretrained(model_name) + self.assertIsNotNone(model) + + def test_get_text_features(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + inputs_dict = { + "input_ids": torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]).to(torch_device), + "attention_mask": torch.LongTensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]).to(torch_device), + } + + model = Blip2Model(config).to(torch_device) + model.eval() + text_features = model.get_text_features(**inputs_dict) + self.assertEqual(text_features[0].shape, (10, config.image_text_hidden_size)) + + def test_get_image_features(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + keys_to_pop = ["input_ids", "attention_mask"] + + for key in keys_to_pop: + inputs_dict.pop(key) + + model = Blip2Model(config).to(torch_device) + model.eval() + image_features = model.get_image_features(**inputs_dict) + self.assertEqual( + image_features.shape, # [12, 32, 256] + ( + self.model_tester.vision_model_tester.batch_size, + config.vision_config.hidden_size, + config.image_text_hidden_size, + ), + ) + + def test_training(self): + pass + + def test_training_gradient_checkpointing(self): + pass + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + if param.requires_grad: + # check if `logit_scale` is initilized as per the original implementation + if name == "logit_scale": + self.assertAlmostEqual( + param.data.item(), + np.log(1 / 0.07), + delta=1e-3, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + elif name == "temp": + self.assertAlmostEqual( + param.data.item(), + 0.07, + delta=1e-3, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + else: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + # We will 
verify our results on an image of cute cats def prepare_img(): url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg" @@ -993,3 +1208,87 @@ def test_inference_t5_multi_gpu(self): [0, 3, 7, 152, 67, 839, 1], ) self.assertEqual(generated_text, "san diego") + + @require_torch_gpu + def test_inference_itm_features(self): + processor = Blip2Processor.from_pretrained("jpizarrom/blip2-itm-vit-g") + model = Blip2Model.from_pretrained( + "jpizarrom/blip2-itm-vit-g", + ).to(torch_device) + + # image features + image = prepare_img() + image_inputs = processor(images=image, return_tensors="pt").to(torch_device) + image_features = model.get_image_features(**image_inputs) + expected_image_features = torch.tensor( + [ + -0.0946953147649765, + -0.07541415840387344, + 0.03312666341662407, + 0.053536128252744675, + 0.03368198126554489, + -0.013867943547666073, + ] + ).to(torch_device) + self.assertTrue(torch.allclose(image_features[0][0][:6], expected_image_features, atol=1e-4)) + + # text features + text_inputs = processor(text="a woman sitting on the beach with a dog", padding=True, return_tensors="pt").to( + torch_device + ) + text_features = model.get_text_features(**text_inputs) + expected_text_features = torch.tensor( + [ + -0.10836730897426605, + 0.05315554141998291, + -0.028310950845479965, + 0.016979066655039787, + 0.0865054652094841, + -0.046645939350128174, + ] + ).to(torch_device) + self.assertTrue(torch.allclose(text_features[0][0][:6], expected_text_features, atol=1e-4)) + + # check similarity + similarity = (image_features @ text_features[:, 0, :].t()).max() + expected_similarity = torch.tensor(0.44385525584220886).to(torch_device) + self.assertTrue(torch.allclose(similarity, expected_similarity, atol=1e-4)) + + @require_torch_gpu + def test_inference_itm_features_fp16(self): + processor = Blip2Processor.from_pretrained("jpizarrom/blip2-itm-vit-g") + model = Blip2Model.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16).to(torch_device) + + # image features + image = prepare_img() + image_inputs = processor(images=image, return_tensors="pt").to(torch_device, dtype=torch.float16) + image_features = model.get_image_features(**image_inputs) + expected_image_features = [ + -0.093994140625, + -0.075927734375, + 0.031890869140625, + 0.053009033203125, + 0.0352783203125, + -0.01190185546875, + ] + self.assertTrue(np.allclose(image_features[0][0][:6].tolist(), expected_image_features, atol=1e-4)) + + # text features + text_inputs = processor(text="a woman sitting on the beach with a dog", padding=True, return_tensors="pt").to( + torch_device + ) + text_features = model.get_text_features(**text_inputs) + expected_text_features = [ + -0.1082763671875, + 0.053192138671875, + -0.02825927734375, + 0.0169830322265625, + 0.08648681640625, + -0.04656982421875, + ] + self.assertTrue(np.allclose(text_features[0][0][:6].tolist(), expected_text_features, atol=1e-4)) + + # check similarity + similarity = (image_features @ text_features[:, 0, :].t()).max() + expected_similarity = 0.44384765625 + self.assertTrue(np.allclose(similarity.item(), expected_similarity, atol=1e-4)) diff --git a/utils/check_repo.py b/utils/check_repo.py index c8bd228eaa776e..792501b9d540d2 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -146,6 +146,7 @@ "ClapAudioModel", "ClapAudioModelWithProjection", "Blip2ForConditionalGeneration", + "Blip2ForImageTextRetrieval", "Blip2QFormerModel", "Blip2VisionModel", "ErnieMForInformationExtraction", From 
8eb718a6cc291705d325fd96035039acfa8fa19f Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Sun, 24 Sep 2023 10:59:30 +0200 Subject: [PATCH 02/25] add Blip2ModelWithProjection --- docs/source/en/model_doc/blip-2.md | 4 + src/transformers/__init__.py | 2 + src/transformers/models/blip_2/__init__.py | 2 + .../models/blip_2/modeling_blip_2.py | 398 ++++++++++++++---- src/transformers/utils/dummy_pt_objects.py | 7 + tests/models/blip_2/test_modeling_blip_2.py | 29 +- utils/check_repo.py | 1 + 7 files changed, 360 insertions(+), 83 deletions(-) diff --git a/docs/source/en/model_doc/blip-2.md b/docs/source/en/model_doc/blip-2.md index b6045d2b4f4c6e..c160af37764651 100644 --- a/docs/source/en/model_doc/blip-2.md +++ b/docs/source/en/model_doc/blip-2.md @@ -93,3 +93,7 @@ If you're interested in submitting a resource to be included here, please feel f [[autodoc]] Blip2ForImageTextRetrieval - forward + +## Blip2ModelWithProjection + +[[autodoc]] Blip2ModelWithProjection diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 3b9d6bc87137db..ce17207d51b75c 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1340,6 +1340,7 @@ "Blip2ForConditionalGeneration", "Blip2ForImageTextRetrieval", "Blip2Model", + "Blip2ModelWithProjection", "Blip2PreTrainedModel", "Blip2QFormerModel", "Blip2VisionModel", @@ -5360,6 +5361,7 @@ Blip2ForConditionalGeneration, Blip2ForImageTextRetrieval, Blip2Model, + Blip2ModelWithProjection, Blip2PreTrainedModel, Blip2QFormerModel, Blip2VisionModel, diff --git a/src/transformers/models/blip_2/__init__.py b/src/transformers/models/blip_2/__init__.py index cd2396d3288837..d2d0b0dc0a5603 100644 --- a/src/transformers/models/blip_2/__init__.py +++ b/src/transformers/models/blip_2/__init__.py @@ -35,6 +35,7 @@ _import_structure["modeling_blip_2"] = [ "BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST", "Blip2Model", + "Blip2ModelWithProjection", "Blip2QFormerModel", "Blip2PreTrainedModel", "Blip2ForConditionalGeneration", @@ -62,6 +63,7 @@ Blip2ForConditionalGeneration, Blip2ForImageTextRetrieval, Blip2Model, + Blip2ModelWithProjection, Blip2PreTrainedModel, Blip2QFormerModel, Blip2VisionModel, diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index b71724d8db0865..13651a51eb2545 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -54,6 +54,18 @@ ] +# Copied from transformers.models.clip.modeling_clip.contrastive_loss +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) + + +# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->blip2 +def blip2_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.t()) + return (caption_loss + image_loss) / 2.0 + + @dataclass class Blip2ForConditionalGenerationModelOutput(ModelOutput): """ @@ -87,6 +99,45 @@ def to_tuple(self) -> Tuple[Any]: ) +@dataclass +# Copied from transformers.models.blip.modeling_blip.BlipOutput with Blip->Blip2 +class Blip2Output(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. 
+ logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`Blip2TextModel`]. + image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of + [`Blip2VisionModel`]. + text_model_output(`BaseModelOutputWithPooling`): + The output of the [`Blip2TextModel`]. + vision_model_output(`BaseModelOutputWithPooling`): + The output of the [`Blip2VisionModel`]. + """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: torch.FloatTensor = None + logits_per_text: torch.FloatTensor = None + text_embeds: torch.FloatTensor = None + image_embeds: torch.FloatTensor = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + @dataclass # Copied from transformers.models.blip.modeling_blip.BlipImageTextMatchingModelOutput with Blip->Blip2 class Blip2ImageTextMatchingModelOutput(ModelOutput): @@ -1381,27 +1432,17 @@ def __init__(self, config: Blip2Config): self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) self.qformer = Blip2QFormerModel(config.qformer_config) - if self.config.qformer_config.qformer_text_input: - self._keep_in_fp32_modules = [] # fp16 compatibility - - # vision projection layer - self.vision_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) - - # text projection layer - self.text_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) - + self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) + if config.use_decoder_only_language_model: + language_model = AutoModelForCausalLM.from_config(config.text_config) else: - self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) - if config.use_decoder_only_language_model: - language_model = AutoModelForCausalLM.from_config(config.text_config) - else: - language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) + language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) - # Update _tied_weights_keys using the base model used. - if language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] + # Update _tied_weights_keys using the base model used. 
+ if language_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] - self.language_model = language_model + self.language_model = language_model # Initialize weights and apply final processing self.post_init() @@ -1416,10 +1457,7 @@ def set_output_embeddings(self, new_embeddings): self.language_model.set_output_embeddings(new_embeddings) def get_output_embeddings(self) -> nn.Module: - if hasattr(self, "language_model"): - return self.language_model.get_output_embeddings() - else: - return self.qformer.get_output_embeddings() + return self.language_model.get_output_embeddings() def get_encoder(self): return self.language_model.get_encoder() @@ -1428,7 +1466,7 @@ def get_decoder(self): return self.language_model.get_decoder() def _tie_weights(self): - if not self.config.use_decoder_only_language_model and hasattr(self, "language_model"): + if not self.config.use_decoder_only_language_model: self.language_model.encoder.embed_tokens = self.language_model.shared self.language_model.decoder.embed_tokens = self.language_model.shared @@ -1471,42 +1509,29 @@ def get_text_features( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if hasattr(self, "language_model"): - if self.config.use_decoder_only_language_model: - text_outputs = self.language_model( - input_ids=input_ids, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - else: - inputs_embeds = self.language_model.get_input_embeddings()(input_ids) - - text_outputs = self.language_model( - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - labels=labels, - ) - - return text_outputs - else: - query_outputs = self.qformer( + if self.config.use_decoder_only_language_model: + text_outputs = self.language_model( input_ids=input_ids, - # query_embeds=query_tokens, attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, return_dict=return_dict, ) - embeds = query_outputs.last_hidden_state - text_features = self.text_proj(embeds) - text_features = normalize(text_features, dim=-1) + else: + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) - return text_features + text_outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + labels=labels, + ) + + return text_outputs @add_start_docstrings_to_model_forward(BLIP_2_VISION_INPUTS_DOCSTRING) def get_image_features( @@ -1554,26 +1579,7 @@ def get_image_features( return_dict=return_dict, ) - if not hasattr(self, "vision_proj"): - return vision_outputs - - else: - image_embeds = vision_outputs[0] - - image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) - - query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) - query_outputs = self.qformer( - query_embeds=query_tokens, - encoder_hidden_states=image_embeds, - encoder_attention_mask=image_attention_mask, - return_dict=return_dict, - ) - - embeds = 
query_outputs.last_hidden_state - image_features = normalize(self.vision_proj(embeds), dim=-1) - - return image_features + return vision_outputs @add_start_docstrings_to_model_forward(BLIP_2_INPUTS_DOCSTRING) def get_qformer_features( @@ -1764,6 +1770,248 @@ def forward( ) +@add_start_docstrings( + """ + BLIP-2 Model for generating text and image features. The model consists of a vision encoder, Querying Transformer + (Q-Former) and a language model. + """, + BLIP_2_START_DOCSTRING, +) +class Blip2ModelWithProjection(Blip2PreTrainedModel): + config_class = Blip2Config + main_input_name = "pixel_values" + _keep_in_fp32_modules = [] + + def __init__(self, config: Blip2Config): + super().__init__(config) + + self.vision_model = Blip2VisionModel(config.vision_config) + + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + self.qformer = Blip2QFormerModel(config.qformer_config) + + # vision projection layer + self.vision_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + + # text projection layer + self.text_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + + self.temp = nn.Parameter(0.07 * torch.ones([])) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BLIP_2_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Returns: + text_outputs (`CausalLMOutputWithPast`, or `tuple(torch.FloatTensor)` if `return_dict=False`): + The language model outputs. If `return_dict=True`, the output is a [`CausalLMOutputWithPast`] that + contains the language model logits, the past key values and the hidden states if + `output_hidden_states=True`. 
+ Examples: + ```python + >>> import torch + >>> from transformers import AutoProcessor, Blip2ModelWithProjection + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> model = Blip2ModelWithProjection.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16) + + >>> model.to(device) # doctest: +IGNORE_RESULT + + >>> processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g") + >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], return_tensors="pt").to(device) + >>> text_features = model.get_text_features(**inputs) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + query_outputs = self.qformer( + input_ids=input_ids, + # query_embeds=query_tokens, + attention_mask=attention_mask, + return_dict=return_dict, + ) + embeds = query_outputs.last_hidden_state + text_features = self.text_proj(embeds) + text_features = normalize(text_features, dim=-1) + + return text_features + + @add_start_docstrings_to_model_forward(BLIP_2_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Returns: + vision_outputs (`BaseModelOutputWithPooling` or tuple of `torch.FloatTensor`): + The vision model outputs. If `return_dict=True`, the output is a [`BaseModelOutputWithPooling`] that + contains the image features, the pooled image features and the hidden states if + `output_hidden_states=True`. 
+ Examples: + ```python + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Blip2ModelWithProjection + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> model = Blip2ModelWithProjection.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16) + + >>> model.to(device) # doctest: +IGNORE_RESULT + + >>> processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16) + >>> image_outputs = model.get_image_features(**inputs) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[0] + + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_outputs = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=return_dict, + ) + + embeds = query_outputs.last_hidden_state + image_features = normalize(self.vision_proj(embeds), dim=-1) + + return image_features + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Blip2Output]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Blip2ModelWithProjection + + >>> model = Blip2ModelWithProjection.from_pretrained("jpizarrom/blip2-itm-vit-g") + >>> processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor( + ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True + ... ) + + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + # Use BLIP2 model's config for some fields (if specified) instead of those of vision & text components. 
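For reference, the image-text contrastive scoring implemented in the forward body that follows reduces the per-query-token similarities with a max before applying the learned temperature. Below is a minimal sketch of that computation with plain tensors; the batch sizes, the 32 query tokens and the 256-dimensional shared space are assumptions taken from the defaults used in this patch, and `blip2_loss` is assumed to be the usual symmetric cross-entropy (as in the existing `blip_loss`), not confirmed by the diff shown here.

```python
import torch
import torch.nn.functional as F

# Assumed shapes: 2 texts, 2 images, 32 query tokens, 256-d shared embedding space.
text_feat = F.normalize(torch.randn(2, 256), dim=-1)        # (text_batch, dim), CLS projection
image_feat = F.normalize(torch.randn(2, 32, 256), dim=-1)   # (image_batch, num_query_tokens, dim)
temp = 0.07

# Similarity of every text against every query token of every image:
# (text_batch, 1, 1, dim) @ (image_batch, dim, num_query_tokens) -> (text_batch, image_batch, 1, num_query_tokens)
sim_t2q = torch.matmul(text_feat.unsqueeze(1).unsqueeze(1), image_feat.permute(0, 2, 1)).squeeze(2)

# Aggregate across the query tokens with a max, then scale by the learned temperature.
logits_per_text = sim_t2q.max(-1).values / temp              # (text_batch, image_batch)
logits_per_image = logits_per_text.t()                       # (image_batch, text_batch)

# Assuming blip2_loss is the standard symmetric cross-entropy over matching pairs:
targets = torch.arange(logits_per_text.size(0))
loss = (F.cross_entropy(logits_per_text, targets) + F.cross_entropy(logits_per_image, targets)) / 2
```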
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[0] + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_outputs = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=return_dict, + ) + + image_embeds = query_outputs[0] if not return_dict else query_outputs.last_hidden_state + image_embeds = self.vision_proj(image_embeds) + + question_outputs = self.qformer( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=return_dict, + ) + question_embeds = question_outputs[0] if not return_dict else question_outputs.last_hidden_state + question_embeds = self.text_proj(question_embeds[:, 0, :]) + + image_feat = normalize(image_embeds, dim=-1) + text_feat = normalize(question_embeds, dim=-1) + + # text-query similarity + sim_t2q = torch.matmul(text_feat.unsqueeze(1).unsqueeze(1), image_feat.permute(0, 2, 1)).squeeze(2) + + # text-image similarity: aggregate across all query tokens + logits_per_text, _ = sim_t2q.max(-1) + logits_per_text = logits_per_text / self.temp + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + loss = blip2_loss(logits_per_text) + + if not return_dict: + output = (logits_per_image, logits_per_text, question_embeds, image_embeds, query_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return Blip2Output( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=question_embeds, + image_embeds=image_embeds, + text_model_output=query_outputs, + vision_model_output=vision_outputs, + ) + + @add_start_docstrings( """ BLIP-2 Model for generating text given an image and an optional text prompt. 
The model consists of a vision @@ -2132,8 +2380,6 @@ def __init__(self, config: Blip2Config): # image text matching head self.itm_head = nn.Linear(config.qformer_config.hidden_size, 2) - self.temp = nn.Parameter(0.07 * torch.ones([])) - # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index e9e6b2c6c748d3..387f3a95bee6e6 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1517,6 +1517,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Blip2ModelWithProjection(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Blip2PreTrainedModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index d8fc2690302564..74591ceb68f288 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -48,7 +48,13 @@ import torch from torch import nn - from transformers import Blip2ForConditionalGeneration, Blip2ForImageTextRetrieval, Blip2Model, Blip2VisionModel + from transformers import ( + Blip2ForConditionalGeneration, + Blip2ForImageTextRetrieval, + Blip2Model, + Blip2ModelWithProjection, + Blip2VisionModel, + ) from transformers.models.blip_2.modeling_blip_2 import BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST @@ -895,7 +901,14 @@ def prepare_config_and_inputs_for_common(self): @require_torch class Blip2TextRetrievalModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (Blip2ForImageTextRetrieval,) if is_torch_available() else () + all_model_classes = ( + ( + Blip2ForImageTextRetrieval, + Blip2ModelWithProjection, + ) + if is_torch_available() + else () + ) fx_compatible = False test_head_masking = False test_pruning = False @@ -934,7 +947,7 @@ def test_forward_signature(self): "attention_mask", ] - expected_arg_names.extend(["use_itm_head"] if "use_itm_head" in arg_names else ["decoder_input_ids"]) + expected_arg_names.extend(["use_itm_head"] if "use_itm_head" in arg_names else []) self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) else: # TODO @@ -970,7 +983,7 @@ def test_get_text_features(self): "attention_mask": torch.LongTensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]).to(torch_device), } - model = Blip2Model(config).to(torch_device) + model = Blip2ModelWithProjection(config).to(torch_device) model.eval() text_features = model.get_text_features(**inputs_dict) self.assertEqual(text_features[0].shape, (10, config.image_text_hidden_size)) @@ -983,7 +996,7 @@ def test_get_image_features(self): for key in keys_to_pop: inputs_dict.pop(key) - model = Blip2Model(config).to(torch_device) + model = Blip2ModelWithProjection(config).to(torch_device) model.eval() image_features = model.get_image_features(**inputs_dict) self.assertEqual( @@ -1212,7 +1225,7 @@ def test_inference_t5_multi_gpu(self): @require_torch_gpu def test_inference_itm_features(self): processor = Blip2Processor.from_pretrained("jpizarrom/blip2-itm-vit-g") - model = Blip2Model.from_pretrained( + model = Blip2ModelWithProjection.from_pretrained( "jpizarrom/blip2-itm-vit-g", ).to(torch_device) @@ -1257,7 +1270,9 @@ def test_inference_itm_features(self): @require_torch_gpu def test_inference_itm_features_fp16(self): processor = Blip2Processor.from_pretrained("jpizarrom/blip2-itm-vit-g") - 
model = Blip2Model.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16).to(torch_device) + model = Blip2ModelWithProjection.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16).to( + torch_device + ) # image features image = prepare_img() diff --git a/utils/check_repo.py b/utils/check_repo.py index 792501b9d540d2..446c7e97968f23 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -147,6 +147,7 @@ "ClapAudioModelWithProjection", "Blip2ForConditionalGeneration", "Blip2ForImageTextRetrieval", + "Blip2ModelWithProjection", "Blip2QFormerModel", "Blip2VisionModel", "ErnieMForInformationExtraction", From 786da89cd2d4824967ec016e9fe4eaac5a65c97c Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Sun, 24 Sep 2023 11:41:19 +0200 Subject: [PATCH 03/25] use gpu on Blip2ForImageTextRetrieval.forward doctest --- src/transformers/models/blip_2/modeling_blip_2.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 13651a51eb2545..fe06b96f936865 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -2401,18 +2401,23 @@ def forward( Examples: ```python + >>> import torch >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, Blip2ForImageTextRetrieval + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + >>> model = Blip2ForImageTextRetrieval.from_pretrained("jpizarrom/blip2-itm-vit-g") >>> processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g") + >>> model.to(device) # doctest: +IGNORE_RESULT + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> text = "an image of a cat" - >>> inputs = processor(images=image, text=text, return_tensors="pt") + >>> inputs = processor(images=image, text=text, return_tensors="pt").to(device) >>> outputs = model(**inputs) ``` """ From 188e3a733b1f997782fb7b0a31fc123df8ccdb68 Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Sun, 24 Sep 2023 11:50:05 +0200 Subject: [PATCH 04/25] use gpu on Blip2ModelWithProjection.forward doctest --- src/transformers/models/blip_2/modeling_blip_2.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index fe06b96f936865..64c28d0e90a6ab 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -1928,19 +1928,24 @@ def forward( Examples: ```python + >>> import torch >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, Blip2ModelWithProjection - >>> model = Blip2ModelWithProjection.from_pretrained("jpizarrom/blip2-itm-vit-g") + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> model = Blip2ModelWithProjection.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16) >>> processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g") + >>> model.to(device) # doctest: +IGNORE_RESULT + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor( ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True - ... ) + ... 
).to(device, torch.float16) >>> outputs = model(**inputs) >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score From d1cc037a8eeabcdb63ade14d1fc18e789ea11683 Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Sun, 24 Sep 2023 12:01:51 +0200 Subject: [PATCH 05/25] use float16 on Blip2ForImageTextRetrieval.forward doctest --- src/transformers/models/blip_2/modeling_blip_2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 64c28d0e90a6ab..7f10b6edce2a18 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -2365,6 +2365,7 @@ def generate( class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): main_input_name = "pixel_values" config_class = Blip2Config + _keep_in_fp32_modules = [] def __init__(self, config: Blip2Config): super().__init__(config) @@ -2413,7 +2414,7 @@ def forward( >>> device = "cuda" if torch.cuda.is_available() else "cpu" - >>> model = Blip2ForImageTextRetrieval.from_pretrained("jpizarrom/blip2-itm-vit-g") + >>> model = Blip2ForImageTextRetrieval.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16) >>> processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g") >>> model.to(device) # doctest: +IGNORE_RESULT @@ -2422,7 +2423,7 @@ def forward( >>> image = Image.open(requests.get(url, stream=True).raw) >>> text = "an image of a cat" - >>> inputs = processor(images=image, text=text, return_tensors="pt").to(device) + >>> inputs = processor(images=image, text=text, return_tensors="pt").to(device, torch.float16) >>> outputs = model(**inputs) ``` """ From 5f72231fdb6dac5c8c5a21b8b6b3ae48e9cce101 Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Sun, 24 Sep 2023 12:25:09 +0200 Subject: [PATCH 06/25] add _tied_weights_keys to Blip2ForImageTextRetrieval --- .../models/blip_2/modeling_blip_2.py | 1 + tests/models/blip_2/test_modeling_blip_2.py | 26 ++++++++++++------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 7f10b6edce2a18..e3aa6d2e0a2657 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -2366,6 +2366,7 @@ class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): main_input_name = "pixel_values" config_class = Blip2Config _keep_in_fp32_modules = [] + _tied_weights_keys = ["cls.predictions.decoder.bias"] def __init__(self, config: Blip2Config): super().__init__(config) diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index 74591ceb68f288..5affd2e3793be0 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -916,14 +916,6 @@ class Blip2TextRetrievalModelTest(ModelTesterMixin, unittest.TestCase): test_attention_outputs = False test_torchscript = False - # TODO add or skip tests - test_model_outputs_equivalence = False - test_tied_weights_keys = False - test_hidden_states_output = False - test_inputs_embeds = False - test_model_common_attributes = False - test_retain_grad_hidden_states_attentions = False - def setUp(self): self.model_tester = Blip2TextRetrievalModelTester(self) @@ -931,6 +923,22 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() 
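As context for the `Blip2ForImageTextRetrieval` doctest updated above: with the ITM head, the model classifies each image-text pair as matched or not matched. The head itself is declared in this patch as `nn.Linear(qformer_hidden_size, 2)` and is applied per query token; the reference BLIP-2 implementation then averages the per-token logits, which is assumed here. A small sketch under those assumptions, with hypothetical shapes:

```python
import torch
import torch.nn as nn

hidden_size, num_query_tokens = 768, 32   # assumed Q-Former defaults
itm_head = nn.Linear(hidden_size, 2)      # class 1 is conventionally the "matched" class

# Hypothetical multimodal Q-Former output for a batch of 4 image-text pairs.
multimodal_query_output = torch.randn(4, num_query_tokens, hidden_size)

# Score every query token, then average the logits over the query tokens
# (the aggregation used by the reference BLIP-2 code; an assumption here).
itm_logits = itm_head(multimodal_query_output).mean(dim=1)   # (batch, 2)
match_probability = itm_logits.softmax(dim=-1)[:, 1]         # probability that image and text match
```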
self.model_tester.create_and_check_model(*config_and_inputs) + @unittest.skip(reason="Hidden_states is tested in individual model tests") + def test_hidden_states_output(self): + pass + + @unittest.skip(reason="Inputs_embeds is tested in individual model tests") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Retain_grad is tested in individual model tests") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="Blip2Model does not have input/output embeddings") + def test_model_common_attributes(self): + pass + def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -971,7 +979,7 @@ def test_load_vision_qformer_text_config(self): @slow def test_model_from_pretrained(self): for model_name in ["jpizarrom/blip2-itm-vit-g"]: - for model_class in self.all_model_classes + (Blip2Model,): + for model_class in self.all_model_classes: model = model_class.from_pretrained(model_name) self.assertIsNotNone(model) From e099caaad7dfce8f530efd5fb10a66821fe46307 Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Sun, 24 Sep 2023 18:17:10 +0200 Subject: [PATCH 07/25] add temp param to Blip2ForImageTextRetrieval --- src/transformers/models/blip_2/modeling_blip_2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index e3aa6d2e0a2657..65b23c3746db7a 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -2387,6 +2387,8 @@ def __init__(self, config: Blip2Config): # image text matching head self.itm_head = nn.Linear(config.qformer_config.hidden_size, 2) + self.temp = nn.Parameter(0.07 * torch.ones([])) + # Initialize weights and apply final processing self.post_init() From a1ab97fa44612fbc8dbf9f0823b6b6d369b6c673 Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Fri, 6 Oct 2023 14:47:57 +0200 Subject: [PATCH 08/25] add Blip2TextModelWithProjection and Blip2VisionModelWithProjection --- docs/source/en/model_doc/blip-2.md | 8 + src/transformers/__init__.py | 4 + src/transformers/models/blip_2/__init__.py | 4 + .../models/blip_2/modeling_blip_2.py | 232 +++++++++++++ src/transformers/utils/dummy_pt_objects.py | 14 + tests/models/blip_2/test_modeling_blip_2.py | 310 +++++++++++++++++- utils/check_repo.py | 2 + 7 files changed, 572 insertions(+), 2 deletions(-) diff --git a/docs/source/en/model_doc/blip-2.md b/docs/source/en/model_doc/blip-2.md index c160af37764651..828d12ac43990d 100644 --- a/docs/source/en/model_doc/blip-2.md +++ b/docs/source/en/model_doc/blip-2.md @@ -97,3 +97,11 @@ If you're interested in submitting a resource to be included here, please feel f ## Blip2ModelWithProjection [[autodoc]] Blip2ModelWithProjection + +## Blip2TextModelWithProjection + +[[autodoc]] Blip2TextModelWithProjection + +## Blip2VisionModelWithProjection + +[[autodoc]] Blip2VisionModelWithProjection diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ce17207d51b75c..66a5959fcde55a 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1343,7 +1343,9 @@ "Blip2ModelWithProjection", "Blip2PreTrainedModel", "Blip2QFormerModel", + "Blip2TextModelWithProjection", "Blip2VisionModel", + "Blip2VisionModelWithProjection", ] ) _import_structure["models.bloom"].extend( @@ -5364,7 +5366,9 @@ Blip2ModelWithProjection, Blip2PreTrainedModel, Blip2QFormerModel, + Blip2TextModelWithProjection, Blip2VisionModel, + 
Blip2VisionModelWithProjection, ) from .models.bloom import ( BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/models/blip_2/__init__.py b/src/transformers/models/blip_2/__init__.py index d2d0b0dc0a5603..226c52a36f4419 100644 --- a/src/transformers/models/blip_2/__init__.py +++ b/src/transformers/models/blip_2/__init__.py @@ -36,11 +36,13 @@ "BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST", "Blip2Model", "Blip2ModelWithProjection", + "Blip2VisionModelWithProjection", "Blip2QFormerModel", "Blip2PreTrainedModel", "Blip2ForConditionalGeneration", "Blip2ForImageTextRetrieval", "Blip2VisionModel", + "Blip2TextModelWithProjection", ] if TYPE_CHECKING: @@ -66,7 +68,9 @@ Blip2ModelWithProjection, Blip2PreTrainedModel, Blip2QFormerModel, + Blip2TextModelWithProjection, Blip2VisionModel, + Blip2VisionModelWithProjection, ) else: diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 65b23c3746db7a..dbca2c7df8c355 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -182,6 +182,66 @@ class Blip2ImageTextMatchingModelOutput(ModelOutput): question_embeds: Optional[Tuple[torch.FloatTensor]] = None +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Blip2 +class Blip2TextModelOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. + + Args: + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Blip2 +class Blip2VisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. 
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + # Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2 class Blip2VisionEmbeddings(nn.Module): def __init__(self, config: Blip2VisionConfig): @@ -2017,6 +2077,178 @@ def forward( ) +class Blip2TextModelWithProjection(Blip2PreTrainedModel): + config_class = Blip2Config + supports_gradient_checkpointing = False + _keep_in_fp32_modules = [] + + def __init__(self, config: Blip2Config): + super().__init__(config) + + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + self.qformer = Blip2QFormerModel(config.qformer_config) + + # text projection layer + self.text_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BLIP_2_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Blip2TextModelOutput, config_class=Blip2Config) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Blip2TextModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, Blip2TextModelWithProjection + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> model = Blip2TextModelWithProjection.from_pretrained( + ... "jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16 + ... 
) + + >>> model.to(device) # doctest: +IGNORE_RESULT + + >>> processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g") + + >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], return_tensors="pt").to(device) + + >>> outputs = model(**inputs) + >>> text_embeds = outputs.text_embeds + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.qformer( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[0] if not return_dict else text_outputs.last_hidden_state + + text_embeds = self.text_proj(pooled_output) + text_embeds = normalize(text_embeds, dim=-1) + + if not return_dict: + outputs = (text_embeds, text_outputs[0]) + text_outputs[2:] + return tuple(output for output in outputs if output is not None) + + return Blip2TextModelOutput( + text_embeds=text_embeds, + last_hidden_state=text_outputs.last_hidden_state, + hidden_states=text_outputs.hidden_states, + attentions=text_outputs.attentions, + ) + + +# Adapted from transformers.models.clip.modeling_clip.CLIPVisionModelWithProjection +class Blip2VisionModelWithProjection(Blip2PreTrainedModel): + config_class = Blip2Config + main_input_name = "pixel_values" + + def __init__(self, config: Blip2VisionConfig): + super().__init__(config) + + self.vision_model = Blip2VisionModel(config.vision_config) + + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + self.qformer = Blip2QFormerModel(config.qformer_config) + + # vision projection layer + self.vision_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BLIP_2_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Blip2VisionModelOutput, config_class=Blip2VisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Blip2VisionModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Blip2VisionModelWithProjection + + >>> model = Blip2VisionModelWithProjection.from_pretrained("jpizarrom/blip2-itm-vit-g") + >>> processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> image_embeds = outputs.image_embeds + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[0] if not return_dict else 
vision_outputs.last_hidden_state + + image_attention_mask = torch.ones(pooled_output.size()[:-1], dtype=torch.long, device=pooled_output.device) + + query_tokens = self.query_tokens.expand(pooled_output.shape[0], -1, -1) + + query_outputs = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=pooled_output, + encoder_attention_mask=image_attention_mask, + return_dict=return_dict, + ) + + embeds = query_outputs[0] if not return_dict else query_outputs.last_hidden_state + image_embeds = self.vision_proj(embeds) + image_embeds = normalize(image_embeds, dim=-1) + + if not return_dict: + outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:] + return tuple(output for output in outputs if output is not None) + + return Blip2VisionModelOutput( + image_embeds=image_embeds, + last_hidden_state=vision_outputs.last_hidden_state, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + ) + + @add_start_docstrings( """ BLIP-2 Model for generating text given an image and an optional text prompt. The model consists of a vision diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 387f3a95bee6e6..bb75c37c5af577 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1538,6 +1538,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Blip2TextModelWithProjection(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Blip2VisionModel(metaclass=DummyObject): _backends = ["torch"] @@ -1545,6 +1552,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Blip2VisionModelWithProjection(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index 5affd2e3793be0..6b41eb604c8edb 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -53,7 +53,9 @@ Blip2ForImageTextRetrieval, Blip2Model, Blip2ModelWithProjection, + Blip2TextModelWithProjection, Blip2VisionModel, + Blip2VisionModelWithProjection, ) from transformers.models.blip_2.modeling_blip_2 import BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST @@ -484,7 +486,7 @@ def test_forward_signature(self): self.assertListEqual(arg_names[:1], expected_arg_names) def test_load_vision_qformer_text_config(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config, _ = self.model_tester.prepare_config_and_inputs_for_common() # Save Blip2Config and check if we can load Blip2VisionConfig from it with tempfile.TemporaryDirectory() as tmp_dir_name: @@ -745,7 +747,7 @@ def test_forward_signature(self): self.assertListEqual(arg_names[:1], expected_arg_names) def test_load_vision_qformer_text_config(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config, _ = self.model_tester.prepare_config_and_inputs_for_common() # Save Blip2Config and check if we can load Blip2VisionConfig from it with tempfile.TemporaryDirectory() as tmp_dir_name: @@ -833,6 +835,310 @@ def test_initialization(self): ) +class Blip2TextModelWithProjectionTester: + def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training=True): + if vision_kwargs is None: + 
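To illustrate how the two projection models added above fit together, here is a sketch of a simple two-tower retrieval setup: texts and images are encoded separately, then scored with the same max-over-query-tokens rule used elsewhere in this patch series. The checkpoint name is the one used in the doctests; the exact output shapes are assumptions based on the testers below.

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Blip2TextModelWithProjection, Blip2VisionModelWithProjection

processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g")
text_model = Blip2TextModelWithProjection.from_pretrained("jpizarrom/blip2-itm-vit-g")
vision_model = Blip2VisionModelWithProjection.from_pretrained("jpizarrom/blip2-itm-vit-g")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

text_inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
image_inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    # text_embeds: (num_texts, seq_len, image_text_hidden_size); the first (CLS) position is used for retrieval
    text_embeds = text_model(**text_inputs).text_embeds[:, 0, :]
    # image_embeds: (num_images, num_query_tokens, image_text_hidden_size), already L2-normalized
    image_embeds = vision_model(**image_inputs).image_embeds

# Score each caption against the image: max cosine similarity over the query tokens.
scores = torch.matmul(text_embeds, image_embeds[0].t()).max(dim=-1).values
print(scores)  # the matching caption should get the higher score
```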
vision_kwargs = {} + if qformer_kwargs is None: + qformer_kwargs = {"qformer_text_input": True} + text_kwargs = {} + + self.parent = parent + self.vision_model_tester = Blip2VisionModelTester(parent, **vision_kwargs) + self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs) + self.text_model_tester = Blip2TextModelTester(parent, **text_kwargs) + self.is_training = is_training + + def prepare_config_and_inputs(self): + _, input_ids, attention_mask = self.qformer_model_tester.prepare_config_and_inputs() + # _, pixel_values = self.vision_model_tester.prepare_config_and_inputs() + + config = self.get_config() + + return config, input_ids, attention_mask + + def get_config(self): + return Blip2Config.from_vision_qformer_text_configs( + vision_config=self.vision_model_tester.get_config(), + qformer_config=self.qformer_model_tester.get_config(), + text_config=self.text_model_tester.get_config(), + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, attention_mask = config_and_inputs + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + # "pixel_values": pixel_values, + } + return config, inputs_dict + + def create_and_check_model(self, config, input_ids, attention_mask): + model = Blip2TextModelWithProjection(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(input_ids, attention_mask=attention_mask, output_attentions=True, output_hidden_states=True) + + self.parent.assertEqual( + result.last_hidden_state.shape, + (self.vision_model_tester.batch_size, input_ids.shape[1], self.qformer_model_tester.hidden_size), + ) + self.parent.assertEqual( + result.text_embeds.shape, + ( + self.text_model_tester.batch_size, + input_ids.shape[1], + config.image_text_hidden_size, + ), + ) + + with torch.no_grad(): + result2 = model( + input_ids, + attention_mask=attention_mask, + return_dict=not config.use_return_dict, + output_attentions=True, + output_hidden_states=True, + ) + + self.parent.assertTrue(torch.allclose(result.text_embeds, result2[0])) + self.parent.assertTrue(torch.allclose(result.last_hidden_state, result2[1])) + self.parent.assertTrue(torch.allclose(result.hidden_states[0], result2[2][0])) + self.parent.assertTrue(torch.allclose(result.hidden_states[1], result2[2][1])) + self.parent.assertTrue(torch.allclose(result.attentions[0], result2[3][0])) + self.parent.assertTrue(torch.allclose(result.attentions[1], result2[3][1])) + + +@require_torch +class Blip2TextModelWithProjectionTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (Blip2TextModelWithProjection,) if is_torch_available() else () + fx_compatible = False + test_pruning = False + test_head_masking = False + + test_resize_embeddings = False + test_attention_outputs = False + test_torchscript = False + + def setUp(self): + self.model_tester = Blip2TextModelWithProjectionTester(self) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_training(self): + pass + + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="Hidden_states is tested in individual model tests") + def test_hidden_states_output(self): + pass + + @unittest.skip(reason="Blip2TextModelWithProjection does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Retain_grad is tested in individual model tests") + def 
test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="Blip2TextModelWithProjection does not have input/output embeddings") + def test_model_common_attributes(self): + pass + + @unittest.skip(reason="Blip2TextModelWithProjection has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="Blip2TextModelWithProjection has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_to_base(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + if model.config.is_encoder_decoder: + expected_arg_names = [ + "input_ids", + "attention_mask", + "position_ids", + ] + + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + else: + # TODO + raise NotImplementedError + + @slow + def test_model_from_pretrained(self): + for model_name in ["jpizarrom/blip2-itm-vit-g"]: + model = Blip2TextModelWithProjection.from_pretrained(model_name) + self.assertIsNotNone(model) + self.assertTrue(hasattr(model, "text_proj")) + + +class Blip2VisionModelWithProjectionTester: + def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training=True): + if vision_kwargs is None: + vision_kwargs = {} + if qformer_kwargs is None: + qformer_kwargs = {"qformer_text_input": True} + text_kwargs = {} + + self.parent = parent + self.vision_model_tester = Blip2VisionModelTester(parent, **vision_kwargs) + self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs) + self.text_model_tester = Blip2TextModelTester(parent, **text_kwargs) + self.is_training = is_training + + def prepare_config_and_inputs(self): + _, pixel_values = self.vision_model_tester.prepare_config_and_inputs() + + config = self.get_config() + + return config, pixel_values + + def get_config(self): + return Blip2Config.from_vision_qformer_text_configs( + vision_config=self.vision_model_tester.get_config(), + qformer_config=self.qformer_model_tester.get_config(), + text_config=self.text_model_tester.get_config(), + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = { + "pixel_values": pixel_values, + } + return config, inputs_dict + + def create_and_check_model(self, config, pixel_values): + model = Blip2VisionModelWithProjection(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(pixel_values, output_attentions=True, output_hidden_states=True) + + self.parent.assertEqual( + result.last_hidden_state.shape, + ( + self.vision_model_tester.batch_size, + self.vision_model_tester.seq_length, + self.qformer_model_tester.hidden_size, + ), + ) + self.parent.assertEqual( + result.image_embeds.shape, + ( + self.text_model_tester.batch_size, + config.vision_config.hidden_size, + config.image_text_hidden_size, + ), + ) + + with torch.no_grad(): + result2 = model( + pixel_values, + return_dict=not config.use_return_dict, + output_attentions=True, + output_hidden_states=True, + ) + + self.parent.assertTrue(torch.allclose(result.image_embeds, result2[0])) + self.parent.assertTrue(torch.allclose(result.last_hidden_state, 
result2[1])) + self.parent.assertTrue(torch.allclose(result.hidden_states[0], result2[2][0])) + self.parent.assertTrue(torch.allclose(result.hidden_states[1], result2[2][1])) + self.parent.assertTrue(torch.allclose(result.attentions[0], result2[3][0])) + self.parent.assertTrue(torch.allclose(result.attentions[1], result2[3][1])) + + +@require_torch +class Blip2VisionModelWithProjectionTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (Blip2VisionModelWithProjection,) if is_torch_available() else () + fx_compatible = False + test_pruning = False + test_head_masking = False + + test_resize_embeddings = False + test_attention_outputs = False + test_torchscript = False + + def setUp(self): + self.model_tester = Blip2VisionModelWithProjectionTester(self) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_training(self): + pass + + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="Hidden_states is tested in individual model tests") + def test_hidden_states_output(self): + pass + + @unittest.skip(reason="Blip2TextModelWithProjection does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Retain_grad is tested in individual model tests") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="Blip2TextModelWithProjection does not have input/output embeddings") + def test_model_common_attributes(self): + pass + + @unittest.skip(reason="Blip2TextModelWithProjection has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="Blip2TextModelWithProjection has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_to_base(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + if model.config.is_encoder_decoder: + expected_arg_names = [ + "pixel_values", + ] + + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + else: + # TODO + raise NotImplementedError + + @slow + def test_model_from_pretrained(self): + for model_name in ["jpizarrom/blip2-itm-vit-g"]: + model = Blip2VisionModelWithProjection.from_pretrained(model_name) + self.assertIsNotNone(model) + self.assertTrue(hasattr(model, "vision_proj")) + + class Blip2TextRetrievalModelTester: def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training=True): if vision_kwargs is None: diff --git a/utils/check_repo.py b/utils/check_repo.py index 446c7e97968f23..1fabe6eff69e6f 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -148,6 +148,8 @@ "Blip2ForConditionalGeneration", "Blip2ForImageTextRetrieval", "Blip2ModelWithProjection", + "Blip2TextModelWithProjection", + "Blip2VisionModelWithProjection", "Blip2QFormerModel", "Blip2VisionModel", "ErnieMForInformationExtraction", From 18d534084dd6bc5364016061553270d880506171 Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Fri, 6 Oct 2023 15:34:37 +0200 Subject: [PATCH 09/25] use cuda and float16 in doctest Blip2VisionModelWithProjection --- src/transformers/models/blip_2/modeling_blip_2.py | 10 
++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index dbca2c7df8c355..8efbd2968b14c0 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -2161,6 +2161,7 @@ def forward( class Blip2VisionModelWithProjection(Blip2PreTrainedModel): config_class = Blip2Config main_input_name = "pixel_values" + _keep_in_fp32_modules = [] def __init__(self, config: Blip2VisionConfig): super().__init__(config) @@ -2196,13 +2197,18 @@ def forward( >>> import requests >>> from transformers import AutoProcessor, Blip2VisionModelWithProjection - >>> model = Blip2VisionModelWithProjection.from_pretrained("jpizarrom/blip2-itm-vit-g") + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + >>> processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g") + >>> model = Blip2VisionModelWithProjection.from_pretrained( + ... "jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16 + ... ) + >>> model.to(device) # doctest: +IGNORE_RESULT >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) - >>> inputs = processor(images=image, return_tensors="pt") + >>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16) >>> outputs = model(**inputs) >>> image_embeds = outputs.image_embeds From 0a227d03eabba0b6b9aec88f33e962e5e16cb136 Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Fri, 6 Oct 2023 19:29:36 +0200 Subject: [PATCH 10/25] rename Blip2ModelWithProjection to Blip2ModelWithoutLM --- docs/source/en/model_doc/blip-2.md | 8 +- src/transformers/__init__.py | 6 +- src/transformers/models/blip_2/__init__.py | 6 +- .../models/blip_2/configuration_blip_2.py | 108 +++++++++++++++++- .../models/blip_2/modeling_blip_2.py | 39 ++++--- src/transformers/utils/dummy_pt_objects.py | 2 +- tests/models/blip_2/test_modeling_blip_2.py | 41 +++---- utils/check_repo.py | 2 +- 8 files changed, 161 insertions(+), 51 deletions(-) diff --git a/docs/source/en/model_doc/blip-2.md b/docs/source/en/model_doc/blip-2.md index 828d12ac43990d..250049d86a28a1 100644 --- a/docs/source/en/model_doc/blip-2.md +++ b/docs/source/en/model_doc/blip-2.md @@ -61,6 +61,10 @@ If you're interested in submitting a resource to be included here, please feel f [[autodoc]] Blip2QFormerConfig +## Blip2ModelWithoutLMConfig + +[[autodoc]] Blip2ModelWithoutLMConfig + ## Blip2Processor [[autodoc]] Blip2Processor @@ -94,9 +98,9 @@ If you're interested in submitting a resource to be included here, please feel f [[autodoc]] Blip2ForImageTextRetrieval - forward -## Blip2ModelWithProjection +## Blip2ModelWithoutLM -[[autodoc]] Blip2ModelWithProjection +[[autodoc]] Blip2ModelWithoutLM ## Blip2TextModelWithProjection diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 66a5959fcde55a..25ae59e570c656 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -209,6 +209,7 @@ "models.blip_2": [ "BLIP_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Blip2Config", + "Blip2ModelWithoutLMConfig", "Blip2Processor", "Blip2QFormerConfig", "Blip2VisionConfig", @@ -1340,7 +1341,7 @@ "Blip2ForConditionalGeneration", "Blip2ForImageTextRetrieval", "Blip2Model", - "Blip2ModelWithProjection", + "Blip2ModelWithoutLM", "Blip2PreTrainedModel", "Blip2QFormerModel", "Blip2TextModelWithProjection", @@ -4347,6 +4348,7 @@ from .models.blip_2 import ( 
BLIP_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Blip2Config, + Blip2ModelWithoutLMConfig, Blip2Processor, Blip2QFormerConfig, Blip2VisionConfig, @@ -5363,7 +5365,7 @@ Blip2ForConditionalGeneration, Blip2ForImageTextRetrieval, Blip2Model, - Blip2ModelWithProjection, + Blip2ModelWithoutLM, Blip2PreTrainedModel, Blip2QFormerModel, Blip2TextModelWithProjection, diff --git a/src/transformers/models/blip_2/__init__.py b/src/transformers/models/blip_2/__init__.py index 226c52a36f4419..3defe7b923a413 100644 --- a/src/transformers/models/blip_2/__init__.py +++ b/src/transformers/models/blip_2/__init__.py @@ -20,6 +20,7 @@ "configuration_blip_2": [ "BLIP_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Blip2Config", + "Blip2ModelWithoutLMConfig", "Blip2QFormerConfig", "Blip2VisionConfig", ], @@ -35,7 +36,7 @@ _import_structure["modeling_blip_2"] = [ "BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST", "Blip2Model", - "Blip2ModelWithProjection", + "Blip2ModelWithoutLM", "Blip2VisionModelWithProjection", "Blip2QFormerModel", "Blip2PreTrainedModel", @@ -49,6 +50,7 @@ from .configuration_blip_2 import ( BLIP_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Blip2Config, + Blip2ModelWithoutLMConfig, Blip2QFormerConfig, Blip2VisionConfig, ) @@ -65,7 +67,7 @@ Blip2ForConditionalGeneration, Blip2ForImageTextRetrieval, Blip2Model, - Blip2ModelWithProjection, + Blip2ModelWithoutLM, Blip2PreTrainedModel, Blip2QFormerModel, Blip2TextModelWithProjection, diff --git a/src/transformers/models/blip_2/configuration_blip_2.py b/src/transformers/models/blip_2/configuration_blip_2.py index ef246b129cad16..ac0359c55a6ee1 100644 --- a/src/transformers/models/blip_2/configuration_blip_2.py +++ b/src/transformers/models/blip_2/configuration_blip_2.py @@ -117,7 +117,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) # get the vision config dict if we are loading from Blip2Config - if config_dict.get("model_type") == "blip-2": + if config_dict.get("model_type") in ["blip-2", "blip-2-without-lm"]: config_dict = config_dict["vision_config"] if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: @@ -237,7 +237,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) # get the qformer config dict if we are loading from Blip2Config - if config_dict.get("model_type") == "blip-2": + if config_dict.get("model_type") in ["blip-2", "blip-2-without-lm"]: config_dict = config_dict["qformer_config"] if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: @@ -364,3 +364,107 @@ def from_vision_qformer_text_configs( text_config=text_config.to_dict(), **kwargs, ) + + +class Blip2ModelWithoutLMConfig(PretrainedConfig): + r""" + [`Blip2ModelWithoutLMConfig`] is the configuration class to store the configuration of a + [`Blip2ForImageTextRetrieval`, `Blip2ModelWithoutLM`]. It is used to instantiate a BLIP-2 model according to the + specified arguments, defining the vision model and Q-Former model. Instantiating a configuration with the defaults + will yield a similar configuration to that of the BLIP-2 + [jpizarrom/blip2-itm-vit-g](https://huggingface.co/jpizarrom/blip2-itm-vit-g) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Blip2VisionConfig`]. + qformer_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Blip2QFormerConfig`]. + num_query_tokens (`int`, *optional*, defaults to 32): + The number of query tokens passed through the Transformer. + + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import ( + ... Blip2VisionConfig, + ... Blip2QFormerConfig, + ... Blip2ModelWithoutLMConfig, + ... Blip2ForImageTextRetrieval, + ... ) + + >>> # Initializing a Blip2ModelWithoutLMConfig with jpizarrom/blip2-itm-vit-g style configuration + >>> configuration = Blip2ModelWithoutLMConfig() + + >>> # Initializing a Blip2ForImageTextRetrieval (with random weights) from the jpizarrom/blip2-itm-vit-g style configuration + >>> model = Blip2ForImageTextRetrieval(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a Blip2Config from a Blip2VisionConfig, Blip2QFormerConfig and any PretrainedConfig + + >>> # Initializing BLIP-2 vision and BLIP-2 Q-Former model configurations + >>> vision_config = Blip2VisionConfig() + >>> qformer_config = Blip2QFormerConfig() + + >>> config = Blip2ModelWithoutLMConfig.from_vision_qformer_text_configs(vision_config, qformer_config) + ```""" + + model_type = "blip-2-without-lm" + + def __init__( + self, + vision_config=None, + qformer_config=None, + num_query_tokens=32, + image_text_hidden_size=256, + **kwargs, + ): + super().__init__(**kwargs) + + if vision_config is None: + vision_config = {} + logger.info("vision_config is None. initializing the Blip2VisionConfig with default values.") + + if qformer_config is None: + qformer_config = {} + logger.info("qformer_config is None. Initializing the Blip2QFormerConfig with default values.") + + self.vision_config = Blip2VisionConfig(**vision_config) + self.qformer_config = Blip2QFormerConfig(**qformer_config) + + self.is_encoder_decoder = self.qformer_config.is_encoder_decoder + + self.num_query_tokens = num_query_tokens + self.image_text_hidden_size = image_text_hidden_size + self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size + # self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + self.initializer_factor = 1.0 + self.initializer_range = 0.02 + + @classmethod + def from_vision_qformer_text_configs( + cls, + vision_config: Blip2VisionConfig, + qformer_config: Blip2QFormerConfig, + **kwargs, + ): + r""" + Instantiate a [`Blip2Config`] (or a derived class) from a BLIP-2 vision model, Q-Former and language model + configurations. 
+ + Returns: + [`Blip2Config`]: An instance of a configuration object + """ + + return cls( + vision_config=vision_config.to_dict(), + qformer_config=qformer_config.to_dict(), + **kwargs, + ) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 8efbd2968b14c0..5fd49f170e811a 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -41,7 +41,7 @@ replace_return_docstrings, ) from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM -from .configuration_blip_2 import Blip2Config, Blip2QFormerConfig, Blip2VisionConfig +from .configuration_blip_2 import Blip2Config, Blip2ModelWithoutLMConfig, Blip2QFormerConfig, Blip2VisionConfig logger = logging.get_logger(__name__) @@ -1837,12 +1837,12 @@ def forward( """, BLIP_2_START_DOCSTRING, ) -class Blip2ModelWithProjection(Blip2PreTrainedModel): - config_class = Blip2Config +class Blip2ModelWithoutLM(Blip2PreTrainedModel): + config_class = Blip2ModelWithoutLMConfig main_input_name = "pixel_values" _keep_in_fp32_modules = [] - def __init__(self, config: Blip2Config): + def __init__(self, config: Blip2ModelWithoutLMConfig): super().__init__(config) self.vision_model = Blip2VisionModel(config.vision_config) @@ -1879,11 +1879,11 @@ def get_text_features( Examples: ```python >>> import torch - >>> from transformers import AutoProcessor, Blip2ModelWithProjection + >>> from transformers import AutoProcessor, Blip2ModelWithoutLM >>> device = "cuda" if torch.cuda.is_available() else "cpu" - >>> model = Blip2ModelWithProjection.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16) + >>> model = Blip2ModelWithoutLM.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16) >>> model.to(device) # doctest: +IGNORE_RESULT @@ -1928,11 +1928,11 @@ def get_image_features( >>> import torch >>> from PIL import Image >>> import requests - >>> from transformers import AutoProcessor, Blip2ModelWithProjection + >>> from transformers import AutoProcessor, Blip2ModelWithoutLM >>> device = "cuda" if torch.cuda.is_available() else "cpu" - >>> model = Blip2ModelWithProjection.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16) + >>> model = Blip2ModelWithoutLM.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16) >>> model.to(device) # doctest: +IGNORE_RESULT @@ -1991,11 +1991,11 @@ def forward( >>> import torch >>> from PIL import Image >>> import requests - >>> from transformers import AutoProcessor, Blip2ModelWithProjection + >>> from transformers import AutoProcessor, Blip2ModelWithoutLM >>> device = "cuda" if torch.cuda.is_available() else "cpu" - >>> model = Blip2ModelWithProjection.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16) + >>> model = Blip2ModelWithoutLM.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16) >>> processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g") >>> model.to(device) # doctest: +IGNORE_RESULT @@ -2061,6 +2061,7 @@ def forward( loss = None if return_loss: loss = blip2_loss(logits_per_text) + # TODO add Image-text Matching and Image Captioning loss computation if not return_dict: output = (logits_per_image, logits_per_text, question_embeds, image_embeds, query_outputs, vision_outputs) @@ -2078,11 +2079,11 @@ def forward( class Blip2TextModelWithProjection(Blip2PreTrainedModel): - config_class = Blip2Config + config_class = Blip2ModelWithoutLMConfig supports_gradient_checkpointing = False 
_keep_in_fp32_modules = [] - def __init__(self, config: Blip2Config): + def __init__(self, config: Blip2ModelWithoutLMConfig): super().__init__(config) self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) @@ -2095,7 +2096,7 @@ def __init__(self, config: Blip2Config): self.post_init() @add_start_docstrings_to_model_forward(BLIP_2_TEXT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Blip2TextModelOutput, config_class=Blip2Config) + @replace_return_docstrings(output_type=Blip2TextModelOutput, config_class=Blip2ModelWithoutLMConfig) def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -2159,11 +2160,11 @@ def forward( # Adapted from transformers.models.clip.modeling_clip.CLIPVisionModelWithProjection class Blip2VisionModelWithProjection(Blip2PreTrainedModel): - config_class = Blip2Config + config_class = Blip2ModelWithoutLMConfig main_input_name = "pixel_values" _keep_in_fp32_modules = [] - def __init__(self, config: Blip2VisionConfig): + def __init__(self, config: Blip2ModelWithoutLMConfig): super().__init__(config) self.vision_model = Blip2VisionModel(config.vision_config) @@ -2178,7 +2179,7 @@ def __init__(self, config: Blip2VisionConfig): self.post_init() @add_start_docstrings_to_model_forward(BLIP_2_VISION_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Blip2VisionModelOutput, config_class=Blip2VisionConfig) + @replace_return_docstrings(output_type=Blip2VisionModelOutput, config_class=Blip2ModelWithoutLMConfig) def forward( self, pixel_values: Optional[torch.FloatTensor] = None, @@ -2602,11 +2603,11 @@ def generate( ) class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): main_input_name = "pixel_values" - config_class = Blip2Config + config_class = Blip2ModelWithoutLMConfig _keep_in_fp32_modules = [] _tied_weights_keys = ["cls.predictions.decoder.bias"] - def __init__(self, config: Blip2Config): + def __init__(self, config: Blip2ModelWithoutLMConfig): super().__init__(config) self.vision_model = Blip2VisionModel(config.vision_config) @@ -2631,7 +2632,7 @@ def __init__(self, config: Blip2Config): self.post_init() @add_start_docstrings_to_model_forward(BLIP_2_VISION_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Blip2ImageTextMatchingModelOutput, config_class=Blip2VisionConfig) + @replace_return_docstrings(output_type=Blip2ImageTextMatchingModelOutput, config_class=Blip2ModelWithoutLMConfig) def forward( self, pixel_values: torch.FloatTensor, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index bb75c37c5af577..a605ddd62e1ae9 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1517,7 +1517,7 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class Blip2ModelWithProjection(metaclass=DummyObject): +class Blip2ModelWithoutLM(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index 6b41eb604c8edb..a7bd3095bbe23e 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -22,7 +22,7 @@ import numpy as np import requests -from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig +from transformers import CONFIG_MAPPING, Blip2Config, Blip2ModelWithoutLMConfig, Blip2QFormerConfig, Blip2VisionConfig from transformers.testing_utils import ( 
require_torch, require_torch_gpu, @@ -52,7 +52,7 @@ Blip2ForConditionalGeneration, Blip2ForImageTextRetrieval, Blip2Model, - Blip2ModelWithProjection, + Blip2ModelWithoutLM, Blip2TextModelWithProjection, Blip2VisionModel, Blip2VisionModelWithProjection, @@ -858,10 +858,9 @@ def prepare_config_and_inputs(self): return config, input_ids, attention_mask def get_config(self): - return Blip2Config.from_vision_qformer_text_configs( + return Blip2ModelWithoutLMConfig.from_vision_qformer_text_configs( vision_config=self.vision_model_tester.get_config(), qformer_config=self.qformer_model_tester.get_config(), - text_config=self.text_model_tester.get_config(), ) def prepare_config_and_inputs_for_common(self): @@ -969,6 +968,9 @@ def test_forward_signature(self): arg_names = [*signature.parameters.keys()] if model.config.is_encoder_decoder: + # TODO + raise NotImplementedError + else: expected_arg_names = [ "input_ids", "attention_mask", @@ -976,9 +978,6 @@ def test_forward_signature(self): ] self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) - else: - # TODO - raise NotImplementedError @slow def test_model_from_pretrained(self): @@ -1010,10 +1009,9 @@ def prepare_config_and_inputs(self): return config, pixel_values def get_config(self): - return Blip2Config.from_vision_qformer_text_configs( + return Blip2ModelWithoutLMConfig.from_vision_qformer_text_configs( vision_config=self.vision_model_tester.get_config(), qformer_config=self.qformer_model_tester.get_config(), - text_config=self.text_model_tester.get_config(), ) def prepare_config_and_inputs_for_common(self): @@ -1122,14 +1120,14 @@ def test_forward_signature(self): arg_names = [*signature.parameters.keys()] if model.config.is_encoder_decoder: + # TODO + raise NotImplementedError + else: expected_arg_names = [ "pixel_values", ] self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) - else: - # TODO - raise NotImplementedError @slow def test_model_from_pretrained(self): @@ -1162,10 +1160,9 @@ def prepare_config_and_inputs(self): return config, input_ids, attention_mask, pixel_values def get_config(self): - return Blip2Config.from_vision_qformer_text_configs( + return Blip2ModelWithoutLMConfig.from_vision_qformer_text_configs( vision_config=self.vision_model_tester.get_config(), qformer_config=self.qformer_model_tester.get_config(), - text_config=self.text_model_tester.get_config(), ) def create_and_check_model(self, config, input_ids, attention_mask, pixel_values): @@ -1210,7 +1207,7 @@ class Blip2TextRetrievalModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = ( ( Blip2ForImageTextRetrieval, - Blip2ModelWithProjection, + Blip2ModelWithoutLM, ) if is_torch_available() else () @@ -1255,6 +1252,9 @@ def test_forward_signature(self): arg_names = [*signature.parameters.keys()] if model.config.is_encoder_decoder: + # TODO + raise NotImplementedError + else: expected_arg_names = [ "pixel_values", "input_ids", @@ -1263,9 +1263,6 @@ def test_forward_signature(self): expected_arg_names.extend(["use_itm_head"] if "use_itm_head" in arg_names else []) self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) - else: - # TODO - raise NotImplementedError def test_load_vision_qformer_text_config(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -1297,7 +1294,7 @@ def test_get_text_features(self): "attention_mask": torch.LongTensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]).to(torch_device), } - model = Blip2ModelWithProjection(config).to(torch_device) + 
model = Blip2ModelWithoutLM(config).to(torch_device) model.eval() text_features = model.get_text_features(**inputs_dict) self.assertEqual(text_features[0].shape, (10, config.image_text_hidden_size)) @@ -1310,7 +1307,7 @@ def test_get_image_features(self): for key in keys_to_pop: inputs_dict.pop(key) - model = Blip2ModelWithProjection(config).to(torch_device) + model = Blip2ModelWithoutLM(config).to(torch_device) model.eval() image_features = model.get_image_features(**inputs_dict) self.assertEqual( @@ -1539,7 +1536,7 @@ def test_inference_t5_multi_gpu(self): @require_torch_gpu def test_inference_itm_features(self): processor = Blip2Processor.from_pretrained("jpizarrom/blip2-itm-vit-g") - model = Blip2ModelWithProjection.from_pretrained( + model = Blip2ModelWithoutLM.from_pretrained( "jpizarrom/blip2-itm-vit-g", ).to(torch_device) @@ -1584,7 +1581,7 @@ def test_inference_itm_features(self): @require_torch_gpu def test_inference_itm_features_fp16(self): processor = Blip2Processor.from_pretrained("jpizarrom/blip2-itm-vit-g") - model = Blip2ModelWithProjection.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16).to( + model = Blip2ModelWithoutLM.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16).to( torch_device ) diff --git a/utils/check_repo.py b/utils/check_repo.py index 1fabe6eff69e6f..6c82795e180561 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -147,7 +147,7 @@ "ClapAudioModelWithProjection", "Blip2ForConditionalGeneration", "Blip2ForImageTextRetrieval", - "Blip2ModelWithProjection", + "Blip2ModelWithoutLM", "Blip2TextModelWithProjection", "Blip2VisionModelWithProjection", "Blip2QFormerModel", From 43fb263c79466085fb9377fec731768c2b7836fe Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Fri, 6 Oct 2023 19:55:04 +0200 Subject: [PATCH 11/25] add image_text_hidden_size to docstring --- src/transformers/models/blip_2/configuration_blip_2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/blip_2/configuration_blip_2.py b/src/transformers/models/blip_2/configuration_blip_2.py index ac0359c55a6ee1..ca17984f3ff486 100644 --- a/src/transformers/models/blip_2/configuration_blip_2.py +++ b/src/transformers/models/blip_2/configuration_blip_2.py @@ -268,7 +268,8 @@ class Blip2Config(PretrainedConfig): Dictionary of configuration options used to initialize any [`PretrainedConfig`]. num_query_tokens (`int`, *optional*, defaults to 32): The number of query tokens passed through the Transformer. - + image_text_hidden_size (`int`, *optional*, defaults to 256): + Dimentionality of the hidden state of the image-text fusion layer. kwargs (*optional*): Dictionary of keyword arguments. @@ -384,7 +385,8 @@ class Blip2ModelWithoutLMConfig(PretrainedConfig): Dictionary of configuration options used to initialize [`Blip2QFormerConfig`]. num_query_tokens (`int`, *optional*, defaults to 32): The number of query tokens passed through the Transformer. - + image_text_hidden_size (`int`, *optional*, defaults to 256): + Dimentionality of the hidden state of the image-text fusion layer. kwargs (*optional*): Dictionary of keyword arguments. 
From f8b0ed50fdb7343eafee3ffd54176c0d08222f13 Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Fri, 6 Oct 2023 20:08:47 +0200 Subject: [PATCH 12/25] remove image_text_hidden_size from BlipConfig --- src/transformers/models/blip/configuration_blip.py | 4 ---- .../models/blip_2/configuration_blip_2.py | 13 +------------ 2 files changed, 1 insertion(+), 16 deletions(-) diff --git a/src/transformers/models/blip/configuration_blip.py b/src/transformers/models/blip/configuration_blip.py index 39760a7e22a96d..e96f37748bd80d 100644 --- a/src/transformers/models/blip/configuration_blip.py +++ b/src/transformers/models/blip/configuration_blip.py @@ -295,8 +295,6 @@ class BlipConfig(PretrainedConfig): Dimentionality of text and vision projection layers. logit_scale_init_value (`float`, *optional*, defaults to 2.6592): The inital value of the *logit_scale* paramter. Default is used as per the original BLIP implementation. - image_text_hidden_size (`int`, *optional*, defaults to 256): - Dimentionality of the hidden state of the image-text fusion layer. kwargs (*optional*): Dictionary of keyword arguments. @@ -331,7 +329,6 @@ def __init__( vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, - image_text_hidden_size=256, **kwargs, ): super().__init__(**kwargs) @@ -353,7 +350,6 @@ def __init__( self.logit_scale_init_value = logit_scale_init_value self.initializer_factor = 1.0 self.initializer_range = 0.02 - self.image_text_hidden_size = image_text_hidden_size @classmethod def from_text_vision_configs(cls, text_config: BlipTextConfig, vision_config: BlipVisionConfig, **kwargs): diff --git a/src/transformers/models/blip_2/configuration_blip_2.py b/src/transformers/models/blip_2/configuration_blip_2.py index ca17984f3ff486..b5803d809e0cd0 100644 --- a/src/transformers/models/blip_2/configuration_blip_2.py +++ b/src/transformers/models/blip_2/configuration_blip_2.py @@ -268,8 +268,6 @@ class Blip2Config(PretrainedConfig): Dictionary of configuration options used to initialize any [`PretrainedConfig`]. num_query_tokens (`int`, *optional*, defaults to 32): The number of query tokens passed through the Transformer. - image_text_hidden_size (`int`, *optional*, defaults to 256): - Dimentionality of the hidden state of the image-text fusion layer. kwargs (*optional*): Dictionary of keyword arguments. 
@@ -305,15 +303,7 @@ class Blip2Config(PretrainedConfig): model_type = "blip-2" - def __init__( - self, - vision_config=None, - qformer_config=None, - text_config=None, - num_query_tokens=32, - image_text_hidden_size=256, - **kwargs, - ): + def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs): super().__init__(**kwargs) if vision_config is None: @@ -337,7 +327,6 @@ def __init__( self.is_encoder_decoder = self.text_config.is_encoder_decoder self.num_query_tokens = num_query_tokens - self.image_text_hidden_size = image_text_hidden_size self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES self.initializer_factor = 1.0 From 401b8b8b0146f3abaf9d1d86dc1dad4016ec375c Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Fri, 6 Oct 2023 20:48:43 +0200 Subject: [PATCH 13/25] use Blip2ModelWithoutLMConfig in convert script --- .../models/blip_2/convert_blip_2_original_to_pytorch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py b/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py index d2e1b27dfb0869..54e6add8de50a8 100644 --- a/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py +++ b/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py @@ -35,6 +35,7 @@ Blip2Config, Blip2ForConditionalGeneration, Blip2ForImageTextRetrieval, + Blip2ModelWithoutLMConfig, Blip2Processor, Blip2QFormerConfig, Blip2VisionConfig, @@ -124,9 +125,8 @@ def get_blip2_config(model_name, eos_token_id): raise ValueError("Model name not supported") if "itm" in model_name: - config = Blip2Config( + config = Blip2ModelWithoutLMConfig( vision_config=vision_config, - text_config=text_config, qformer_config=Blip2QFormerConfig(vocab_size=30523, qformer_text_input=True).to_dict(), ) else: From a0f714244b3e55447d0c97b36da0bb049baecd2d Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Fri, 6 Oct 2023 21:12:15 +0200 Subject: [PATCH 14/25] remove not used text_model_tester --- tests/models/blip_2/test_modeling_blip_2.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index a7bd3095bbe23e..f1869dbaa02be9 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -841,17 +841,14 @@ def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training= vision_kwargs = {} if qformer_kwargs is None: qformer_kwargs = {"qformer_text_input": True} - text_kwargs = {} self.parent = parent self.vision_model_tester = Blip2VisionModelTester(parent, **vision_kwargs) self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs) - self.text_model_tester = Blip2TextModelTester(parent, **text_kwargs) self.is_training = is_training def prepare_config_and_inputs(self): _, input_ids, attention_mask = self.qformer_model_tester.prepare_config_and_inputs() - # _, pixel_values = self.vision_model_tester.prepare_config_and_inputs() config = self.get_config() @@ -869,7 +866,6 @@ def prepare_config_and_inputs_for_common(self): inputs_dict = { "input_ids": input_ids, "attention_mask": attention_mask, - # "pixel_values": pixel_values, } return config, inputs_dict @@ -887,7 +883,7 @@ def create_and_check_model(self, config, input_ids, attention_mask): 
self.parent.assertEqual( result.text_embeds.shape, ( - self.text_model_tester.batch_size, + self.vision_model_tester.batch_size, input_ids.shape[1], config.image_text_hidden_size, ), @@ -993,12 +989,10 @@ def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training= vision_kwargs = {} if qformer_kwargs is None: qformer_kwargs = {"qformer_text_input": True} - text_kwargs = {} self.parent = parent self.vision_model_tester = Blip2VisionModelTester(parent, **vision_kwargs) self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs) - self.text_model_tester = Blip2TextModelTester(parent, **text_kwargs) self.is_training = is_training def prepare_config_and_inputs(self): @@ -1040,7 +1034,7 @@ def create_and_check_model(self, config, pixel_values): self.parent.assertEqual( result.image_embeds.shape, ( - self.text_model_tester.batch_size, + self.vision_model_tester.batch_size, config.vision_config.hidden_size, config.image_text_hidden_size, ), @@ -1143,12 +1137,10 @@ def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training= vision_kwargs = {} if qformer_kwargs is None: qformer_kwargs = {"qformer_text_input": True} - text_kwargs = {} self.parent = parent self.vision_model_tester = Blip2VisionModelTester(parent, **vision_kwargs) self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs) - self.text_model_tester = Blip2TextModelTester(parent, **text_kwargs) self.is_training = is_training def prepare_config_and_inputs(self): @@ -1185,7 +1177,7 @@ def create_and_check_model(self, config, input_ids, attention_mask, pixel_values self.parent.assertEqual( result.question_embeds.shape, ( - self.text_model_tester.batch_size, + self.vision_model_tester.batch_size, self.vision_model_tester.hidden_size + input_ids.shape[1], self.qformer_model_tester.hidden_size, ), From 46adfd5f8feb74dc4cf620b84c8bad3230b41f6f Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Fri, 6 Oct 2023 21:15:57 +0200 Subject: [PATCH 15/25] restore image_text_hidden_size in BlipConfig --- src/transformers/models/blip/configuration_blip.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/models/blip/configuration_blip.py b/src/transformers/models/blip/configuration_blip.py index e96f37748bd80d..39760a7e22a96d 100644 --- a/src/transformers/models/blip/configuration_blip.py +++ b/src/transformers/models/blip/configuration_blip.py @@ -295,6 +295,8 @@ class BlipConfig(PretrainedConfig): Dimentionality of text and vision projection layers. logit_scale_init_value (`float`, *optional*, defaults to 2.6592): The inital value of the *logit_scale* paramter. Default is used as per the original BLIP implementation. + image_text_hidden_size (`int`, *optional*, defaults to 256): + Dimentionality of the hidden state of the image-text fusion layer. kwargs (*optional*): Dictionary of keyword arguments. 
@@ -329,6 +331,7 @@ def __init__( vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, + image_text_hidden_size=256, **kwargs, ): super().__init__(**kwargs) @@ -350,6 +353,7 @@ def __init__( self.logit_scale_init_value = logit_scale_init_value self.initializer_factor = 1.0 self.initializer_range = 0.02 + self.image_text_hidden_size = image_text_hidden_size @classmethod def from_text_vision_configs(cls, text_config: BlipTextConfig, vision_config: BlipVisionConfig, **kwargs): From a2c098efcdd88c3e03b8b67f21cccac4721b91bd Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Tue, 17 Oct 2023 15:52:30 +0200 Subject: [PATCH 16/25] rename Blip2ModelWithoutLMConfig.from_vision_qformer_configs --- src/transformers/models/blip_2/configuration_blip_2.py | 10 +++++----- tests/models/blip_2/test_modeling_blip_2.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/blip_2/configuration_blip_2.py b/src/transformers/models/blip_2/configuration_blip_2.py index b5803d809e0cd0..e30e2c12cfc3d8 100644 --- a/src/transformers/models/blip_2/configuration_blip_2.py +++ b/src/transformers/models/blip_2/configuration_blip_2.py @@ -398,13 +398,13 @@ class Blip2ModelWithoutLMConfig(PretrainedConfig): >>> # Accessing the model configuration >>> configuration = model.config - >>> # We can also initialize a Blip2Config from a Blip2VisionConfig, Blip2QFormerConfig and any PretrainedConfig + >>> # We can also initialize a Blip2ModelWithoutLMConfig from a Blip2VisionConfig, Blip2QFormerConfig and any PretrainedConfig >>> # Initializing BLIP-2 vision and BLIP-2 Q-Former model configurations >>> vision_config = Blip2VisionConfig() >>> qformer_config = Blip2QFormerConfig() - >>> config = Blip2ModelWithoutLMConfig.from_vision_qformer_text_configs(vision_config, qformer_config) + >>> config = Blip2ModelWithoutLMConfig.from_vision_qformer_configs(vision_config, qformer_config) ```""" model_type = "blip-2-without-lm" @@ -440,18 +440,18 @@ def __init__( self.initializer_range = 0.02 @classmethod - def from_vision_qformer_text_configs( + def from_vision_qformer_configs( cls, vision_config: Blip2VisionConfig, qformer_config: Blip2QFormerConfig, **kwargs, ): r""" - Instantiate a [`Blip2Config`] (or a derived class) from a BLIP-2 vision model, Q-Former and language model + Instantiate a [`Blip2ModelWithoutLMConfig`] (or a derived class) from a BLIP-2 vision and Q-Former model configurations. 
Returns: - [`Blip2Config`]: An instance of a configuration object + [`Blip2ModelWithoutLMConfig`]: An instance of a configuration object """ return cls( diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index f1869dbaa02be9..7d996bf514b611 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -855,7 +855,7 @@ def prepare_config_and_inputs(self): return config, input_ids, attention_mask def get_config(self): - return Blip2ModelWithoutLMConfig.from_vision_qformer_text_configs( + return Blip2ModelWithoutLMConfig.from_vision_qformer_configs( vision_config=self.vision_model_tester.get_config(), qformer_config=self.qformer_model_tester.get_config(), ) @@ -1003,7 +1003,7 @@ def prepare_config_and_inputs(self): return config, pixel_values def get_config(self): - return Blip2ModelWithoutLMConfig.from_vision_qformer_text_configs( + return Blip2ModelWithoutLMConfig.from_vision_qformer_configs( vision_config=self.vision_model_tester.get_config(), qformer_config=self.qformer_model_tester.get_config(), ) @@ -1152,7 +1152,7 @@ def prepare_config_and_inputs(self): return config, input_ids, attention_mask, pixel_values def get_config(self): - return Blip2ModelWithoutLMConfig.from_vision_qformer_text_configs( + return Blip2ModelWithoutLMConfig.from_vision_qformer_configs( vision_config=self.vision_model_tester.get_config(), qformer_config=self.qformer_model_tester.get_config(), ) From 532f5ae3caf8f86720402ebcc60b9e38f9272e5e Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Thu, 26 Oct 2023 19:40:49 +0200 Subject: [PATCH 17/25] remove Blip2ModelWithoutLMConfig --- docs/source/en/model_doc/blip-2.md | 4 - src/transformers/__init__.py | 2 - src/transformers/models/blip_2/__init__.py | 2 - .../models/blip_2/configuration_blip_2.py | 106 +++--------------- .../convert_blip_2_original_to_pytorch.py | 3 +- .../models/blip_2/modeling_blip_2.py | 30 +++-- tests/models/blip_2/test_modeling_blip_2.py | 20 ++-- 7 files changed, 40 insertions(+), 127 deletions(-) diff --git a/docs/source/en/model_doc/blip-2.md b/docs/source/en/model_doc/blip-2.md index 250049d86a28a1..85f263b4b9e83c 100644 --- a/docs/source/en/model_doc/blip-2.md +++ b/docs/source/en/model_doc/blip-2.md @@ -61,10 +61,6 @@ If you're interested in submitting a resource to be included here, please feel f [[autodoc]] Blip2QFormerConfig -## Blip2ModelWithoutLMConfig - -[[autodoc]] Blip2ModelWithoutLMConfig - ## Blip2Processor [[autodoc]] Blip2Processor diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 25ae59e570c656..5c5d55ea15ce4c 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -209,7 +209,6 @@ "models.blip_2": [ "BLIP_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Blip2Config", - "Blip2ModelWithoutLMConfig", "Blip2Processor", "Blip2QFormerConfig", "Blip2VisionConfig", @@ -4348,7 +4347,6 @@ from .models.blip_2 import ( BLIP_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Blip2Config, - Blip2ModelWithoutLMConfig, Blip2Processor, Blip2QFormerConfig, Blip2VisionConfig, diff --git a/src/transformers/models/blip_2/__init__.py b/src/transformers/models/blip_2/__init__.py index 3defe7b923a413..60105fb8539194 100644 --- a/src/transformers/models/blip_2/__init__.py +++ b/src/transformers/models/blip_2/__init__.py @@ -20,7 +20,6 @@ "configuration_blip_2": [ "BLIP_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Blip2Config", - "Blip2ModelWithoutLMConfig", "Blip2QFormerConfig", "Blip2VisionConfig", ], @@ -50,7 +49,6 @@ from 
.configuration_blip_2 import ( BLIP_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Blip2Config, - Blip2ModelWithoutLMConfig, Blip2QFormerConfig, Blip2VisionConfig, ) diff --git a/src/transformers/models/blip_2/configuration_blip_2.py b/src/transformers/models/blip_2/configuration_blip_2.py index e30e2c12cfc3d8..942dbaed91fc13 100644 --- a/src/transformers/models/blip_2/configuration_blip_2.py +++ b/src/transformers/models/blip_2/configuration_blip_2.py @@ -117,7 +117,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) # get the vision config dict if we are loading from Blip2Config - if config_dict.get("model_type") in ["blip-2", "blip-2-without-lm"]: + if config_dict.get("model_type") == "blip-2": config_dict = config_dict["vision_config"] if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: @@ -237,7 +237,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) # get the qformer config dict if we are loading from Blip2Config - if config_dict.get("model_type") in ["blip-2", "blip-2-without-lm"]: + if config_dict.get("model_type") == "blip-2": config_dict = config_dict["qformer_config"] if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: @@ -268,6 +268,8 @@ class Blip2Config(PretrainedConfig): Dictionary of configuration options used to initialize any [`PretrainedConfig`]. num_query_tokens (`int`, *optional*, defaults to 32): The number of query tokens passed through the Transformer. + image_text_hidden_size (`int`, *optional*, defaults to 256): + Dimentionality of the hidden state of the image-text fusion layer. kwargs (*optional*): Dictionary of keyword arguments. @@ -303,7 +305,15 @@ class Blip2Config(PretrainedConfig): model_type = "blip-2" - def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs): + def __init__( + self, + vision_config=None, + qformer_config=None, + text_config=None, + num_query_tokens=32, + image_text_hidden_size=256, + **kwargs, + ): super().__init__(**kwargs) if vision_config is None: @@ -327,6 +337,7 @@ def __init__(self, vision_config=None, qformer_config=None, text_config=None, nu self.is_encoder_decoder = self.text_config.is_encoder_decoder self.num_query_tokens = num_query_tokens + self.image_text_hidden_size = image_text_hidden_size self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES self.initializer_factor = 1.0 @@ -355,90 +366,6 @@ def from_vision_qformer_text_configs( **kwargs, ) - -class Blip2ModelWithoutLMConfig(PretrainedConfig): - r""" - [`Blip2ModelWithoutLMConfig`] is the configuration class to store the configuration of a - [`Blip2ForImageTextRetrieval`, `Blip2ModelWithoutLM`]. It is used to instantiate a BLIP-2 model according to the - specified arguments, defining the vision model and Q-Former model. Instantiating a configuration with the defaults - will yield a similar configuration to that of the BLIP-2 - [jpizarrom/blip2-itm-vit-g](https://huggingface.co/jpizarrom/blip2-itm-vit-g) architecture. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - vision_config (`dict`, *optional*): - Dictionary of configuration options used to initialize [`Blip2VisionConfig`]. - qformer_config (`dict`, *optional*): - Dictionary of configuration options used to initialize [`Blip2QFormerConfig`]. - num_query_tokens (`int`, *optional*, defaults to 32): - The number of query tokens passed through the Transformer. - image_text_hidden_size (`int`, *optional*, defaults to 256): - Dimentionality of the hidden state of the image-text fusion layer. - kwargs (*optional*): - Dictionary of keyword arguments. - - Example: - - ```python - >>> from transformers import ( - ... Blip2VisionConfig, - ... Blip2QFormerConfig, - ... Blip2ModelWithoutLMConfig, - ... Blip2ForImageTextRetrieval, - ... ) - - >>> # Initializing a Blip2ModelWithoutLMConfig with jpizarrom/blip2-itm-vit-g style configuration - >>> configuration = Blip2ModelWithoutLMConfig() - - >>> # Initializing a Blip2ForImageTextRetrieval (with random weights) from the jpizarrom/blip2-itm-vit-g style configuration - >>> model = Blip2ForImageTextRetrieval(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - - >>> # We can also initialize a Blip2ModelWithoutLMConfig from a Blip2VisionConfig, Blip2QFormerConfig and any PretrainedConfig - - >>> # Initializing BLIP-2 vision and BLIP-2 Q-Former model configurations - >>> vision_config = Blip2VisionConfig() - >>> qformer_config = Blip2QFormerConfig() - - >>> config = Blip2ModelWithoutLMConfig.from_vision_qformer_configs(vision_config, qformer_config) - ```""" - - model_type = "blip-2-without-lm" - - def __init__( - self, - vision_config=None, - qformer_config=None, - num_query_tokens=32, - image_text_hidden_size=256, - **kwargs, - ): - super().__init__(**kwargs) - - if vision_config is None: - vision_config = {} - logger.info("vision_config is None. initializing the Blip2VisionConfig with default values.") - - if qformer_config is None: - qformer_config = {} - logger.info("qformer_config is None. Initializing the Blip2QFormerConfig with default values.") - - self.vision_config = Blip2VisionConfig(**vision_config) - self.qformer_config = Blip2QFormerConfig(**qformer_config) - - self.is_encoder_decoder = self.qformer_config.is_encoder_decoder - - self.num_query_tokens = num_query_tokens - self.image_text_hidden_size = image_text_hidden_size - self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size - # self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES - self.initializer_factor = 1.0 - self.initializer_range = 0.02 - @classmethod def from_vision_qformer_configs( cls, @@ -447,11 +374,10 @@ def from_vision_qformer_configs( **kwargs, ): r""" - Instantiate a [`Blip2ModelWithoutLMConfig`] (or a derived class) from a BLIP-2 vision and Q-Former model - configurations. + Instantiate a [`Blip2Config`] (or a derived class) from a BLIP-2 vision and Q-Former model configurations. 
Returns: - [`Blip2ModelWithoutLMConfig`]: An instance of a configuration object + [`Blip2Config`]: An instance of a configuration object """ return cls( diff --git a/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py b/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py index 54e6add8de50a8..3cea40f5371f97 100644 --- a/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py +++ b/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py @@ -35,7 +35,6 @@ Blip2Config, Blip2ForConditionalGeneration, Blip2ForImageTextRetrieval, - Blip2ModelWithoutLMConfig, Blip2Processor, Blip2QFormerConfig, Blip2VisionConfig, @@ -125,7 +124,7 @@ def get_blip2_config(model_name, eos_token_id): raise ValueError("Model name not supported") if "itm" in model_name: - config = Blip2ModelWithoutLMConfig( + config = Blip2Config( vision_config=vision_config, qformer_config=Blip2QFormerConfig(vocab_size=30523, qformer_text_input=True).to_dict(), ) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 5fd49f170e811a..21f41f7f9f27c5 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -41,7 +41,7 @@ replace_return_docstrings, ) from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM -from .configuration_blip_2 import Blip2Config, Blip2ModelWithoutLMConfig, Blip2QFormerConfig, Blip2VisionConfig +from .configuration_blip_2 import Blip2Config, Blip2QFormerConfig, Blip2VisionConfig logger = logging.get_logger(__name__) @@ -1163,7 +1163,6 @@ def custom_forward(*inputs): ) -# Adapted from https://github.com/salesforce/LAVIS/blob/main/lavis/models/blip2_models/Qformer.py#L51 class Blip2TextEmbeddings(nn.Module): """Construct the embeddings from word and position embeddings.""" @@ -1183,8 +1182,6 @@ def __init__(self, config): ) self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") - self.config = config - def forward( self, input_ids: Optional[torch.FloatTensor] = None, @@ -1344,7 +1341,7 @@ def forward( query_length = query_embeds.shape[1] if query_embeds is not None else 0 - if hasattr(self, "embeddings"): + if self.config.qformer_text_input: embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, @@ -1838,11 +1835,11 @@ def forward( BLIP_2_START_DOCSTRING, ) class Blip2ModelWithoutLM(Blip2PreTrainedModel): - config_class = Blip2ModelWithoutLMConfig + config_class = Blip2Config main_input_name = "pixel_values" _keep_in_fp32_modules = [] - def __init__(self, config: Blip2ModelWithoutLMConfig): + def __init__(self, config: Blip2Config): super().__init__(config) self.vision_model = Blip2VisionModel(config.vision_config) @@ -2079,11 +2076,11 @@ def forward( class Blip2TextModelWithProjection(Blip2PreTrainedModel): - config_class = Blip2ModelWithoutLMConfig + config_class = Blip2Config supports_gradient_checkpointing = False _keep_in_fp32_modules = [] - def __init__(self, config: Blip2ModelWithoutLMConfig): + def __init__(self, config: Blip2Config): super().__init__(config) self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) @@ -2096,7 +2093,7 @@ def __init__(self, config: Blip2ModelWithoutLMConfig): self.post_init() @add_start_docstrings_to_model_forward(BLIP_2_TEXT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Blip2TextModelOutput, config_class=Blip2ModelWithoutLMConfig) + 
@replace_return_docstrings(output_type=Blip2TextModelOutput, config_class=Blip2Config) def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -2158,13 +2155,12 @@ def forward( ) -# Adapted from transformers.models.clip.modeling_clip.CLIPVisionModelWithProjection class Blip2VisionModelWithProjection(Blip2PreTrainedModel): - config_class = Blip2ModelWithoutLMConfig + config_class = Blip2Config main_input_name = "pixel_values" _keep_in_fp32_modules = [] - def __init__(self, config: Blip2ModelWithoutLMConfig): + def __init__(self, config: Blip2Config): super().__init__(config) self.vision_model = Blip2VisionModel(config.vision_config) @@ -2179,7 +2175,7 @@ def __init__(self, config: Blip2ModelWithoutLMConfig): self.post_init() @add_start_docstrings_to_model_forward(BLIP_2_VISION_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Blip2VisionModelOutput, config_class=Blip2ModelWithoutLMConfig) + @replace_return_docstrings(output_type=Blip2VisionModelOutput, config_class=Blip2Config) def forward( self, pixel_values: Optional[torch.FloatTensor] = None, @@ -2603,11 +2599,11 @@ def generate( ) class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): main_input_name = "pixel_values" - config_class = Blip2ModelWithoutLMConfig + config_class = Blip2Config _keep_in_fp32_modules = [] _tied_weights_keys = ["cls.predictions.decoder.bias"] - def __init__(self, config: Blip2ModelWithoutLMConfig): + def __init__(self, config: Blip2Config): super().__init__(config) self.vision_model = Blip2VisionModel(config.vision_config) @@ -2632,7 +2628,7 @@ def __init__(self, config: Blip2ModelWithoutLMConfig): self.post_init() @add_start_docstrings_to_model_forward(BLIP_2_VISION_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Blip2ImageTextMatchingModelOutput, config_class=Blip2ModelWithoutLMConfig) + @replace_return_docstrings(output_type=Blip2ImageTextMatchingModelOutput, config_class=Blip2Config) def forward( self, pixel_values: torch.FloatTensor, diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index 7d996bf514b611..caf8d83b9e389e 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -22,7 +22,7 @@ import numpy as np import requests -from transformers import CONFIG_MAPPING, Blip2Config, Blip2ModelWithoutLMConfig, Blip2QFormerConfig, Blip2VisionConfig +from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig from transformers.testing_utils import ( require_torch, require_torch_gpu, @@ -855,7 +855,7 @@ def prepare_config_and_inputs(self): return config, input_ids, attention_mask def get_config(self): - return Blip2ModelWithoutLMConfig.from_vision_qformer_configs( + return Blip2Config.from_vision_qformer_configs( vision_config=self.vision_model_tester.get_config(), qformer_config=self.qformer_model_tester.get_config(), ) @@ -1003,7 +1003,7 @@ def prepare_config_and_inputs(self): return config, pixel_values def get_config(self): - return Blip2ModelWithoutLMConfig.from_vision_qformer_configs( + return Blip2Config.from_vision_qformer_configs( vision_config=self.vision_model_tester.get_config(), qformer_config=self.qformer_model_tester.get_config(), ) @@ -1152,7 +1152,7 @@ def prepare_config_and_inputs(self): return config, input_ids, attention_mask, pixel_values def get_config(self): - return Blip2ModelWithoutLMConfig.from_vision_qformer_configs( + return Blip2Config.from_vision_qformer_configs( vision_config=self.vision_model_tester.get_config(), 
qformer_config=self.qformer_model_tester.get_config(), ) @@ -1546,7 +1546,7 @@ def test_inference_itm_features(self): -0.013867943547666073, ] ).to(torch_device) - self.assertTrue(torch.allclose(image_features[0][0][:6], expected_image_features, atol=1e-4)) + self.assertTrue(torch.allclose(image_features[0][0][:6], expected_image_features, atol=1e-3)) # text features text_inputs = processor(text="a woman sitting on the beach with a dog", padding=True, return_tensors="pt").to( @@ -1563,12 +1563,12 @@ def test_inference_itm_features(self): -0.046645939350128174, ] ).to(torch_device) - self.assertTrue(torch.allclose(text_features[0][0][:6], expected_text_features, atol=1e-4)) + self.assertTrue(torch.allclose(text_features[0][0][:6], expected_text_features, atol=1e-3)) # check similarity similarity = (image_features @ text_features[:, 0, :].t()).max() expected_similarity = torch.tensor(0.44385525584220886).to(torch_device) - self.assertTrue(torch.allclose(similarity, expected_similarity, atol=1e-4)) + self.assertTrue(torch.allclose(similarity, expected_similarity, atol=1e-3)) @require_torch_gpu def test_inference_itm_features_fp16(self): @@ -1589,7 +1589,7 @@ def test_inference_itm_features_fp16(self): 0.0352783203125, -0.01190185546875, ] - self.assertTrue(np.allclose(image_features[0][0][:6].tolist(), expected_image_features, atol=1e-4)) + self.assertTrue(np.allclose(image_features[0][0][:6].tolist(), expected_image_features, atol=1e-3)) # text features text_inputs = processor(text="a woman sitting on the beach with a dog", padding=True, return_tensors="pt").to( @@ -1604,9 +1604,9 @@ def test_inference_itm_features_fp16(self): 0.08648681640625, -0.04656982421875, ] - self.assertTrue(np.allclose(text_features[0][0][:6].tolist(), expected_text_features, atol=1e-4)) + self.assertTrue(np.allclose(text_features[0][0][:6].tolist(), expected_text_features, atol=1e-3)) # check similarity similarity = (image_features @ text_features[:, 0, :].t()).max() expected_similarity = 0.44384765625 - self.assertTrue(np.allclose(similarity.item(), expected_similarity, atol=1e-4)) + self.assertTrue(np.allclose(similarity.item(), expected_similarity, atol=1e-3)) From ce86d4c37031af228e8eaf3ec17e962b89b66176 Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Fri, 27 Oct 2023 18:00:24 +0200 Subject: [PATCH 18/25] remove Blip2ModelWithProjection --- docs/source/en/model_doc/blip-2.md | 4 - src/transformers/__init__.py | 2 - src/transformers/models/blip_2/__init__.py | 2 - .../convert_blip_2_original_to_pytorch.py | 11 +- .../models/blip_2/modeling_blip_2.py | 301 ------------------ src/transformers/utils/dummy_pt_objects.py | 7 - tests/models/blip_2/test_modeling_blip_2.py | 130 +------- 7 files changed, 7 insertions(+), 450 deletions(-) diff --git a/docs/source/en/model_doc/blip-2.md b/docs/source/en/model_doc/blip-2.md index 85f263b4b9e83c..bb6baa553be290 100644 --- a/docs/source/en/model_doc/blip-2.md +++ b/docs/source/en/model_doc/blip-2.md @@ -94,10 +94,6 @@ If you're interested in submitting a resource to be included here, please feel f [[autodoc]] Blip2ForImageTextRetrieval - forward -## Blip2ModelWithoutLM - -[[autodoc]] Blip2ModelWithoutLM - ## Blip2TextModelWithProjection [[autodoc]] Blip2TextModelWithProjection diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 5c5d55ea15ce4c..ab4a148fab0023 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1340,7 +1340,6 @@ "Blip2ForConditionalGeneration", "Blip2ForImageTextRetrieval", "Blip2Model", - 
"Blip2ModelWithoutLM", "Blip2PreTrainedModel", "Blip2QFormerModel", "Blip2TextModelWithProjection", @@ -5363,7 +5362,6 @@ Blip2ForConditionalGeneration, Blip2ForImageTextRetrieval, Blip2Model, - Blip2ModelWithoutLM, Blip2PreTrainedModel, Blip2QFormerModel, Blip2TextModelWithProjection, diff --git a/src/transformers/models/blip_2/__init__.py b/src/transformers/models/blip_2/__init__.py index 60105fb8539194..5ab04479d13596 100644 --- a/src/transformers/models/blip_2/__init__.py +++ b/src/transformers/models/blip_2/__init__.py @@ -35,7 +35,6 @@ _import_structure["modeling_blip_2"] = [ "BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST", "Blip2Model", - "Blip2ModelWithoutLM", "Blip2VisionModelWithProjection", "Blip2QFormerModel", "Blip2PreTrainedModel", @@ -65,7 +64,6 @@ Blip2ForConditionalGeneration, Blip2ForImageTextRetrieval, Blip2Model, - Blip2ModelWithoutLM, Blip2PreTrainedModel, Blip2QFormerModel, Blip2TextModelWithProjection, diff --git a/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py b/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py index 3cea40f5371f97..686cb56ab8432f 100644 --- a/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py +++ b/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py @@ -167,7 +167,6 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ "blip2-flan-t5-xl-coco": ("blip2_t5", "caption_coco_flant5xl"), "blip2-flan-t5-xxl": ("blip2_t5", "pretrain_flant5xxl"), "blip2-itm-vit-g": ("blip2_image_text_matching", "pretrain"), - # "blip2-itm-vit-large": ("blip2_image_text_matching", "pretrain_vitL"), "blip2-itm-vit-g-coco": ("blip2_image_text_matching", "coco"), } @@ -196,8 +195,6 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ # some keys can be renamed efficiently for key, val in state_dict.copy().items(): val = state_dict.pop(key) - if key.startswith("Qformer.cls"): - key = key.replace("Qformer.cls", "cls") if key.startswith("Qformer.bert"): key = key.replace("Qformer.bert", "qformer") if "attention.self" in key: @@ -217,7 +214,12 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False) assert len(missing_keys) == 0 - assert unexpected_keys == ["qformer.embeddings.position_ids"] + + if "itm" in model_name: + unexpected_keys = list(filter(lambda x: not x.startswith("Qformer.cls"), unexpected_keys)) + assert unexpected_keys == ["temp", "qformer.embeddings.position_ids"] + else: + assert unexpected_keys == ["qformer.embeddings.position_ids"] image = load_demo_image() original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device) @@ -349,7 +351,6 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ "blip2-flan-t5-xl-coco", "blip2-flan-t5-xxl", "blip2-itm-vit-g", - # "blip2-itm-vit-large", "blip2-itm-vit-g-coco", ] parser.add_argument( diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 21f41f7f9f27c5..78ebd4cee97763 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -1420,56 +1420,6 @@ def forward( ) -# Copied from transformers.models.blip.modeling_blip_text.BlipTextPredictionHeadTransform with Blip->Blip2 -class Blip2TextPredictionHeadTransform(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, 
config.hidden_size) - if isinstance(config.hidden_act, str): - self.transform_act_fn = ACT2FN[config.hidden_act] - else: - self.transform_act_fn = config.hidden_act - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - -# Copied from transformers.models.blip.modeling_blip_text.BlipTextLMPredictionHead with Blip->Blip2 -class Blip2TextLMPredictionHead(nn.Module): - def __init__(self, config): - super().__init__() - self.transform = Blip2TextPredictionHeadTransform(config) - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def forward(self, hidden_states): - hidden_states = self.transform(hidden_states) - hidden_states = self.decoder(hidden_states) - return hidden_states - - -# Copied from transformers.models.blip.modeling_blip_text.BlipTextOnlyMLMHead with Blip->Blip2 -class Blip2TextOnlyMLMHead(nn.Module): - def __init__(self, config): - super().__init__() - self.predictions = Blip2TextLMPredictionHead(config) - - def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: - prediction_scores = self.predictions(sequence_output) - return prediction_scores - - @add_start_docstrings( """ BLIP-2 Model for generating text and image features. The model consists of a vision encoder, Querying Transformer @@ -1827,254 +1777,6 @@ def forward( ) -@add_start_docstrings( - """ - BLIP-2 Model for generating text and image features. The model consists of a vision encoder, Querying Transformer - (Q-Former) and a language model. - """, - BLIP_2_START_DOCSTRING, -) -class Blip2ModelWithoutLM(Blip2PreTrainedModel): - config_class = Blip2Config - main_input_name = "pixel_values" - _keep_in_fp32_modules = [] - - def __init__(self, config: Blip2Config): - super().__init__(config) - - self.vision_model = Blip2VisionModel(config.vision_config) - - self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) - self.qformer = Blip2QFormerModel(config.qformer_config) - - # vision projection layer - self.vision_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) - - # text projection layer - self.text_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) - - self.temp = nn.Parameter(0.07 * torch.ones([])) - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings_to_model_forward(BLIP_2_TEXT_INPUTS_DOCSTRING) - def get_text_features( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - r""" - Returns: - text_outputs (`CausalLMOutputWithPast`, or `tuple(torch.FloatTensor)` if `return_dict=False`): - The language model outputs. 
If `return_dict=True`, the output is a [`CausalLMOutputWithPast`] that - contains the language model logits, the past key values and the hidden states if - `output_hidden_states=True`. - Examples: - ```python - >>> import torch - >>> from transformers import AutoProcessor, Blip2ModelWithoutLM - - >>> device = "cuda" if torch.cuda.is_available() else "cpu" - - >>> model = Blip2ModelWithoutLM.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16) - - >>> model.to(device) # doctest: +IGNORE_RESULT - - >>> processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g") - >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], return_tensors="pt").to(device) - >>> text_features = model.get_text_features(**inputs) - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - query_outputs = self.qformer( - input_ids=input_ids, - # query_embeds=query_tokens, - attention_mask=attention_mask, - return_dict=return_dict, - ) - embeds = query_outputs.last_hidden_state - text_features = self.text_proj(embeds) - text_features = normalize(text_features, dim=-1) - - return text_features - - @add_start_docstrings_to_model_forward(BLIP_2_VISION_INPUTS_DOCSTRING) - def get_image_features( - self, - pixel_values: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - r""" - Returns: - vision_outputs (`BaseModelOutputWithPooling` or tuple of `torch.FloatTensor`): - The vision model outputs. If `return_dict=True`, the output is a [`BaseModelOutputWithPooling`] that - contains the image features, the pooled image features and the hidden states if - `output_hidden_states=True`. 
- Examples: - ```python - >>> import torch - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Blip2ModelWithoutLM - - >>> device = "cuda" if torch.cuda.is_available() else "cpu" - - >>> model = Blip2ModelWithoutLM.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16) - - >>> model.to(device) # doctest: +IGNORE_RESULT - - >>> processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g") - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - >>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16) - >>> image_outputs = model.get_image_features(**inputs) - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - vision_outputs = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - image_embeds = vision_outputs[0] - - image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) - - query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) - query_outputs = self.qformer( - query_embeds=query_tokens, - encoder_hidden_states=image_embeds, - encoder_attention_mask=image_attention_mask, - return_dict=return_dict, - ) - - embeds = query_outputs.last_hidden_state - image_features = normalize(self.vision_proj(embeds), dim=-1) - - return image_features - - def forward( - self, - pixel_values: Optional[torch.FloatTensor] = None, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - return_loss: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, Blip2Output]: - r""" - Returns: - - Examples: - - ```python - >>> import torch - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Blip2ModelWithoutLM - - >>> device = "cuda" if torch.cuda.is_available() else "cpu" - - >>> model = Blip2ModelWithoutLM.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16) - >>> processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g") - - >>> model.to(device) # doctest: +IGNORE_RESULT - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor( - ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True - ... ).to(device, torch.float16) - - >>> outputs = model(**inputs) - >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score - >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities - ```""" - # Use BLIP2 model's config for some fields (if specified) instead of those of vision & text components. 
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - vision_outputs = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - image_embeds = vision_outputs[0] - image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) - - query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) - query_outputs = self.qformer( - query_embeds=query_tokens, - encoder_hidden_states=image_embeds, - encoder_attention_mask=image_attention_mask, - return_dict=return_dict, - ) - - image_embeds = query_outputs[0] if not return_dict else query_outputs.last_hidden_state - image_embeds = self.vision_proj(image_embeds) - - question_outputs = self.qformer( - input_ids=input_ids, - attention_mask=attention_mask, - return_dict=return_dict, - ) - question_embeds = question_outputs[0] if not return_dict else question_outputs.last_hidden_state - question_embeds = self.text_proj(question_embeds[:, 0, :]) - - image_feat = normalize(image_embeds, dim=-1) - text_feat = normalize(question_embeds, dim=-1) - - # text-query similarity - sim_t2q = torch.matmul(text_feat.unsqueeze(1).unsqueeze(1), image_feat.permute(0, 2, 1)).squeeze(2) - - # text-image similarity: aggregate across all query tokens - logits_per_text, _ = sim_t2q.max(-1) - logits_per_text = logits_per_text / self.temp - logits_per_image = logits_per_text.t() - - loss = None - if return_loss: - loss = blip2_loss(logits_per_text) - # TODO add Image-text Matching and Image Captioning loss computation - - if not return_dict: - output = (logits_per_image, logits_per_text, question_embeds, image_embeds, query_outputs, vision_outputs) - return ((loss,) + output) if loss is not None else output - - return Blip2Output( - loss=loss, - logits_per_image=logits_per_image, - logits_per_text=logits_per_text, - text_embeds=question_embeds, - image_embeds=image_embeds, - text_model_output=query_outputs, - vision_model_output=vision_outputs, - ) - - class Blip2TextModelWithProjection(Blip2PreTrainedModel): config_class = Blip2Config supports_gradient_checkpointing = False @@ -2611,7 +2313,6 @@ def __init__(self, config: Blip2Config): self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) self.qformer = Blip2QFormerModel(config.qformer_config) - self.cls = Blip2TextOnlyMLMHead(config.qformer_config) # vision projection layer self.vision_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) @@ -2622,8 +2323,6 @@ def __init__(self, config: Blip2Config): # image text matching head self.itm_head = nn.Linear(config.qformer_config.hidden_size, 2) - self.temp = nn.Parameter(0.07 * torch.ones([])) - # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index a605ddd62e1ae9..dd12c98f82e8d5 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1517,13 +1517,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class Blip2ModelWithoutLM(metaclass=DummyObject): - _backends = ["torch"] - - 
def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - class Blip2PreTrainedModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index caf8d83b9e389e..688c25805de1bb 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -25,7 +25,6 @@ from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig from transformers.testing_utils import ( require_torch, - require_torch_gpu, require_torch_multi_gpu, require_vision, slow, @@ -52,7 +51,6 @@ Blip2ForConditionalGeneration, Blip2ForImageTextRetrieval, Blip2Model, - Blip2ModelWithoutLM, Blip2TextModelWithProjection, Blip2VisionModel, Blip2VisionModelWithProjection, @@ -1196,14 +1194,7 @@ def prepare_config_and_inputs_for_common(self): @require_torch class Blip2TextRetrievalModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = ( - ( - Blip2ForImageTextRetrieval, - Blip2ModelWithoutLM, - ) - if is_torch_available() - else () - ) + all_model_classes = (Blip2ForImageTextRetrieval,) if is_torch_available() else () fx_compatible = False test_head_masking = False test_pruning = False @@ -1278,39 +1269,6 @@ def test_model_from_pretrained(self): model = model_class.from_pretrained(model_name) self.assertIsNotNone(model) - def test_get_text_features(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - - inputs_dict = { - "input_ids": torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]).to(torch_device), - "attention_mask": torch.LongTensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]).to(torch_device), - } - - model = Blip2ModelWithoutLM(config).to(torch_device) - model.eval() - text_features = model.get_text_features(**inputs_dict) - self.assertEqual(text_features[0].shape, (10, config.image_text_hidden_size)) - - def test_get_image_features(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - keys_to_pop = ["input_ids", "attention_mask"] - - for key in keys_to_pop: - inputs_dict.pop(key) - - model = Blip2ModelWithoutLM(config).to(torch_device) - model.eval() - image_features = model.get_image_features(**inputs_dict) - self.assertEqual( - image_features.shape, # [12, 32, 256] - ( - self.model_tester.vision_model_tester.batch_size, - config.vision_config.hidden_size, - config.image_text_hidden_size, - ), - ) - def test_training(self): pass @@ -1524,89 +1482,3 @@ def test_inference_t5_multi_gpu(self): [0, 3, 7, 152, 67, 839, 1], ) self.assertEqual(generated_text, "san diego") - - @require_torch_gpu - def test_inference_itm_features(self): - processor = Blip2Processor.from_pretrained("jpizarrom/blip2-itm-vit-g") - model = Blip2ModelWithoutLM.from_pretrained( - "jpizarrom/blip2-itm-vit-g", - ).to(torch_device) - - # image features - image = prepare_img() - image_inputs = processor(images=image, return_tensors="pt").to(torch_device) - image_features = model.get_image_features(**image_inputs) - expected_image_features = torch.tensor( - [ - -0.0946953147649765, - -0.07541415840387344, - 0.03312666341662407, - 0.053536128252744675, - 0.03368198126554489, - -0.013867943547666073, - ] - ).to(torch_device) - self.assertTrue(torch.allclose(image_features[0][0][:6], expected_image_features, atol=1e-3)) - - # text features - text_inputs = processor(text="a woman sitting on the beach with a dog", padding=True, return_tensors="pt").to( - torch_device - ) - text_features = 
model.get_text_features(**text_inputs) - expected_text_features = torch.tensor( - [ - -0.10836730897426605, - 0.05315554141998291, - -0.028310950845479965, - 0.016979066655039787, - 0.0865054652094841, - -0.046645939350128174, - ] - ).to(torch_device) - self.assertTrue(torch.allclose(text_features[0][0][:6], expected_text_features, atol=1e-3)) - - # check similarity - similarity = (image_features @ text_features[:, 0, :].t()).max() - expected_similarity = torch.tensor(0.44385525584220886).to(torch_device) - self.assertTrue(torch.allclose(similarity, expected_similarity, atol=1e-3)) - - @require_torch_gpu - def test_inference_itm_features_fp16(self): - processor = Blip2Processor.from_pretrained("jpizarrom/blip2-itm-vit-g") - model = Blip2ModelWithoutLM.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16).to( - torch_device - ) - - # image features - image = prepare_img() - image_inputs = processor(images=image, return_tensors="pt").to(torch_device, dtype=torch.float16) - image_features = model.get_image_features(**image_inputs) - expected_image_features = [ - -0.093994140625, - -0.075927734375, - 0.031890869140625, - 0.053009033203125, - 0.0352783203125, - -0.01190185546875, - ] - self.assertTrue(np.allclose(image_features[0][0][:6].tolist(), expected_image_features, atol=1e-3)) - - # text features - text_inputs = processor(text="a woman sitting on the beach with a dog", padding=True, return_tensors="pt").to( - torch_device - ) - text_features = model.get_text_features(**text_inputs) - expected_text_features = [ - -0.1082763671875, - 0.053192138671875, - -0.02825927734375, - 0.0169830322265625, - 0.08648681640625, - -0.04656982421875, - ] - self.assertTrue(np.allclose(text_features[0][0][:6].tolist(), expected_text_features, atol=1e-3)) - - # check similarity - similarity = (image_features @ text_features[:, 0, :].t()).max() - expected_similarity = 0.44384765625 - self.assertTrue(np.allclose(similarity.item(), expected_similarity, atol=1e-3)) From 253e0676b0a80aece4216d2c65f9e0f4f4a60e78 Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Fri, 27 Oct 2023 18:54:00 +0200 Subject: [PATCH 19/25] remove _tied_weights_keys in Blip2ForImageTextRetrieval --- src/transformers/models/blip_2/modeling_blip_2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 78ebd4cee97763..2356a9c89a3a73 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -2303,7 +2303,6 @@ class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): main_input_name = "pixel_values" config_class = Blip2Config _keep_in_fp32_modules = [] - _tied_weights_keys = ["cls.predictions.decoder.bias"] def __init__(self, config: Blip2Config): super().__init__(config) From 81aea6809712728b0ec9da2eb2cb6c34edfc8960 Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Mon, 30 Oct 2023 12:08:40 +0100 Subject: [PATCH 20/25] remove unused code: blip2_loss --- src/transformers/models/blip_2/modeling_blip_2.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 2356a9c89a3a73..ba732c7f403a08 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -54,18 +54,6 @@ ] -# Copied from transformers.models.clip.modeling_clip.contrastive_loss -def contrastive_loss(logits: torch.Tensor) -> 
torch.Tensor: - return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) - - -# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->blip2 -def blip2_loss(similarity: torch.Tensor) -> torch.Tensor: - caption_loss = contrastive_loss(similarity) - image_loss = contrastive_loss(similarity.t()) - return (caption_loss + image_loss) / 2.0 - - @dataclass class Blip2ForConditionalGenerationModelOutput(ModelOutput): """ From 04e26689f4fc5094f7b1edd3a0a8f47d2b5377ee Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Mon, 30 Oct 2023 12:11:40 +0100 Subject: [PATCH 21/25] remove unused Blip2Output --- .../models/blip_2/modeling_blip_2.py | 39 ------------------- 1 file changed, 39 deletions(-) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index ba732c7f403a08..f1e68c8de1bf6c 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -87,45 +87,6 @@ def to_tuple(self) -> Tuple[Any]: ) -@dataclass -# Copied from transformers.models.blip.modeling_blip.BlipOutput with Blip->Blip2 -class Blip2Output(ModelOutput): - """ - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): - Contrastive loss for image-text similarity. - logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): - The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text - similarity scores. - logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): - The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image - similarity scores. - text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): - The text embeddings obtained by applying the projection layer to the pooled output of [`Blip2TextModel`]. - image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): - The image embeddings obtained by applying the projection layer to the pooled output of - [`Blip2VisionModel`]. - text_model_output(`BaseModelOutputWithPooling`): - The output of the [`Blip2TextModel`]. - vision_model_output(`BaseModelOutputWithPooling`): - The output of the [`Blip2VisionModel`]. 
- """ - - loss: Optional[torch.FloatTensor] = None - logits_per_image: torch.FloatTensor = None - logits_per_text: torch.FloatTensor = None - text_embeds: torch.FloatTensor = None - image_embeds: torch.FloatTensor = None - text_model_output: BaseModelOutputWithPooling = None - vision_model_output: BaseModelOutputWithPooling = None - - def to_tuple(self) -> Tuple[Any]: - return tuple( - self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() - for k in self.keys() - ) - - @dataclass # Copied from transformers.models.blip.modeling_blip.BlipImageTextMatchingModelOutput with Blip->Blip2 class Blip2ImageTextMatchingModelOutput(ModelOutput): From 3d2dfbd6588b8951f37c09891474365c1eac1153 Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Mon, 30 Oct 2023 12:17:33 +0100 Subject: [PATCH 22/25] remove Blip2ModelWithoutLM from check_repo --- utils/check_repo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/utils/check_repo.py b/utils/check_repo.py index 6c82795e180561..0dd44b2010aa9d 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -147,7 +147,6 @@ "ClapAudioModelWithProjection", "Blip2ForConditionalGeneration", "Blip2ForImageTextRetrieval", - "Blip2ModelWithoutLM", "Blip2TextModelWithProjection", "Blip2VisionModelWithProjection", "Blip2QFormerModel", From b9343bad57e3fe81a31a6cd403fa54d76ef52993 Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Mon, 30 Oct 2023 15:27:08 +0100 Subject: [PATCH 23/25] add qformer_text_input line in the docstring --- src/transformers/models/blip_2/configuration_blip_2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/blip_2/configuration_blip_2.py b/src/transformers/models/blip_2/configuration_blip_2.py index 942dbaed91fc13..cd1974c2bde3d3 100644 --- a/src/transformers/models/blip_2/configuration_blip_2.py +++ b/src/transformers/models/blip_2/configuration_blip_2.py @@ -176,6 +176,8 @@ class Blip2QFormerConfig(PretrainedConfig): The frequency of adding cross-attention to the Transformer layers. encoder_hidden_size (`int`, *optional*, defaults to 1408): The hidden size of the hidden states for cross-attention. + qformer_text_input (`bool`, *optional*, defaults to `False`): + Whether to use BERT-style embeddings. Examples: From 6b65330920a010ef2f628ac8ee9f00e1dcb4e6ff Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Mon, 30 Oct 2023 18:44:58 +0100 Subject: [PATCH 24/25] add tests for Blip2ForImageTextRetrieval and Blip2VisionModelWithProjection --- .../models/blip_2/modeling_blip_2.py | 6 +- tests/models/blip_2/test_modeling_blip_2.py | 149 ++++++++++++++++-- 2 files changed, 139 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index f1e68c8de1bf6c..00d6429a7c63f8 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -1306,9 +1306,7 @@ def forward( device = embedding_output.device if attention_mask is None: - attention_mask = torch.ones( - ((batch_size, seq_length + past_key_values_length)), device=device, dtype=torch.long - ) + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
@@ -1326,7 +1324,7 @@ def forward( if type(encoder_attention_mask) == list: encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] elif encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device, dtype=torch.long) + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index 688c25805de1bb..8d0a1d6a5a5256 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -25,6 +25,7 @@ from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig from transformers.testing_utils import ( require_torch, + require_torch_gpu, require_torch_multi_gpu, require_vision, slow, @@ -974,11 +975,23 @@ def test_forward_signature(self): self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) @slow + @require_torch_gpu def test_model_from_pretrained(self): - for model_name in ["jpizarrom/blip2-itm-vit-g"]: - model = Blip2TextModelWithProjection.from_pretrained(model_name) - self.assertIsNotNone(model) - self.assertTrue(hasattr(model, "text_proj")) + model_name = "jpizarrom/blip2-itm-vit-g" + model = Blip2TextModelWithProjection.from_pretrained(model_name) + self.assertIsNotNone(model) + self.assertTrue(hasattr(model, "text_proj")) + + _, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs() + + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(input_ids=input_ids, attention_mask=attention_mask) + + self.assertEqual( + outputs.text_embeds.shape, (self.model_tester.qformer_model_tester.batch_size, input_ids.shape[1], 256) + ) class Blip2VisionModelWithProjectionTester: @@ -1122,11 +1135,21 @@ def test_forward_signature(self): self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) @slow + @require_torch_gpu def test_model_from_pretrained(self): - for model_name in ["jpizarrom/blip2-itm-vit-g"]: - model = Blip2VisionModelWithProjection.from_pretrained(model_name) - self.assertIsNotNone(model) - self.assertTrue(hasattr(model, "vision_proj")) + model_name = "jpizarrom/blip2-itm-vit-g" + model = Blip2VisionModelWithProjection.from_pretrained(model_name) + self.assertIsNotNone(model) + self.assertTrue(hasattr(model, "vision_proj")) + + _, pixel_values = self.model_tester.prepare_config_and_inputs() + + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(pixel_values=pixel_values) + + self.assertEqual(outputs.image_embeds.shape, (self.model_tester.vision_model_tester.batch_size, 32, 256)) class Blip2TextRetrievalModelTester: @@ -1263,11 +1286,26 @@ def test_load_vision_qformer_text_config(self): self.assertDictEqual(config.qformer_config.to_dict(), qformer_config.to_dict()) @slow + @require_torch_gpu def test_model_from_pretrained(self): - for model_name in ["jpizarrom/blip2-itm-vit-g"]: - for model_class in self.all_model_classes: - model = model_class.from_pretrained(model_name) - self.assertIsNotNone(model) + model_name = "jpizarrom/blip2-itm-vit-g" + model = Blip2ForImageTextRetrieval.from_pretrained(model_name) + self.assertIsNotNone(model) + + _, input_ids, attention_mask, pixel_values = 
self.model_tester.prepare_config_and_inputs() + + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask) + self.assertEqual(outputs.itm_score.shape, (self.model_tester.qformer_model_tester.batch_size, 2)) + + with torch.no_grad(): + outputs = model( + pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, use_itm_head=False + ) + self.assertEqual(outputs.itm_score.shape, (self.model_tester.qformer_model_tester.batch_size, 1)) def test_training(self): pass @@ -1482,3 +1520,90 @@ def test_inference_t5_multi_gpu(self): [0, 3, 7, 152, 67, 839, 1], ) self.assertEqual(generated_text, "san diego") + + @require_torch_gpu + def test_inference_itm(self): + model_name = "jpizarrom/blip2-itm-vit-g" + processor = Blip2Processor.from_pretrained(model_name) + model = Blip2ForImageTextRetrieval.from_pretrained(model_name).to(torch_device) + + image = prepare_img() + text = "A woman and her dog sitting in a beach" + inputs = processor(images=image, text=text, return_tensors="pt").to(torch_device) + + # forward pass + out_itm = model(**inputs) + out = model(**inputs, use_itm_head=False) + + # verify + expected_scores = torch.Tensor([[0.0238, 0.9762]]) + self.assertTrue(torch.allclose(torch.nn.Softmax()(out_itm[0].cpu()), expected_scores, rtol=1e-3, atol=1e-3)) + self.assertTrue(torch.allclose(out[0].cpu(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3)) + + @require_torch_gpu + def test_inference_itm_fp16(self): + model_name = "jpizarrom/blip2-itm-vit-g" + processor = Blip2Processor.from_pretrained(model_name) + model = Blip2ForImageTextRetrieval.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device) + + image = prepare_img() + text = "A woman and her dog sitting in a beach" + inputs = processor(images=image, text=text, return_tensors="pt").to(torch_device, dtype=torch.float16) + + # forward pass + out_itm = model(**inputs) + out = model(**inputs, use_itm_head=False) + + # verify + expected_scores = torch.Tensor([[0.0239, 0.9761]]) + self.assertTrue( + torch.allclose(torch.nn.Softmax()(out_itm[0].cpu().float()), expected_scores, rtol=1e-3, atol=1e-3) + ) + self.assertTrue(torch.allclose(out[0].cpu().float(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3)) + + @require_torch_gpu + def test_inference_vision_with_projection_fp16(self): + model_name = "jpizarrom/blip2-itm-vit-g" + processor = Blip2Processor.from_pretrained(model_name) + model = Blip2VisionModelWithProjection.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device) + + image = prepare_img() + inputs = processor(images=image, return_tensors="pt").to(torch_device, dtype=torch.float16) + + # forward pass + out = model(**inputs) + + # verify + expected_image_embeds = [ + -0.093994140625, + -0.075927734375, + 0.031890869140625, + 0.053009033203125, + 0.0352783203125, + -0.01190185546875, + ] + self.assertTrue(np.allclose(out.image_embeds[0][0][:6].tolist(), expected_image_embeds, atol=1e-3)) + + @require_torch_gpu + def test_inference_text_with_projection_fp16(self): + model_name = "jpizarrom/blip2-itm-vit-g" + processor = Blip2Processor.from_pretrained(model_name) + model = Blip2TextModelWithProjection.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device) + + inputs = processor(text="a woman sitting on the beach with a dog", padding=True, return_tensors="pt").to( + torch_device + ) + + # forward pass + out = model(**inputs) + + # verify + expected_text_embeds = [ + 
-0.1082763671875,
+            0.053192138671875,
+            -0.02825927734375,
+            0.0169830322265625,
+            0.08648681640625,
+            -0.04656982421875,
+        ]
+        self.assertTrue(np.allclose(out.text_embeds[0][0][:6].tolist(), expected_text_embeds, atol=1e-3))

From 47acd937d009e694e22dd66551cf2e2e23dcc516 Mon Sep 17 00:00:00 2001
From: Juan Pizarro
Date: Tue, 31 Oct 2023 19:51:50 +0100
Subject: [PATCH 25/25] add skip on test_training_gradient_checkpointing_use_reentrant

---
 tests/models/blip_2/test_modeling_blip_2.py | 24 +++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py
index 0a19174eb76382..d9d34e6cc754ad 100644
--- a/tests/models/blip_2/test_modeling_blip_2.py
+++ b/tests/models/blip_2/test_modeling_blip_2.py
@@ -1103,6 +1103,18 @@ def test_training(self):
     def test_training_gradient_checkpointing(self):
         pass
 
+    @unittest.skip(
+        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
     @unittest.skip(reason="Hidden_states is tested in individual model tests")
     def test_hidden_states_output(self):
         pass
@@ -1325,6 +1337,18 @@ def test_training(self):
     def test_training_gradient_checkpointing(self):
         pass
 
+    @unittest.skip(
+        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
     def test_initialization(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
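
The integration tests above double as a usage recipe for the new retrieval head. Below is a minimal sketch along the same lines; it assumes the `jpizarrom/blip2-itm-vit-g` checkpoint referenced throughout this series is available on the Hub, and the `itm_score` output attribute and `use_itm_head` flag follow the API exercised in `test_inference_itm`.

```python
import requests
import torch
from PIL import Image

from transformers import Blip2ForImageTextRetrieval, Blip2Processor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoint name as used in the tests of this patch series; it may move before merge.
processor = Blip2Processor.from_pretrained("jpizarrom/blip2-itm-vit-g")
model = Blip2ForImageTextRetrieval.from_pretrained("jpizarrom/blip2-itm-vit-g").to(device)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="a photo of a cat", return_tensors="pt").to(device)

with torch.no_grad():
    # Image-text matching head: two logits per pair; per the expected values in the
    # tests, softmax index 1 corresponds to "match".
    itm_logits = model(**inputs).itm_score
    match_prob = itm_logits.softmax(dim=-1)[:, 1]

    # Contrastive path (use_itm_head=False): a single similarity score per pair.
    itc_score = model(**inputs, use_itm_head=False).itm_score

print(f"ITM match probability: {match_prob.item():.3f}, ITC similarity: {itc_score.item():.3f}")
```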