change order of params in Blip2QFormerModel.forward
jpizarrom committed May 1, 2024
1 parent 715f6fa commit da0cc83
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -1265,20 +1265,27 @@ def get_extended_attention_mask(

     def forward(
         self,
+        query_embeds: Optional[torch.FloatTensor] = None,
-        input_ids: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
-        query_embeds: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        input_ids: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+
+
     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         r"""
+        input_ids (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of input sequence tokens in the vocabulary of the language model.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
             the model is configured as a decoder.
@@ -1296,11 +1303,6 @@ def forward(
         use_cache (`bool`, `optional`):
             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
             `past_key_values`).
-        input_ids (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Indices of input sequence tokens in the vocabulary of the language model.
-        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
-            config.max_position_embeddings - 1]`.
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
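The practical effect of a reordering like this is easiest to see at the call site: keyword arguments bind by name and are unaffected, while positional arguments bind by position and silently change meaning. Below is a minimal, self-contained sketch of that distinction; `ToyQFormer`, the tensor shapes, and the call sites are illustrative stand-ins rather than code from this commit, and only the parameter order of the new signature is reproduced.

    from typing import Optional

    import torch

    class ToyQFormer:
        # Reproduces only the parameter order of the reordered forward();
        # everything else about the real Blip2QFormerModel is elided.
        def forward(
            self,
            query_embeds: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            input_ids: Optional[torch.FloatTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
        ):
            return query_embeds, input_ids

    model = ToyQFormer()
    query_tokens = torch.randn(2, 32, 768)  # (batch, num_query_tokens, hidden); illustrative shape

    # Keyword calls bind by name, so they survive the reordering unchanged.
    q, ids = model.forward(query_embeds=query_tokens)
    assert q is query_tokens and ids is None

    # Positional calls bind by position: the first slot is now query_embeds,
    # where the old signature had input_ids. A positional caller written
    # against the old order would now feed its token ids into query_embeds.
    q, ids = model.forward(query_tokens)
    assert q is query_tokens and ids is None

Because every parameter here is keyword-friendly with a `None` default, keyword-only call sites keep working across the reorder; only positional call sites need auditing.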
