
Commit fa1b49f

tmp commit

gante committed Mar 5, 2024
1 parent 7628b3a
Showing 4 changed files with 70 additions and 45 deletions.
53 changes: 34 additions & 19 deletions src/transformers/generation/utils.py
@@ -648,7 +648,6 @@ def _update_model_kwargs_for_generation(
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
model_inputs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
# update past_key_values
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
@@ -678,7 +677,8 @@ def _update_model_kwargs_for_generation(
dim=-1,
)

model_kwargs["cache_position"] = model_inputs.get("cache_position", None)
if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1

return model_kwargs
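
Taken together, the changes in this file drop the old `model_inputs`-based lookup: each decoding loop below seeds `cache_position` from the prompt length, and `_update_model_kwargs_for_generation` (above) advances it by one position per generated token. A minimal standalone sketch of that bookkeeping (plain PyTorch; the helper names are illustrative, not part of the library):

import torch

def init_cache_position(cur_len, device=None):
    # mirrors: model_kwargs["cache_position"] = torch.arange(cur_len, device=...)
    return torch.arange(cur_len, device=device)

def advance_cache_position(cache_position):
    # mirrors: model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
    return cache_position[-1:] + 1

cache_position = init_cache_position(cur_len=5)  # tensor([0, 1, 2, 3, 4]) for a 5-token prompt
for _ in range(3):                               # three decoding steps
    cache_position = advance_cache_position(cache_position)
print(cache_position)                            # tensor([7]): index of the newest cache slot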

@@ -1946,10 +1946,11 @@ def contrastive_search(
)

# keep track of which sequences are already finished
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
batch_size, cur_len = input_ids.shape
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

this_peer_finished = False # used by synced_gpus only
batch_size = input_ids.shape[0]

while True:
if synced_gpus:
@@ -1990,7 +1991,6 @@ def contrastive_search(
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
standardize_cache_format=True,
model_inputs=model_inputs,
)
if not sequential:
# Expands model inputs top_k times, for batched forward passes (akin to beam search).
@@ -2185,7 +2185,7 @@ def contrastive_search(
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder,
)

# if eos_token was found in one sentence, set sentence to finished
@@ -2397,7 +2397,9 @@ def greedy_search(
)

# keep track of which sequences are already finished
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
batch_size, cur_len = input_ids.shape
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

this_peer_finished = False # used by synced_gpus only
while True:
@@ -2464,10 +2466,7 @@ def greedy_search(
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
model_inputs=model_inputs,
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder,
)

# if eos_token was found in one sentence, set sentence to finished
@@ -2689,7 +2688,9 @@ def sample(
)

# keep track of which sequences are already finished
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
batch_size, cur_len = input_ids.shape
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

this_peer_finished = False # used by synced_gpus only
# auto-regressive generation
@@ -2759,7 +2760,7 @@ def sample(
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder,
)

# if eos_token was found in one sentence, set sentence to finished
@@ -2997,6 +2998,7 @@ def beam_search(
num_beams = beam_scorer.num_beams

batch_beam_size, cur_len = input_ids.shape
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

if num_beams * batch_size != batch_beam_size:
raise ValueError(
@@ -3150,7 +3152,7 @@ def beam_search(
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder,
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3384,6 +3386,7 @@ def beam_sample(
num_beams = beam_scorer.num_beams

batch_beam_size, cur_len = input_ids.shape
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
@@ -3497,7 +3500,7 @@ def beam_sample(
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder,
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3727,6 +3730,7 @@ def group_beam_search(
device = input_ids.device

batch_beam_size, cur_len = input_ids.shape
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

if return_dict_in_generate and output_scores:
beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
@@ -3896,7 +3900,7 @@ def group_beam_search(
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder,
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4128,6 +4132,7 @@ def constrained_beam_search(
num_beams = constrained_beam_scorer.num_beams

batch_beam_size, cur_len = input_ids.shape
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

if num_beams * batch_size != batch_beam_size:
raise ValueError(
@@ -4248,7 +4253,7 @@ def constrained_beam_search(

input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder,
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4477,7 +4482,9 @@ def assisted_decoding(
)

# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
batch_size, cur_len = input_ids.shape
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

# other auxiliary variables
max_len = stopping_criteria[0].max_length
@@ -4521,6 +4528,14 @@ def assisted_decoding(
candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
)
candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
if "cache_position" in candidate_kwargs:
candidate_kwargs["cache_position"] = torch.cat(
(
candidate_kwargs["cache_position"],
torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long)
),
dim=0,
)

model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
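
The concatenation above runs just before the single verification forward pass in assisted decoding, so `cache_position` has to cover every draft token rather than only the next slot. A toy illustration (plain PyTorch, not code from this commit):

import torch

cur_len = 5                             # tokens already represented in the cache
candidate_length = 3                    # draft tokens proposed by the assistant model
cache_position = torch.arange(cur_len)  # positions of the existing tokens

cache_position = torch.cat(
    (cache_position, torch.arange(cur_len, cur_len + candidate_length, dtype=torch.long)),
    dim=0,
)
print(cache_position)  # tensor([0, 1, 2, 3, 4, 5, 6, 7])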

@@ -4639,7 +4654,7 @@ def assisted_decoding(
)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder,
)

# if eos_token was found in one sentence, set sentence to finished
10 changes: 7 additions & 3 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -247,7 +247,7 @@ def forward(
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

@@ -334,7 +334,7 @@ def forward(
past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

@@ -533,7 +533,7 @@ def forward(
past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

@@ -782,6 +782,10 @@ def _reset_cache(self):
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""


3 changes: 1 addition & 2 deletions src/transformers/models/idefics/modeling_idefics.py
@@ -1557,10 +1557,9 @@ def _update_model_kwargs_for_generation(
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
model_inputs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
model_kwargs = super()._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder, standardize_cache_format, model_inputs
outputs, model_kwargs, is_encoder_decoder, standardize_cache_format,
)

if "image_attention_mask" in model_kwargs:
49 changes: 28 additions & 21 deletions src/transformers/models/llama/modeling_llama.py
@@ -354,7 +354,7 @@ def forward(
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

@@ -444,7 +444,7 @@ def forward(
past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

@@ -643,7 +643,7 @@ def forward(
past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

@@ -896,6 +896,10 @@ def _reset_cache(self):
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""


@@ -1224,14 +1228,22 @@ def forward(
)

def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
):
# With static cache, the `past_key_values` is None
# TODO joao: standardize interface for the different Cache classes and remove of this if
has_static_cache = False
if past_key_values is None:
past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
has_static_cache = past_key_values is not None

past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
@@ -1264,20 +1276,6 @@ def prepare_inputs_for_generation(
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None:
# generation with static cache
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
past_length = 0
else:
past_length = cache_position[-1] + 1
input_ids = input_ids[:, past_length:]
position_ids = position_ids[:, past_length:]

# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
@@ -1287,6 +1285,15 @@ def prepare_inputs_for_generation(
# TODO: use `next_tokens` directly instead.
model_inputs = {"input_ids": input_ids.contiguous()}

if cache_position is None:
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
else:
cache_position = cache_position[-input_ids.shape[1] :]

if has_static_cache:
past_key_values = None

model_inputs.update(
{
"position_ids": position_ids.contiguous(),
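A compressed sketch of the `cache_position` handling this hunk adds to `prepare_inputs_for_generation` (the helper below is a simplification, not the library's API: the real code also derives the length from `position_ids` when available and, when a static cache was pulled from the attention layers, sets `past_key_values` back to None before returning):

import torch

def derive_cache_position(cache_position, past_length, input_ids):
    if cache_position is None:
        # first call: one cache index per token being fed to the model
        return torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device)
    # later calls: keep only the indices of the tokens actually being fed
    return cache_position[-input_ids.shape[1]:]

prompt = torch.zeros((1, 4), dtype=torch.long)       # 4-token prompt, empty cache
print(derive_cache_position(None, 0, prompt))        # tensor([0, 1, 2, 3])

next_token = torch.zeros((1, 1), dtype=torch.long)   # decoding step: 4 tokens already cached
print(derive_cache_position(torch.tensor([4]), 4, next_token))  # tensor([4])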
