from typing import Dict, Optional

import torch
- import torch.nn as nn
-
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
- from vllm.logger import init_logger
-
-
- logger = init_logger(__name__)


class AscendTopKTopPSampler(TopKTopPSampler):

-     def __init__(self):
-         super().__init__()
-         # TODO(linfeng): eliminate warning for FlashInfer here
-         self.forward = self.forward_npu
-
-     def forward_npu(
+     def forward_native(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
@@ -28,37 +17,48 @@ def forward_npu(
        logits = apply_top_k_top_p_npu(logits, k, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)
-
+


def apply_top_k_top_p_npu(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
-     """Apply top-k and top-p optimized for NPU.
-
-     This algorithm avoids using torch.scatter which is time-consuming on NPU.
-     """
-     # TODO(linfeng): consider the case that either p or k is applied
27+ """Apply top-k and/or top-p optimized for NPU."""
4328 if k is None and p is None :
4429 return logits
30+
    batch_size, vocab_size = logits.shape
+     device = logits.device
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+     if k is not None:
+         safe_k = torch.clamp(k, min=1, max=vocab_size)
+         boundary_idx = (vocab_size - safe_k).unsqueeze(1)
+         boundary = logits_sort.gather(1, boundary_idx)
+         top_k_mask = logits_sort < boundary
+         logits_sort = logits_sort.masked_fill(top_k_mask, -float("inf"))
+     else:
+         top_k_mask = torch.zeros_like(logits_sort, dtype=torch.bool)

-     boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
-     top_k_mask = logits_sort < boundary
-     logits_sort.masked_fill_(top_k_mask, -float("inf"))
-     cutoff = top_k_mask.sum(dim=-1).min()
-     probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
-     probs_sum = probs_sort.cumsum(dim=-1)
-     top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
-     top_p_mask[:, -1] = True
-     strides = torch.arange(0, batch_size * vocab_size, vocab_size, device=logits.device)
-     flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
-     valid_idx = torch.masked_select(flatten_idx, top_p_mask)
+     cutoffs = top_k_mask.sum(dim=-1)
+     strides = torch.arange(0,
+                            batch_size * vocab_size,
+                            vocab_size,
+                            device=device).unsqueeze(1)
+     if p is not None:
+         global_cutoff = cutoffs.min()
+         active_part = logits_idx[:, global_cutoff:]
+         probs_sort = logits_sort[:, global_cutoff:].softmax(dim=-1)
+         cumprob = probs_sort.cumsum(dim=-1)
+         top_p_mask = (cumprob > (1 - p.unsqueeze(1))) | (torch.arange(
+             probs_sort.size(1), device=device) == probs_sort.size(1) - 1)
+     else:
+         active_part = logits_idx
+         top_p_mask = torch.arange(vocab_size, device=device).expand(
+             batch_size, -1) >= cutoffs.unsqueeze(1)

+     valid_idx = (active_part + strides).masked_select(top_p_mask)
    logits_flatten = logits.flatten()
-     valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
-     logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
-     logits[valid_idx] = valid_logits
-     return logits.reshape(batch_size, vocab_size)
+     output = torch.full_like(logits_flatten, -float('inf'))
+     output[valid_idx] = logits_flatten[valid_idx]
+     return output.reshape(batch_size, vocab_size)
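
As a quick illustration (not part of the diff), the sketch below exercises the new helper with toy CPU tensors, assuming apply_top_k_top_p_npu as defined above is in scope; the values and shapes are made up and follow only the signature shown in the diff.

import torch

# Toy inputs (hypothetical values): 2 requests over an 8-token vocabulary.
logits = torch.randn(2, 8)
k = torch.tensor([3, 8])      # request 0 keeps its top-3 tokens, request 1 keeps the full vocab
p = torch.tensor([0.9, 1.0])  # per-request nucleus thresholds

masked = apply_top_k_top_p_npu(logits, k, p)

# Tokens outside the selected set are set to -inf, so softmax assigns them zero probability.
print((masked != -float("inf")).sum(dim=-1))  # tokens left sampleable per request

Both the removed and the new implementation avoid torch.scatter, which the removed docstring notes is slow on NPU; the rewrite additionally handles requests where only k or only p is set, the case flagged by the removed TODO.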