Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions vllm/v1/sample/ops/topk_topp_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,18 @@ def forward_cuda(
p: Optional[torch.Tensor],
) -> torch.Tensor:
"""More optimized implementation for top-k and top-p sampling."""
probs = logits.softmax(dim=-1, dtype=torch.float32)
if k is None and p is None:
# We prefer `random_sample` over `flashinfer_sample` when sorting is
# not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does.
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)
if generators:
logger.warning("FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to "
"PyTorch-native implementation.")
return self.forward_native(logits, generators, k, p)
return flashinfer_sample(probs, k, p, generators)
return flashinfer_sample(logits, k, p, generators)

def forward_tpu(
self,
Expand Down Expand Up @@ -254,17 +254,17 @@ def random_sample(


def flashinfer_sample(
probs: torch.Tensor,
logits: torch.Tensor,
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
generators: dict[int, torch.Generator],
) -> torch.Tensor:
"""Sample from the probabilities using FlashInfer.
"""Sample from the logits using FlashInfer.

Statistically, this function is equivalent to the `random_sample` function.
However, this function is faster because it avoids sorting the logits tensor
via rejection sampling.

NOTE: The outputs of this function do not necessarily match the outputs of
the `random_sample` function. It only guarantees that the outputs are
statistically equivalent.
Expand All @@ -274,18 +274,19 @@ def flashinfer_sample(
the synchronization overhead.
"""
assert not (k is None and p is None)

if k is None:
# Top-p only.
probs = logits.softmax(dim=-1, dtype=torch.float32)
next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
probs, p, deterministic=True)
elif p is None:
# Top-k only.
probs = logits.softmax(dim=-1, dtype=torch.float32)
next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
probs, k, deterministic=True)
else:
# Both top-k and top-p.
next_token_ids = (flashinfer.sampling.top_k_top_p_sampling_from_probs(
probs, k, p, deterministic=True))
next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
logits, k, p, deterministic=True)

return next_token_ids.view(-1)