@@ -268,22 +268,27 @@ def _sample_draft_tokens(
268268 keep_sorted = torch .zeros_like (sp , dtype = torch .bool )
269269 keep_sorted [..., 0 ] = True
270270 keep_sorted [..., 1 :] = csum [..., :- 1 ] < top_p # STRICTLY below (correct rule)
271+ keep_sorted [..., :2 ] = True # Force min_keep=2 (prevents survivors=1)
271272 keep = torch .zeros_like (p , dtype = torch .bool ).scatter (- 1 , si , keep_sorted )
272273 x = torch .where (keep , x , torch .full_like (x , float ("-inf" )))
273274
274- # Optional smoothing with untempered baseline
275- probs_full = torch .softmax (x , dim = - 1 )
275+ # Optional smoothing over kept set (uniform mix)
276276 lam = float (getattr (self .opt_config , "draft_mix_lambda_max" , 0.0 ) or 0.0 )
277277 print (f"[SMOOTH_DEBUG] lambda_max from config: { lam } , will run smoothing: { lam > 0.0 } " ,
278278 file = sys .stderr , flush = True )
279+ logp_full = torch .log_softmax (x , dim = - 1 )
279280 if lam > 0.0 :
280- base = torch .softmax (logits_f32 , dim = - 1 ) # untempered baseline
281- probs_full = (1.0 - lam ) * probs_full + lam * base
282- probs_full = probs_full / probs_full .sum (dim = - 1 , keepdim = True )
283- logp_full = torch .log (probs_full .clamp_min (1e-20 ))
281+ kept = torch .isfinite (logp_full )
282+ p = torch .exp (logp_full )
283+ # Uniform over survivors only
284+ u = kept .float () / kept .float ().sum (dim = - 1 , keepdim = True ).clamp_min (1.0 )
285+ p = (1.0 - lam ) * p + lam * u
286+ p = p * kept # Ensure dropped stay at 0
287+ logp_full = torch .log (p .clamp_min (1e-45 ))
284288
285289 # Sample token and gather its logp
286- tok = torch .distributions .Categorical (probs = probs_full ).sample ()
290+ cat = torch .distributions .Categorical (logits = logp_full )
291+ tok = cat .sample ()
287292 tok_logp = logp_full .gather (- 1 , tok .unsqueeze (- 1 )).squeeze (- 1 )
288293
289294 # Debug logging
@@ -318,9 +323,6 @@ def propose(
318323 sampling_metadata : SamplingMetadata ,
319324 mm_embeds : Optional [list [torch .Tensor ]] = None ,
320325 ) -> torch .Tensor :
321- # Store sampling_metadata so _sample_draft_tokens() can access target temperature
322- self ._current_sampling_metadata = sampling_metadata
323-
324326 num_tokens = target_token_ids .shape [0 ]
325327 batch_size = next_token_ids .shape [0 ]
326328
0 commit comments