26 changes: 25 additions & 1 deletion tests/v1/tpu/test_topk_topp_sampler.py
@@ -5,7 +5,8 @@
 import torch
 
 from vllm.platforms import current_platform
-from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu
+from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
+                                                  apply_top_k_top_p_tpu)
 
 if not current_platform.is_tpu():
     pytest.skip("This test needs a TPU.", allow_module_level=True)
@@ -16,6 +17,29 @@
 TOLERANCE = 1e-6
 
 
+def test_topk_equivalence_to_native_impl():
+    with torch.device(xm.xla_device()):
+        xm.set_rng_state(seed=33)
+
+        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
+
+        # Random top-k values between 1 and 10.
+        k = torch.randint(1, 10, (BATCH_SIZE, ))
+
+        # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
+        k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool),
+                       VOCAB_SIZE)
+
+        result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)
+
+        result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
+
+        xm.mark_step()
+
+        # Perform assertion on CPU.
+        assert torch.allclose(result_native.cpu(), result_tpu.cpu())
+
+
 def test_topp_result_sums_past_p():
     with torch.device(xm.xla_device()):
         xm.set_rng_state(seed=33)
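The new test only runs on TPU hosts, but the behavior it checks is easy to reproduce on CPU. The following is a minimal, self-contained sketch (not part of the PR; reference_top_k_mask is an illustrative name) of the semantics that apply_top_k_top_p and apply_top_k_top_p_tpu are expected to agree on when p=None: keep each row's k largest logits, mask the rest to -inf, and leave rows whose k equals the vocabulary size untouched. The strict less-than comparison mirrors the TPU kernel's documented tie handling, where values equal to the cut-off all survive.

import torch

def reference_top_k_mask(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Toy reference: keep the k[i] largest logits in row i, set the rest to -inf.
    # A row with k[i] == vocab_size is left untouched (top-k disabled).
    out = logits.clone()
    vocab_size = logits.shape[-1]
    for i, ki in enumerate(k.tolist()):
        if ki >= vocab_size:
            continue  # top-k disabled for this request
        cutoff = out[i].topk(ki).values[-1]  # k-th largest logit in the row
        out[i][out[i] < cutoff] = -float("inf")  # strict <: ties at the cutoff survive
    return out

batch_size, vocab_size = 4, 8
logits = torch.randn(batch_size, vocab_size)
k = torch.tensor([1, 3, vocab_size, 2])  # third request: top-k disabled
print(reference_top_k_mask(logits, k))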
22 changes: 16 additions & 6 deletions vllm/v1/sample/ops/topk_topp_sampler.py
@@ -138,25 +138,35 @@ def apply_top_k_top_p_tpu(
     This algorithm avoids using torch.scatter which is extremely slow on TPU.
     This is achieved by finding a "cut-off" element in the original logit, and
     after thresholding the logit using this cut-off, the remaining elements
-    shall constitute the top-p set.
+    shall constitute the top-k/p set.
 
     Note: in the case of a tie (i.e. multiple cut-off elements present in the
-    logit), all tie elements are included in the top-p set. In other words,
+    logit), all tie elements are included in the top-k/p set. In other words,
     this function does not break ties. Instead, these tie tokens have equal
     chance of being chosen during final sampling, so we can consider the tie
     being broken then.
     """
+    probs = logits.softmax(dim=-1)
+    probs_sort, _ = probs.sort(dim=-1, descending=False)
+
     if k is not None:
-        logits = apply_top_k_only(logits, k)
+        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
+        top_k_count = top_k_count.unsqueeze(dim=1)
+        top_k_cutoff = probs_sort.gather(-1, top_k_count)
+
+        # Make sure rows with top-k disabled (k == vocab_size) are no-ops.
+        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
+        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
+
+        elements_to_discard = probs < top_k_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
 
     if p is not None:
-        probs = logits.softmax(dim=-1)
-        probs_sort, _ = probs.sort(dim=-1, descending=False)
         cumprob = torch.cumsum(probs_sort, dim=-1)
         top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
         top_p_mask[:, -1] = False  # at least one
 
-        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
+        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(dim=1)
         top_p_cutoff = probs_sort.gather(-1, top_p_count)
         elements_to_discard = probs < top_p_cutoff
         logits.masked_fill_(elements_to_discard, -float("inf"))
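To make the cut-off trick from the docstring concrete, here is a small standalone reproduction of the new top-k branch on toy shapes (illustrative values only, plain CPU tensors instead of XLA, not vLLM code): sort the probabilities in ascending order, gather the probability of the k-th largest token as a per-row cut-off, neutralize the cut-off for rows where k equals the vocabulary size, then mask every logit whose probability falls strictly below the cut-off, so ties at the cut-off survive.

import torch

batch_size, vocab_size = 2, 6
torch.manual_seed(0)
logits = torch.randn(batch_size, vocab_size)
k = torch.tensor([2, vocab_size])  # second row: top-k disabled

probs = logits.softmax(dim=-1)
probs_sort, _ = probs.sort(dim=-1, descending=False)

# In the ascending sort, the k-th largest probability sits at index vocab_size - k.
top_k_count = (probs_sort.size(1) - k.to(torch.long)).unsqueeze(dim=1)
top_k_cutoff = probs_sort.gather(-1, top_k_count)

# Rows with k == vocab_size keep everything: push their cut-off to -inf.
no_top_k_mask = (k == vocab_size).unsqueeze(dim=1)
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))

# Strictly-below comparison means ties at the cut-off are not discarded.
logits.masked_fill_(probs < top_k_cutoff, -float("inf"))
print(logits)  # row 0 keeps its 2 largest logits, row 1 is untouched

The same sorted probabilities are reused by the top-p branch, which is why the softmax and sort moved above the if k is not None: block in this PR instead of staying inside the top-p path.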