diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 9505bd7ce43d..3efafa8f0b1f 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -264,7 +264,9 @@ def compute_probs_and_sample_next_token( # TODO(woosuk): Consider seeds. q = torch.empty_like(probs) q.exponential_() - next_token_ids = probs.div_(q).argmax(dim=-1).view(-1) + # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs + # will be used later for rejection sampling. + next_token_ids = probs.div(q).argmax(dim=-1).view(-1) if not sampling_metadata.all_random: greedy_token_ids = probs.argmax(dim=-1) next_token_ids = torch.where(