@@ -156,14 +156,13 @@ def rejection_sample(
156156 bonus_token_ids ,
157157 )
158158 else :
159- num_draft_tokens_tensor = torch .tensor (num_draft_tokens ,
160- device = device )
161159 rejection_greedy_sample_pytorch (
162160 output_token_ids ,
163- num_draft_tokens_tensor ,
161+ cu_num_draft_tokens ,
164162 draft_token_ids ,
165163 target_argmax ,
166164 bonus_token_ids ,
165+ num_draft_tokens ,
167166 max_spec_len ,
168167 is_greedy ,
169168 )
@@ -311,40 +310,72 @@ def rejection_greedy_sample_spec_len_1_pytorch(
311310
def rejection_greedy_sample_pytorch(
    output_token_ids,  # [batch_size, max_spec_len + 1], modified in place
    cu_num_draft_tokens,  # [batch_size], cumulative draft-token counts
    draft_token_ids,  # [num_tokens], flattened across requests
    target_argmax,  # [num_tokens], target model's argmax per draft position
    bonus_token_ids,  # [batch_size] or [batch_size, 1]
    draft_tokens_per_req,  # [batch_size], python list of per-request counts
    max_spec_len,  # int, max number of draft tokens over the batch
    is_greedy=None,  # [batch_size] bool or None (None => all rows greedy)
):
    """Vectorized greedy rejection sampling for speculative decoding.

    For each greedy request, draft tokens are accepted while they match the
    target model's argmax. The accepted tokens plus the first corrective
    target token are written into ``output_token_ids`` in place; when every
    draft token is accepted, the bonus token is appended after them. Rows
    whose ``is_greedy`` flag is False are left untouched, as are all slots
    past the written prefix (callers pre-fill with a placeholder, e.g. -1).
    """
    batch_size = output_token_ids.size(0)
    num_tokens = draft_token_ids.size(0)
    device = output_token_ids.device
    # as_tensor with an explicit device avoids the extra host-side copy and
    # the pointless non_blocking transfer from unpinned CPU memory that
    # torch.tensor(...).to(device, non_blocking=True) performed.
    draft_tokens_per_req = torch.as_tensor(draft_tokens_per_req,
                                           dtype=torch.long,
                                           device=device)
    if is_greedy is None:
        is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)

    # Map every flat token index to (request id, position within request).
    start_indices = cu_num_draft_tokens - draft_tokens_per_req
    req_ids = torch.arange(batch_size, device=device)
    token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
    token_positions = torch.arange(
        num_tokens, device=device) - start_indices[token_req_ids]

    # Find the first draft/target mismatch position of each request.
    mismatch_global = (draft_token_ids != target_argmax)
    if max_spec_len == 0:
        first_mismatch_pos_per_req = torch.zeros(batch_size,
                                                 dtype=torch.long,
                                                 device=device)
    else:
        # Scatter per-token data into dense [batch_size, max_spec_len]
        # matrices so the first mismatch is found with one row-wise min().
        pos_matrix = torch.full((batch_size, max_spec_len),
                                -1,
                                dtype=torch.long,
                                device=device)
        pos_matrix[token_req_ids, token_positions] = token_positions
        mismatch_matrix = torch.zeros((batch_size, max_spec_len),
                                      dtype=torch.bool,
                                      device=device)
        mismatch_matrix[token_req_ids, token_positions] = mismatch_global
        # max_spec_len * 2 acts as +inf for positions without a mismatch.
        mismatch_positions = torch.where(mismatch_matrix, pos_matrix,
                                         max_spec_len * 2)
        first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
        # Requests with no mismatch accept all of their draft tokens.
        no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
        first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
            no_mismatch_mask]

    # Copy accepted target tokens (plus the first corrective token) into the
    # output; copy_len is first_mismatch + 1 capped at the draft count.
    copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
                             draft_tokens_per_req)
    copy_indices = torch.arange(max_spec_len + 1,
                                device=device).expand(batch_size, -1)
    copy_mask = copy_indices < copy_len.unsqueeze(1)
    final_copy_mask = copy_mask & is_greedy.unsqueeze(1)
    global_idx = start_indices.unsqueeze(1) + copy_indices
    output_token_ids[final_copy_mask] = target_argmax[
        global_idx[final_copy_mask]].to(output_token_ids.dtype)

    # Fill the bonus token for requests that accepted every draft token.
    needs_bonus = is_greedy & (first_mismatch_pos_per_req
                               >= draft_tokens_per_req)
    if torch.any(needs_bonus):
        bonus_rows = torch.where(needs_bonus)[0]
        bonus_cols = draft_tokens_per_req[bonus_rows]
        # Accept both [batch_size] and [batch_size, 1] bonus tensors: the
        # previous unconditional squeeze(1) raised IndexError on 1-D input
        # even though the parameter is documented as [batch_size].
        if bonus_token_ids.dim() > 1:
            bonus_token_ids = bonus_token_ids.squeeze(1)
        output_token_ids[bonus_rows, bonus_cols] = bonus_token_ids[bonus_rows]
348379
349380
350381def rejection_random_sample_pytorch (
0 commit comments