Skip to content

Commit

Permalink
fix assisted gen (consistent return api)
Browse files Browse the repository at this point in the history
  • Loading branch information
gante committed Oct 11, 2024
1 parent 53d8e10 commit 262c971
Showing 1 changed file with 11 additions and 24 deletions.
35 changes: 11 additions & 24 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4167,17 +4167,8 @@ def _assisted_decoding(
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

# This is needed if return_dict_in_generate is True
start_from_empty_dynamic_cache = False
past_key_values = model_kwargs.get("past_key_values", None)
if isinstance(past_key_values, DynamicCache) or (
isinstance(past_key_values, EncoderDecoderCache)
and isinstance(past_key_values.self_attention_cache, DynamicCache)
):
if past_key_values.get_seq_length() == 0:
start_from_empty_dynamic_cache = True

this_peer_finished = False
is_first_iteration = True # to preserve the same API in the output as other generation methods
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
cur_len = input_ids.shape[-1]

Expand Down Expand Up @@ -4289,50 +4280,46 @@ def _assisted_decoding(
# Store scores, attentions and hidden_states when required
# Assistant: modified to append one tuple element per token, as in the other generation methods.
if return_dict_in_generate:
newly_added_length = n_matches + 1
if output_scores:
scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
scores += tuple(new_logits[:, i, :] for i in range(newly_added_length))
if output_logits:
raw_logits += (next_token_logits,)

if "past_key_values" not in model_kwargs or start_from_empty_dynamic_cache:
added_len = new_cur_len
# set it to false for other iterations
start_from_empty_dynamic_cache = False
else:
added_len = n_matches + 1
raw_logits += tuple(next_token_logits[:, i, :] for i in range(newly_added_length))

newly_added_length = new_cur_len if is_first_iteration else newly_added_length
if output_attentions:
if self.config.is_encoder_decoder:
cross_attentions = _split_model_outputs(
cross_attentions, outputs.cross_attentions, cur_len, added_len
cross_attentions, outputs.cross_attentions, cur_len, newly_added_length
)
decoder_attentions = _split_model_outputs(
decoder_attentions,
outputs.decoder_attentions,
cur_len,
added_len,
newly_added_length,
is_decoder_attention=True,
)
else:
decoder_attentions = _split_model_outputs(
decoder_attentions,
outputs.attentions,
cur_len,
added_len,
newly_added_length,
is_decoder_attention=True,
)
if output_hidden_states:
if self.config.is_encoder_decoder:
decoder_hidden_states = _split_model_outputs(
decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len
decoder_hidden_states, outputs.decoder_hidden_states, cur_len, newly_added_length
)
else:
decoder_hidden_states = _split_model_outputs(
decoder_hidden_states, outputs.hidden_states, cur_len, added_len
decoder_hidden_states, outputs.hidden_states, cur_len, newly_added_length
)

unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
is_first_iteration = False

if streamer is not None:
streamer.end()
Expand Down

0 comments on commit 262c971

Please sign in to comment.