class TopKTopPSampler(nn.Module):
+    """
+    Module that performs optional top-k and top-p filtering followed by
+    weighted random sampling of logits.
+
+    Implementations may update the logits tensor in-place.
+    """

    def __init__(self):
        super().__init__()
@@ -84,7 +90,11 @@ def forward_native(
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
-        """PyTorch-native implementation of top-k and top-p sampling."""
+        """
+        PyTorch-native implementation of top-k and top-p sampling.
+
+        The logits tensor may be updated in-place.
+        """
        logits = apply_top_k_top_p(logits, k, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)
@@ -136,10 +146,18 @@ def apply_top_k_top_p(
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

-    This function sorts the logits tensor, which can be slow for large batches.
+    If a top-p is used, this function will sort the logits tensor,
+    which can be slow for large batches.
+
+    The logits tensor may be updated in-place.
    """
-    if k is None and p is None:
-        return logits
+    if p is None:
+        if k is None:
+            return logits
+
+        # Avoid sorting vocab for top-k only case.
+        return apply_top_k_only(logits, k)
+
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if k is not None:
@@ -153,7 +171,7 @@ def apply_top_k_top_p(
    if p is not None:
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
-        probs_sum = probs_sort.cumsum(dim=-1)
+        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # at least one
        top_p_mask[:, -1] = False
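
A side note on the `probs_sum` change above: passing `out=probs_sort` lets the cumulative sum overwrite the already-allocated `[batch_size, vocab_size]` buffer instead of allocating a second one. A minimal sketch of the same call on made-up values (not part of the diff):

```python
import torch

# Toy sorted probabilities for a single row; values are illustrative only.
probs_sort = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)

assert probs_sum is probs_sort  # the cumsum reuses probs_sort's storage
print(probs_sum)                # tensor([[0.1000, 0.3000, 0.6000, 1.0000]])
```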
@@ -164,6 +182,31 @@ def apply_top_k_top_p(
    return logits


+def apply_top_k_only(
+    logits: torch.Tensor,
+    k: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Apply top-k mask to the logits.
+
+    This implementation doesn't involve sorting the entire vocab.
+
+    The logits tensor may be updated in-place.
+    """
+    no_top_k_mask = k == logits.shape[1]
+    # Set non-top-k rows to 1 so that we can gather.
+    k = k.masked_fill(no_top_k_mask, 1)
+    max_top_k = k.max()
+    # topk.values tensor has shape [batch_size, max_top_k].
+    # Convert top k to 0-based index in range [0, max_top_k).
+    k_index = k.sub_(1).unsqueeze(1)
+    top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index)
+    # Handle non-topk rows.
+    top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
+    logits.masked_fill_(logits < top_k_mask, -float("inf"))
+    return logits
+
+
def random_sample(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],
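
For reviewers skimming the diff: when only top-k is requested (`p is None`), `apply_top_k_top_p` now dispatches to `apply_top_k_only`, which finds the k-th largest logit per row via `topk` + `gather` and masks everything below it, instead of sorting the full vocab. A rough sketch of the same masking on toy values (the tensors below are made up, not taken from the PR):

```python
import torch

logits = torch.tensor([[1.0, 4.0, 2.0, 3.0]])  # one row, vocab size 4
k = torch.tensor([2])                           # keep the 2 largest logits

no_top_k_mask = k == logits.shape[1]            # rows with k == vocab size stay unmasked
k = k.masked_fill(no_top_k_mask, 1)
k_index = (k - 1).unsqueeze(1)                  # 0-based index of the k-th largest value
top_k_mask = logits.topk(int(k.max()), dim=1).values.gather(1, k_index)
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
logits.masked_fill_(logits < top_k_mask, -float("inf"))

print(logits)  # tensor([[-inf, 4., -inf, 3.]])
```

The `topk` call is sized by the largest k in the batch, so only a `[batch_size, max_top_k]` slice is materialized rather than a full sorted vocab.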