Generate: handle `cache_position` update in generate #29467
Changes from all commits
Review comments:

- We should also set the dtype of the cache positions to …
- Our integers inputs (…
```diff
@@ -633,7 +633,6 @@ def _update_model_kwargs_for_generation(
         model_kwargs: Dict[str, Any],
         is_encoder_decoder: bool = False,
         standardize_cache_format: bool = False,
-        model_inputs: Optional[Dict[str, Any]] = None,
     ) -> Dict[str, Any]:
         # update past_key_values
         model_kwargs["past_key_values"] = self._extract_past_from_model_output(
```
```diff
@@ -663,7 +662,8 @@ def _update_model_kwargs_for_generation(
                     dim=-1,
                 )

-        model_kwargs["cache_position"] = model_inputs.get("cache_position", None)
+        if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1

         return model_kwargs
```

Review comments on this change:

- my single worry here is potential stride, adding a …
- I've double-checked, it's always … Its shape will indeed be different, at least between prefill and subsequent generation.
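For readers skimming the diff, here is a minimal standalone sketch (toy values, not the library code itself) of how `cache_position` evolves under this scheme: the prefill step covers the whole prompt, and every later step keeps only the last position and advances it by one.

```python
import torch

# Prefill: one position per prompt token.
cur_len = 5
cache_position = torch.arange(cur_len)  # tensor([0, 1, 2, 3, 4])

# Each subsequent decoding step adds exactly one token, so only the last
# position is kept and incremented, mirroring the diff above.
for _ in range(3):
    cache_position = cache_position[-1:] + 1
print(cache_position)  # tensor([7])
```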
```diff
@@ -1931,10 +1931,15 @@ def _contrastive_search(
         )

         # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

         this_peer_finished = False  # used by synced_gpus only
-        batch_size = input_ids.shape[0]

         while True:
             if synced_gpus:
```
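The same initialization pattern appears in the greedy, sample, and assisted-decoding hunks below. For reference, a standalone sketch of the shape bookkeeping it performs, with toy tensors rather than real model inputs; the attention mask shape is preferred when present, and the prompt length seeds the initial cache positions.

```python
import torch

# Toy inputs: a batch of 2 prompts, 3 tokens each.
input_ids = torch.tensor([[10, 11, 12], [13, 14, 15]])
model_kwargs = {"attention_mask": torch.ones_like(input_ids)}

# Same derivation as the diff: fall back to input_ids.shape when no mask is given.
batch_size, cur_len = (
    model_kwargs["attention_mask"].shape
    if model_kwargs.get("attention_mask", None) is not None
    else input_ids.shape
)

unfinished_sequences = torch.ones(batch_size, dtype=torch.long)
model_kwargs["cache_position"] = torch.arange(cur_len)
print(batch_size, cur_len, model_kwargs["cache_position"])  # 2 3 tensor([0, 1, 2])
```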
```diff
@@ -1975,7 +1980,6 @@ def _contrastive_search(
                     model_kwargs,
                     is_encoder_decoder=self.config.is_encoder_decoder,
                     standardize_cache_format=True,
-                    model_inputs=model_inputs,
                 )
                 if not sequential:
                     # Expands model inputs top_k times, for batched forward passes (akin to beam search).
```
```diff
@@ -2170,7 +2174,9 @@ def _contrastive_search(
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )

             # if eos_token was found in one sentence, set sentence to finished
```
```diff
@@ -2389,7 +2395,13 @@ def _greedy_search(
         )

         # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

         this_peer_finished = False  # used by synced_gpus only
         while True:
```
```diff
@@ -2459,7 +2471,6 @@ def _greedy_search(
                 outputs,
                 model_kwargs,
                 is_encoder_decoder=self.config.is_encoder_decoder,
-                model_inputs=model_inputs,
             )

             # if eos_token was found in one sentence, set sentence to finished
```
```diff
@@ -2688,7 +2699,13 @@ def _sample(
         )

         # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

         this_peer_finished = False  # used by synced_gpus only
         # auto-regressive generation
```
```diff
@@ -2758,7 +2775,9 @@ def _sample(
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )

             # if eos_token was found in one sentence, set sentence to finished
```
```diff
@@ -3003,6 +3022,7 @@ def _beam_search(
         num_beams = beam_scorer.num_beams

         batch_beam_size, cur_len = input_ids.shape
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

         if num_beams * batch_size != batch_beam_size:
             raise ValueError(
```
```diff
@@ -3156,7 +3176,9 @@ def _beam_search(
             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )
             if model_kwargs.get("past_key_values", None) is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
```
```diff
@@ -3397,6 +3419,7 @@ def _beam_sample(
         num_beams = beam_scorer.num_beams

         batch_beam_size, cur_len = input_ids.shape
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
```
```diff
@@ -3510,7 +3533,9 @@ def _beam_sample(
             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )
             if model_kwargs.get("past_key_values", None) is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
```
```diff
@@ -3747,6 +3772,7 @@ def _group_beam_search(
         device = input_ids.device

         batch_beam_size, cur_len = input_ids.shape
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

         if return_dict_in_generate and output_scores:
             beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
```
```diff
@@ -3916,7 +3942,9 @@ def _group_beam_search(
             input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )
             if model_kwargs.get("past_key_values", None) is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
```
```diff
@@ -4155,6 +4183,7 @@ def _constrained_beam_search(
         num_beams = constrained_beam_scorer.num_beams

         batch_beam_size, cur_len = input_ids.shape
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

         if num_beams * batch_size != batch_beam_size:
             raise ValueError(
```
```diff
@@ -4275,7 +4304,9 @@ def _constrained_beam_search(

             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )
             if model_kwargs.get("past_key_values", None) is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
```
```diff
@@ -4511,7 +4542,13 @@ def _assisted_decoding(
         )

         # keep track of which sequences are already finished
-        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

         # other auxiliary variables
         max_len = stopping_criteria[0].max_length
```
```diff
@@ -4555,6 +4592,14 @@ def _assisted_decoding(
                 candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
             )
             candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
+            if "cache_position" in candidate_kwargs:
+                candidate_kwargs["cache_position"] = torch.cat(
+                    (
+                        candidate_kwargs["cache_position"],
+                        torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
+                    ),
+                    dim=0,
+                )

             model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
```
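To make the assisted-decoding change above concrete, here is a small standalone sketch with toy numbers, not the library call itself; `existing` stands in for whatever positions are already tracked for the current step, and `candidate_length` new positions are appended so the candidate tokens can be verified in a single forward pass.

```python
import torch

# Toy illustration: cur_len tokens are already covered, and the assistant model
# proposed candidate_length extra tokens that need positions of their own.
cur_len, candidate_length = 6, 3
existing = torch.arange(cur_len, dtype=torch.long)

# Append positions for the candidate tokens, mirroring the torch.cat in the diff.
extended = torch.cat(
    (existing, torch.arange(cur_len, cur_len + candidate_length, dtype=torch.long)),
    dim=0,
)
print(extended)  # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
```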
```diff
@@ -4673,7 +4718,9 @@ def _assisted_decoding(
             )

             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )

             # if eos_token was found in one sentence, set sentence to finished
```
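For context, a minimal usage sketch that exercises this code path from the caller's side; "gpt2" is only an illustrative checkpoint, and nothing about `cache_position` is exposed to the user, since `generate` manages it internally after this change.

```python
# Hypothetical usage sketch: any causal LM checkpoint would do here.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The cache position is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```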
Review comment: alright, we are deprecating this anyways.