 FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available


+@pytest.fixture(autouse=True)
+def reset_default_device():
+    """
+    Tests in this module set the global default device, which can leak into
+    subsequent tests. This autouse fixture restores the original default
+    device after each test.
+    """
+    original_device = torch.get_default_device()
+    yield
+    torch.set_default_device(original_device)
+
+
 def test_topk_impl_equivalance():

-    with torch.device(DEVICE):
-        generator = Generator(device=DEVICE).manual_seed(33)
+    torch.set_default_device(DEVICE)
+    generator = Generator(device=DEVICE).manual_seed(33)

-        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
+    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)

-        # Random top-k values between 1 and 9.
-        k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
+    # Random top-k values between 1 and 9.
+    k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)

-        # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
-        k.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=bool), VOCAB_SIZE)
+    # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
+    k.masked_fill_(
+        torch.randint(0, 2, (BATCH_SIZE, ), generator=generator, dtype=bool),
+        VOCAB_SIZE)

-        # Top-k only implementation
-        result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
+    # Top-k only implementation
+    result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)

-        # Top-p + top-k
-        no_op_top_p = torch.tensor([1.0])
-        result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
+    # Top-p + top-k
+    no_op_top_p = torch.tensor([1.0])
+    result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)

-        assert torch.allclose(result1, result2)
+    assert torch.allclose(result1, result2)


 def test_flashinfer_sampler():
@@ -58,50 +67,49 @@ def test_flashinfer_sampler():
         pytest.skip(
             "FlashInfer not installed or not available on this platform.")

-    with torch.device(DEVICE):
-        generator = Generator(device=DEVICE).manual_seed(42)
-
-        # Generate random logits
-        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
-
-        # Generate various top-k and top-p values
-        k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
-        p_values = torch.rand(
-            (BATCH_SIZE, ),
-            generator=generator) * 0.5 + 0.5  # range in [0.5, 1.0]
-
-        # Sometimes disable top-k (k=vocab_size)
-        k_values.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=torch.bool), VOCAB_SIZE)
-
-        # Sometimes disable top-p (p=1.0)
-        p_values.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=torch.bool), 1.0)
-
-        python_logits = apply_top_k_top_p(
-            logits=logits.clone(),
-            k=k_values,
-            p=p_values,
-        )
-        python_probs = torch.softmax(python_logits, dim=-1)
-
-        # FlashInfer only exposed renorm interfaces for probs so convert first
-        flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
-        flashinfer_probs = top_k_renorm_probs(
-            probs=flashinfer_probs,
-            top_k=k_values,
-        )
-        flashinfer_probs = top_p_renorm_probs(
-            probs=flashinfer_probs,
-            top_p=p_values,
-        )
-
-        # Compare the results
-        assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
-            "FlashInfer and Python sampling implementations do not match!"
+    torch.set_default_device(DEVICE)
+    generator = Generator(device=DEVICE).manual_seed(42)
+
+    # Generate random logits
+    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
+
+    # Generate various top-k and top-p values
+    k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
+    p_values = torch.rand(
+        (BATCH_SIZE, ), generator=generator) * 0.5 + 0.5  # range in [0.5, 1.0]
+
+    # Sometimes disable top-k (k=vocab_size)
+    k_values.masked_fill_(
+        torch.randint(0,
+                      2, (BATCH_SIZE, ),
+                      generator=generator,
+                      dtype=torch.bool), VOCAB_SIZE)
+
+    # Sometimes disable top-p (p=1.0)
+    p_values.masked_fill_(
+        torch.randint(0,
+                      2, (BATCH_SIZE, ),
+                      generator=generator,
+                      dtype=torch.bool), 1.0)
+
+    python_logits = apply_top_k_top_p(
+        logits=logits.clone(),
+        k=k_values,
+        p=p_values,
+    )
+    python_probs = torch.softmax(python_logits, dim=-1)
+
+    # FlashInfer only exposes renorm interfaces for probs, so convert first
+    flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
+    flashinfer_probs = top_k_renorm_probs(
+        probs=flashinfer_probs,
+        top_k=k_values,
+    )
+    flashinfer_probs = top_p_renorm_probs(
+        probs=flashinfer_probs,
+        top_p=p_values,
+    )
+
+    # Compare the results
+    assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
+        "FlashInfer and Python sampling implementations do not match!"
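
For context on why the new autouse fixture is needed: torch.set_default_device changes process-wide state, so a test that sets it and never restores it changes what every later test sees. Below is a minimal sketch of that leakage, not part of this PR; the test names are hypothetical and it uses the always-available "meta" device so it runs without a GPU.

import torch


def test_changes_default_device():
    # Sets the process-wide default device and never restores it.
    torch.set_default_device("meta")
    assert torch.empty(1).device.type == "meta"


def test_expects_cpu_default():
    # Without a fixture restoring the default device, this test now also
    # allocates on "meta", and this assertion fails.
    assert torch.empty(1).device.type == "cpu"

With the autouse reset_default_device fixture from this diff in place, the second test would see the original default device again.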