@@ -59,10 +59,14 @@ def build_baseline(
         )  # convert logits to probs for sampling
         inputs.append(inp)
 
-        freq_dist = _compute_frequency_distribution(
-            ref_runnable, inp, device, defn, num_trials=50000
-        )
-        outputs.append({"frequency_distribution": freq_dist})
+        thresholding_method = _detect_thresholding_method(defn)
+        params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp}
+        valid_mask = _compute_valid_sampling_mask(inp["probs"], thresholding_method, params)
+
+        masked_probs = inp["probs"] * valid_mask.float()
+        expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
+
+        outputs.append({"expected_probs": expected_probs})
 
     latencies: List[float] = []
     for inp in inputs:
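
The hunk above replaces the sampled `frequency_distribution` baseline with an analytic target: the input probabilities are restricted to the valid token set and renormalized into `expected_probs`. A minimal standalone sketch of that renormalization (the shapes and the top-k-style mask below are illustrative, not taken from the harness):

```python
import torch

probs = torch.softmax(torch.randn(2, 8), dim=-1)              # [batch, vocab]
valid_mask = probs >= probs.topk(3, dim=-1).values[..., -1:]  # e.g. a top-k style mask

masked_probs = probs * valid_mask.float()
expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)

# Each row sums to 1 and is zero outside the mask.
assert torch.allclose(expected_probs.sum(dim=-1), torch.ones(2))
```
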
@@ -94,15 +98,20 @@ def check_correctness(
     log_path: str,
     device: str,
 ) -> Tuple[Optional[Correctness], Optional[Evaluation]]:
-    ref_freq = ref_outputs[0]["frequency_distribution"]
-    vocab_size = ref_freq.shape[0]
+    expected_probs = ref_outputs[0]["expected_probs"]
+    vocab_size = expected_probs.shape[-1]
 
     inp = inputs[0]
     params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp}
 
     output_names = list(defn.outputs.keys())
     output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()}
 
+    # Compute valid sampling mask based on thresholding
+    thresholding_method = _detect_thresholding_method(defn)
+    probs = inp["probs"]
+    valid_mask = _compute_valid_sampling_mask(probs, thresholding_method, params)
+
     # Validate correct sampling token set
     for _ in range(cfg.sampling_validation_trials):
         try:
@@ -137,27 +146,34 @@ def check_correctness(
                 correctness=correctness,
             )
 
-        # Validate thresholding
-        thresholding_method = _detect_thresholding_method(defn)
-        probs = inp["probs"]
-        if not _check_thresholding(samples, probs, thresholding_method, params):
-            correctness = Correctness(
-                max_relative_error=float("inf"), max_absolute_error=float("inf")
-            )
-            message = (
-                f"Samples {samples.tolist()} does not meet {thresholding_method} thresholding"
-            )
-            print(message, file=sys.stderr)
-            return correctness, make_eval(
-                status=EvaluationStatus.INCORRECT_NUMERICAL,
-                device=device,
-                log_path=log_path,
-                correctness=correctness,
-            )
+        # Validate thresholding - check samples are within valid mask
+        if samples.dim() == 0:
+            samples_flat = samples.unsqueeze(0)
+        else:
+            samples_flat = samples.flatten()
+
+        batch_size = valid_mask.shape[0]
+        for i in range(len(samples_flat)):
+            batch_idx = i % batch_size
+            sample_idx = samples_flat[i].item()
+            if not valid_mask[batch_idx, sample_idx]:
+                correctness = Correctness(
+                    max_relative_error=float("inf"), max_absolute_error=float("inf")
+                )
+                message = (
+                    f"Sample {sample_idx} is outside valid {thresholding_method} mask for batch {batch_idx}"
+                )
+                print(message, file=sys.stderr)
+                return correctness, make_eval(
+                    status=EvaluationStatus.INCORRECT_NUMERICAL,
+                    device=device,
+                    log_path=log_path,
+                    correctness=correctness,
+                )
 
     try:
-        sol_freq = _compute_frequency_distribution(
-            sol_runnable, inp, device, defn, num_trials=50000
+        sol_freqs = _sample_token_distributions(
+            sol_runnable, inp, device, defn, num_trials=500000
         )
         torch.cuda.synchronize(device)
     except Exception:
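
The validation loop above maps flattened sample position `i` back to its batch row as `i % batch_size` and rejects any drawn token that falls outside `valid_mask`. The same membership test can also be written without the Python loop; a sketch (the helper name is hypothetical and assumes the batch-interleaved layout used above):

```python
import torch

def all_samples_in_mask(samples_flat: torch.Tensor, valid_mask: torch.Tensor) -> bool:
    # Position i belongs to batch row i % batch_size; valid_mask is a [batch, vocab] bool tensor.
    batch_size = valid_mask.shape[0]
    rows = torch.arange(samples_flat.numel(), device=valid_mask.device) % batch_size
    return bool(valid_mask[rows, samples_flat].all())

mask = torch.tensor([[True, False, True], [False, True, True]])
print(all_samples_in_mask(torch.tensor([0, 1, 2, 2]), mask))  # True
print(all_samples_in_mask(torch.tensor([1, 0, 0, 0]), mask))  # False (token 1 invalid for row 0)
```
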
@@ -166,13 +182,29 @@ def check_correctness(
             status=EvaluationStatus.RUNTIME_ERROR, device=device, log_path=log_path
         )
 
-    # total variation distance
-    tvd = 0.5 * torch.sum(torch.abs(sol_freq - ref_freq)).item()
-    max_abs, max_rel, _, _ = compute_error_stats(sol_freq, ref_freq, cfg)
-
-    numerical_incorrect = tvd > cfg.sampling_tvd_threshold
+    batch_size = expected_probs.shape[0]
+    tvds = []
+    max_abs_errors = []
+    max_rel_errors = []
+
+    for i in range(batch_size):
+        tvd_i = 0.5 * torch.sum(torch.abs(sol_freqs[i] - expected_probs[i])).item()
+        tvds.append(tvd_i)
+
+        max_abs_i, max_rel_i, _, _ = compute_error_stats(sol_freqs[i], expected_probs[i], cfg)
+        max_abs_errors.append(max_abs_i)
+        max_rel_errors.append(max_rel_i)
+
+    # Use the worst (max) TVD and errors across all batch elements
+    max_tvd = max(tvds)
+    max_abs = max(max_abs_errors)
+    max_rel = max(max_rel_errors)
+
+    numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
     correctness = Correctness(
-        max_relative_error=max_rel, max_absolute_error=max_abs, extra={"tvd": tvd}
+        max_relative_error=max_rel,
+        max_absolute_error=max_abs,
+        extra={"tvd": max_tvd, "tvds_per_batch": tvds}
     )
     if numerical_incorrect:
         return correctness, make_eval(
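
Total variation distance here is half the L1 distance between the empirical frequencies and the expected distribution, TVD(P, Q) = 0.5 * Σ_i |P_i - Q_i|; the hunk computes it per batch row and gates on the worst row. A small worked example with made-up numbers:

```python
import torch

expected = torch.tensor([0.5, 0.3, 0.2, 0.0])     # renormalized masked probs
observed = torch.tensor([0.48, 0.33, 0.19, 0.0])  # empirical frequencies from sampling

tvd = 0.5 * torch.sum(torch.abs(observed - expected)).item()
print(tvd)  # 0.5 * (0.02 + 0.03 + 0.01 + 0.0) ≈ 0.03
```
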
@@ -201,23 +233,125 @@ def _detect_thresholding_method(defn: Definition) -> str:
     return "none"  # no thresholding
 
 
-def _compute_frequency_distribution(
+def _compute_valid_sampling_mask(
+    probs: torch.Tensor, method: str, params: Dict[str, Any], eps: float = 1e-5
+) -> torch.Tensor:
+    """Compute a boolean mask of token ids that are valid to sample under the given method.
+
+    Allows for tie-breaking in top_k (any token with prob >= the k-th largest is valid)
+    and for numerical precision in top_p (tokens within eps of the nucleus boundary are valid).
+    """
+    if probs.dim() == 1:
+        probs = probs.unsqueeze(0)
+
+    batch_size, vocab_size = probs.shape
+    device = probs.device
+
+    if method == "none":
+        return torch.ones((batch_size, vocab_size), dtype=torch.bool, device=device)
+
+    mask = torch.ones((batch_size, vocab_size), dtype=torch.bool, device=device)
+
+    # Apply top_k mask
+    if method in ["top_k", "top_k_top_p"]:
+        if "top_k" not in params:
+            raise ValueError(f"top_k parameter required for {method} but not found")
+
+        top_k_param = params["top_k"]
+        for i in range(batch_size):
+            k = (
+                int(top_k_param[i].item())
+                if top_k_param.dim() > 0
+                else int(top_k_param.item())
+            )
+
+            if 0 < k < vocab_size:
+                sorted_probs, _ = torch.sort(probs[i], descending=True)
+                # k-th largest value (0-indexed, so k-1)
+                pivot = sorted_probs[k - 1]
+                mask[i] = probs[i] >= pivot  # tie-breaking handling
+
+    # Apply top_p mask with epsilon tolerance
+    if method in ["top_p", "top_k_top_p"]:
+        if "top_p" not in params:
+            raise ValueError(f"top_p parameter required for {method} but not found")
+
+        top_p_param = params["top_p"]
+        for i in range(batch_size):
+            p = (
+                float(top_p_param[i].item())
+                if top_p_param.dim() > 0
+                else float(top_p_param.item())
+            )
+
+            if 0 < p < 1:
+                sorted_probs, sorted_indices = torch.sort(probs[i], descending=True)
+                cumsum = torch.cumsum(sorted_probs, dim=0)
+
+                # Find tokens in nucleus (cumsum <= p + eps for numerical tolerance)
+                nucleus_mask = cumsum <= (p + eps)
+
+                if not nucleus_mask.any():
+                    nucleus_mask[0] = True
+
+                # Map back to original indices
+                p_mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
+                p_mask[sorted_indices[nucleus_mask]] = True
+
+                mask[i] = mask[i] & p_mask
+
+    return mask
+
+
+def _sample_token_distributions(
     runnable: Runnable,
     inputs: Dict[str, Any],
     device: str,
     defn: Definition,
-    num_trials: int = 10000,
+    num_trials: int = 500000,
 ) -> torch.Tensor:
-    batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
+    original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
     vocab_size = inputs["probs"].shape[-1]
-    counter = torch.zeros(vocab_size, dtype=torch.int64, device=torch.device(device))
-
-    trials_needed = (num_trials + batch_size - 1) // batch_size
-    total_samples_collected = 0
+
+    # Repeat the entire input batch to fill up to target_batch_size for efficient sampling
+    target_batch_size = 10000
+    repeat_count = target_batch_size // original_batch_size
+    actual_batch_size = repeat_count * original_batch_size
+
+    padded_inputs = {}
+    for key, value in inputs.items():
+        if isinstance(value, torch.Tensor) and value.dim() > 0:
+            if key == "probs":
+                # For probs, repeat the entire batch
+                if value.dim() == 1:
+                    value = value.unsqueeze(0)
+                # Repeat the entire batch repeat_count times
+                padded_value = value.repeat(repeat_count, *([1] * (value.dim() - 1)))
+            elif key in ["top_k", "top_p"]:
+                # For sampling parameters, repeat the entire batch
+                if value.dim() == 0:
+                    padded_value = value.unsqueeze(0).repeat(actual_batch_size)
+                else:
+                    padded_value = value.repeat(repeat_count)
+            else:
+                # For other tensors, repeat the entire batch along the batch dimension
+                if value.dim() == 0:
+                    padded_value = value.unsqueeze(0).repeat(actual_batch_size)
+                else:
+                    padded_value = value.repeat(repeat_count, *([1] * (value.dim() - 1)))
+            padded_inputs[key] = padded_value
+        else:
+            # For non-tensor inputs, keep as is
+            padded_inputs[key] = value
+
+    counters = torch.zeros(
+        (original_batch_size, vocab_size), dtype=torch.int64, device=torch.device(device)
+    )
+
+    trials_needed = (num_trials + actual_batch_size - 1) // actual_batch_size
+    total_samples_per_batch = 0
 
     for _ in range(trials_needed):
         with torch.no_grad():
-            out = runnable(**inputs)
+            out = runnable(**padded_inputs)
 
         output_names = list(defn.outputs.keys())
         output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()}
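
`_compute_valid_sampling_mask` above accepts any token tied with the k-th largest probability and any token whose descending cumulative probability stays within `p + eps`. A toy, standalone illustration of those two rules on a single already-sorted row (values chosen only for readability):

```python
import torch

probs = torch.tensor([0.4, 0.3, 0.2, 0.1])  # one row, already sorted descending

# top_k = 2: any token with prob >= the 2nd-largest prob (0.3) stays valid
sorted_probs, _ = torch.sort(probs, descending=True)
top_k_mask = probs >= sorted_probs[1]
print(top_k_mask)               # tensor([ True,  True, False, False])

# top_p = 0.9: keep sorted tokens while the cumulative sum stays <= p + eps
eps = 1e-5
cumsum = torch.cumsum(sorted_probs, dim=0)   # tensor([0.4, 0.7, 0.9, 1.0])
top_p_mask = cumsum <= (0.9 + eps)
print(top_p_mask)               # tensor([ True,  True,  True, False])

# top_k_top_p combines both masks elementwise (sorted order == token order here)
print(top_k_mask & top_p_mask)  # tensor([ True,  True, False, False])
```
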
@@ -229,118 +363,19 @@ def _compute_frequency_distribution(
         samples = out_normalized["samples"]
 
         if samples.dim() == 0:
+            # Single sample - assign to first batch element
             sample_idx = samples.item()
-            counter[sample_idx] += 1
-            total_samples_collected += 1
-        else:  # Batch of samples
-            for i in range(samples.numel()):
-                sample_idx = samples.flatten()[i].item()
-                counter[sample_idx] += 1
-                total_samples_collected += 1
-
-    frequency = counter.float() / total_samples_collected
-    return frequency
-
-
-def _check_thresholding(
-    samples: torch.Tensor, probs: torch.Tensor, method: str, params: Dict[str, Any]
-) -> bool:
-    """Check if samples conform to the specified thresholding method.
-
-    Parameters
-    ----------
-    samples : torch.Tensor
-        Sampled token indices.
-    probs : torch.Tensor
-        Probability distribution used for sampling.
-    method : str
-        Thresholding method: "top_k", "top_p", "top_k_top_p", or "none".
-    params : Dict[str, Any]
-        Sampling parameters (top_k, top_p values).
-
-    Returns
-    -------
-    bool
-        True if samples are valid, False otherwise.
-    """
-    batch_size, vocab_size = probs.shape
-    device = probs.device
-
-    for i in range(batch_size):
-        prob_row = probs[i]
-        sample = samples[i].item()
-
-        if method == "top_k":
-            if "top_k" not in params:
-                raise ValueError("top_k parameter is required for top_k thresholding but not found")
-            k = (
-                int(params["top_k"][i].item())
-                if params["top_k"].dim() > 0
-                else int(params["top_k"].item())
-            )
-
-            if 0 < k < vocab_size:
-                sorted_prob_desc, _ = torch.sort(prob_row, descending=True)
-                pivot = sorted_prob_desc[k - 1]
-                mask_top_k = (prob_row >= pivot).int()
-                if mask_top_k[sample] != 1:
-                    return False
-
-        elif method == "top_p":
-            if "top_p" not in params:
-                raise ValueError("top_p parameter is required for top_p thresholding but not found")
-            p = (
-                float(params["top_p"][i].item())
-                if params["top_p"].dim() > 0
-                else float(params["top_p"].item())
-            )
-
-            if 0 < p < 1:
-                eps = 1e-4  # numerical stability
-                sorted_probs, indices = torch.sort(prob_row, descending=False)
-                cdf = torch.cumsum(sorted_probs, dim=0)
-                valid_mask = cdf > (1 - p) - eps
-                valid_indices = indices[valid_mask]
-
-                if sample not in valid_indices:
-                    return False
-
-        elif method == "top_k_top_p":
-            if "top_k" not in params or "top_p" not in params:
-                raise ValueError(
-                    "top_k and top_p parameters are both required for top_k_top_p thresholding but not found"
-                )
-            k = (
-                int(params["top_k"][i].item())
-                if params["top_k"].dim() > 0
-                else int(params["top_k"].item())
-            )
-            p = (
-                float(params["top_p"][i].item())
-                if params["top_p"].dim() > 0
-                else float(params["top_p"].item())
-            )
-
-            if 0 < k < vocab_size:
-                sorted_prob_desc, _ = torch.sort(prob_row, descending=True)
-                pivot = sorted_prob_desc[k - 1]
-                mask_top_k = (prob_row >= pivot).int()
-            else:
-                mask_top_k = torch.ones(vocab_size, dtype=torch.int32, device=device)
-
-            if 0 < p < 1:
-                eps = 1e-4
-                sorted_probs_asc, indices = torch.sort(prob_row, descending=False)
-                cdf = torch.cumsum(sorted_probs_asc, dim=0)
-                mask_top_p = torch.zeros(vocab_size, dtype=torch.int32, device=device)
-                valid_p_mask = cdf > (1 - p) - eps
-                mask_top_p[indices[valid_p_mask]] = 1
-            else:
-                mask_top_p = torch.ones(vocab_size, dtype=torch.int32, device=device)
-
-            joint_mask = torch.minimum(mask_top_k, mask_top_p)
-
-            if joint_mask[sample] != 1:
-                return False
-
-    return True
+            counters[0, sample_idx] += 1
+            total_samples_per_batch += 1
+        else:
+            # Slice and accumulate per original batch element
+            samples_flat = samples.flatten()
+            for i in range(samples_flat.numel()):
+                batch_idx = i % original_batch_size
+                sample_idx = samples_flat[i].item()
+                counters[batch_idx, sample_idx] += 1
+            total_samples_per_batch += repeat_count
+
+    # [batch_size, vocab_size]
+    frequencies = counters.float() / total_samples_per_batch
+    return frequencies
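
Since the inputs are tiled `repeat_count` times, flattened sample position `i` again maps to original row `i % original_batch_size`, and each batched trial adds `repeat_count` draws per row to the denominator. The inner counting loop could also be expressed as a scatter-add; a sketch (hypothetical helper, assuming the same interleaved layout):

```python
import torch

def accumulate_counts(samples_flat: torch.Tensor, counters: torch.Tensor) -> None:
    # In-place: add one trial's draws to per-row token counters ([batch, vocab], int64).
    batch, vocab = counters.shape
    rows = torch.arange(samples_flat.numel(), device=counters.device) % batch
    flat_idx = rows * vocab + samples_flat
    counters.view(-1).scatter_add_(0, flat_idx, torch.ones_like(flat_idx, dtype=counters.dtype))

counters = torch.zeros(2, 4, dtype=torch.int64)
accumulate_counts(torch.tensor([1, 3, 1, 0]), counters)
print(counters)  # row 0 counted token 1 twice; row 1 counted tokens 3 and 0
```
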