@@ -53,19 +53,15 @@ def build_baseline(
     outputs: List[Dict[str, torch.Tensor]] = []
 
     inp = gen_inputs(defn, workload, device=device, stensors=loaded_stensors)
-    if "probs" in inp:
-        inp["probs"] = torch.softmax(
-            inp["probs"], dim=-1
-        )  # convert logits to probs for sampling
     inputs.append(inp)
 
     thresholding_method = _detect_thresholding_method(defn)
     params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp}
     valid_mask = _compute_valid_sampling_mask(inp["probs"], thresholding_method, params)
-
+
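     # Reference distribution: restrict probs to the tokens the thresholding
     # method may legally emit, then renormalize each row to sum to 1.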
     masked_probs = inp["probs"] * valid_mask.float()
     expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
-
+
     outputs.append({"expected_probs": expected_probs})
 
     latencies: List[float] = []
@@ -151,7 +147,7 @@ def check_correctness(
         samples_flat = samples.unsqueeze(0)
     else:
         samples_flat = samples.flatten()
-
+
     batch_size = valid_mask.shape[0]
     for i in range(len(samples_flat)):
         batch_idx = i % batch_size
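         # Samples are laid out over repeated copies of the original batch,
         # so flat position i maps back to original row i % batch_size.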
@@ -160,9 +156,7 @@ def check_correctness(
             correctness = Correctness(
                 max_relative_error=float("inf"), max_absolute_error=float("inf")
             )
-            message = (
-                f"Sample {sample_idx} is outside valid {thresholding_method} mask for batch {batch_idx}"
-            )
+            message = f"Sample {sample_idx} is outside valid {thresholding_method} mask for batch {batch_idx}"
             print(message, file=sys.stderr)
             return correctness, make_eval(
                 status=EvaluationStatus.INCORRECT_NUMERICAL,
@@ -186,11 +180,11 @@ def check_correctness(
     tvds = []
     max_abs_errors = []
     max_rel_errors = []
-
+
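     # Total variation distance per row: half the L1 distance between the
     # empirical sample frequencies and the expected distribution.
     # E.g. expected [0.5, 0.5] vs. observed [0.6, 0.4] gives TVD = 0.1.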
     for i in range(batch_size):
         tvd_i = 0.5 * torch.sum(torch.abs(sol_freqs[i] - expected_probs[i])).item()
         tvds.append(tvd_i)
-
+
         max_abs_i, max_rel_i, _, _ = compute_error_stats(sol_freqs[i], expected_probs[i], cfg)
         max_abs_errors.append(max_abs_i)
         max_rel_errors.append(max_rel_i)
@@ -202,9 +196,9 @@ def check_correctness(
 
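     # Correctness is gated on the worst row: any row whose TVD exceeds the
     # configured threshold marks the solution numerically incorrect.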
     numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
     correctness = Correctness(
-        max_relative_error=max_rel,
-        max_absolute_error=max_abs,
-        extra={"tvd": max_tvd, "tvds_per_batch": tvds}
+        max_relative_error=max_rel,
+        max_absolute_error=max_abs,
+        extra={"tvd": max_tvd, "tvds_per_batch": tvds},
     )
     if numerical_incorrect:
         return correctness, make_eval(
@@ -234,15 +228,15 @@ def _detect_thresholding_method(defn: Definition) -> str:
 
 
 def _compute_valid_sampling_mask(
-    probs: torch.Tensor, method: str, params: Dict[str, Any], eps: float = 1e-5
+    probs: torch.Tensor, method: str, params: Dict[str, Any], eps: float = 5e-2
 ) -> torch.Tensor:
     """
     Build a mask of the tokens a correct sampler may emit, allowing for
     tie-breaking in top_k (any token with prob >= the k-th largest) and for
     numerical precision in top_p (tokens within eps of the nucleus boundary).
     """
     if probs.dim() == 1:
         probs = probs.unsqueeze(0)
-
+
     batch_size, vocab_size = probs.shape
     device = probs.device
 
@@ -254,48 +248,40 @@ def _compute_valid_sampling_mask(
     if method in ["top_k", "top_k_top_p"]:
         if "top_k" not in params:
             raise ValueError(f"top_k parameter required for {method} but not found")
-
+
         top_k_param = params["top_k"]
         for i in range(batch_size):
-            k = (
-                int(top_k_param[i].item())
-                if top_k_param.dim() > 0
-                else int(top_k_param.item())
-            )
+            k = int(top_k_param[i].item()) if top_k_param.dim() > 0 else int(top_k_param.item())
 
             if 0 < k < vocab_size:
                 sorted_probs, _ = torch.sort(probs[i], descending=True)
                 # k-th largest value (0-indexed, so k-1)
                 pivot = sorted_probs[k - 1]
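                 # E.g. with k = 2 and probs [0.4, 0.3, 0.3], pivot = 0.3 and
                 # all three tokens pass: a sampler may break the tie between
                 # the two 0.3 tokens either way.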
-                mask[i] = probs[i] >= pivot # tie-breaking handling
+                mask[i] = probs[i] >= pivot  # tie-breaking handling
 
     # Apply top_p mask with epsilon tolerance
     if method in ["top_p", "top_k_top_p"]:
         if "top_p" not in params:
             raise ValueError(f"top_p parameter required for {method} but not found")
-
+
         top_p_param = params["top_p"]
         for i in range(batch_size):
-            p = (
-                float(top_p_param[i].item())
-                if top_p_param.dim() > 0
-                else float(top_p_param.item())
-            )
+            p = float(top_p_param[i].item()) if top_p_param.dim() > 0 else float(top_p_param.item())
 
             if 0 < p < 1:
                 sorted_probs, sorted_indices = torch.sort(probs[i], descending=True)
                 cumsum = torch.cumsum(sorted_probs, dim=0)
-
+
                 # Find tokens in nucleus (cumsum <= p + eps for numerical tolerance)
                 nucleus_mask = cumsum <= (p + eps)
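                 # E.g. p = 0.9, eps = 0.05, sorted probs [0.5, 0.42, 0.08]:
                 # cumsum = [0.5, 0.92, 1.0], so the second token stays in the
                 # nucleus (0.92 <= 0.95) although a strict cutoff would drop it.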
-
+
                 if not nucleus_mask.any():
                     nucleus_mask[0] = True
-
+
                 # Map back to original indices
                 p_mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
                 p_mask[sorted_indices[nucleus_mask]] = True
-
+
                 mask[i] = mask[i] & p_mask
 
     return mask
@@ -310,12 +296,12 @@ def _sample_token_distributions(
 ) -> torch.Tensor:
     original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
     vocab_size = inputs["probs"].shape[-1]
-
+
     # Repeat entire input batch to fill up to target_batch_size for efficient sampling
     target_batch_size = 10000
     repeat_count = target_batch_size // original_batch_size
     actual_batch_size = repeat_count * original_batch_size
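     # E.g. original_batch_size = 3 gives repeat_count = 3333 and
     # actual_batch_size = 9999: each pass yields 3333 samples per original row.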
-
+
     padded_inputs = {}
     for key, value in inputs.items():
         if isinstance(value, torch.Tensor) and value.dim() > 0:
@@ -341,12 +327,12 @@ def _sample_token_distributions(
         else:
             # For non-tensor inputs, keep as is
             padded_inputs[key] = value
-
+
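     # counters[r, v] counts how many times token v is drawn for original
     # row r, accumulated over the sampling passes below.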
     counters = torch.zeros(
         (original_batch_size, vocab_size), dtype=torch.int64, device=torch.device(device)
     )
 
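     # Ceiling division: each pass contributes repeat_count samples per
     # original row, so this many passes guarantees at least num_trials
     # samples per row.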
-    trials_needed = (num_trials + actual_batch_size - 1) // actual_batch_size
+    trials_needed = (num_trials + repeat_count - 1) // repeat_count
     total_samples_per_batch = 0
 
     for _ in range(trials_needed):