diff --git a/docs/source/en/model_doc/blip-2.md b/docs/source/en/model_doc/blip-2.md index 0890e612561a..bb6baa553be2 100644 --- a/docs/source/en/model_doc/blip-2.md +++ b/docs/source/en/model_doc/blip-2.md @@ -87,4 +87,17 @@ If you're interested in submitting a resource to be included here, please feel f [[autodoc]] Blip2ForConditionalGeneration - forward - - generate \ No newline at end of file + - generate + +## Blip2ForImageTextRetrieval + +[[autodoc]] Blip2ForImageTextRetrieval + - forward + +## Blip2TextModelWithProjection + +[[autodoc]] Blip2TextModelWithProjection + +## Blip2VisionModelWithProjection + +[[autodoc]] Blip2VisionModelWithProjection diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 1fc1ff38d06d..735c51640819 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1363,10 +1363,13 @@ [ "BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST", "Blip2ForConditionalGeneration", + "Blip2ForImageTextRetrieval", "Blip2Model", "Blip2PreTrainedModel", "Blip2QFormerModel", + "Blip2TextModelWithProjection", "Blip2VisionModel", + "Blip2VisionModelWithProjection", ] ) _import_structure["models.bloom"].extend( @@ -5438,10 +5441,13 @@ from .models.blip_2 import ( BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST, Blip2ForConditionalGeneration, + Blip2ForImageTextRetrieval, Blip2Model, Blip2PreTrainedModel, Blip2QFormerModel, + Blip2TextModelWithProjection, Blip2VisionModel, + Blip2VisionModelWithProjection, ) from .models.bloom import ( BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/models/blip_2/__init__.py b/src/transformers/models/blip_2/__init__.py index 6fbfd53b3703..5ab04479d135 100644 --- a/src/transformers/models/blip_2/__init__.py +++ b/src/transformers/models/blip_2/__init__.py @@ -35,10 +35,13 @@ _import_structure["modeling_blip_2"] = [ "BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST", "Blip2Model", + "Blip2VisionModelWithProjection", "Blip2QFormerModel", "Blip2PreTrainedModel", "Blip2ForConditionalGeneration", + "Blip2ForImageTextRetrieval", "Blip2VisionModel", + "Blip2TextModelWithProjection", ] if TYPE_CHECKING: @@ -59,10 +62,13 @@ from .modeling_blip_2 import ( BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST, Blip2ForConditionalGeneration, + Blip2ForImageTextRetrieval, Blip2Model, Blip2PreTrainedModel, Blip2QFormerModel, + Blip2TextModelWithProjection, Blip2VisionModel, + Blip2VisionModelWithProjection, ) else: diff --git a/src/transformers/models/blip_2/configuration_blip_2.py b/src/transformers/models/blip_2/configuration_blip_2.py index 1b375e147f78..cd1974c2bde3 100644 --- a/src/transformers/models/blip_2/configuration_blip_2.py +++ b/src/transformers/models/blip_2/configuration_blip_2.py @@ -176,6 +176,8 @@ class Blip2QFormerConfig(PretrainedConfig): The frequency of adding cross-attention to the Transformer layers. encoder_hidden_size (`int`, *optional*, defaults to 1408): The hidden size of the hidden states for cross-attention. + qformer_text_input (`bool`, *optional*, defaults to `False`): + Whether to use BERT-style embeddings. 
Examples: @@ -209,6 +211,7 @@ def __init__( position_embedding_type="absolute", cross_attention_frequency=2, encoder_hidden_size=1408, + qformer_text_input=False, **kwargs, ): super().__init__(pad_token_id=pad_token_id, **kwargs) @@ -227,6 +230,7 @@ def __init__( self.position_embedding_type = position_embedding_type self.cross_attention_frequency = cross_attention_frequency self.encoder_hidden_size = encoder_hidden_size + self.qformer_text_input = qformer_text_input @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": @@ -266,7 +270,8 @@ class Blip2Config(PretrainedConfig): Dictionary of configuration options used to initialize any [`PretrainedConfig`]. num_query_tokens (`int`, *optional*, defaults to 32): The number of query tokens passed through the Transformer. - + image_text_hidden_size (`int`, *optional*, defaults to 256): + Dimentionality of the hidden state of the image-text fusion layer. kwargs (*optional*): Dictionary of keyword arguments. @@ -302,7 +307,15 @@ class Blip2Config(PretrainedConfig): model_type = "blip-2" - def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs): + def __init__( + self, + vision_config=None, + qformer_config=None, + text_config=None, + num_query_tokens=32, + image_text_hidden_size=256, + **kwargs, + ): super().__init__(**kwargs) if vision_config is None: @@ -326,6 +339,7 @@ def __init__(self, vision_config=None, qformer_config=None, text_config=None, nu self.is_encoder_decoder = self.text_config.is_encoder_decoder self.num_query_tokens = num_query_tokens + self.image_text_hidden_size = image_text_hidden_size self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES self.initializer_factor = 1.0 @@ -353,3 +367,23 @@ def from_vision_qformer_text_configs( text_config=text_config.to_dict(), **kwargs, ) + + @classmethod + def from_vision_qformer_configs( + cls, + vision_config: Blip2VisionConfig, + qformer_config: Blip2QFormerConfig, + **kwargs, + ): + r""" + Instantiate a [`Blip2Config`] (or a derived class) from a BLIP-2 vision and Q-Former model configurations. 
+ + Returns: + [`Blip2Config`]: An instance of a configuration object + """ + + return cls( + vision_config=vision_config.to_dict(), + qformer_config=qformer_config.to_dict(), + **kwargs, + ) diff --git a/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py b/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py index c2e6eceae532..686cb56ab843 100644 --- a/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py +++ b/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py @@ -31,9 +31,12 @@ from transformers import ( AutoTokenizer, + BertTokenizer, Blip2Config, Blip2ForConditionalGeneration, + Blip2ForImageTextRetrieval, Blip2Processor, + Blip2QFormerConfig, Blip2VisionConfig, BlipImageProcessor, OPTConfig, @@ -51,7 +54,7 @@ def load_demo_image(): # here we list all keys to be renamed (original name on the left, our name on the right) -def create_rename_keys(config): +def create_rename_keys(config, model_name): rename_keys = [] # fmt: off @@ -77,8 +80,9 @@ def create_rename_keys(config): rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias")) # QFormer - rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight")) - rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias")) + if "itm" not in model_name: + rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight")) + rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias")) # fmt: on return rename_keys @@ -114,8 +118,18 @@ def get_blip2_config(model_name, eos_token_id): text_config = T5Config.from_pretrained("google/flan-t5-xl", dense_act_fn="gelu", bos_token_id=1).to_dict() elif "t5-xxl" in model_name: text_config = T5Config.from_pretrained("google/flan-t5-xxl", dense_act_fn="gelu", bos_token_id=1).to_dict() - - config = Blip2Config(vision_config=vision_config, text_config=text_config) + elif "itm" in model_name: + text_config = {} + else: + raise ValueError("Model name not supported") + + if "itm" in model_name: + config = Blip2Config( + vision_config=vision_config, + qformer_config=Blip2QFormerConfig(vocab_size=30523, qformer_text_input=True).to_dict(), + ) + else: + config = Blip2Config(vision_config=vision_config, text_config=text_config) return config, image_size @@ -125,15 +139,24 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ """ Copy/paste/tweak model's weights to Transformers design. 
""" - tokenizer = ( - AutoTokenizer.from_pretrained("facebook/opt-2.7b") - if "opt" in model_name - else AutoTokenizer.from_pretrained("google/flan-t5-xl") - ) - eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0] + if "opt" in model_name: + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b") + elif "itm" in model_name: + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right") + tokenizer.add_special_tokens({"bos_token": "[DEC]"}) + else: + tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl") + + if "itm" in model_name: + eos_token_id = None + else: + eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0] config, image_size = get_blip2_config(model_name, eos_token_id=eos_token_id) - hf_model = Blip2ForConditionalGeneration(config).eval() + if "itm" in model_name: + hf_model = Blip2ForImageTextRetrieval(config).eval() + else: + hf_model = Blip2ForConditionalGeneration(config).eval() model_name_to_original = { "blip2-opt-2.7b": ("blip2_opt", "pretrain_opt2.7b"), @@ -143,6 +166,8 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ "blip2-flan-t5-xl": ("blip2_t5", "pretrain_flant5xl"), "blip2-flan-t5-xl-coco": ("blip2_t5", "caption_coco_flant5xl"), "blip2-flan-t5-xxl": ("blip2_t5", "pretrain_flant5xxl"), + "blip2-itm-vit-g": ("blip2_image_text_matching", "pretrain"), + "blip2-itm-vit-g-coco": ("blip2_image_text_matching", "coco"), } name, type = model_name_to_original[model_name] @@ -163,7 +188,7 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ # update state dict keys state_dict = original_model.state_dict() - rename_keys = create_rename_keys(config) + rename_keys = create_rename_keys(config, model_name) for src, dest in rename_keys: rename_key(state_dict, src, dest) @@ -189,11 +214,15 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False) assert len(missing_keys) == 0 - assert unexpected_keys == ["qformer.embeddings.position_ids"] + + if "itm" in model_name: + unexpected_keys = list(filter(lambda x: not x.startswith("Qformer.cls"), unexpected_keys)) + assert unexpected_keys == ["temp", "qformer.embeddings.position_ids"] + else: + assert unexpected_keys == ["qformer.embeddings.position_ids"] image = load_demo_image() original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device) - input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device) # create processor image_processor = BlipImageProcessor( @@ -207,50 +236,100 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ original_model.to(lavis_device) hf_model.to(hf_model_device) - with torch.no_grad(): - if "opt" in model_name: - original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits - logits = hf_model(pixel_values, input_ids).logits - else: - original_logits = original_model( - {"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]} - ).logits - labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100) - logits = hf_model(pixel_values, input_ids, labels=labels).logits - assert original_logits.shape == logits.shape - print("First values of original logits:", original_logits[0, :3, :3]) - print("First values of HF logits:", logits[0, :3, :3]) + if "itm" in model_name: + caption = "a large fountain spewing water into the air" + input_ids 
= tokenizer([caption], return_tensors="pt").input_ids.to(hf_model_device) + attention_mask = processor(text=caption, return_tensors="pt").attention_mask.to(hf_model_device) - # assert values - assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4) - print("Looks ok!") + with torch.no_grad(): + original_logits = original_model( + {"image": original_pixel_values, "text_input": [caption]}, match_head="itm" + ) + logits = hf_model(pixel_values=original_pixel_values, input_ids=input_ids, attention_mask=attention_mask) - print("Generating a caption...") - prompt = "Question: what object is in this image? Answer:" - input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device) + assert original_logits.shape == logits.itm_score.shape + print("First values of original logits:", original_logits[0, :3]) + print("First values of HF logits:", logits.itm_score[0, :3]) - set_seed(42) + # assert values + # cast to same type + target_dtype = logits.itm_score.dtype + assert torch.allclose(original_logits.to(target_dtype), logits.itm_score, atol=1e-2) - original_outputs = original_model.generate( - {"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True - ) - outputs = hf_model.generate( - pixel_values, - input_ids, - do_sample=True, - num_beams=5, - max_length=30, - min_length=1, - top_p=0.9, - repetition_penalty=1.0, - length_penalty=1.0, - temperature=1, - ) - output_text = processor.batch_decode(outputs, skip_special_tokens=True) - output_text = [text.strip() for text in output_text] - print("Original generation:", original_outputs) - print("HF generation:", output_text) + original_itm_scores = torch.nn.functional.softmax(original_logits, dim=1) + itm_scores = torch.nn.functional.softmax(logits.itm_score, dim=1) + assert torch.allclose(original_itm_scores.to(target_dtype), itm_scores, atol=1e-2) + print("Looks ok!") + + with torch.no_grad(): + original_logits = original_model( + {"image": original_pixel_values, "text_input": [caption]}, match_head="itc" + ) + logits = hf_model( + pixel_values=original_pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + use_itm_head=False, + ) + + assert original_logits.shape == logits.itm_score.shape + print("First values of original logits:", original_logits[0, :3]) + print("First values of HF logits:", logits.itm_score[0, :3]) + + # assert values + # cast to same type + target_dtype = logits.itm_score.dtype + assert torch.allclose(original_logits.to(target_dtype), logits.itm_score, atol=1e-2) + print("Looks ok!") + + else: + input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device) + + with torch.no_grad(): + if "opt" in model_name: + original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits + logits = hf_model(pixel_values, input_ids).logits + else: + original_logits = original_model( + {"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]} + ).logits + labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100) + logits = hf_model(pixel_values, input_ids, labels=labels).logits + + assert original_logits.shape == logits.shape + print("First values of original logits:", original_logits[0, :3, :3]) + print("First values of HF logits:", logits[0, :3, :3]) + + # assert values + assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4) + print("Looks ok!") + + print("Generating a caption...") + prompt = "Question: what object is in this image? 
Answer:" + input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device) + + set_seed(42) + + original_outputs = original_model.generate( + {"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True + ) + outputs = hf_model.generate( + pixel_values, + input_ids, + do_sample=True, + num_beams=5, + max_length=30, + min_length=1, + top_p=0.9, + repetition_penalty=1.0, + length_penalty=1.0, + temperature=1, + ) + output_text = processor.batch_decode(outputs, skip_special_tokens=True) + output_text = [text.strip() for text in output_text] + print("Original generation:", original_outputs) + print("HF generation:", output_text) if pytorch_dump_folder_path is not None: processor.save_pretrained(pytorch_dump_folder_path) @@ -271,6 +350,8 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ "blip2-flan-t5-xl", "blip2-flan-t5-xl-coco", "blip2-flan-t5-xxl", + "blip2-itm-vit-g", + "blip2-itm-vit-g-coco", ] parser.add_argument( "--model_name", diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 10a37c79b863..d8bfe065a155 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -22,6 +22,7 @@ import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss +from torch.nn.functional import normalize from ...activations import ACT2FN from ...modeling_outputs import ( @@ -86,6 +87,110 @@ def to_tuple(self) -> Tuple[Any]: ) +@dataclass +# Copied from transformers.models.blip.modeling_blip.BlipImageTextMatchingModelOutput with Blip->Blip2 +class Blip2ImageTextMatchingModelOutput(ModelOutput): + """ + Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the + last hidden states. This class also adds the loss term from the text decoder as well as the image-text similarity + scores. + + Args: + itm_score (`torch.FloatTensor`): + The image-text similarity scores. + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Languge modeling loss from the text decoder. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + vision_pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*): + Last layer hidden-state of the vision of the vision-only branch of the model. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. 
+ + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + question_embeds (`torch.FloatTensor`): + The question embeddings obtained by the text projection layer. + """ + + itm_score: Optional[torch.FloatTensor] = None + loss: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + vision_pooler_output: Optional[torch.FloatTensor] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + question_embeds: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Blip2 +class Blip2TextModelOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. + + Args: + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Blip2 +class Blip2VisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + # Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2 class Blip2VisionEmbeddings(nn.Module): def __init__(self, config: Blip2VisionConfig): @@ -806,6 +911,10 @@ def __init__(self, config, layer_idx): else: self.has_cross_attention = False + if config.qformer_text_input: + self.intermediate = Blip2QFormerIntermediate(config) + self.output = Blip2QFormerOutput(config) + self.intermediate_query = Blip2QFormerIntermediate(config) self.output_query = Blip2QFormerOutput(config) @@ -986,6 +1095,56 @@ def forward( ) +class Blip2TextEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + def forward( + self, + input_ids: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if input_ids is not None: + input_ids = input_ids.to(self.word_embeddings.weight.device) + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + class Blip2QFormerModel(Blip2PreTrainedModel): """ Querying Transformer (Q-Former), used in BLIP-2. 
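With the `qformer_text_input` flag added above, `Blip2QFormerModel` gains a second input path: token ids are embedded by the new BERT-style `Blip2TextEmbeddings`, while the learned query tokens (with image features as cross-attention context) keep working as before. The snippet below is a minimal sketch of both call patterns; it uses a small, made-up configuration rather than a released checkpoint, so the sizes are illustrative only.

```python
# Sketch of the two Q-Former input modes enabled by qformer_text_input=True (hypothetical tiny config).
import torch

from transformers import Blip2QFormerConfig, Blip2QFormerModel

config = Blip2QFormerConfig(
    vocab_size=30523,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    encoder_hidden_size=32,
    qformer_text_input=True,  # routes inputs through Blip2TextEmbeddings
)
qformer = Blip2QFormerModel(config).eval()

batch_size, num_query_tokens, text_len, num_patches = 2, 8, 5, 10
query_tokens = torch.zeros(batch_size, num_query_tokens, config.hidden_size)
image_states = torch.rand(batch_size, num_patches, config.encoder_hidden_size)
input_ids = torch.randint(0, config.vocab_size, (batch_size, text_len))

with torch.no_grad():
    # Query mode: learned queries cross-attend to the (dummy) image features.
    image_out = qformer(query_embeds=query_tokens, encoder_hidden_states=image_states)
    # Text mode: token ids go through Blip2TextEmbeddings, no image features needed.
    text_out = qformer(input_ids=input_ids)

print(image_out.last_hidden_state.shape)  # torch.Size([2, 8, 32])
print(text_out.last_hidden_state.shape)   # torch.Size([2, 5, 32])
```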
@@ -995,8 +1154,11 @@ def __init__(self, config: Blip2QFormerConfig): super().__init__(config) self.config = config - self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + if config.qformer_text_input: + self.embeddings = Blip2TextEmbeddings(config) + else: + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) self.encoder = Blip2QFormerEncoder(config) @@ -1063,7 +1225,7 @@ def get_extended_attention_mask( def forward( self, - query_embeds: torch.FloatTensor, + query_embeds: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, @@ -1073,6 +1235,8 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + input_ids: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`): @@ -1099,6 +1263,9 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if input_ids is None and query_embeds is None: + raise ValueError("You have to specify query_embeds when input_ids is None") + # past_key_values_length past_key_values_length = ( past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0 @@ -1106,8 +1273,16 @@ def forward( query_length = query_embeds.shape[1] if query_embeds is not None else 0 - embedding_output = self.layernorm(query_embeds) - embedding_output = self.dropout(embedding_output) + if self.config.qformer_text_input: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = self.layernorm(query_embeds) + embedding_output = self.dropout(embedding_output) input_shape = embedding_output.size()[:-1] batch_size, seq_length = input_shape @@ -1521,6 +1696,183 @@ def forward( ) +class Blip2TextModelWithProjection(Blip2PreTrainedModel): + config_class = Blip2Config + supports_gradient_checkpointing = False + _keep_in_fp32_modules = [] + + def __init__(self, config: Blip2Config): + super().__init__(config) + + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + self.qformer = Blip2QFormerModel(config.qformer_config) + + # text projection layer + self.text_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BLIP_2_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Blip2TextModelOutput, config_class=Blip2Config) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Blip2TextModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, 
Blip2TextModelWithProjection + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> model = Blip2TextModelWithProjection.from_pretrained( + ... "jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16 + ... ) + + >>> model.to(device) # doctest: +IGNORE_RESULT + + >>> processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g") + + >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], return_tensors="pt").to(device) + + >>> outputs = model(**inputs) + >>> text_embeds = outputs.text_embeds + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.qformer( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[0] if not return_dict else text_outputs.last_hidden_state + + text_embeds = self.text_proj(pooled_output) + text_embeds = normalize(text_embeds, dim=-1) + + if not return_dict: + outputs = (text_embeds, text_outputs[0]) + text_outputs[2:] + return tuple(output for output in outputs if output is not None) + + return Blip2TextModelOutput( + text_embeds=text_embeds, + last_hidden_state=text_outputs.last_hidden_state, + hidden_states=text_outputs.hidden_states, + attentions=text_outputs.attentions, + ) + + +class Blip2VisionModelWithProjection(Blip2PreTrainedModel): + config_class = Blip2Config + main_input_name = "pixel_values" + _keep_in_fp32_modules = [] + + def __init__(self, config: Blip2Config): + super().__init__(config) + + self.vision_model = Blip2VisionModel(config.vision_config) + + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + self.qformer = Blip2QFormerModel(config.qformer_config) + + # vision projection layer + self.vision_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BLIP_2_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Blip2VisionModelOutput, config_class=Blip2Config) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Blip2VisionModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Blip2VisionModelWithProjection + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g") + >>> model = Blip2VisionModelWithProjection.from_pretrained( + ... "jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16 + ... 
) + >>> model.to(device) # doctest: +IGNORE_RESULT + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16) + + >>> outputs = model(**inputs) + >>> image_embeds = outputs.image_embeds + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[0] if not return_dict else vision_outputs.last_hidden_state + + image_attention_mask = torch.ones(pooled_output.size()[:-1], dtype=torch.long, device=pooled_output.device) + + query_tokens = self.query_tokens.expand(pooled_output.shape[0], -1, -1) + + query_outputs = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=pooled_output, + encoder_attention_mask=image_attention_mask, + return_dict=return_dict, + ) + + embeds = query_outputs[0] if not return_dict else query_outputs.last_hidden_state + image_embeds = self.vision_proj(embeds) + image_embeds = normalize(image_embeds, dim=-1) + + if not return_dict: + outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:] + return tuple(output for output in outputs if output is not None) + + return Blip2VisionModelOutput( + image_embeds=image_embeds, + last_hidden_state=vision_outputs.last_hidden_state, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + ) + + @add_start_docstrings( """ BLIP-2 Model for generating text given an image and an optional text prompt. The model consists of a vision @@ -1834,3 +2186,144 @@ def generate( ) return outputs + + +@add_start_docstrings( + """ + BLIP-2 Model with a vision and text projector, and a classification head on top. The model is used in the context + of image-text retrieval. Given an image and a text, the model returns the probability of the text being relevant to + the image. 
+ """, + BLIP_2_START_DOCSTRING, +) +class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): + main_input_name = "pixel_values" + config_class = Blip2Config + _keep_in_fp32_modules = [] + + def __init__(self, config: Blip2Config): + super().__init__(config) + + self.vision_model = Blip2VisionModel(config.vision_config) + + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + + self.qformer = Blip2QFormerModel(config.qformer_config) + + # vision projection layer + self.vision_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + + # text projection layer + self.text_proj = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size) + + # image text matching head + self.itm_head = nn.Linear(config.qformer_config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BLIP_2_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Blip2ImageTextMatchingModelOutput, config_class=Blip2Config) + def forward( + self, + pixel_values: torch.FloatTensor, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + use_itm_head: Optional[bool] = True, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Blip2ImageTextMatchingModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Blip2ForImageTextRetrieval + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> model = Blip2ForImageTextRetrieval.from_pretrained("jpizarrom/blip2-itm-vit-g", torch_dtype=torch.float16) + >>> processor = AutoProcessor.from_pretrained("jpizarrom/blip2-itm-vit-g") + + >>> model.to(device) # doctest: +IGNORE_RESULT + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text = "an image of a cat" + + >>> inputs = processor(images=image, text=text, return_tensors="pt").to(device, torch.float16) + >>> outputs = model(**inputs) + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[0] + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + if use_itm_head: + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(query_tokens.device) + attention_mask = torch.cat([query_attention_mask, attention_mask], dim=1) + + question_embeds = self.qformer( + input_ids=input_ids, + query_embeds=query_tokens, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=return_dict, + ) + question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state + + output = 
self.itm_head(question_embeds[:, : query_tokens.size(1), :]) + output = output.mean(dim=1) + else: + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_outputs = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=return_dict, + ) + image_embeds = query_outputs[0] if not return_dict else query_outputs.last_hidden_state + + question_embeds = self.qformer( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=return_dict, + ) + question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state + + image_feat = normalize(self.vision_proj(image_embeds), dim=-1) + text_feat = normalize(self.text_proj(question_embeds[:, 0, :]), dim=-1) + + outputs = torch.bmm(image_feat, text_feat.unsqueeze(-1)) + output, _ = torch.max(outputs, dim=1) + + if not return_dict: + outputs = (output, vision_outputs[0]) + vision_outputs[2:] + (question_embeds,) + return tuple(output for output in outputs if output is not None) + + return Blip2ImageTextMatchingModelOutput( + itm_score=output, + last_hidden_state=vision_outputs.last_hidden_state, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + question_embeds=question_embeds, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 1310312519cc..9a40f098e370 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1503,6 +1503,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Blip2ForImageTextRetrieval(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Blip2Model(metaclass=DummyObject): _backends = ["torch"] @@ -1524,6 +1531,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Blip2TextModelWithProjection(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Blip2VisionModel(metaclass=DummyObject): _backends = ["torch"] @@ -1531,6 +1545,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Blip2VisionModelWithProjection(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index dd87961372d2..d9d34e6cc754 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -25,6 +25,7 @@ from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig from transformers.testing_utils import ( require_torch, + require_torch_gpu, require_torch_multi_accelerator, require_vision, slow, @@ -47,7 +48,14 @@ import torch from torch import nn - from transformers import Blip2ForConditionalGeneration, Blip2Model, Blip2VisionModel + from transformers import ( + Blip2ForConditionalGeneration, + Blip2ForImageTextRetrieval, + Blip2Model, + Blip2TextModelWithProjection, + Blip2VisionModel, + Blip2VisionModelWithProjection, + ) from transformers.models.blip_2.modeling_blip_2 import BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST @@ -242,6 +250,7 @@ def __init__( initializer_range=0.02, bos_token_id=0, scope=None, + qformer_text_input=False, ): 
self.parent = parent self.batch_size = batch_size @@ -261,6 +270,7 @@ def __init__( self.initializer_range = initializer_range self.scope = scope self.bos_token_id = bos_token_id + self.qformer_text_input = qformer_text_input def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -293,6 +303,7 @@ def get_config(self): max_position_embeddings=self.max_position_embeddings, initializer_range=self.initializer_range, bos_token_id=self.bos_token_id, + qformer_text_input=self.qformer_text_input, ) @@ -486,7 +497,7 @@ def test_forward_signature(self): self.assertListEqual(arg_names[:1], expected_arg_names) def test_load_vision_qformer_text_config(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config, _ = self.model_tester.prepare_config_and_inputs_for_common() # Save Blip2Config and check if we can load Blip2VisionConfig from it with tempfile.TemporaryDirectory() as tmp_dir_name: @@ -747,7 +758,7 @@ def test_forward_signature(self): self.assertListEqual(arg_names[:1], expected_arg_names) def test_load_vision_qformer_text_config(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config, _ = self.model_tester.prepare_config_and_inputs_for_common() # Save Blip2Config and check if we can load Blip2VisionConfig from it with tempfile.TemporaryDirectory() as tmp_dir_name: @@ -835,6 +846,540 @@ def test_initialization(self): ) +class Blip2TextModelWithProjectionTester: + def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training=True): + if vision_kwargs is None: + vision_kwargs = {} + if qformer_kwargs is None: + qformer_kwargs = {"qformer_text_input": True} + + self.parent = parent + self.vision_model_tester = Blip2VisionModelTester(parent, **vision_kwargs) + self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs) + self.is_training = is_training + + def prepare_config_and_inputs(self): + _, input_ids, attention_mask = self.qformer_model_tester.prepare_config_and_inputs() + + config = self.get_config() + + return config, input_ids, attention_mask + + def get_config(self): + return Blip2Config.from_vision_qformer_configs( + vision_config=self.vision_model_tester.get_config(), + qformer_config=self.qformer_model_tester.get_config(), + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, attention_mask = config_and_inputs + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + def create_and_check_model(self, config, input_ids, attention_mask): + model = Blip2TextModelWithProjection(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(input_ids, attention_mask=attention_mask, output_attentions=True, output_hidden_states=True) + + self.parent.assertEqual( + result.last_hidden_state.shape, + (self.vision_model_tester.batch_size, input_ids.shape[1], self.qformer_model_tester.hidden_size), + ) + self.parent.assertEqual( + result.text_embeds.shape, + ( + self.vision_model_tester.batch_size, + input_ids.shape[1], + config.image_text_hidden_size, + ), + ) + + with torch.no_grad(): + result2 = model( + input_ids, + attention_mask=attention_mask, + return_dict=not config.use_return_dict, + output_attentions=True, + output_hidden_states=True, + ) + + self.parent.assertTrue(torch.allclose(result.text_embeds, result2[0])) + 
self.parent.assertTrue(torch.allclose(result.last_hidden_state, result2[1])) + self.parent.assertTrue(torch.allclose(result.hidden_states[0], result2[2][0])) + self.parent.assertTrue(torch.allclose(result.hidden_states[1], result2[2][1])) + self.parent.assertTrue(torch.allclose(result.attentions[0], result2[3][0])) + self.parent.assertTrue(torch.allclose(result.attentions[1], result2[3][1])) + + +@require_torch +class Blip2TextModelWithProjectionTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (Blip2TextModelWithProjection,) if is_torch_available() else () + fx_compatible = False + test_pruning = False + test_head_masking = False + + test_resize_embeddings = False + test_attention_outputs = False + test_torchscript = False + + def setUp(self): + self.model_tester = Blip2TextModelWithProjectionTester(self) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_training(self): + pass + + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="Hidden_states is tested in individual model tests") + def test_hidden_states_output(self): + pass + + @unittest.skip(reason="Blip2TextModelWithProjection does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Retain_grad is tested in individual model tests") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="Blip2TextModelWithProjection does not have input/output embeddings") + def test_model_common_attributes(self): + pass + + @unittest.skip(reason="Blip2TextModelWithProjection has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="Blip2TextModelWithProjection has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_to_base(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + if model.config.is_encoder_decoder: + # TODO + raise NotImplementedError + else: + expected_arg_names = [ + "input_ids", + "attention_mask", + "position_ids", + ] + + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + + @slow + @require_torch_gpu + def test_model_from_pretrained(self): + model_name = "jpizarrom/blip2-itm-vit-g" + model = Blip2TextModelWithProjection.from_pretrained(model_name) + self.assertIsNotNone(model) + self.assertTrue(hasattr(model, "text_proj")) + + _, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs() + + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(input_ids=input_ids, attention_mask=attention_mask) + + self.assertEqual( + outputs.text_embeds.shape, (self.model_tester.qformer_model_tester.batch_size, input_ids.shape[1], 256) + ) + + +class Blip2VisionModelWithProjectionTester: + def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training=True): + if vision_kwargs is None: + vision_kwargs = {} + if qformer_kwargs is None: + qformer_kwargs = {"qformer_text_input": True} + + self.parent = parent + self.vision_model_tester = Blip2VisionModelTester(parent, **vision_kwargs) + 
self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs) + self.is_training = is_training + + def prepare_config_and_inputs(self): + _, pixel_values = self.vision_model_tester.prepare_config_and_inputs() + + config = self.get_config() + + return config, pixel_values + + def get_config(self): + return Blip2Config.from_vision_qformer_configs( + vision_config=self.vision_model_tester.get_config(), + qformer_config=self.qformer_model_tester.get_config(), + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = { + "pixel_values": pixel_values, + } + return config, inputs_dict + + def create_and_check_model(self, config, pixel_values): + model = Blip2VisionModelWithProjection(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(pixel_values, output_attentions=True, output_hidden_states=True) + + self.parent.assertEqual( + result.last_hidden_state.shape, + ( + self.vision_model_tester.batch_size, + self.vision_model_tester.seq_length, + self.qformer_model_tester.hidden_size, + ), + ) + self.parent.assertEqual( + result.image_embeds.shape, + ( + self.vision_model_tester.batch_size, + config.vision_config.hidden_size, + config.image_text_hidden_size, + ), + ) + + with torch.no_grad(): + result2 = model( + pixel_values, + return_dict=not config.use_return_dict, + output_attentions=True, + output_hidden_states=True, + ) + + self.parent.assertTrue(torch.allclose(result.image_embeds, result2[0])) + self.parent.assertTrue(torch.allclose(result.last_hidden_state, result2[1])) + self.parent.assertTrue(torch.allclose(result.hidden_states[0], result2[2][0])) + self.parent.assertTrue(torch.allclose(result.hidden_states[1], result2[2][1])) + self.parent.assertTrue(torch.allclose(result.attentions[0], result2[3][0])) + self.parent.assertTrue(torch.allclose(result.attentions[1], result2[3][1])) + + +@require_torch +class Blip2VisionModelWithProjectionTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (Blip2VisionModelWithProjection,) if is_torch_available() else () + fx_compatible = False + test_pruning = False + test_head_masking = False + + test_resize_embeddings = False + test_attention_outputs = False + test_torchscript = False + + def setUp(self): + self.model_tester = Blip2VisionModelWithProjectionTester(self) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_training(self): + pass + + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="Hidden_states is tested in individual model tests") + def test_hidden_states_output(self): + pass + + @unittest.skip(reason="Blip2TextModelWithProjection does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Retain_grad is tested in individual model tests") + def test_retain_grad_hidden_states_attentions(self): + pass + + 
@unittest.skip(reason="Blip2TextModelWithProjection does not have input/output embeddings") + def test_model_common_attributes(self): + pass + + @unittest.skip(reason="Blip2TextModelWithProjection has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="Blip2TextModelWithProjection has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_to_base(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + if model.config.is_encoder_decoder: + # TODO + raise NotImplementedError + else: + expected_arg_names = [ + "pixel_values", + ] + + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + + @slow + @require_torch_gpu + def test_model_from_pretrained(self): + model_name = "jpizarrom/blip2-itm-vit-g" + model = Blip2VisionModelWithProjection.from_pretrained(model_name) + self.assertIsNotNone(model) + self.assertTrue(hasattr(model, "vision_proj")) + + _, pixel_values = self.model_tester.prepare_config_and_inputs() + + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(pixel_values=pixel_values) + + self.assertEqual(outputs.image_embeds.shape, (self.model_tester.vision_model_tester.batch_size, 32, 256)) + + +class Blip2TextRetrievalModelTester: + def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training=True): + if vision_kwargs is None: + vision_kwargs = {} + if qformer_kwargs is None: + qformer_kwargs = {"qformer_text_input": True} + + self.parent = parent + self.vision_model_tester = Blip2VisionModelTester(parent, **vision_kwargs) + self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs) + self.is_training = is_training + + def prepare_config_and_inputs(self): + _, input_ids, attention_mask = self.qformer_model_tester.prepare_config_and_inputs() + _, pixel_values = self.vision_model_tester.prepare_config_and_inputs() + + config = self.get_config() + + return config, input_ids, attention_mask, pixel_values + + def get_config(self): + return Blip2Config.from_vision_qformer_configs( + vision_config=self.vision_model_tester.get_config(), + qformer_config=self.qformer_model_tester.get_config(), + ) + + def create_and_check_model(self, config, input_ids, attention_mask, pixel_values): + model = Blip2ForImageTextRetrieval(config).to(torch_device).eval() + with torch.no_grad(): + result = model(pixel_values, input_ids, attention_mask) + + image_size = (self.vision_model_tester.image_size, self.vision_model_tester.image_size) + patch_size = (self.vision_model_tester.patch_size, self.vision_model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + + self.parent.assertEqual( + result.itm_score.shape, + (self.vision_model_tester.batch_size, 2), + ) + self.parent.assertEqual( + result.last_hidden_state.shape, + (self.vision_model_tester.batch_size, num_patches + 1, self.qformer_model_tester.hidden_size), + ) + self.parent.assertEqual( + result.question_embeds.shape, + ( + self.vision_model_tester.batch_size, + self.vision_model_tester.hidden_size + input_ids.shape[1], + self.qformer_model_tester.hidden_size, + ), + ) + + def 
prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, attention_mask, pixel_values = config_and_inputs + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + } + return config, inputs_dict + + +@require_torch +class Blip2TextRetrievalModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (Blip2ForImageTextRetrieval,) if is_torch_available() else () + fx_compatible = False + test_head_masking = False + test_pruning = False + test_resize_embeddings = False + test_attention_outputs = False + test_torchscript = False + + def setUp(self): + self.model_tester = Blip2TextRetrievalModelTester(self) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="Hidden_states is tested in individual model tests") + def test_hidden_states_output(self): + pass + + @unittest.skip(reason="Inputs_embeds is tested in individual model tests") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Retain_grad is tested in individual model tests") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="Blip2Model does not have input/output embeddings") + def test_model_common_attributes(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + if model.config.is_encoder_decoder: + # TODO + raise NotImplementedError + else: + expected_arg_names = [ + "pixel_values", + "input_ids", + "attention_mask", + ] + + expected_arg_names.extend(["use_itm_head"] if "use_itm_head" in arg_names else []) + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + + def test_load_vision_qformer_text_config(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + # Save Blip2Config and check if we can load Blip2VisionConfig from it + with tempfile.TemporaryDirectory() as tmp_dir_name: + config.save_pretrained(tmp_dir_name) + vision_config = Blip2VisionConfig.from_pretrained(tmp_dir_name) + self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict()) + + # Save Blip2Config and check if we can load Blip2QFormerConfig from it + with tempfile.TemporaryDirectory() as tmp_dir_name: + config.save_pretrained(tmp_dir_name) + qformer_config = Blip2QFormerConfig.from_pretrained(tmp_dir_name) + self.assertDictEqual(config.qformer_config.to_dict(), qformer_config.to_dict()) + + @slow + @require_torch_gpu + def test_model_from_pretrained(self): + model_name = "jpizarrom/blip2-itm-vit-g" + model = Blip2ForImageTextRetrieval.from_pretrained(model_name) + self.assertIsNotNone(model) + + _, input_ids, attention_mask, pixel_values = self.model_tester.prepare_config_and_inputs() + + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask) + self.assertEqual(outputs.itm_score.shape, (self.model_tester.qformer_model_tester.batch_size, 2)) + + with torch.no_grad(): + outputs = model( + pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, 
+                pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, use_itm_head=False
+            )
+        self.assertEqual(outputs.itm_score.shape, (self.model_tester.qformer_model_tester.batch_size, 1))
+
+    def test_training(self):
+        pass
+
+    def test_training_gradient_checkpointing(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
+    def test_initialization(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        configs_no_init = _config_zero_init(config)
+        for model_class in self.all_model_classes:
+            model = model_class(config=configs_no_init)
+            for name, param in model.named_parameters():
+                if param.requires_grad:
+                    # check if `logit_scale` is initialized as per the original implementation
+                    if name == "logit_scale":
+                        self.assertAlmostEqual(
+                            param.data.item(),
+                            np.log(1 / 0.07),
+                            delta=1e-3,
+                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                        )
+                    elif name == "temp":
+                        self.assertAlmostEqual(
+                            param.data.item(),
+                            0.07,
+                            delta=1e-3,
+                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                        )
+                    else:
+                        self.assertIn(
+                            ((param.data.mean() * 1e9).round() / 1e9).item(),
+                            [0.0, 1.0],
+                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                        )
+
+
 # We will verify our results on an image of cute cats
 def prepare_img():
     url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
@@ -1011,3 +1556,90 @@ def test_inference_t5_multi_accelerator(self):
             [0, 3, 7, 152, 67, 839, 1],
         )
         self.assertEqual(generated_text, "san diego")
+
+    @require_torch_gpu
+    def test_inference_itm(self):
+        model_name = "jpizarrom/blip2-itm-vit-g"
+        processor = Blip2Processor.from_pretrained(model_name)
+        model = Blip2ForImageTextRetrieval.from_pretrained(model_name).to(torch_device)
+
+        image = prepare_img()
+        text = "A woman and her dog sitting in a beach"
+        inputs = processor(images=image, text=text, return_tensors="pt").to(torch_device)
+
+        # forward pass
+        out_itm = model(**inputs)
+        out = model(**inputs, use_itm_head=False)
+
+        # verify
+        expected_scores = torch.Tensor([[0.0238, 0.9762]])
+        self.assertTrue(torch.allclose(torch.nn.Softmax()(out_itm[0].cpu()), expected_scores, rtol=1e-3, atol=1e-3))
+        self.assertTrue(torch.allclose(out[0].cpu(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3))
+
+    @require_torch_gpu
+    def test_inference_itm_fp16(self):
+        model_name = "jpizarrom/blip2-itm-vit-g"
+        processor = Blip2Processor.from_pretrained(model_name)
+        model = Blip2ForImageTextRetrieval.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device)
+
+        image = prepare_img()
+        text = "A woman and her dog sitting in a beach"
+        inputs = processor(images=image, text=text, return_tensors="pt").to(torch_device, dtype=torch.float16)
+
+        # forward pass
+        out_itm = model(**inputs)
+        out = model(**inputs, use_itm_head=False)
+
+        # verify
+        expected_scores = torch.Tensor([[0.0239, 0.9761]])
+        self.assertTrue(
+            torch.allclose(torch.nn.Softmax()(out_itm[0].cpu().float()), expected_scores, rtol=1e-3, atol=1e-3)
+        )
+        self.assertTrue(torch.allclose(out[0].cpu().float(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3))
+
+    @require_torch_gpu
+    def test_inference_vision_with_projection_fp16(self):
+        model_name = "jpizarrom/blip2-itm-vit-g"
+        processor = Blip2Processor.from_pretrained(model_name)
+        model = Blip2VisionModelWithProjection.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device)
+
+        image = prepare_img()
+        inputs = processor(images=image, return_tensors="pt").to(torch_device, dtype=torch.float16)
+
+        # forward pass
+        out = model(**inputs)
+
+        # verify
+        expected_image_embeds = [
+            -0.093994140625,
+            -0.075927734375,
+            0.031890869140625,
+            0.053009033203125,
+            0.0352783203125,
+            -0.01190185546875,
+        ]
+        self.assertTrue(np.allclose(out.image_embeds[0][0][:6].tolist(), expected_image_embeds, atol=1e-3))
+
+    @require_torch_gpu
+    def test_inference_text_with_projection_fp16(self):
+        model_name = "jpizarrom/blip2-itm-vit-g"
+        processor = Blip2Processor.from_pretrained(model_name)
+        model = Blip2TextModelWithProjection.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device)
+
+        inputs = processor(text="a woman sitting on the beach with a dog", padding=True, return_tensors="pt").to(
+            torch_device
+        )
+
+        # forward pass
+        out = model(**inputs)
+
+        # verify
+        expected_text_embeds = [
+            -0.1082763671875,
+            0.053192138671875,
+            -0.02825927734375,
+            0.0169830322265625,
+            0.08648681640625,
+            -0.04656982421875,
+        ]
+        self.assertTrue(np.allclose(out.text_embeds[0][0][:6].tolist(), expected_text_embeds, atol=1e-3))
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 181905dab9f7..d6f53097d785 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -153,6 +153,9 @@
     "ClapAudioModel",
     "ClapAudioModelWithProjection",
     "Blip2ForConditionalGeneration",
+    "Blip2ForImageTextRetrieval",
+    "Blip2TextModelWithProjection",
+    "Blip2VisionModelWithProjection",
     "Blip2QFormerModel",
     "Blip2VisionModel",
     "ErnieMForInformationExtraction",
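For reference, the integration tests above imply the following end-to-end usage of the new retrieval head. This is an illustrative sketch only, mirroring the calls made in `test_inference_itm`: the `jpizarrom/blip2-itm-vit-g` checkpoint name, the `use_itm_head` flag, and the `itm_score` output are taken from the tests, while everything else (device handling, printing) is an assumption and the final API may still change.

```python
import requests
import torch
from PIL import Image

from transformers import Blip2ForImageTextRetrieval, Blip2Processor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoint name taken from the integration tests above.
processor = Blip2Processor.from_pretrained("jpizarrom/blip2-itm-vit-g")
model = Blip2ForImageTextRetrieval.from_pretrained("jpizarrom/blip2-itm-vit-g").to(device)

url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="A woman and her dog sitting in a beach", return_tensors="pt").to(device)

with torch.no_grad():
    # Image-text matching head: two logits (no match / match); softmax gives a probability.
    itm_logits = model(**inputs).itm_score
    match_probability = torch.softmax(itm_logits, dim=1)[:, 1]

    # Bypassing the ITM head returns a single cosine-similarity style score instead.
    itc_score = model(**inputs, use_itm_head=False).itm_score

print(f"match probability: {match_probability.item():.4f}, similarity score: {itc_score.item():.4f}")
```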
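The projection-model tests likewise suggest how image and text embeddings can be extracted separately and compared. The shapes in the comments come from the tests; the max-over-query-tokens cosine similarity at the end is an assumption modeled on how BLIP-2 computes its image-text contrastive score, not something the tests assert.

```python
import requests
import torch
from PIL import Image

from transformers import (
    Blip2Processor,
    Blip2TextModelWithProjection,
    Blip2VisionModelWithProjection,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = "jpizarrom/blip2-itm-vit-g"  # checkpoint name taken from the tests above

processor = Blip2Processor.from_pretrained(checkpoint)
vision_model = Blip2VisionModelWithProjection.from_pretrained(checkpoint).to(device)
text_model = Blip2TextModelWithProjection.from_pretrained(checkpoint).to(device)

url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_inputs = processor(images=image, return_tensors="pt").to(device)
text_inputs = processor(text="a woman sitting on the beach with a dog", padding=True, return_tensors="pt").to(device)

with torch.no_grad():
    # (batch, num_query_tokens, image_text_hidden_size), e.g. (1, 32, 256) per the test above.
    image_embeds = vision_model(**image_inputs).image_embeds
    # (batch, sequence_length, image_text_hidden_size); index 0 is the projected [CLS] token.
    text_embeds = text_model(**text_inputs).text_embeds

# Assumed scoring scheme: cosine similarity between the projected [CLS] text embedding and each
# query-token embedding, reduced with a max (as in BLIP-2's image-text contrastive objective).
image_embeds = torch.nn.functional.normalize(image_embeds, dim=-1)
text_cls = torch.nn.functional.normalize(text_embeds[:, 0, :], dim=-1)
similarity = (image_embeds @ text_cls.unsqueeze(-1)).squeeze(-1).max(dim=-1).values
print(f"image-text similarity: {similarity.item():.4f}")
```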