@@ -72,14 +72,7 @@ def __init__(self):
7272 "best performance, please install FlashInfer." )
7373 self .forward = self .forward_native
7474 elif current_platform .is_tpu ():
75- if envs .VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION :
76- logger .warning (
77- "TPU-specific optimization for top-k & top-p sampling are "
78- "disabled, falling back to PyTorch-native implementation "
79- "which could be very slow." )
80- self .forward = self .forward_native
81- else :
82- self .forward = self .forward_tpu
75+ self .forward = self .forward_tpu
8376 else :
8477 self .forward = self .forward_native
8578
@@ -146,12 +139,22 @@ def apply_top_k_top_p_tpu(
     chance of being chosen during final sampling, so we can consider the tie
     being broken then.
     """
+    probs = logits.softmax(dim=-1)
+    probs_sort, _ = probs.sort(dim=-1, descending=False)
+
     if k is not None:
-        logits = apply_top_k_only(logits, k)
+        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
+        top_k_count = top_k_count.unsqueeze(dim=1)
+        top_k_cutoff = probs_sort.gather(-1, top_k_count)
+
+        # Make sure the no top-k rows are no-op.
+        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
+        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
+
+        elements_to_discard = probs < top_k_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
 
     if p is not None:
-        probs = logits.softmax(dim=-1)
-        probs_sort, _ = probs.sort(dim=-1, descending=False)
         cumprob = torch.cumsum(probs_sort, dim=-1)
         top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
         top_p_mask[:, -1] = False  # at least one
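
For context, the new TPU path no longer calls apply_top_k_only; it instead derives a per-row probability cutoff from the ascending-sorted probabilities and masks everything below it. The standalone sketch below reproduces that cutoff logic on a toy batch; the helper name top_k_via_sorted_probs and all tensor values are illustrative assumptions, not code from the repository.

import torch

def top_k_via_sorted_probs(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Sketch of the cutoff-based top-k used in the diff above.
    probs = logits.softmax(dim=-1)
    probs_sort, _ = probs.sort(dim=-1, descending=False)
    # In ascending order, the k-th largest probability sits at index (vocab_size - k).
    top_k_count = (probs_sort.size(1) - k.to(torch.long)).unsqueeze(dim=1)
    top_k_cutoff = probs_sort.gather(-1, top_k_count)
    # Rows whose k equals the vocabulary size keep every token (cutoff -> -inf).
    no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
    top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
    return logits.masked_fill(probs < top_k_cutoff, -float("inf"))

logits = torch.tensor([[1.0, 3.0, 2.0, 0.5],
                       [0.1, 0.2, 0.3, 0.4]])
k = torch.tensor([2, 4])  # row 0 keeps its 2 best tokens, row 1 requests the full vocab
print(top_k_via_sorted_probs(logits, k))

Row 0 keeps only its two largest logits (3.0 and 2.0); row 1, whose k equals the vocabulary size, is left untouched via the no_top_k_mask branch. Everything stays as fixed-shape tensor ops, which is the point of the TPU-specific rewrite.
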
@@ -224,7 +227,7 @@ def apply_top_k_only(
     max_top_k = k.max()
     # topk.values tensor has shape [batch_size, max_top_k].
     # Convert top k to 0-based index in range [0, max_top_k).
-    k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1)
+    k_index = k.sub_(1).unsqueeze(1)
     top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
     # Handle non-topk rows.
     top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
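
The one-line change above drops the .expand(logits.shape[0], 1) call: k.sub_(1).unsqueeze(1) already yields a [batch_size, 1] tensor, which is exactly the index shape gather expects, so the expand was redundant. A minimal sketch under assumed toy values (not taken from the repository):

import torch

logits = torch.tensor([[1.0, 4.0, 3.0, 2.0],
                       [5.0, 1.0, 2.0, 0.0]])
k = torch.tensor([2, 3])                       # per-row top-k
no_top_k_mask = k == logits.shape[1]           # no row requests the full vocab here
max_top_k = k.max()

k_index = k.sub_(1).unsqueeze(1)               # already [batch_size, 1]; no expand needed
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
result = logits.masked_fill(logits < top_k_mask, -float("inf"))
print(result)  # row 0 keeps its top 2 logits, row 1 keeps its top 3

Each row's k-th largest value is gathered as a threshold, and anything strictly below it is masked to -inf, so the behavior is unchanged by removing the expand.
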