diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py
index dce0303e68d5..9eea315918b6 100644
--- a/tests/v1/tpu/test_topk_topp_sampler.py
+++ b/tests/v1/tpu/test_topk_topp_sampler.py
@@ -5,7 +5,8 @@
 import torch
 
 from vllm.platforms import current_platform
-from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu
+from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
+                                                  apply_top_k_top_p_tpu)
 
 if not current_platform.is_tpu():
     pytest.skip("This test needs a TPU.", allow_module_level=True)
@@ -16,6 +17,29 @@
 TOLERANCE = 1e-6
 
 
+def test_topk_equivalence_to_native_impl():
+    with torch.device(xm.xla_device()):
+        xm.set_rng_state(seed=33)
+
+        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
+
+        # Random top-k values between 1 and 10.
+        k = torch.randint(1, 10, (BATCH_SIZE, ))
+
+        # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
+        k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool),
+                       VOCAB_SIZE)
+
+        result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)
+
+        result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
+
+        xm.mark_step()
+
+    # Perform assertion on CPU.
+    assert torch.allclose(result_native.cpu(), result_tpu.cpu())
+
+
 def test_topp_result_sums_past_p():
     with torch.device(xm.xla_device()):
         xm.set_rng_state(seed=33)
diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index f69623edd632..90123a427f0f 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -138,25 +138,35 @@ def apply_top_k_top_p_tpu(
     This algorithm avoids using torch.scatter which is extremely slow on TPU.
     This is achieved by finding a "cut-off" element in the original logit, and
     after thresholding the logit using this cut-off, the remaining elements
-    shall constitute the top-p set.
+    shall constitute the top-k/p set.
 
     Note: in the case of tie (i.e. multipple cut-off elements present in the
-    logit), all tie elements are included in the top-p set. In other words,
+    logit), all tie elements are included in the top-k/p set. In other words,
     this function does not break ties. Instead, these tie tokens have equal
     chance of being chosen during final sampling, so we can consider the tie
     being broken then.
     """
+    probs = logits.softmax(dim=-1)
+    probs_sort, _ = probs.sort(dim=-1, descending=False)
+
     if k is not None:
-        logits = apply_top_k_only(logits, k)
+        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
+        top_k_count = top_k_count.unsqueeze(dim=1)
+        top_k_cutoff = probs_sort.gather(-1, top_k_count)
+
+        # Make sure the no top-k rows are no-op.
+        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
+        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
+
+        elements_to_discard = probs < top_k_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
 
     if p is not None:
-        probs = logits.softmax(dim=-1)
-        probs_sort, _ = probs.sort(dim=-1, descending=False)
         cumprob = torch.cumsum(probs_sort, dim=-1)
         top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
         top_p_mask[:, -1] = False  # at least one
 
-        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
+        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(dim=1)
         top_p_cutoff = probs_sort.gather(-1, top_p_count)
         elements_to_discard = probs < top_p_cutoff
         logits.masked_fill_(elements_to_discard, -float("inf"))
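
Not part of the patch above: a minimal standalone sketch of the cutoff-based top-k masking that the new apply_top_k_top_p_tpu body implements, assuming only stock PyTorch on CPU. The helper name top_k_by_cutoff and the toy shapes are illustrative, not vLLM API.

import torch


def top_k_by_cutoff(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Mask logits outside each row's top-k set without torch.scatter.

    The k-th largest probability per row is found by sorting probabilities in
    ascending order and gathering the element at index (vocab_size - k); every
    element strictly below that cutoff is discarded. Ties at the cutoff are
    kept, mirroring the patch.
    """
    probs = logits.softmax(dim=-1)
    probs_sort, _ = probs.sort(dim=-1, descending=False)

    # Index of the per-row cutoff element in the ascending sort, shape (batch, 1).
    top_k_count = (probs_sort.size(1) - k.to(torch.long)).unsqueeze(dim=1)
    top_k_cutoff = probs_sort.gather(-1, top_k_count)

    # Rows with k == vocab_size mean "top-k disabled": push the cutoff to -inf
    # so that no element of those rows is discarded.
    no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
    top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))

    return logits.masked_fill(probs < top_k_cutoff, -float("inf"))


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 8)
    k = torch.tensor([3, 8])  # second row: k == vocab_size, i.e. top-k disabled
    masked = top_k_by_cutoff(logits, k)
    # Expect 3 finite logits in row 0 (barring ties) and all 8 in row 1.
    print((masked > -float("inf")).sum(dim=-1))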