@@ -146,17 +146,26 @@ def rejection_sample(
146146 is_greedy = sampling_metadata .temperature == GREEDY_TEMPERATURE
147147 if not sampling_metadata .all_random :
148148 # Rejection sampling for greedy sampling requests.
149- target_argmax = target_probs .argmax (dim = - 1 )
150- rejection_greedy_sample_pytorch (
151- output_token_ids ,
152- cu_num_draft_tokens ,
153- draft_token_ids ,
154- target_argmax ,
155- bonus_token_ids ,
156- is_greedy ,
157- max_spec_len ,
158- # num_warps=1,
159- )
149+ target_argmax = target_probs .argmax (dim = - 1 ).to (torch .int32 )
150+ if min (num_draft_tokens ) == 1 and max (
151+ num_draft_tokens ) == 1 and sampling_metadata .all_greedy :
152+ rejection_greedy_sample_spec_len_1_pytorch (
153+ output_token_ids ,
154+ draft_token_ids ,
155+ target_argmax ,
156+ bonus_token_ids ,
157+ )
158+ else :
159+ rejection_greedy_sample_pytorch (
160+ output_token_ids ,
161+ cu_num_draft_tokens ,
162+ draft_token_ids ,
163+ target_argmax ,
164+ bonus_token_ids ,
165+ num_draft_tokens ,
166+ max_spec_len ,
167+ is_greedy ,
168+ )
160169 if sampling_metadata .all_greedy :
161170 return output_token_ids
162171
@@ -284,47 +293,47 @@ def sample_recovered_tokens(
284293 return recovered_token_ids
285294
286295
def rejection_greedy_sample_spec_len_1_pytorch(
    output_token_ids,  # [batch_size, 2]
    draft_token_ids,  # [num_tokens]
    target_argmax,  # [num_tokens]
    bonus_token_ids,  # [batch_size, 1]
):
    """Fast path for greedy rejection sampling when every request has
    exactly one draft token.

    Position 0 of each row always receives the target model's argmax;
    rows whose single draft token matches that argmax additionally
    receive their bonus token at position 1.  ``output_token_ids`` is
    modified in place.
    """
    num_reqs = output_token_ids.size(0)
    # One draft token per request means the flat token tensors are
    # batch-aligned.
    assert num_reqs == draft_token_ids.size(0)
    output_token_ids[:, 0] = target_argmax
    accepted = draft_token_ids == target_argmax
    flat_bonus = bonus_token_ids.squeeze(1)
    output_token_ids[accepted, 1] = flat_bonus[accepted]
310+
def rejection_greedy_sample_pytorch(
    output_token_ids,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens,  # [batch_size], inclusive cumulative draft counts
    draft_token_ids,  # [num_tokens]
    target_argmax,  # [num_tokens]
    bonus_token_ids,  # [batch_size] or [batch_size, 1]
    max_spec_len,  # int
    is_greedy=None,  # [batch_size] bool or None (None => all greedy)
):
    """Vectorized greedy rejection sampling.

    For each greedy request, draft tokens are accepted while they match
    the target model's argmax.  The target argmax is written at every
    position up to and including the first mismatch; when all draft
    tokens are accepted, the bonus token is written immediately after
    the last draft position.  Rows with ``is_greedy == False`` are left
    untouched.

    ``output_token_ids`` is modified in place and also returned.
    """
    batch_size = output_token_ids.shape[0]
    device = output_token_ids.device
    if is_greedy is None:
        is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)

    # Recover per-request draft counts from the cumulative sums.
    starts = torch.cat(
        [cu_num_draft_tokens.new_zeros(1), cu_num_draft_tokens[:-1]])
    num_drafts = (cu_num_draft_tokens - starts).to(torch.long)  # [batch_size]

    # [1, max_spec_len + 1] grid of position indices.
    pos = torch.arange(max_spec_len + 1, device=device).unsqueeze(0)
    # True where a grid cell holds a draft token for that request.
    has_draft = pos < num_drafts.unsqueeze(1)

    # Scatter the flat per-token data into the [batch, max_spec_len + 1]
    # grid.  Boolean-mask assignment fills True cells in row-major order,
    # which matches the request-ordered flat layout implied by
    # cu_num_draft_tokens.
    accept_grid = torch.zeros(batch_size, max_spec_len + 1,
                              dtype=torch.bool, device=device)
    accept_grid[has_draft] = draft_token_ids == target_argmax
    target_grid = torch.zeros(batch_size, max_spec_len + 1,
                              dtype=output_token_ids.dtype, device=device)
    target_grid[has_draft] = target_argmax.to(output_token_ids.dtype)

    # Length of the leading run of accepted draft tokens per request
    # (cumprod turns the first False and everything after it into 0).
    num_accepted = torch.cumprod(accept_grid.to(torch.long), dim=1).sum(dim=1)
    # The target argmax is emitted through the first mismatch, inclusive:
    # min(num_accepted + 1, num_drafts) positions per request.
    emit_len = torch.minimum(num_accepted + 1, num_drafts)
    write_mask = (pos < emit_len.unsqueeze(1)) & is_greedy.unsqueeze(1)
    output_token_ids[write_mask] = target_grid[write_mask]

    # A fully-accepted request also takes the bonus token right after its
    # final draft position (position 0 when it had no draft tokens).
    take_bonus = is_greedy & (num_accepted == num_drafts)
    rows = torch.arange(batch_size, device=device)
    flat_bonus = bonus_token_ids.reshape(-1).to(output_token_ids.dtype)
    output_token_ids[rows[take_bonus],
                     num_drafts[take_bonus]] = flat_bonus[take_bonus]
    return output_token_ids
328337
329338
330339def rejection_random_sample_pytorch (
0 commit comments