@@ -456,6 +456,21 @@ def _init_weights(self, module):
             configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
 """
 
+BLIP_2_QFORMER_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.).
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`Blip2QFormerConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
 BLIP_2_VISION_INPUTS_DOCSTRING = r"""
     Args:
         pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
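As a quick illustration of the `Parameters` note in the new docstring above, here is a minimal, hedged sketch (not part of the diff, assumes a transformers install that already ships BLIP-2): initializing from a config builds the Q-Former with freshly initialized weights, while pretrained weights would instead be loaded via `from_pretrained`.

    import torch
    from transformers import Blip2QFormerConfig
    from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel

    config = Blip2QFormerConfig()        # default Q-Former hyperparameters
    qformer = Blip2QFormerModel(config)  # architecture only, no pretrained weights loaded
    print(sum(p.numel() for p in qformer.parameters()))  # parameter count of the randomly initialized model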
@@ -621,6 +636,60 @@ def _init_weights(self, module):
 """
 
 
+BLIP2_QFORMER_INPUTS_DOCSTRING = r"""
+    Args:
+        query_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Hidden states to be used in the attention computation. When cross-attention is performed, these serve as
+            the queries (i.e., the keys and values come from `encoder_hidden_states`).
+
+        query_length (`int`, *optional*):
+            Length of the query, usually based on the number of query tokens. If not provided, it is inferred from
+            `query_embeds`.
+
+        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Attention mask where padding elements are indicated by 0.
+
+        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+
+        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers`, with each tuple having 4 tensors
+            of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
 # Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Blip2
 class Blip2Encoder(nn.Module):
     """
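To make the inputs documented in `BLIP2_QFORMER_INPUTS_DOCSTRING` concrete, here is a hedged usage sketch (illustrative only, not part of the diff): it feeds randomly generated query embeddings and image features to a freshly initialized Q-Former. The shapes (32 query tokens, 257 image patches) follow the typical BLIP-2 setup but are assumptions, not values taken from this change.

    import torch
    from transformers import Blip2QFormerConfig
    from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel

    config = Blip2QFormerConfig()
    qformer = Blip2QFormerModel(config).eval()

    batch_size, num_query_tokens = 1, 32
    # In Blip2Model the query embeddings are a learned parameter; random values stand in here.
    query_embeds = torch.randn(batch_size, num_query_tokens, config.hidden_size)
    # Image features from the frozen vision encoder play the role of `encoder_hidden_states`.
    image_embeds = torch.randn(batch_size, 257, config.encoder_hidden_size)
    image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long)

    with torch.no_grad():
        outputs = qformer(
            query_embeds=query_embeds,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=True,
        )
    print(outputs.last_hidden_state.shape)  # (batch_size, num_query_tokens, config.hidden_size)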
@@ -1248,11 +1317,13 @@ def forward(
         return embeddings
 
 
-class Blip2QFormerModel(Blip2PreTrainedModel):
-    """
-    Querying Transformer (Q-Former), used in BLIP-2.
+@add_start_docstrings(
     """
-
+    BLIP-2 Querying Transformer (Q-Former).
+    """,
+    BLIP_2_QFORMER_START_DOCSTRING,
+)
+class Blip2QFormerModel(Blip2PreTrainedModel):
     def __init__(self, config: Blip2QFormerConfig):
         super().__init__(config)
         self.config = config
@@ -1323,6 +1394,10 @@ def get_extended_attention_mask(
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
         return extended_attention_mask
 
+    @add_start_docstrings_to_model_forward(BLIP2_QFORMER_INPUTS_DOCSTRING)
+    @replace_return_docstrings(
+        output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=Blip2QFormerConfig
+    )
     def forward(
         self,
         query_embeds: torch.FloatTensor,
@@ -1338,23 +1413,7 @@ def forward(
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         r"""
-        encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
-            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
-            the model is configured as a decoder.
-        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
-            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
-            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
-            - 1 for tokens that are **not masked**,
-            - 0 for tokens that are **masked**.
-        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
-            shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
-            value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
-            used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
-            value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
-            `(batch_size, sequence_length)`.
-        use_cache (`bool`, `optional`):
-            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
-            `past_key_values`).
+        Returns:
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
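A small hedged sketch of what the refactor buys (illustrative, assumes a transformers build that contains this change): the decorators assemble the shared docstring constants onto the class and its `forward` method at import time, so the previously inlined argument documentation now shows up in the generated docs without being repeated in the method body.

    from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel

    # The class docstring now combines the short description with BLIP_2_QFORMER_START_DOCSTRING.
    print("Querying Transformer" in Blip2QFormerModel.__doc__)   # expected: True
    # The forward docstring carries the shared Args section plus the injected Returns section.
    print("query_embeds" in Blip2QFormerModel.forward.__doc__)   # expected: True
    print("Returns:" in Blip2QFormerModel.forward.__doc__)       # expected: True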