From 8dbb065bae782dd06fd65808fd501723e7217bd2 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 12 Dec 2023 15:52:50 +0000 Subject: [PATCH 1/8] speculative decoding --- src/transformers/generation/utils.py | 85 ++++++++++++++++++++++------ tests/generation/test_utils.py | 81 +++++++++++++++++++++++++- 2 files changed, 147 insertions(+), 19 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d23f7f9245d7e8..7c5ef5f2efadf1 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4624,40 +4624,89 @@ def assisted_decoding( for i in range(candidate_length + 1): new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) - # 3. Obtain the next tokens from the original model logits. - if do_sample: - probs = new_logits.softmax(dim=-1) - selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] + # 3. Select the accepted tokens. There are two possible cases: + # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) + # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf). + # NOTE:Unless otherwise stated, the variable names match those in the paper. + if do_sample and candidate_logits is not None: + # Gets the probabilities from the logits. q_i and p_i denote the model and assistant (respectively) + # probabilities of the tokens selected by the assistant. + q = candidate_logits.softmax(dim=-1) + q_i = q[ + :, + torch.range(0, candidate_length - 1, dtype=torch.int), + candidate_input_ids[:, -candidate_length:], + ].squeeze(0, 1) + p = new_logits.softmax(dim=-1) + p_i = p[ + :, + torch.range(0, candidate_length - 1, dtype=torch.int), + candidate_input_ids[:, -candidate_length:], + ].squeeze(0, 1) + probability_ratio = p_i / q_i + + # When probability_ratio > 1 (i.e. q_i(x) < p_i(x)), keep the token. Otherwise reject with + # p = 1 - probability_ratio (= keep with p=probability_ratio). Keep all the tokens until the first + # rejection + r_i = torch.rand_like(probability_ratio) + is_rejected = r_i > probability_ratio # equivalent: is_accepted = r_i <= probability_ratio + n_matches = (is_rejected.cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1 + + # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct + # behavior) + if last_assistant_token_is_eos and n_matches == candidate_length: + n_matches -= 1 + n_matches = min(n_matches, max_len - cur_len - 1) + + # Next token selection: if there is a rejection, adjust the distribution from the main model before + # sampling. + gamma = candidate_logits.shape[1] + p_n_plus_1 = p[:, n_matches, :] + if n_matches < gamma: + q_n_plus_1 = q[:, n_matches, :] + p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0).softmax(dim=-1) + else: + p_prime = p_n_plus_1 + t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :] + + # The selected tokens include the matches plus the next sampled token + selected_tokens = torch.cat((candidate_input_ids[:, :n_matches], t), dim=-1) + + # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the + # original model logits with the candidate tokens. We can keep the candidate tokens until the first + # mismatch, or until the max length is reached. 
else: - selected_tokens = new_logits.argmax(dim=-1) + if do_sample: + probs = new_logits.softmax(dim=-1) + selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] + else: + selected_tokens = new_logits.argmax(dim=-1) + + candidate_new_tokens = candidate_input_ids[:, -candidate_length:] + n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() - # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep - # the assistant forecasted tokens until the first mismatch, or until the max length is reached. - candidate_new_tokens = candidate_input_ids[:, -candidate_length:] - n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() + # Ensure we don't generate beyond max_len or an EOS token + if last_assistant_token_is_eos and n_matches == candidate_length: + n_matches -= 1 + n_matches = min(n_matches, max_len - cur_len - 1) - # 5. Update variables according to the number of matching assistant tokens. Remember: the token generated + # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated # by the model after the last candidate match is also valid, as it is generated from a correct sequence. # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there # is no match. - # 5.1. Ensure we don't generate beyond max_len or an EOS token - if last_assistant_token_is_eos and n_matches == candidate_length: - n_matches -= 1 - n_matches = min(n_matches, max_len - cur_len - 1) - - # 5.2. Get the valid continuation, after the matching tokens + # 4.1. Get the valid continuation, after the matching tokens valid_tokens = selected_tokens[:, : n_matches + 1] input_ids = torch.cat((input_ids, valid_tokens), dim=-1) if streamer is not None: streamer.put(valid_tokens.cpu()) new_cur_len = input_ids.shape[-1] - # 5.3. Discard past key values relative to unused assistant tokens + # 4.2. Discard past key values relative to unused assistant tokens new_cache_size = new_cur_len - 1 outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) - # 6. Update the candidate generation strategy if needed + # 5. 
Update the candidate generation strategy if needed candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches) if synced_gpus and this_peer_finished: diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 973f54f0039701..14353cafb55386 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3170,9 +3170,88 @@ def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist()) # Assistant model - assistant = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( + assistant = AutoModelForSeq2SeqLM.from_pretrained( + "hf-internal-testing/tiny-random-BartForConditionalGeneration" + ).to(torch_device) + + # If assisted generation passes model_kwargs correctly, should be same as previous + outputs_assisted = model.generate( + input_ids, + foo=True, + assistant_model=assistant, + ) + self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) + + # Check that passing encoder_outputs directly also works as expected + encoder_outputs = assistant.get_encoder()(input_ids) + + outputs_assisted = model.generate( + foo=True, + assistant_model=assistant, + encoder_outputs=encoder_outputs, + assistant_encoder_outputs=encoder_outputs, + ) + self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) + + def test_assisted_decoding_encoder_decoder_shared_encoder(self): + """ + Tests that the following scenario is compatible with assisted generation: + 1. encoder-decoder main model + 2. decoder-only assistant model + 3. both have a custom input + (e.g. DistilWhisper) + """ + + # PT-only test: TF doesn't support assisted decoding yet. + # Bart subclass with a kwarg called foo that distorts the output + class FakeBartSeq2Seq(BartForConditionalGeneration): + def forward(self, input_ids, foo=False, **kwargs): + outs = super().forward(input_ids, **kwargs) + if foo: + outs["logits"][:, :, :] = 0.0 + return outs + + def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): + kwargs["encoder_outputs"] = encoder_outputs + inputs = super().prepare_inputs_for_generation(*args, **kwargs) + inputs["foo"] = foo + return inputs + + class FakeBartCausalLM(BartForCausalLM): + def forward(self, input_ids, attention_mask, past_key_values, foo=False, **kwargs): + outs = super().forward(input_ids, attention_mask, past_key_values=past_key_values, **kwargs) + if foo: + outs["logits"][:, :, :] = 0.0 + return outs + + def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): + kwargs["encoder_outputs"] = encoder_outputs + inputs = super().prepare_inputs_for_generation(*args, **kwargs) + inputs["foo"] = foo + return inputs + + model = FakeBartSeq2Seq.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( torch_device ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration") + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + + # Traditional way of generating text + outputs_normal = model.generate(input_ids) + self.assertEqual(outputs_normal.shape, (1, 20)) + + # Should be different with foo + outputs_foo = model.generate(input_ids, foo=True) + with self.assertRaises(AssertionError): + self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist()) + + # Assistant model + assistant = 
FakeBartCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-BartForConditionalGeneration" + ).to(torch_device) # If assisted generation passes model_kwargs correctly, should be same as previous outputs_assisted = model.generate( From a726936b068739bcf097f9a1e1b88b8ce684cba8 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 12 Dec 2023 16:42:02 +0000 Subject: [PATCH 2/8] fix test --- tests/generation/test_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 14353cafb55386..201ca716835970 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3170,9 +3170,9 @@ def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist()) # Assistant model - assistant = AutoModelForSeq2SeqLM.from_pretrained( - "hf-internal-testing/tiny-random-BartForConditionalGeneration" - ).to(torch_device) + assistant = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( + torch_device + ) # If assisted generation passes model_kwargs correctly, should be same as previous outputs_assisted = model.generate( From 7e4deaba22877babd2f9349b0030567dddaed37a Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 12 Dec 2023 17:09:47 +0000 Subject: [PATCH 3/8] space --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 7c5ef5f2efadf1..5857d6f7518b37 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4627,7 +4627,7 @@ def assisted_decoding( # 3. Select the accepted tokens. There are two possible cases: # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf). - # NOTE:Unless otherwise stated, the variable names match those in the paper. + # NOTE: Unless otherwise stated, the variable names match those in the paper. if do_sample and candidate_logits is not None: # Gets the probabilities from the logits. q_i and p_i denote the model and assistant (respectively) # probabilities of the tokens selected by the assistant. From e234e1ef80c211757c6aa272fd46de1355e2dc75 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 12 Dec 2023 17:16:32 +0000 Subject: [PATCH 4/8] better comments --- src/transformers/generation/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5857d6f7518b37..cbbf347bfbe1a3 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4629,8 +4629,8 @@ def assisted_decoding( # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf). # NOTE: Unless otherwise stated, the variable names match those in the paper. if do_sample and candidate_logits is not None: - # Gets the probabilities from the logits. q_i and p_i denote the model and assistant (respectively) - # probabilities of the tokens selected by the assistant. + # Gets the probabilities from the logits. q_i and p_i denote the model and assistant probabilities of + # the tokens selected by the assistant, respectivelly. 
q = candidate_logits.softmax(dim=-1) q_i = q[ :, @@ -4646,7 +4646,7 @@ def assisted_decoding( probability_ratio = p_i / q_i # When probability_ratio > 1 (i.e. q_i(x) < p_i(x)), keep the token. Otherwise reject with - # p = 1 - probability_ratio (= keep with p=probability_ratio). Keep all the tokens until the first + # p = 1 - probability_ratio (= keep with p = probability_ratio). Keep all the tokens until the first # rejection r_i = torch.rand_like(probability_ratio) is_rejected = r_i > probability_ratio # equivalent: is_accepted = r_i <= probability_ratio From b4dab21d9926897459d2b2dc4b21c3be8e0e23c2 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 14 Dec 2023 14:04:38 +0000 Subject: [PATCH 5/8] remove redundant test --- tests/generation/test_utils.py | 78 ---------------------------------- 1 file changed, 78 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 201ca716835970..b8d46ff1a3bbf8 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3271,81 +3271,3 @@ def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, assistant_encoder_outputs=encoder_outputs, ) self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) - - def test_assisted_decoding_encoder_decoder_shared_encoder(self): - """ - Tests that the following scenario is compatible with assisted generation: - 1. encoder-decoder main model - 2. decoder-only assistant model - 3. both have a custom input - (e.g. DistilWhisper) - """ - - # PT-only test: TF doesn't support assisted decoding yet. - # Bart subclass with a kwarg called foo that distorts the output - class FakeBartSeq2Seq(BartForConditionalGeneration): - def forward(self, input_ids, foo=False, **kwargs): - outs = super().forward(input_ids, **kwargs) - if foo: - outs["logits"][:, :, :] = 0.0 - return outs - - def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): - kwargs["encoder_outputs"] = encoder_outputs - inputs = super().prepare_inputs_for_generation(*args, **kwargs) - inputs["foo"] = foo - return inputs - - class FakeBartCausalLM(BartForCausalLM): - def forward(self, input_ids, attention_mask, past_key_values, foo=False, **kwargs): - outs = super().forward(input_ids, attention_mask, past_key_values=past_key_values, **kwargs) - if foo: - outs["logits"][:, :, :] = 0.0 - return outs - - def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): - kwargs["encoder_outputs"] = encoder_outputs - inputs = super().prepare_inputs_for_generation(*args, **kwargs) - inputs["foo"] = foo - return inputs - - model = FakeBartSeq2Seq.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( - torch_device - ) - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration") - - text = "Hello world" - tokenized_inputs = tokenizer([text], return_tensors="pt") - input_ids = tokenized_inputs.input_ids.to(torch_device) - - # Traditional way of generating text - outputs_normal = model.generate(input_ids) - self.assertEqual(outputs_normal.shape, (1, 20)) - - # Should be different with foo - outputs_foo = model.generate(input_ids, foo=True) - with self.assertRaises(AssertionError): - self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist()) - - # Assistant model - assistant = FakeBartCausalLM.from_pretrained( - "hf-internal-testing/tiny-random-BartForConditionalGeneration" - ).to(torch_device) - - # If assisted generation passes 
model_kwargs correctly, should be same as previous - outputs_assisted = model.generate( - input_ids, - foo=True, - assistant_model=assistant, - ) - self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) - - # Check that passing encoder_outputs directly also works as expected - encoder_outputs = model.get_encoder()(input_ids) - - outputs_assisted = model.generate( - foo=True, - assistant_model=assistant, - encoder_outputs=encoder_outputs, - ) - self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) From f2f99f34ea05975f1d9215826c388b4c7fbe5d49 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 14 Dec 2023 14:05:48 +0000 Subject: [PATCH 6/8] test nit --- tests/generation/test_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index b8d46ff1a3bbf8..973f54f0039701 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3262,12 +3262,11 @@ def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) # Check that passing encoder_outputs directly also works as expected - encoder_outputs = assistant.get_encoder()(input_ids) + encoder_outputs = model.get_encoder()(input_ids) outputs_assisted = model.generate( foo=True, assistant_model=assistant, encoder_outputs=encoder_outputs, - assistant_encoder_outputs=encoder_outputs, ) self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) From 64c59a51a2a4366ccb1c8624dcc86ff5aae7beef Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 18 Dec 2023 10:46:17 +0000 Subject: [PATCH 7/8] Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index cbbf347bfbe1a3..e896229badd98f 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4630,7 +4630,7 @@ def assisted_decoding( # NOTE: Unless otherwise stated, the variable names match those in the paper. if do_sample and candidate_logits is not None: # Gets the probabilities from the logits. q_i and p_i denote the model and assistant probabilities of - # the tokens selected by the assistant, respectivelly. + # the tokens selected by the assistant, respectively. q = candidate_logits.softmax(dim=-1) q_i = q[ :, From c7f1d12c35573705d58b27f9408387130f52d98c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 18 Dec 2023 11:53:44 +0000 Subject: [PATCH 8/8] PR comments --- src/transformers/generation/utils.py | 111 ++++++++++++++++----------- 1 file changed, 67 insertions(+), 44 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e896229badd98f..b3bc4cd8d875cc 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4627,50 +4627,18 @@ def assisted_decoding( # 3. Select the accepted tokens. There are two possible cases: # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf). - # NOTE: Unless otherwise stated, the variable names match those in the paper. 
+ max_matches = max_len - cur_len - 1 if do_sample and candidate_logits is not None: - # Gets the probabilities from the logits. q_i and p_i denote the model and assistant probabilities of - # the tokens selected by the assistant, respectively. - q = candidate_logits.softmax(dim=-1) - q_i = q[ - :, - torch.range(0, candidate_length - 1, dtype=torch.int), - candidate_input_ids[:, -candidate_length:], - ].squeeze(0, 1) - p = new_logits.softmax(dim=-1) - p_i = p[ - :, - torch.range(0, candidate_length - 1, dtype=torch.int), - candidate_input_ids[:, -candidate_length:], - ].squeeze(0, 1) - probability_ratio = p_i / q_i - - # When probability_ratio > 1 (i.e. q_i(x) < p_i(x)), keep the token. Otherwise reject with - # p = 1 - probability_ratio (= keep with p = probability_ratio). Keep all the tokens until the first - # rejection - r_i = torch.rand_like(probability_ratio) - is_rejected = r_i > probability_ratio # equivalent: is_accepted = r_i <= probability_ratio - n_matches = (is_rejected.cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1 - - # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct - # behavior) - if last_assistant_token_is_eos and n_matches == candidate_length: - n_matches -= 1 - n_matches = min(n_matches, max_len - cur_len - 1) - - # Next token selection: if there is a rejection, adjust the distribution from the main model before - # sampling. - gamma = candidate_logits.shape[1] - p_n_plus_1 = p[:, n_matches, :] - if n_matches < gamma: - q_n_plus_1 = q[:, n_matches, :] - p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0).softmax(dim=-1) - else: - p_prime = p_n_plus_1 - t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :] - - # The selected tokens include the matches plus the next sampled token - selected_tokens = torch.cat((candidate_input_ids[:, :n_matches], t), dim=-1) + next_sampled_tokens, n_matches = _speculative_sampling( + candidate_input_ids, + candidate_logits, + candidate_length, + new_logits, + last_assistant_token_is_eos, + max_matches, + ) + # The selected tokens include the matches plus the next sampled tokens + selected_tokens = torch.cat((candidate_input_ids[:, :n_matches], next_sampled_tokens), dim=-1) # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the # original model logits with the candidate tokens. We can keep the candidate tokens until the first @@ -4688,7 +4656,7 @@ def assisted_decoding( # Ensure we don't generate beyond max_len or an EOS token if last_assistant_token_is_eos and n_matches == candidate_length: n_matches -= 1 - n_matches = min(n_matches, max_len - cur_len - 1) + n_matches = min(n_matches, max_matches) # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated # by the model after the last candidate match is also valid, as it is generated from a correct sequence. @@ -4804,6 +4772,61 @@ def assisted_decoding( return input_ids +def _speculative_sampling( + candidate_input_ids, + candidate_logits, + candidate_length, + new_logits, + last_assistant_token_is_eos, + max_matches, +): + """ + Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns + the next selected token, as well as the number of candidate matches. + + NOTE: Unless otherwise stated, the variable names match those in the paper. + """ + # Gets the probabilities from the logits. 
q_i and p_i denote the assistant and model probabilities of the tokens
+    # selected by the assistant, respectively.
+    q = candidate_logits.softmax(dim=-1)
+    q_i = q[
+        :,
+        torch.range(0, candidate_length - 1, dtype=torch.int),
+        candidate_input_ids[:, -candidate_length:],
+    ].squeeze(0, 1)
+    p = new_logits.softmax(dim=-1)
+    p_i = p[
+        :,
+        torch.range(0, candidate_length - 1, dtype=torch.int),
+        candidate_input_ids[:, -candidate_length:],
+    ].squeeze(0, 1)
+    probability_ratio = p_i / q_i
+
+    # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
+    # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio
+    # (= keep with p = probability_ratio). Keep all the tokens until the first rejection
+    r_i = torch.rand_like(probability_ratio)
+    is_accepted = r_i <= probability_ratio
+    n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()  # this is `n` in algorithm 1
+
+    # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
+    if last_assistant_token_is_eos and n_matches == candidate_length:
+        n_matches -= 1
+    n_matches = min(n_matches, max_matches)
+
+    # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
+    gamma = candidate_logits.shape[1]
+    p_n_plus_1 = p[:, n_matches, :]
+    if n_matches < gamma:
+        q_n_plus_1 = q[:, n_matches, :]
+        p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0).softmax(dim=-1)
+    else:
+        p_prime = p_n_plus_1
+    t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
+
+    return t, n_matches
+
+
 def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
     """
     Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple