@@ -456,6 +456,21 @@ def _init_weights(self, module):
456456 configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
457457"""
458458
459+ BLIP_2_QFORMER_START_DOCSTRING = r"""
460+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
461+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
462+ etc.)
463+
464+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
465+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
466+ and behavior.
467+
468+ Parameters:
469+ config ([`Blip2QFormerConfig`]): Model configuration class with all the parameters of the model.
470+ Initializing with a config file does not load the weights associated with the model, only the
471+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
472+ """
473+
459474 BLIP_2_VISION_INPUTS_DOCSTRING = r"""
460475 Args:
461476 pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
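The `BLIP_2_QFORMER_START_DOCSTRING` added above documents the usual two-step pattern: constructing the Q-Former from a `Blip2QFormerConfig` only builds the architecture with freshly initialized weights, and `from_pretrained` is what actually loads them. A minimal sketch, assuming `Blip2QFormerModel` is imported from the BLIP-2 modeling module where this diff defines it (`Blip2QFormerConfig` is already part of the public API):

```python
from transformers import Blip2QFormerConfig
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel

# Initializing from a config creates a randomly initialized Q-Former;
# weights are only loaded when calling `from_pretrained`, as the docstring notes.
config = Blip2QFormerConfig()
model = Blip2QFormerModel(config)
print(sum(p.numel() for p in model.parameters()))  # parameter count of the fresh model
```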
@@ -621,6 +636,60 @@ def _init_weights(self, module):
621636 """
622637
623638
639+ BLIP2_QFORMER_INPUTS_DOCSTRING = r"""
640+ Args:
641+ query_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
642+ Hidden states to be used in the attention computation. When cross-attention is performed, these act as
643+ the queries (i.e., the keys and values come from `encoder_hidden_states`).
644+
645+ query_length (`int`, *optional*):
646+ Length of the query, usually based on the number of query tokens.
647+ If no value is provided, `query_length` is inferred from `query_embeds`.
648+
649+ attention_mask (`torch.FloatTensor`, *optional*):
650+ Attention mask of size `(batch_size, sequence_length)` where padding elements
651+ are indicated by 0.
652+
653+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
654+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
655+
656+ - 1 indicates the head is **not masked**,
657+ - 0 indicates the head is **masked**.
658+
659+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
660+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
661+ the model is configured as a decoder.
662+
663+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
664+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
665+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
666+ - 1 for tokens that are **not masked**,
667+ - 0 for tokens that are **masked**.
668+
669+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers`, with each tuple having 4 tensors
670+ of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
671+ value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
672+ used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
673+ value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
674+ `(batch_size, sequence_length)`.
675+
676+ use_cache (`bool`, *optional*):
677+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
678+ `past_key_values`).
679+
680+ output_attentions (`bool`, *optional*):
681+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
682+ tensors for more detail.
683+
684+ output_hidden_states (`bool`, *optional*):
685+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
686+ more detail.
687+
688+ return_dict (`bool`, *optional*):
689+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
690+ """
691+
692+
624693# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Blip2
625694 class Blip2Encoder(nn.Module):
626695 """
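To make the `BLIP2_QFORMER_INPUTS_DOCSTRING` above concrete, here is a rough sketch of a standalone forward pass. Shapes and tensors are placeholders: in BLIP-2 proper, `query_embeds` comes from the learned query tokens and `encoder_hidden_states` from the frozen vision encoder; `config.encoder_hidden_size` is assumed to be the width expected by the cross-attention key/value projections.

```python
import torch
from transformers import Blip2QFormerConfig
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel

config = Blip2QFormerConfig()
model = Blip2QFormerModel(config).eval()

batch_size, num_query_tokens, num_patches = 2, 32, 257
# Placeholder query embeddings (learned query tokens in the full BLIP-2 model).
query_embeds = torch.randn(batch_size, num_query_tokens, config.hidden_size)
# Placeholder image features serving as cross-attention keys/values.
image_embeds = torch.randn(batch_size, num_patches, config.encoder_hidden_size)
image_attention_mask = torch.ones(batch_size, num_patches, dtype=torch.long)

with torch.no_grad():
    outputs = model(
        query_embeds=query_embeds,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_attention_mask,
        return_dict=True,
    )
print(outputs.last_hidden_state.shape)  # (batch_size, num_query_tokens, hidden_size)
```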
@@ -1248,11 +1317,13 @@ def forward(
12481317 return embeddings
12491318
12501319
1251- class Blip2QFormerModel(Blip2PreTrainedModel):
1252- """
1253- Querying Transformer (Q-Former), used in BLIP-2.
1320+ @add_start_docstrings(
12541321 """
1255-
1322+ BLIP-2 Querying Transformer (Q-Former).
1323+ """,
1324+ BLIP_2_QFORMER_START_DOCSTRING,
1325+ )
1326+ class Blip2QFormerModel(Blip2PreTrainedModel):
12561327 def __init__(self, config: Blip2QFormerConfig):
12571328 super().__init__(config)
12581329 self.config = config
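For readers unfamiliar with the `@add_start_docstrings` wrapper applied above: it prepends the shared boilerplate to the class docstring at import time, so `help(Blip2QFormerModel)` shows the one-line intro followed by `BLIP_2_QFORMER_START_DOCSTRING`. A simplified stand-in for the real helper (which lives in `transformers.utils`) illustrates the mechanism:

```python
# Simplified sketch, not the library implementation: concatenate the given strings
# in front of whatever docstring the decorated object already has.
def add_start_docstrings(*docstr):
    def decorator(obj):
        obj.__doc__ = "".join(docstr) + (obj.__doc__ or "")
        return obj
    return decorator

@add_start_docstrings("BLIP-2 Querying Transformer (Q-Former).\n", "<shared boilerplate>\n")
class Demo:
    """Model-specific notes."""

print(Demo.__doc__)  # intro line, boilerplate, then the class's own docstring
```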
@@ -1323,6 +1394,10 @@ def get_extended_attention_mask(
13231394 extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
13241395 return extended_attention_mask
13251396
1397+ @add_start_docstrings_to_model_forward(BLIP2_QFORMER_INPUTS_DOCSTRING)
1398+ @replace_return_docstrings(
1399+ output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=Blip2QFormerConfig
1400+ )
13261401 def forward(
13271402 self,
13281403 query_embeds: torch.FloatTensor,
@@ -1338,23 +1413,7 @@ def forward(
13381413 return_dict: Optional[bool] = None,
13391414 ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
13401415 r"""
1341- encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
1342- Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1343- the model is configured as a decoder.
1344- encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
1345- Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1346- the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1347- - 1 for tokens that are **not masked**,
1348- - 0 for tokens that are **masked**.
1349- past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
1350- shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
1351- value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
1352- used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
1353- value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
1354- `(batch_size, sequence_length)`.
1355- use_cache (`bool`, `optional`):
1356- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1357- `past_key_values`).
1416+ Returns:
13581417 """
13591418 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
13601419 output_hidden_states = (
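One detail worth calling out from the `get_extended_attention_mask` context shown earlier: the `(1.0 - extended_attention_mask) * -10000.0` line converts a 0/1 padding mask into an additive bias on the attention logits, so masked positions are effectively zeroed out by the softmax. A tiny self-contained illustration:

```python
import torch

# 1 -> keep (bias 0.0); 0 -> mask out (bias -10000.0, ~zero probability after softmax).
attention_mask = torch.tensor([[1.0, 1.0, 0.0]])
extended_attention_mask = (1.0 - attention_mask) * -10000.0

logits = torch.zeros(1, 3)  # stand-in attention scores
probs = torch.softmax(logits + extended_attention_mask, dim=-1)
print(probs)  # the padded position receives ~0 attention weight
```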