@@ -146,17 +146,25 @@ def rejection_sample(
146146 is_greedy = sampling_metadata .temperature == GREEDY_TEMPERATURE
147147 if not sampling_metadata .all_random :
148148 # Rejection sampling for greedy sampling requests.
149- target_argmax = target_probs .argmax (dim = - 1 )
150- rejection_greedy_sample_pytorch (
151- output_token_ids ,
152- cu_num_draft_tokens ,
153- draft_token_ids ,
154- target_argmax ,
155- bonus_token_ids ,
156- is_greedy ,
157- max_spec_len ,
158- # num_warps=1,
159- )
149+ target_argmax = target_probs .argmax (dim = - 1 ).to (torch .int32 )
150+ if min (num_draft_tokens ) == 1 and max (
151+ num_draft_tokens ) == 1 and sampling_metadata .all_greedy :
152+ rejection_greedy_sample_spec_len_1_pytorch (
153+ output_token_ids ,
154+ draft_token_ids ,
155+ target_argmax ,
156+ bonus_token_ids ,
157+ )
158+ else :
159+ rejection_greedy_sample_pytorch (
160+ output_token_ids ,
161+ cu_num_draft_tokens ,
162+ draft_token_ids ,
163+ target_argmax ,
164+ bonus_token_ids ,
165+ max_spec_len ,
166+ is_greedy ,
167+ )
160168 if sampling_metadata .all_greedy :
161169 return output_token_ids
162170
@@ -284,47 +292,52 @@ def sample_recovered_tokens(
284292 return recovered_token_ids
285293
286294
def rejection_greedy_sample_spec_len_1_pytorch(
    output_token_ids,  # [batch_size, 2], mutated in place
    draft_token_ids,  # [num_tokens] (== batch_size: one draft per request)
    target_argmax,  # [num_tokens]
    bonus_token_ids,  # [batch_size] or [batch_size, 1]
):
    """Fast path for greedy rejection sampling with exactly one draft token.

    Position 0 always receives the target argmax token: it equals the draft
    token when the draft is accepted, and replaces it when rejected.
    Accepted requests additionally emit their bonus token at position 1;
    rejected requests leave position 1 untouched (placeholder).
    """
    batch_size = output_token_ids.size(0)
    num_tokens = draft_token_ids.size(0)
    # One draft token per request is the precondition of this fast path.
    assert batch_size == num_tokens
    accept_req_mask = draft_token_ids == target_argmax
    output_token_ids[:, 0] = target_argmax
    # view(-1) tolerates both [batch_size] and [batch_size, 1] bonus shapes
    # (the original squeeze(1) required the 2-D layout).
    flat_bonus = bonus_token_ids.view(-1)
    output_token_ids[accept_req_mask, 1] = flat_bonus[accept_req_mask]
309+
287310def rejection_greedy_sample_pytorch (
288- output_token_ids , # [batch_size, max_spec_len + 1]
289- cu_num_draft_tokens , # [batch_size]
290- draft_token_ids , # [num_tokens]
291- target_argmax , # [num_tokens]
292- bonus_token_ids , # [batch_size]
293- is_greedy = None , # [batch_size] or None
294- max_spec_len = None ,
311+ output_token_ids , # [batch_size, max_spec_len + 1]
312+ cu_num_draft_tokens , # [batch_size]
313+ draft_token_ids , # [num_tokens]
314+ target_argmax , # [num_tokens]
315+ bonus_token_ids , # [batch_size]
316+ max_spec_len , # int
317+ is_greedy = None , # [batch_size] or None
295318):
296319 batch_size = output_token_ids .shape [0 ]
297-
320+ device = output_token_ids . device
298321 if is_greedy is None :
299- is_greedy = torch .ones (batch_size ,
300- dtype = torch .bool ,
301- device = output_token_ids .device )
302-
303- for req_idx in range (batch_size ):
304- if not is_greedy [req_idx ]:
305- continue
306-
307- if req_idx == 0 :
308- start_idx = 0
309- else :
310- start_idx = cu_num_draft_tokens [req_idx - 1 ].item ()
311- end_idx = cu_num_draft_tokens [req_idx ].item ()
312- num_draft_tokens = end_idx - start_idx
313-
314- rejected = False
315- for pos in range (num_draft_tokens ):
316- if not rejected :
317- draft_token_id = draft_token_ids [start_idx + pos ].item ()
318- target_argmax_id = target_argmax [start_idx + pos ].item ()
319-
320- output_token_ids [req_idx , pos ] = target_argmax_id
321-
322- if draft_token_id != target_argmax_id :
323- rejected = True
324-
325- if not rejected :
326- bonus_token_id = bonus_token_ids [req_idx ].item ()
327- output_token_ids [req_idx , num_draft_tokens ] = bonus_token_id
322+ is_greedy = torch .ones (batch_size , dtype = torch .bool , device = device )
323+ draft_token_mask = draft_token_ids == target_argmax
324+ pos_ids = torch .arange (0 , max_spec_len + 1 ,
325+ device = device ).view (1 , - 1 ).expand (batch_size , - 1 )
326+ pos_mask = pos_ids < cu_num_draft_tokens .view (- 1 , 1 )
327+ output_token_mask = torch .zeros ([batch_size , max_spec_len + 1 ],
328+ dtype = torch .bool ,
329+ device = device )
330+ output_token_mask [pos_mask ] = draft_token_mask
331+ output_token_mask = torch .cumprod (output_token_mask ,
332+ dim = 1 ) # [batch_size, max_spec_len + 1]
333+ extra_accept_id = torch .max (
334+ pos_ids * output_token_mask , dim = 1 , keepdim = True ) + 1
335+ output_token_mask [extra_accept_id ] = True
336+ output_token_mask *= is_greedy .view (- 1 , 1 )
337+ output_token_ids [pos_ids ] = draft_token_ids
338+ output_token_ids [:, - 1 ] = bonus_token_ids
339+ output_token_ids = output_token_ids * output_token_mask
340+ return output_token_ids
328341
329342
330343def rejection_random_sample_pytorch (
0 commit comments