diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index e28615bb3c2621..f9bd4db76325a7 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -173,10 +173,12 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
         # (which implicitly contains the number of accepted candidates from the previous round)
+        new_cur_len = input_ids.shape[-1]
+        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
+        if max_new_tokens == 0:
+            return input_ids, None
         has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
         if has_past_key_values:
-            new_cur_len = input_ids.shape[-1]
-
             new_cache_size = new_cur_len - 1
             self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
                 self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
             )  # the assistant does not have the token after the last match, hence the -1
@@ -190,7 +192,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         # 2. Forecast next N tokens using the assistant model.
         assistant_generation_kwargs = {
             self.input_ids_key: input_ids,
-            "max_new_tokens": int(self.num_assistant_tokens),
+            "max_new_tokens": max_new_tokens,
             "generation_config": self.generation_config,
             "logits_processor": self.logits_processor,
         }
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 7e269a7b178019..f1856eba0327a4 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -4658,9 +4658,11 @@ def assisted_decoding(
             else:
                 selected_tokens = new_logits.argmax(dim=-1)
 
-            candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
-            n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
-
+            if candidate_length > 0:
+                candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
+                n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
+            else:
+                n_matches = 0
             # Ensure we don't generate beyond max_len or an EOS token
             if last_assistant_token_is_eos and n_matches == candidate_length:
                 n_matches -= 1
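
Note on the utils.py hunk: the `candidate_length > 0` guard is needed because a Python slice starting at `-0` is not empty; `candidate_input_ids[:, -0:]` returns the whole tensor, so the old expression would miscount matches whenever the assistant produced zero candidates — a case the candidate_generator.py hunk now makes reachable by returning early near `max_length`. Below is a minimal, self-contained sketch of the match-counting expression, using made-up token ids that are not from the patch:

import torch

# Hypothetical values, for illustration only: 5 assistant-proposed tokens...
candidate_new_tokens = torch.tensor([[5, 8, 3, 9, 2]])
# ...and the target model's greedy picks for the same positions, plus one
# extra token for the position after the last candidate.
selected_tokens = torch.tensor([[5, 8, 7, 9, 2, 4]])

# Length of the longest matching prefix: the running count of mismatches
# stays < 1 only up to the first disagreement.
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
print(n_matches)  # tensor(2): only the first two candidates are accepted

# Why the guard matters: a -0 slice is the whole tensor, not an empty one.
assert torch.equal(selected_tokens[:, -0:], selected_tokens)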