Idefics: generate fix (#29320)
gante authored Feb 28, 2024
1 parent 2ce56d3 commit 7628b3a
Showing 1 changed file with 21 additions and 33 deletions.
src/transformers/models/idefics/modeling_idefics.py (54 changes: 21 additions & 33 deletions)
@@ -19,7 +19,7 @@
 # limitations under the License.
 """ PyTorch Idefics model."""
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -187,35 +187,6 @@ def expand_inputs_for_generation(
     return input_ids, model_kwargs


-def update_model_kwargs_for_generation(outputs, model_kwargs):
-    # must have this key set to at least None
-    if "past_key_values" in outputs:
-        model_kwargs["past_key_values"] = outputs.past_key_values
-    else:
-        model_kwargs["past_key_values"] = None
-
-    # update token_type_ids with last value
-    if "token_type_ids" in model_kwargs:
-        token_type_ids = model_kwargs["token_type_ids"]
-        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
-
-    # update attention masks
-    if "attention_mask" in model_kwargs:
-        attention_mask = model_kwargs["attention_mask"]
-        model_kwargs["attention_mask"] = torch.cat(
-            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
-        )
-    if "image_attention_mask" in model_kwargs:
-        image_attention_mask = model_kwargs["image_attention_mask"]
-        last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
-        model_kwargs["image_attention_mask"] = last_mask
-
-    # Get the precomputed image_hidden_states
-    model_kwargs["image_hidden_states"] = outputs.image_hidden_states
-
-    return model_kwargs
-
-
 def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
     token_type_ids = kwargs.get("token_type_ids", None)
     # only last token for inputs_ids if past is defined in kwargs
@@ -1580,9 +1551,26 @@ def _expand_inputs_for_generation(
     ):
         return expand_inputs_for_generation(*args, **model_kwargs)

-    @staticmethod
-    def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder):
-        return update_model_kwargs_for_generation(outputs, model_kwargs)
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: ModelOutput,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        standardize_cache_format: bool = False,
+        model_inputs: Optional[Dict[str, Any]] = None,
+    ) -> Dict[str, Any]:
+        model_kwargs = super()._update_model_kwargs_for_generation(
+            outputs, model_kwargs, is_encoder_decoder, standardize_cache_format, model_inputs
+        )
+
+        if "image_attention_mask" in model_kwargs:
+            image_attention_mask = model_kwargs["image_attention_mask"]
+            last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
+            model_kwargs["image_attention_mask"] = last_mask
+
+        # Get the precomputed image_hidden_states
+        model_kwargs["image_hidden_states"] = outputs.image_hidden_states
+        return model_kwargs

     @staticmethod
     def _reorder_cache(past, beam_idx):
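For context, the generation loop in transformers invokes _update_model_kwargs_for_generation once per decoding step to advance step-dependent inputs before the next forward pass. The sketch below illustrates where the overridden hook fires; it is a simplified stand-in for the real generate() internals, and the function name greedy_decode_sketch, the greedy decoding strategy, and the fixed token budget are assumptions made for this example.

import torch

def greedy_decode_sketch(model, input_ids, model_kwargs, max_new_tokens=8):
    # Simplified greedy loop; not the library's actual implementation.
    for _ in range(max_new_tokens):
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = model(**model_inputs, return_dict=True)
        # Greedily pick the next token from the last position's logits.
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        # After this commit, the Idefics override first defers to the base
        # class (which advances past_key_values, attention_mask, etc.) and
        # then handles the model-specific kwargs: image_attention_mask is
        # sliced to its last step and image_hidden_states is carried over.
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=False
        )
    return input_ids

Deferring to super() means the Idefics-specific code only touches image_attention_mask and image_hidden_states, while cache and attention-mask bookkeeping stays in one place in GenerationMixin, so the override no longer has to mirror base-class update logic by hand as the deleted module-level helper did.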
