From 190a63a0b9741cf7c2596557f11eb4da6203c04f Mon Sep 17 00:00:00 2001
From: Ofir Zafrir
Date: Wed, 17 Jan 2024 04:02:31 +0200
Subject: [PATCH] @ofirzaf Fix candidate generation more than max_length - 1

---
 .../generation/candidate_generator.py |  6 ++--
 src/transformers/generation/utils.py  | 34 ++++++++++---------
 2 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index e28615bb3c2621..0c2be0bcf4b3ee 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -173,10 +173,10 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
 
         # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
         # (which implicitly contains the number of accepted candidates from the previous round)
+        new_cur_len = input_ids.shape[-1]
+        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
         has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
         if has_past_key_values:
-            new_cur_len = input_ids.shape[-1]
-
             new_cache_size = new_cur_len - 1
             self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
                 self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
@@ -190,7 +190,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         # 2. Forecast next N tokens using the assistant model.
         assistant_generation_kwargs = {
             self.input_ids_key: input_ids,
-            "max_new_tokens": int(self.num_assistant_tokens),
+            "max_new_tokens": max_new_tokens,
             "generation_config": self.generation_config,
             "logits_processor": self.logits_processor,
         }
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 87cdc5b7166e94..7231de2003e02e 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -4813,24 +4813,26 @@ def _speculative_sampling(
     # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
     if last_assistant_token_is_eos and n_matches == candidate_length:
         n_matches -= 1
-    n_matches = min(n_matches, max_matches)
-
-    # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
-    gamma = candidate_logits.shape[1]
-    p_n_plus_1 = p[:, n_matches, :]
-    if n_matches < gamma:
-        q_n_plus_1 = q[:, n_matches, :]
-        p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
-        p_prime.div_(p_prime.sum())
+        valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
     else:
-        p_prime = p_n_plus_1
-    t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
+        n_matches = min(n_matches, max_matches)
+
+        # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
+        gamma = min(candidate_logits.shape[1], max_matches)
+        p_n_plus_1 = p[:, n_matches, :]
+        if n_matches < gamma:
+            q_n_plus_1 = q[:, n_matches, :]
+            p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
+            p_prime.div_(p_prime.sum())
+        else:
+            p_prime = p_n_plus_1
+        t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
 
-    # The selected tokens include the matches (if any) plus the next sampled tokens
-    if n_matches > 0:
-        valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
-    else:
-        valid_tokens = t
+        # The selected tokens include the matches (if any) plus the next sampled tokens
+        if n_matches > 0:
+            valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
+        else:
+            valid_tokens = t
 
     return valid_tokens, n_matches
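For reference, a minimal sketch (not part of the patch) of the clamping introduced in candidate_generator.py; the numbers below are made up and only illustrate how the assistant's per-round candidate budget shrinks as generation approaches the length limit:

    # Illustrative values; in the library they come from the generation config and the prompt.
    num_assistant_tokens = 5   # assistant's usual per-round candidate budget
    max_length = 20            # overall generation limit
    new_cur_len = 18           # tokens already present in input_ids this round

    # The assistant may propose at most max_length - new_cur_len - 1 tokens, leaving room
    # for the one extra token the target model appends after verification.
    max_new_tokens = min(int(num_assistant_tokens), max_length - new_cur_len - 1)
    assert max_new_tokens == 1  # without the clamp, 5 candidates would overshoot max_length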
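Likewise, a toy sketch (again not part of the patch) of the rejection-correction step that _speculative_sampling now keeps inside the new else branch, using a made-up four-token vocabulary:

    import torch

    p_n_plus_1 = torch.tensor([[0.4, 0.3, 0.2, 0.1]])  # target model distribution (made up)
    q_n_plus_1 = torch.tensor([[0.1, 0.5, 0.1, 0.3]])  # assistant (draft) distribution (made up)

    # On a rejection, sample from max(0, p - q) renormalized, so the overall output
    # distribution still matches the target model's.
    p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
    p_prime.div_(p_prime.sum())                          # -> tensor([[0.75, 0.0, 0.25, 0.0]])
    t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]  # next-token id, shape (1, 1)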