Skip to content

Commit a4bc75e

Browse files
committed
Update tests.
* Added a more comprehensive correctness test for top-p. * Included tests/v1/tpu/test_topk_topp_sampler.py in run-tpu-v1-test.sh. Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
1 parent 99de8cd commit a4bc75e

File tree

2 files changed

+28
-29
lines changed

2 files changed

+28
-29
lines changed

.buildkite/run-tpu-v1-test.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ docker run --privileged --net host --shm-size=16G -it \
3636
&& echo TEST_6 \
3737
&& pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
3838
&& echo TEST_7 \
39-
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \
39+
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py \
40+
&& echo TEST_8 \
41+
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py" \
4042

4143

4244
# TODO: This test fails because it uses RANDOM_SEED sampling

tests/v1/tpu/test_topk_topp_sampler.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,47 @@
11
# SPDX-License-Identifier: Apache-2.0
22
import math
33

4+
import pytest
45
import torch
56

67
from vllm.platforms import current_platform
78
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu
89

9-
if current_platform.is_tpu():
10-
import torch_xla.core.xla_model as xm
11-
12-
DEVICE = xm.xla_device() if current_platform.is_tpu() else torch.device("cuda")
10+
if not current_platform.is_tpu():
11+
pytest.skip("This test needs a TPU.", allow_module_level=True)
12+
import torch_xla.core.xla_model as xm
1313

1414
BATCH_SIZE = 1024
1515
VOCAB_SIZE = 128 * 1024
16+
TOLERANCE = 1e-4
1617

1718

18-
def test_topk_and_no_op_topp():
19-
with torch.device(DEVICE):
20-
if current_platform.is_tpu():
21-
xm.set_rng_state(seed=33)
22-
else:
23-
torch.manual_seed(33)
19+
def test_topp_result_sums_past_p():
20+
with torch.device(xm.xla_device()):
21+
xm.set_rng_state(seed=33)
2422

2523
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
24+
probs = logits.softmax(dim=-1)
2625

27-
# Random top-k values between 1 and 9.
28-
k = torch.randint(1, 10, (BATCH_SIZE, ))
29-
30-
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
31-
k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool),
32-
VOCAB_SIZE)
26+
# Random top-p values between 0 and 1.
27+
p = torch.rand((BATCH_SIZE, ))
3328

34-
# Top-k only implementation
35-
result1 = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)
29+
# Set p=1 for ~50% of requests in the batch (top-p disabled).
30+
p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool), 1)
3631

37-
# Top-p + top-k
38-
no_op_top_p = torch.tensor([1.0])
39-
result2 = apply_top_k_top_p_tpu(logits=logits.clone(),
40-
k=k,
41-
p=no_op_top_p)
32+
no_op_k = torch.tensor([VOCAB_SIZE])
33+
logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(),
34+
k=no_op_k,
35+
p=p)
4236

43-
assert torch.allclose(result1, result2)
37+
# Verify that the masked logit's probability sums to at least p.
38+
probs.masked_fill_(logits_masked.isinf(), 0)
39+
masked_prob_sum = probs.sum(dim=-1)
40+
assert torch.all(torch.ge(masked_prob_sum + TOLERANCE, p))
4441

4542

4643
def test_topp_basic():
47-
with torch.device(DEVICE):
44+
with torch.device(xm.xla_device()):
4845
logits = torch.tensor([[math.log(0.2),
4946
math.log(0.3),
5047
math.log(0.5)],
@@ -64,7 +61,7 @@ def test_topp_basic():
6461

6562

6663
def test_topp_select_all():
67-
with torch.device(DEVICE):
64+
with torch.device(xm.xla_device()):
6865
logits = torch.tensor([[math.log(0.2),
6966
math.log(0.3),
7067
math.log(0.5)],
@@ -80,7 +77,7 @@ def test_topp_select_all():
8077

8178

8279
def test_topp_with_ties():
83-
with torch.device(DEVICE):
80+
with torch.device(xm.xla_device()):
8481
# Input has multiple math.log(0.3).
8582
logits = torch.tensor(
8683
[[math.log(0.3),
@@ -98,7 +95,7 @@ def test_topp_with_ties():
9895

9996

10097
def test_both_topk_topp():
101-
with torch.device(DEVICE):
98+
with torch.device(xm.xla_device()):
10299
logits = torch.tensor([[math.log(0.2),
103100
math.log(0.3),
104101
math.log(0.5)],

0 commit comments

Comments
 (0)