Skip to content

Commit

Permalink
Improve readability of _speculative_sampling
Browse files (browse the repository at this point in the history)
  • Loading branch information
ofirzaf committed Jan 15, 2024
1 parent d5224cf commit 2725e16
Showing 1 changed file with 4 additions and 11 deletions.
15 changes: 4 additions & 11 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4794,12 +4794,13 @@ def _speculative_sampling(
NOTE: Unless otherwise stated, the variable names match those in the paper.
"""
+ new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
# Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
# selected by the assistant, respectively.
q = candidate_logits.softmax(dim=-1)
- q_i = q[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
+ q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
p = new_logits.softmax(dim=-1)
- p_i = p[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
+ p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
probability_ratio = p_i / q_i

# When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
Expand Down Expand Up @@ -4827,15 +4828,7 @@ def _speculative_sampling(

# The selected tokens include the matches (if any) plus the next sampled tokens
if n_matches > 0:
- valid_tokens = torch.cat(
- (
- candidate_input_ids[
- :, -candidate_length : candidate_input_ids.size(1) - candidate_length + n_matches :
- ],
- t,
- ),
- dim=-1,
- )
+ valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
else:
valid_tokens = t

Expand Down

0 comments on commit 2725e16

Please sign in to comment.