
Commit 9a7928f

alex-jw-brooks authored and zucchini-nlp committed
Expose blip2qformer (huggingface#37254)
* Expose blip2qformer
* Add missing args to blip2 config
1 parent 79f14be commit 9a7928f

File tree

6 files changed (+86 −23 lines)


src/transformers/models/auto/configuration_auto.py

Lines changed: 3 additions & 0 deletions
@@ -54,6 +54,7 @@
         ("blenderbot-small", "BlenderbotSmallConfig"),
         ("blip", "BlipConfig"),
         ("blip-2", "Blip2Config"),
+        ("blip_2_qformer", "Blip2QFormerConfig"),
         ("bloom", "BloomConfig"),
         ("bridgetower", "BridgeTowerConfig"),
         ("bros", "BrosConfig"),
@@ -391,6 +392,7 @@
         ("blenderbot-small", "BlenderbotSmall"),
         ("blip", "BLIP"),
         ("blip-2", "BLIP-2"),
+        ("blip_2_qformer", "BLIP-2 QFormer"),
         ("bloom", "BLOOM"),
         ("bort", "BORT"),
         ("bridgetower", "BridgeTower"),
@@ -781,6 +783,7 @@
         ("granitevision", "llava_next"),
         ("sam_vision_model", "sam"),
         ("llama4_text", "llama4"),
+        ("blip_2_qformer", "blip_2"),
     ]
 )
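With these mappings in place, the auto API can resolve the Q-Former config from its model_type string. A minimal sketch for illustration (not part of the diff):

from transformers import AutoConfig, Blip2QFormerConfig

# "blip_2_qformer" now resolves through the auto config mapping.
config = AutoConfig.for_model("blip_2_qformer")
assert isinstance(config, Blip2QFormerConfig)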

src/transformers/models/auto/modeling_auto.py

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@
         ("blenderbot-small", "BlenderbotSmallModel"),
         ("blip", "BlipModel"),
         ("blip-2", "Blip2Model"),
+        ("blip_2_qformer", "Blip2QFormerModel"),
         ("bloom", "BloomModel"),
         ("bridgetower", "BridgeTowerModel"),
         ("bros", "BrosModel"),

src/transformers/models/blip_2/configuration_blip_2.py

Lines changed: 2 additions & 0 deletions
@@ -144,6 +144,8 @@ class Blip2QFormerConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (`float`, *optional*, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Index to be used for padding token.
         position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
             Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
             positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
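The pad_token_id argument itself already existed on the config; this change only documents it. For reference, a small sketch:

from transformers import Blip2QFormerConfig

# pad_token_id is optional and defaults to 0, as documented above.
cfg = Blip2QFormerConfig()
print(cfg.pad_token_id)  # 0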

src/transformers/models/blip_2/modeling_blip_2.py

Lines changed: 80 additions & 21 deletions
@@ -456,6 +456,21 @@ def _init_weights(self, module):
         configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
 """

+BLIP_2_QFORMER_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`Blip2QFormerConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
 BLIP_2_VISION_INPUTS_DOCSTRING = r"""
     Args:
         pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
@@ -621,6 +636,60 @@ def _init_weights(self, module):
 """


+BLIP2_QFORMER_INPUTS_DOCSTRING = r"""
+    Args:
+        query_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Hidden states to be used in the attention computation. If cross-attention,
+            will be used for the query (i.e., key and value will use the encoder_hidden_states).
+
+        query_length (`int`, *optional*):
+            Length of the query, usually based on the number of query tokens.
+            If no value is provided, query_length will be inferred by the query_embeds.
+
+        attention_mask (`torch.FloatTensor`, *optional*):
+            Attention mask of size `(batch, sequence_length)` where padding elements
+            are indicated by 0.
+
+        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+
+        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
+            shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
+            value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
+            used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
+            value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
+            `(batch_size, sequence_length)`.
+
+        use_cache (`bool`, `optional`):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
 # Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Blip2
 class Blip2Encoder(nn.Module):
     """
@@ -1248,11 +1317,13 @@ def forward(
         return embeddings


-class Blip2QFormerModel(Blip2PreTrainedModel):
-    """
-    Querying Transformer (Q-Former), used in BLIP-2.
+@add_start_docstrings(
     """
-
+    BLIP-2 Querying Transformer (Q-Former).
+    """,
+    BLIP_2_QFORMER_START_DOCSTRING,
+)
+class Blip2QFormerModel(Blip2PreTrainedModel):
     def __init__(self, config: Blip2QFormerConfig):
         super().__init__(config)
         self.config = config
@@ -1323,6 +1394,10 @@ def get_extended_attention_mask(
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
         return extended_attention_mask

+    @add_start_docstrings_to_model_forward(BLIP2_QFORMER_INPUTS_DOCSTRING)
+    @replace_return_docstrings(
+        output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=Blip2QFormerConfig
+    )
     def forward(
         self,
         query_embeds: torch.FloatTensor,
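These are the standard transformers doc decorators: the first attaches BLIP2_QFORMER_INPUTS_DOCSTRING to forward.__doc__ at import time, and the second fills in the bare Returns: section from the declared output type. A quick way to see the effect (a sketch, not part of the diff):

from transformers import Blip2QFormerModel

# The shared inputs docstring is now part of the generated forward docstring.
print("query_embeds" in Blip2QFormerModel.forward.__doc__)  # True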
@@ -1338,23 +1413,7 @@ def forward(
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         r"""
-        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
-            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
-            the model is configured as a decoder.
-        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
-            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
-            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
-            - 1 for tokens that are **not masked**,
-            - 0 for tokens that are **masked**.
-        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
-            shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
-            value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
-            used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
-            value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
-            `(batch_size, sequence_length)`.
-        use_cache (`bool`, `optional`):
-            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
-            `past_key_values`).
+        Returns:
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
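For context on what the newly documented forward expects, a minimal end-to-end sketch of the exposed Q-Former; all sizes here are illustrative assumptions, not the dimensions of any released checkpoint:

import torch
from transformers import Blip2QFormerConfig, Blip2QFormerModel

# Tiny, illustrative configuration.
config = Blip2QFormerConfig(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    encoder_hidden_size=64,  # width of the (hypothetical) image encoder features
    cross_attention_frequency=1,
)
qformer = Blip2QFormerModel(config)

batch_size, num_query_tokens, num_patches = 1, 8, 16
query_embeds = torch.randn(batch_size, num_query_tokens, config.hidden_size)
image_embeds = torch.randn(batch_size, num_patches, config.encoder_hidden_size)

# Query embeddings attend to the encoder features through cross-attention.
outputs = qformer(
    query_embeds=query_embeds,
    encoder_hidden_states=image_embeds,
    encoder_attention_mask=torch.ones(batch_size, num_patches, dtype=torch.long),
    return_dict=True,
)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 8, 32])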

utils/check_docstrings.py

Lines changed: 0 additions & 1 deletion
@@ -106,7 +106,6 @@
     "BlenderbotSmallConfig",
     "BlenderbotSmallTokenizerFast",
     "BlenderbotTokenizerFast",
-    "Blip2QFormerConfig",
     "Blip2VisionConfig",
     "BlipTextConfig",
     "BlipVisionConfig",

utils/check_repo.py

Lines changed: 0 additions & 1 deletion
@@ -187,7 +187,6 @@
     "ClapAudioModelWithProjection",
     "Blip2TextModelWithProjection",
     "Blip2VisionModelWithProjection",
-    "Blip2QFormerModel",
     "Blip2VisionModel",
     "ErnieMForInformationExtraction",
     "FastSpeech2ConformerHifiGan",

0 commit comments
