diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index e28615bb3c2621..3bdd88300469b9 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -171,12 +171,16 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         """
         input_ids = input_ids.to(self.assistant_model.device)

+        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
+        new_cur_len = input_ids.shape[-1]
+        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
+        if max_new_tokens == 0:
+            return input_ids, None
+
         # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
         # (which implicitly contains the number of accepted candidates from the previous round)
         has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
         if has_past_key_values:
-            new_cur_len = input_ids.shape[-1]
-
             new_cache_size = new_cur_len - 1
             self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
                 self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
@@ -190,7 +194,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         # 2. Forecast next N tokens using the assistant model.
         assistant_generation_kwargs = {
             self.input_ids_key: input_ids,
-            "max_new_tokens": int(self.num_assistant_tokens),
+            "max_new_tokens": max_new_tokens,
             "generation_config": self.generation_config,
             "logits_processor": self.logits_processor,
         }
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 0a05a81ec9a8e3..a0d2f61e5f6bde 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -4404,7 +4404,7 @@ def assisted_decoding(
             else:
                 selected_tokens = new_logits.argmax(dim=-1)

-            candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
+            candidate_new_tokens = candidate_input_ids[:, cur_len:]
             n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()

             # Ensure we don't generate beyond max_len or an EOS token
@@ -4540,12 +4540,13 @@ def _speculative_sampling(

     NOTE: Unless otherwise stated, the variable names match those in the paper.
     """
+    new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
     # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
     # selected by the assistant, respectively.
     q = candidate_logits.softmax(dim=-1)
-    q_i = q[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
+    q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
     p = new_logits.softmax(dim=-1)
-    p_i = p[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
+    p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
     probability_ratio = p_i / q_i

     # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
@@ -4553,28 +4554,33 @@
     # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio
     # (= keep with p = probability_ratio). Keep all the tokens until the first rejection
     r_i = torch.rand_like(probability_ratio)
     is_accepted = r_i <= probability_ratio
-    n_matches = (~is_accepted.cumsum(dim=-1) < 1).sum()  # this is `n` in algorithm 1
+    n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()  # this is `n` in algorithm 1

     # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
     if last_assistant_token_is_eos and n_matches == candidate_length:
+        # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
+        # due to acceptance on EOS we fix `n_matches`
         n_matches -= 1
-    n_matches = min(n_matches, max_matches)
-
-    # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
-    gamma = candidate_logits.shape[1]
-    p_n_plus_1 = p[:, n_matches, :]
-    if n_matches < gamma:
-        q_n_plus_1 = q[:, n_matches, :]
-        p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0).softmax(dim=-1)
+        valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
     else:
-        p_prime = p_n_plus_1
-    t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
+        n_matches = min(n_matches, max_matches)
+
+        # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
+        gamma = min(candidate_logits.shape[1], max_matches)
+        p_n_plus_1 = p[:, n_matches, :]
+        if n_matches < gamma:
+            q_n_plus_1 = q[:, n_matches, :]
+            p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
+            p_prime.div_(p_prime.sum())
+        else:
+            p_prime = p_n_plus_1
+        t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]

-    # The selected tokens include the matches (if any) plus the next sampled tokens
-    if n_matches > 0:
-        valid_tokens = torch.cat((candidate_input_ids[:, -n_matches:], t), dim=-1)
-    else:
-        valid_tokens = t
+        # The selected tokens include the matches (if any) plus the next sampled tokens
+        if n_matches > 0:
+            valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
+        else:
+            valid_tokens = t

     return valid_tokens, n_matches
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 05f0981dba3714..2c16f41ae171dc 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -88,6 +88,7 @@
         TopKLogitsWarper,
         TopPLogitsWarper,
     )
+    from transformers.generation.utils import _speculative_sampling


 class GenerationTesterMixin:
@@ -2424,6 +2425,43 @@ def test_top_k_top_p_filtering_with_filter_value(self):

         self.assertTrue(torch.allclose(expected_output, output, atol=1e-12))

+    def test_speculative_sampling(self):
+        # assume vocab size 10, input length 5 + 3 generated candidates
+        candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]])  # input tokens
+        candidate_logits = torch.tensor(
+            [
+                [
+                    [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # generated 1
+                    [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # generated 4
+                    [-10.0, -10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0],  # generated 5
+                ]
+            ]
+        )
+        candidate_length = 3
+        inf = float("inf")
+        new_logits = torch.tensor(
+            [
+                [
+                    [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # accepts 1
+                    [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # accepts 4
+                    [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 10.0, -inf],  # rejects 5, accepts 8
+                    [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # N/A
+                ]
+            ]
+        )
+        last_assistant_token_is_eos = False
+        max_matches = 5
+        validated_tokens, n_matches = _speculative_sampling(
+            candidate_input_ids,
+            candidate_logits,
+            candidate_length,
+            new_logits,
+            last_assistant_token_is_eos,
+            max_matches,
+        )
+        self.assertTrue(n_matches.item() == 2)
+        self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8])
+

 @require_torch
 class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
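For reference, the acceptance rule exercised by the new `test_speculative_sampling` test can be reproduced outside the library with plain tensor operations. The snippet below is only an illustrative sketch (the patched `_speculative_sampling` above is the actual implementation); it reuses the tensors from the test, and the variable names mirror the patched code, but the snippet itself is standalone and not part of either file.

```python
# Illustrative sketch (not part of the patch): walks through the speculative sampling
# acceptance rule on the same tensors used by test_speculative_sampling above.
import torch

torch.manual_seed(0)  # makes torch.rand_like / torch.multinomial deterministic

candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]])  # 5 prompt tokens + 3 candidates
candidate_length = 3

# Assistant (draft) logits: near one-hot on the proposed tokens 1, 4, 5.
candidate_logits = torch.full((1, candidate_length, 10), -10.0)
candidate_logits[0, 0, 1] = 10.0
candidate_logits[0, 1, 4] = 10.0
candidate_logits[0, 2, 5] = 10.0

# Target model logits: agrees on 1 and 4, but puts all probability mass on 8 at the third step.
new_logits = torch.full((1, candidate_length + 1, 10), -10.0)
new_logits[0, 0, 1] = 10.0
new_logits[0, 1, 4] = 10.0
new_logits[0, 2, :] = float("-inf")
new_logits[0, 2, 8] = 10.0

new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
q = candidate_logits.softmax(dim=-1)  # assistant probabilities
p = new_logits.softmax(dim=-1)  # target probabilities
q_i = q[0, torch.arange(candidate_length), new_candidate_input_ids[0]]
p_i = p[0, torch.arange(candidate_length), new_candidate_input_ids[0]]

# Accept candidate i with probability min(1, p_i / q_i); keep everything up to the first rejection.
probability_ratio = p_i / q_i
is_accepted = torch.rand_like(probability_ratio) <= probability_ratio
n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()

# After a rejection, sample the replacement token from the normalized residual max(0, p - q).
p_prime = torch.clamp(p[:, n_matches, :] - q[:, n_matches, :], min=0)
p_prime = p_prime / p_prime.sum()
t = torch.multinomial(p_prime, num_samples=1)

print(n_matches.item(), t.item())  # -> 2 8: two accepted candidates, then token 8 is sampled
```

With the fixed seed the run is deterministic: candidates 1 and 4 are accepted, candidate 5 is rejected because the target model assigns it zero probability, and the replacement token sampled from the residual distribution is 8, matching the `[1, 4, 8]` output asserted in the test.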