-
Notifications
You must be signed in to change notification settings - Fork 27.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generate: speculative decoding #27979
Changes from all commits
8dbb065
a726936
7e4deab
e234e1e
b4dab21
f2f99f3
64c59a5
c7f1d12
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -4624,40 +4624,57 @@ def assisted_decoding( | |||||||||||||||||||||
for i in range(candidate_length + 1): | ||||||||||||||||||||||
new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) | ||||||||||||||||||||||
|
||||||||||||||||||||||
# 3. Obtain the next tokens from the original model logits. | ||||||||||||||||||||||
if do_sample: | ||||||||||||||||||||||
probs = new_logits.softmax(dim=-1) | ||||||||||||||||||||||
selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] | ||||||||||||||||||||||
# 3. Select the accepted tokens. There are two possible cases: | ||||||||||||||||||||||
# Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) | ||||||||||||||||||||||
# 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf). | ||||||||||||||||||||||
max_matches = max_len - cur_len - 1 | ||||||||||||||||||||||
if do_sample and candidate_logits is not None: | ||||||||||||||||||||||
next_sampled_tokens, n_matches = _speculative_sampling( | ||||||||||||||||||||||
candidate_input_ids, | ||||||||||||||||||||||
candidate_logits, | ||||||||||||||||||||||
candidate_length, | ||||||||||||||||||||||
new_logits, | ||||||||||||||||||||||
last_assistant_token_is_eos, | ||||||||||||||||||||||
max_matches, | ||||||||||||||||||||||
) | ||||||||||||||||||||||
# The selected tokens include the matches plus the next sampled tokens | ||||||||||||||||||||||
selected_tokens = torch.cat((candidate_input_ids[:, :n_matches], next_sampled_tokens), dim=-1) | ||||||||||||||||||||||
|
||||||||||||||||||||||
# Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the | ||||||||||||||||||||||
# original model logits with the candidate tokens. We can keep the candidate tokens until the first | ||||||||||||||||||||||
# mismatch, or until the max length is reached. | ||||||||||||||||||||||
else: | ||||||||||||||||||||||
selected_tokens = new_logits.argmax(dim=-1) | ||||||||||||||||||||||
if do_sample: | ||||||||||||||||||||||
probs = new_logits.softmax(dim=-1) | ||||||||||||||||||||||
selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] | ||||||||||||||||||||||
else: | ||||||||||||||||||||||
selected_tokens = new_logits.argmax(dim=-1) | ||||||||||||||||||||||
Comment on lines
+4647
to
+4651
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
It's probably time to soon factor this out into something like: selected_tokens = Categorical(new_logits / temperature).sample() everywhere in generate There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes! Then equivalent sampling/non-sampling methods (e.g. greedy decoding/samplinh) could be merged into a single function, facilitating maintenance. I'm going to leave it to a follow-up PR, though, to keep this PR exclusively about speculative decoding. |
||||||||||||||||||||||
|
||||||||||||||||||||||
# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep | ||||||||||||||||||||||
# the assistant forecasted tokens until the first mismatch, or until the max length is reached. | ||||||||||||||||||||||
candidate_new_tokens = candidate_input_ids[:, -candidate_length:] | ||||||||||||||||||||||
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() | ||||||||||||||||||||||
candidate_new_tokens = candidate_input_ids[:, -candidate_length:] | ||||||||||||||||||||||
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() | ||||||||||||||||||||||
gante marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||
|
||||||||||||||||||||||
# 5. Update variables according to the number of matching assistant tokens. Remember: the token generated | ||||||||||||||||||||||
# Ensure we don't generate beyond max_len or an EOS token | ||||||||||||||||||||||
if last_assistant_token_is_eos and n_matches == candidate_length: | ||||||||||||||||||||||
n_matches -= 1 | ||||||||||||||||||||||
n_matches = min(n_matches, max_matches) | ||||||||||||||||||||||
|
||||||||||||||||||||||
# 4. Update variables according to the number of matching assistant tokens. Remember: the token generated | ||||||||||||||||||||||
# by the model after the last candidate match is also valid, as it is generated from a correct sequence. | ||||||||||||||||||||||
# Because of this last token, assisted generation search reduces to a normal greedy search/sample if there | ||||||||||||||||||||||
# is no match. | ||||||||||||||||||||||
|
||||||||||||||||||||||
# 5.1. Ensure we don't generate beyond max_len or an EOS token | ||||||||||||||||||||||
if last_assistant_token_is_eos and n_matches == candidate_length: | ||||||||||||||||||||||
n_matches -= 1 | ||||||||||||||||||||||
n_matches = min(n_matches, max_len - cur_len - 1) | ||||||||||||||||||||||
|
||||||||||||||||||||||
# 5.2. Get the valid continuation, after the matching tokens | ||||||||||||||||||||||
# 4.1. Get the valid continuation, after the matching tokens | ||||||||||||||||||||||
valid_tokens = selected_tokens[:, : n_matches + 1] | ||||||||||||||||||||||
input_ids = torch.cat((input_ids, valid_tokens), dim=-1) | ||||||||||||||||||||||
if streamer is not None: | ||||||||||||||||||||||
streamer.put(valid_tokens.cpu()) | ||||||||||||||||||||||
new_cur_len = input_ids.shape[-1] | ||||||||||||||||||||||
|
||||||||||||||||||||||
# 5.3. Discard past key values relative to unused assistant tokens | ||||||||||||||||||||||
# 4.2. Discard past key values relative to unused assistant tokens | ||||||||||||||||||||||
new_cache_size = new_cur_len - 1 | ||||||||||||||||||||||
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) | ||||||||||||||||||||||
|
||||||||||||||||||||||
# 6. Update the candidate generation strategy if needed | ||||||||||||||||||||||
# 5. Update the candidate generation strategy if needed | ||||||||||||||||||||||
candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches) | ||||||||||||||||||||||
|
||||||||||||||||||||||
if synced_gpus and this_peer_finished: | ||||||||||||||||||||||
|
@@ -4755,6 +4772,61 @@ def assisted_decoding( | |||||||||||||||||||||
return input_ids | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
def _speculative_sampling( | ||||||||||||||||||||||
candidate_input_ids, | ||||||||||||||||||||||
candidate_logits, | ||||||||||||||||||||||
candidate_length, | ||||||||||||||||||||||
new_logits, | ||||||||||||||||||||||
last_assistant_token_is_eos, | ||||||||||||||||||||||
max_matches, | ||||||||||||||||||||||
): | ||||||||||||||||||||||
""" | ||||||||||||||||||||||
Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns | ||||||||||||||||||||||
the next selected token, as well as the number of candidate matches. | ||||||||||||||||||||||
|
||||||||||||||||||||||
NOTE: Unless otherwise stated, the variable names match those in the paper. | ||||||||||||||||||||||
""" | ||||||||||||||||||||||
# Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens | ||||||||||||||||||||||
# selected by the assistant, respectively. | ||||||||||||||||||||||
q = candidate_logits.softmax(dim=-1) | ||||||||||||||||||||||
q_i = q[ | ||||||||||||||||||||||
:, | ||||||||||||||||||||||
torch.range(0, candidate_length - 1, dtype=torch.int), | ||||||||||||||||||||||
candidate_input_ids[:, -candidate_length:], | ||||||||||||||||||||||
].squeeze(0, 1) | ||||||||||||||||||||||
p = new_logits.softmax(dim=-1) | ||||||||||||||||||||||
p_i = p[ | ||||||||||||||||||||||
:, | ||||||||||||||||||||||
torch.range(0, candidate_length - 1, dtype=torch.int), | ||||||||||||||||||||||
candidate_input_ids[:, -candidate_length:], | ||||||||||||||||||||||
].squeeze(0, 1) | ||||||||||||||||||||||
probability_ratio = p_i / q_i | ||||||||||||||||||||||
|
||||||||||||||||||||||
# When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller | ||||||||||||||||||||||
# than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio | ||||||||||||||||||||||
# (= keep with p = probability_ratio). Keep all the tokens until the first rejection | ||||||||||||||||||||||
r_i = torch.rand_like(probability_ratio) | ||||||||||||||||||||||
is_accepted = r_i <= probability_ratio | ||||||||||||||||||||||
n_matches = (~is_accepted.cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1 | ||||||||||||||||||||||
|
||||||||||||||||||||||
# Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior) | ||||||||||||||||||||||
if last_assistant_token_is_eos and n_matches == candidate_length: | ||||||||||||||||||||||
n_matches -= 1 | ||||||||||||||||||||||
n_matches = min(n_matches, max_matches) | ||||||||||||||||||||||
|
||||||||||||||||||||||
# Next token selection: if there is a rejection, adjust the distribution from the main model before sampling. | ||||||||||||||||||||||
gamma = candidate_logits.shape[1] | ||||||||||||||||||||||
p_n_plus_1 = p[:, n_matches, :] | ||||||||||||||||||||||
if n_matches < gamma: | ||||||||||||||||||||||
q_n_plus_1 = q[:, n_matches, :] | ||||||||||||||||||||||
p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0).softmax(dim=-1) | ||||||||||||||||||||||
else: | ||||||||||||||||||||||
p_prime = p_n_plus_1 | ||||||||||||||||||||||
t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :] | ||||||||||||||||||||||
|
||||||||||||||||||||||
return t, n_matches | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False): | ||||||||||||||||||||||
""" | ||||||||||||||||||||||
Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple | ||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this case still relevant? Not sure it's a good idea to have two "assisted decoding" do_sample=True cases in our generate. Should we maybe just deprecate this case?