
Fix candidate generation more than max_length - 1
ofirzaf committed Jan 17, 2024
1 parent 2725e16 commit 190a63a
Showing 2 changed files with 21 additions and 19 deletions.
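
Both changed functions run on the assisted-generation (speculative decoding) path, which `generate` enters when it is given an `assistant_model`. For orientation, a call along the following lines exercises the fixed code path; the checkpoint names are placeholders chosen for illustration, not part of the commit:

```python
# Illustrative only: any target/assistant pair with a shared tokenizer works.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
assistant = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

inputs = tokenizer("The sky is", return_tensors="pt")

# With a tight max_length, the assistant must stop proposing candidates at
# max_length - 1, because the target model appends one more token after
# verification. Before this commit the candidate loop could overshoot.
outputs = model.generate(
    **inputs,
    assistant_model=assistant,
    do_sample=True,
    max_length=20,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```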
src/transformers/generation/candidate_generator.py (6 changes: 3 additions & 3 deletions)

```diff
@@ -173,10 +173,10 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
 
         # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
         # (which implicitly contains the number of accepted candidates from the previous round)
+        new_cur_len = input_ids.shape[-1]
+        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
         has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
         if has_past_key_values:
-            new_cur_len = input_ids.shape[-1]
-
             new_cache_size = new_cur_len - 1
             self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
                 self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
@@ -190,7 +190,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         # 2. Forecast next N tokens using the assistant model.
         assistant_generation_kwargs = {
             self.input_ids_key: input_ids,
-            "max_new_tokens": int(self.num_assistant_tokens),
+            "max_new_tokens": max_new_tokens,
             "generation_config": self.generation_config,
             "logits_processor": self.logits_processor,
         }
```
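
The key line above is the new cap on `max_new_tokens`: once `new_cur_len` tokens have accumulated, at most `max_length - new_cur_len - 1` candidates are worth proposing, since the target model always appends one token of its own after verifying them. A toy check of the arithmetic, in plain Python with made-up values mirroring the diff's names:

```python
# Made-up values; names mirror the diff.
max_length = 20            # generation_config.max_length
num_assistant_tokens = 5   # assistant's default candidate count
new_cur_len = 17           # tokens already in input_ids

# The target model samples one extra token after verification, so the
# assistant may propose at most max_length - new_cur_len - 1 candidates.
max_new_tokens = min(int(num_assistant_tokens), max_length - new_cur_len - 1)
assert max_new_tokens == 2  # proposing 5 would overshoot: 17 + 5 + 1 > 20

# Before this commit the assistant always proposed num_assistant_tokens,
# letting candidate generation run past max_length - 1.
```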
src/transformers/generation/utils.py (34 changes: 18 additions & 16 deletions)

```diff
@@ -4813,24 +4813,26 @@ def _speculative_sampling(
     # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
     if last_assistant_token_is_eos and n_matches == candidate_length:
         n_matches -= 1
-    n_matches = min(n_matches, max_matches)
+        valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
+    else:
+        n_matches = min(n_matches, max_matches)
 
-    # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
-    gamma = candidate_logits.shape[1]
-    p_n_plus_1 = p[:, n_matches, :]
-    if n_matches < gamma:
-        q_n_plus_1 = q[:, n_matches, :]
-        p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
-        p_prime.div_(p_prime.sum())
-    else:
-        p_prime = p_n_plus_1
-    t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
+        # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
+        gamma = min(candidate_logits.shape[1], max_matches)
+        p_n_plus_1 = p[:, n_matches, :]
+        if n_matches < gamma:
+            q_n_plus_1 = q[:, n_matches, :]
+            p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
+            p_prime.div_(p_prime.sum())
+        else:
+            p_prime = p_n_plus_1
+        t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
 
-    # The selected tokens include the matches (if any) plus the next sampled tokens
-    if n_matches > 0:
-        valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
-    else:
-        valid_tokens = t
+        # The selected tokens include the matches (if any) plus the next sampled tokens
+        if n_matches > 0:
+            valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
+        else:
+            valid_tokens = t
 
     return valid_tokens, n_matches
 
```
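
The restructuring does two things: `gamma` is now capped at `max_matches`, so when the match count is truncated by the length budget rather than by a rejection, the bonus token is sampled from the full target distribution instead of the residual one; and the EOS case in which every candidate matched now returns the candidates directly instead of falling through to resampling. On a genuine rejection, the next token is drawn from the residual distribution p'(x) proportional to max(0, p(x) - q(x)), where p and q are the target and assistant probabilities. A self-contained sketch of the post-fix control flow on dummy tensors (not the library's API):

```python
# Dummy distributions; shapes mimic _speculative_sampling's p and q.
import torch

torch.manual_seed(0)
vocab, gamma = 8, 4                              # toy vocab size and candidate count
p = torch.rand(1, gamma + 1, vocab).softmax(-1)  # target-model probabilities
q = torch.rand(1, gamma, vocab).softmax(-1)      # assistant probabilities
n_matches, max_matches = 2, 3                    # pretend candidate 3 was rejected

n_matches = min(n_matches, max_matches)
gamma = min(gamma, max_matches)  # the fix: respect the length budget
p_n_plus_1 = p[:, n_matches, :]
if n_matches < gamma:
    # Rejected: resample from the residual p'(x) = max(0, p(x) - q(x)), renormalized.
    q_n_plus_1 = q[:, n_matches, :]
    p_prime = torch.clamp(p_n_plus_1 - q_n_plus_1, min=0)
    p_prime.div_(p_prime.sum())
else:
    # Truncated by the budget, not rejected: sample straight from the target.
    p_prime = p_n_plus_1
t = torch.multinomial(p_prime, num_samples=1)
print(n_matches, t)  # accepted-candidate count and the extra sampled token
```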
