# SPDX-License-Identifier: Apache-2.0
import math

import pytest
import torch

from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu

# This module exercises the TPU-specific top-k/top-p kernel; skip the whole
# module at collection time on non-TPU hosts so importing torch_xla below
# (which requires an XLA device) never runs there.
if not current_platform.is_tpu():
    pytest.skip("This test needs a TPU.", allow_module_level=True)
import torch_xla.core.xla_model as xm

BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024
# Slack for floating-point comparisons of probability sums against p.
TOLERANCE = 1e-4
1718
def test_topp_result_sums_past_p():
    """Tokens kept by top-p masking must carry at least probability ``p``.

    Runs with ``k = VOCAB_SIZE`` (a no-op top-k) so only the top-p filter is
    active, then checks per-row that the probability mass of the surviving
    tokens is >= the requested ``p`` (within TOLERANCE).
    """
    with torch.device(xm.xla_device()):
        # Fixed RNG seed so the random logits/p values are reproducible.
        xm.set_rng_state(seed=33)

        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
        probs = logits.softmax(dim=-1)

        # Random top-p values between 0 and 1.
        p = torch.rand((BATCH_SIZE, ))

        # Set p=1 for ~50% of requests in the batch (top-p disabled).
        p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool), 1)

        # k equal to the vocabulary size keeps every token: top-k is a no-op.
        no_op_k = torch.tensor([VOCAB_SIZE])
        logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(),
                                              k=no_op_k,
                                              p=p)

        # Verify that the masked logits' probability sums to at least p.
        # Filtered-out entries come back as -inf, so zero their probability
        # before summing the surviving mass per row.
        probs.masked_fill_(logits_masked.isinf(), 0)
        masked_prob_sum = probs.sum(dim=-1)
        assert torch.all(torch.ge(masked_prob_sum + TOLERANCE, p))
4441
4542
4643def test_topp_basic ():
47- with torch .device (DEVICE ):
44+ with torch .device (xm . xla_device () ):
4845 logits = torch .tensor ([[math .log (0.2 ),
4946 math .log (0.3 ),
5047 math .log (0.5 )],
@@ -64,7 +61,7 @@ def test_topp_basic():
6461
6562
6663def test_topp_select_all ():
67- with torch .device (DEVICE ):
64+ with torch .device (xm . xla_device () ):
6865 logits = torch .tensor ([[math .log (0.2 ),
6966 math .log (0.3 ),
7067 math .log (0.5 )],
@@ -80,7 +77,7 @@ def test_topp_select_all():
8077
8178
8279def test_topp_with_ties ():
83- with torch .device (DEVICE ):
80+ with torch .device (xm . xla_device () ):
8481 # Input has multiple math.log(0.3).
8582 logits = torch .tensor (
8683 [[math .log (0.3 ),
@@ -98,7 +95,7 @@ def test_topp_with_ties():
9895
9996
10097def test_both_topk_topp ():
101- with torch .device (DEVICE ):
98+ with torch .device (xm . xla_device () ):
10299 logits = torch .tensor ([[math .log (0.2 ),
103100 math .log (0.3 ),
104101 math .log (0.5 )],
0 commit comments