18 changes: 18 additions & 0 deletions tests/v1/tpu/test_sampler.py
@@ -1,4 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import random

import pytest

from vllm import LLM, envs
@@ -39,3 +41,19 @@ def test_sampler_different(model_name: str):
# Unsupported `seed` param.
sampling_params = SamplingParams(temperature=0.3, seed=42)
output2 = llm.generate(prompts, sampling_params)

# Batch-case with TopK
for B in [4, 16]:
p = prompts * B
sampling_params = [
SamplingParams(
temperature=0.1,
min_p=0.8,
max_tokens=64,
# Vary number of ks
top_k=random.randint(4, 12)) for _ in range(B)
]
# Make sure first two reqs have the same K
sampling_params[0] = sampling_params[1]
output = llm.generate(p, sampling_params)
assert output[0].outputs[0].text == output[1].outputs[0].text
22 changes: 21 additions & 1 deletion tests/v1/tpu/test_topk_topp_sampler.py
@@ -5,7 +5,8 @@
import torch

from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu
from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
apply_top_k_top_p_tpu)

if not current_platform.is_tpu():
pytest.skip("This test needs a TPU.", allow_module_level=True)
@@ -16,6 +17,25 @@
TOLERANCE = 1e-6


def test_topk_equivalence_to_native_impl():
with torch.device(xm.xla_device()):
xm.set_rng_state(seed=33)

logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))

# Random top-k values between 1 and 10.
k = torch.randint(1, 10, (BATCH_SIZE, ))

# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool),
VOCAB_SIZE)

result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)

result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
assert torch.allclose(result_native, result_tpu)


def test_topp_result_sums_past_p():
with torch.device(xm.xla_device()):
xm.set_rng_state(seed=33)
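A complementary sanity check one could add (a sketch, not part of this PR's tests): with p=None and randomly drawn float logits, ties at the cutoff are vanishingly unlikely, so each row of the masked output should keep exactly k finite entries (rows with k set to VOCAB_SIZE keep everything). The commented assertion assumes the names k and result_tpu from the test above.

import torch

def count_kept(masked_logits: torch.Tensor) -> torch.Tensor:
    # Per-row count of logits that survived masking, i.e. were not set to -inf.
    return torch.isfinite(masked_logits).sum(dim=-1)

# Hypothetical extra assertion for test_topk_equivalence_to_native_impl:
#   assert torch.equal(count_kept(result_tpu).cpu(), k.cpu().to(torch.long))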
6 changes: 0 additions & 6 deletions vllm/envs.py
@@ -103,7 +103,6 @@
VLLM_DP_MASTER_PORT: int = 0
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_USE_DEEP_GEMM: bool = False
VLLM_XGRAMMAR_CACHE_MB: int = 0
@@ -684,11 +683,6 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
"VLLM_V0_USE_OUTLINES_CACHE":
lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",

# If set, disables TPU-specific optimization for top-k & top-p sampling
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,

    # Gap between padding buckets for the forward pass. So, if set to
    # 8, we will run the forward pass with [16, 24, 32, ...].
"VLLM_TPU_BUCKET_PADDING_GAP":
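As a small illustration of the VLLM_TPU_BUCKET_PADDING_GAP comment above (a sketch only, not vLLM's actual bucketing code; the start value and upper bound are made up):

def padding_buckets(gap: int, start: int = 16, limit: int = 64) -> list[int]:
    # Bucket sizes spaced `gap` apart, starting at `start`.
    return list(range(start, limit + 1, gap))

print(padding_buckets(gap=8))  # [16, 24, 32, 40, 48, 56, 64]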
27 changes: 15 additions & 12 deletions vllm/v1/sample/ops/topk_topp_sampler.py
@@ -72,14 +72,7 @@ def __init__(self):
"best performance, please install FlashInfer.")
self.forward = self.forward_native
elif current_platform.is_tpu():
if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION:
logger.warning(
"TPU-specific optimization for top-k & top-p sampling are "
"disabled, falling back to PyTorch-native implementation "
"which could be very slow.")
self.forward = self.forward_native
else:
self.forward = self.forward_tpu
self.forward = self.forward_tpu
else:
self.forward = self.forward_native

@@ -146,12 +139,22 @@ def apply_top_k_top_p_tpu(
chance of being chosen during final sampling, so we can consider the tie
being broken then.
"""
probs = logits.softmax(dim=-1)
probs_sort, _ = probs.sort(dim=-1, descending=False)

if k is not None:
logits = apply_top_k_only(logits, k)
top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, )
top_k_count = top_k_count.unsqueeze(dim=1)
top_k_cutoff = probs_sort.gather(-1, top_k_count)

        # Make sure rows with top-k disabled are a no-op.
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))

elements_to_discard = probs < top_k_cutoff
logits.masked_fill_(elements_to_discard, -float("inf"))

if p is not None:
probs = logits.softmax(dim=-1)
probs_sort, _ = probs.sort(dim=-1, descending=False)
cumprob = torch.cumsum(probs_sort, dim=-1)
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
top_p_mask[:, -1] = False # at least one
@@ -224,7 +227,7 @@ def apply_top_k_only(
max_top_k = k.max()
# topk.values tensor has shape [batch_size, max_top_k].
# Convert top k to 0-based index in range [0, max_top_k).
k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1)
k_index = k.sub_(1).unsqueeze(1)
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
# Handle non-topk rows.
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
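To make the change above concrete, here is a minimal, self-contained sketch of the cutoff-based top-k that apply_top_k_top_p_tpu now performs inline (it mirrors the technique; it is not the function itself). Sorting the probabilities in ascending order and gathering the element at index vocab_size - k yields the k-th largest probability per row; masking everything strictly below that cutoff keeps the top-k set without torch.scatter, which is slow on TPU, and ties at the cutoff are kept rather than broken.

import torch

def topk_mask_via_cutoff(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    probs = logits.softmax(dim=-1)
    probs_sort, _ = probs.sort(dim=-1, descending=False)
    # Index of the k-th largest probability in the ascending sort.
    cutoff_idx = (probs_sort.size(1) - k.to(torch.long)).unsqueeze(1)
    cutoff = probs_sort.gather(-1, cutoff_idx)
    # Rows with k == vocab_size are a no-op (top-k disabled for them).
    disabled = (k == logits.shape[1]).unsqueeze(1)
    cutoff = cutoff.masked_fill(disabled, -float("inf"))
    # Only probabilities strictly below the cutoff are dropped, so ties stay.
    return logits.masked_fill(probs < cutoff, -float("inf"))

# Example: with k=2, only the two largest logits of the row survive.
out = topk_mask_via_cutoff(torch.tensor([[1.0, 3.0, 2.0, 0.5]]), torch.tensor([2]))
# out -> [[-inf, 3.0, 2.0, -inf]]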
13 changes: 8 additions & 5 deletions vllm/v1/sample/tpu/metadata.py
@@ -10,7 +10,7 @@
temperature=-1.0,
min_p=0.0,
# strictly disabled for now
# top_k=-1,
top_k=0,
# top_p=0.0,
# frequency_penalties=0.0,
# presence_penalties=0.0,
@@ -99,11 +99,13 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:

fill_slice(input_batch.temperature_cpu_tensor,
DEFAULT_SAMPLING_PARAMS["temperature"])
# TODO Temporarily disabled until sampling options are enabled
# fill_slice(input_batch.top_p_cpu_tensor)
# fill_slice(input_batch.top_k_cpu_tensor)
fill_slice(input_batch.min_p_cpu_tensor,
DEFAULT_SAMPLING_PARAMS["min_p"])
fill_slice(input_batch.top_k_cpu_tensor,
DEFAULT_SAMPLING_PARAMS["top_k"])
# TODO Temporarily disabled until sampling options are enabled
# fill_slice(input_batch.top_p_cpu_tensor,
# DEFAULT_SAMPLING_PARAMS["top_p"])

# Slice persistent device tensors to a fixed pre-compiled padded shape.
return cls(
@@ -112,6 +114,7 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
all_greedy=input_batch.all_greedy,
# TODO enable more and avoid returning None values
top_p=None, # input_batch.top_p[:padded_num_reqs],
top_k=None, # input_batch.top_k[:padded_num_reqs],
top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
xla_device),
min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
xla_device))
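The fill-then-slice pattern above can be summarized by a standalone sketch (assumed shapes and names; this is not the TPUSupportedSamplingMetadata code). Padding rows beyond the live requests get a neutral default so the pre-compiled padded batch treats them as no-ops, then the tensor is sliced to the padded size before being moved to the XLA device.

import torch

def fill_and_slice(cpu_tensor: torch.Tensor, fill_val, num_reqs: int,
                   padded_num_reqs: int) -> torch.Tensor:
    # Neutralize padding rows, then keep only the padded-batch prefix.
    cpu_tensor[num_reqs:padded_num_reqs] = fill_val
    return cpu_tensor[:padded_num_reqs]

top_k_cpu = torch.tensor([5, 9, 7, 3, 2, 8, 6, 1])  # only 2 live requests
top_k = fill_and_slice(top_k_cpu, fill_val=0, num_reqs=2, padded_num_reqs=4)
print(top_k)  # tensor([5, 9, 0, 0]) -- ready for .to(xla_device)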
9 changes: 7 additions & 2 deletions vllm/v1/worker/tpu_model_runner.py
@@ -883,14 +883,19 @@ def _precompile_select_hidden_states(self) -> None:
device=self.device)
torch._dynamo.mark_dynamic(indices, 0)
self.select_hidden_states(dummy_hidden, indices)
logger.info(" -- num_tokens: %d", num_tokens)
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
num_reqs)
# Requests can't be more than tokens. But do compile for the
# next bigger value in case num_tokens uses bucketed padding.
if num_reqs >= min(num_tokens, self.max_num_reqs):
break
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
self._update_num_xla_graphs("select_hidden_states")

def _precompile_sample_from_hidden(self) -> None:
logger.info("Compiling sampling with different input shapes.")
logger.info("Compiling sampling with different num_reqs.")
start = time.perf_counter()
hsize = self.model_config.get_hidden_size()
for num_reqs in self.num_reqs_paddings:
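The early-exit logic added above can be read as the following loop skeleton (a sketch with hypothetical padding lists; the real values come from the runner's bucketing configuration). Each num_reqs size is compiled, and the inner loop breaks right after the first size that reaches min(num_tokens, max_num_reqs), so one request padding above the token count is still covered when num_tokens itself was rounded up by bucketed padding.

num_tokens_paddings = [16, 32, 64]   # assumed token buckets
num_reqs_paddings = [8, 16, 32, 64]  # assumed request buckets
max_num_reqs = 32

for num_tokens in num_tokens_paddings:
    for num_reqs in num_reqs_paddings:
        print(f"precompile select_hidden_states(num_tokens={num_tokens}, num_reqs={num_reqs})")
        # Requests can't outnumber tokens, but compile this one extra size in
        # case num_tokens itself was rounded up by bucketed padding.
        if num_reqs >= min(num_tokens, max_num_reqs):
            break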