 # SPDX-License-Identifier: Apache-2.0
+from typing import List
+
 import torch
 import torch.nn as nn
 from torch.nn.utils.rnn import pad_sequence
@@ -52,62 +54,62 @@ def __init__(self):
         else:
             self.forward_method = self.forward_native

-    def forward(self, logits: torch.Tensor,
+    def forward(self, draft_token_ids: List[List[int]],
+                target_probs: torch.Tensor,
                 sampling_metadata: SamplingMetadata) -> SamplerOutput:
         if not sampling_metadata.all_greedy:
             raise NotImplementedError(
                 "Currently, only greedy sampling is supported by "
                 "rejection sampler.")
-        return self.forward_method(logits, sampling_metadata)
+        return self.forward_method(draft_token_ids, target_probs,
+                                   sampling_metadata)

     def flashinfer_sample(
         self,
-        logits: torch.Tensor,
+        draft_token_ids: List[List[int]],
+        target_probs: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
         # NOTE: The following input preparation can be moved
         # to the model runner in a persistent manner for better
         # performance.
-        assert sampling_metadata.spec_token_ids is not None
-        spec_token_ids = sampling_metadata.spec_token_ids
-        max_spec_len = max(len(s) for s in spec_token_ids)
-        batch_size = len(spec_token_ids)
-        draft_token_ids = torch.full((batch_size, max_spec_len),
-                                     INVALID_TOKEN_ID,
-                                     device="cpu",
-                                     dtype=torch.long)
-
-        target_token_ids = torch.full((batch_size, max_spec_len + 1),
-                                      fill_value=INVALID_TOKEN_ID,
-                                      device=logits.device,
-                                      dtype=torch.long)
-
-        # TODO: Vectorize the following loop for better performance.
-        start_loc = 0
-        for i in range(batch_size):
-            num_spec_tokens = len(spec_token_ids[i])
-            draft_token_ids[i, :num_spec_tokens] = torch.tensor(
-                spec_token_ids[i], device="cpu", dtype=torch.long)
-            end_loc = start_loc + num_spec_tokens + 1
-            # Assume greedy sampling.
-            target_token_ids[i, :num_spec_tokens + 1] = torch.argmax(
-                logits[start_loc:end_loc], dim=-1)
-            start_loc = end_loc
-
-        vocab_size = logits.size(-1)
-        # NOTE: CPU <-> GPU synchronization happens here.
-        draft_token_ids = draft_token_ids.to(logits.device)
-        draft_probs = _create_greedy_token_probs(draft_token_ids, vocab_size,
-                                                 logits.device)
-        target_probs = _create_greedy_token_probs(target_token_ids, vocab_size,
-                                                  logits.device)
-        uniform_samples = torch.zeros(batch_size,
-                                      max_spec_len + 1,
-                                      device=logits.device)
+        sample_lens = [len(x) + 1 for x in draft_token_ids]
+        # Convert draft token IDs to a tensor, split by sample_lens, then pad.
+        draft_token_ids = [
+            torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
+        ]
+        draft_token_ids_tensor = pad_sequence(draft_token_ids,
+                                              batch_first=True,
+                                              padding_value=INVALID_TOKEN_ID)
+
+        if sampling_metadata.all_greedy:
+            target_token_ids = target_probs.argmax(dim=-1).view(-1)
+            target_token_ids = target_token_ids.split(sample_lens)
+            target_token_ids = pad_sequence(target_token_ids,
+                                            batch_first=True,
+                                            padding_value=INVALID_TOKEN_ID)
+
+            vocab_size = target_probs.size(-1)
+            # NOTE: CPU <-> GPU synchronization happens here.
+            draft_token_ids_tensor = draft_token_ids_tensor.to(
+                target_probs.device)
+            draft_probs = _create_greedy_token_probs(draft_token_ids_tensor,
+                                                     vocab_size,
+                                                     target_probs.device)
+            target_probs = _create_greedy_token_probs(target_token_ids,
+                                                      vocab_size,
+                                                      target_probs.device)
+            uniform_samples = torch.zeros(draft_token_ids_tensor.size(0),
+                                          draft_token_ids_tensor.size(1) + 1,
+                                          device=target_probs.device)
+        else:
+            raise NotImplementedError(
+                "Currently, only greedy sampling is supported by "
+                "rejection sampler.")

         sampled_token_ids, _, _ = fs.chain_speculative_sampling(
             draft_probs,
-            draft_token_ids,
+            draft_token_ids_tensor,
             uniform_samples,
             target_probs,
         )
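
# Not part of the diff: _create_greedy_token_probs is called above but its body
# is outside this hunk. A minimal sketch of the idea, assuming it builds a
# one-hot distribution over the vocabulary for every non-padding token ID; the
# helper name greedy_token_probs and the standalone INVALID_TOKEN_ID constant
# below are illustrative only.
import torch

INVALID_TOKEN_ID = -1  # assumed padding sentinel

def greedy_token_probs(token_ids: torch.Tensor,
                       vocab_size: int) -> torch.Tensor:
    """Turn token IDs of shape [batch, seq] into one-hot probs [batch, seq, vocab]."""
    probs = torch.zeros(*token_ids.shape, vocab_size)
    valid = token_ids != INVALID_TOKEN_ID
    # Put all probability mass on each valid token; padded slots stay all-zero.
    probs[valid] = torch.nn.functional.one_hot(
        token_ids[valid], num_classes=vocab_size).float()
    return probs

ids = torch.tensor([[5, 2, INVALID_TOKEN_ID]])  # one request with a padded draft
print(greedy_token_probs(ids, vocab_size=8)[0, 0, 5])  # tensor(1.)
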
@@ -117,35 +119,35 @@ def flashinfer_sample(
     # TODO: The following method can be optimized for better performance.
     def forward_native(
         self,
-        logits: torch.Tensor,
+        draft_token_ids: List[List[int]],
+        target_probs: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
-        assert sampling_metadata.spec_token_ids is not None
-        spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
-        # Add 1 to include the 'bonus' token.
-        sample_lens = [x + 1 for x in spec_lens]
-
-        output_token_ids = logits.argmax(dim=-1).view(-1)
-        output_token_ids = output_token_ids.split(sample_lens)
-        output_token_ids = pad_sequence(output_token_ids,
-                                        batch_first=True,
-                                        padding_value=INVALID_TOKEN_ID)
-
-        # Convert spec token IDs to a tensor, split by sample_lens, then pad.
-        spec_token_ids = [
-            torch.tensor(x,
-                         dtype=output_token_ids.dtype,
-                         device=output_token_ids.device)
-            for x in sampling_metadata.spec_token_ids
+        sample_lens = [len(x) + 1 for x in draft_token_ids]
+        # Convert draft token IDs to a tensor, split by sample_lens, then pad.
+        draft_token_ids = [
+            torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
         ]
-        spec_token_ids = pad_sequence(spec_token_ids,
-                                      batch_first=True,
-                                      padding_value=INVALID_TOKEN_ID)
-
-        # Produce a mask that remains 1 (True) until the first
-        # mismatch (cumprod turns 0 after a mismatch).
-        accept_mask = (output_token_ids[:, :-1] == spec_token_ids).cumprod(
-            dim=1)
+        draft_token_ids_tensor = pad_sequence(draft_token_ids,
+                                              batch_first=True,
+                                              padding_value=INVALID_TOKEN_ID)
+        draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device)
+        # Add 1 to include the 'bonus' token.
+        if sampling_metadata.all_greedy:
+            output_token_ids = target_probs.argmax(dim=-1).view(-1)
+            output_token_ids = output_token_ids.split(sample_lens)
+            output_token_ids = pad_sequence(output_token_ids,
+                                            batch_first=True,
+                                            padding_value=INVALID_TOKEN_ID)
+            # Produce a mask that remains 1 (True) until the first
+            # mismatch (cumprod turns 0 after a mismatch).
+            accept_mask = (
+                output_token_ids[:, :-1] == draft_token_ids_tensor).cumprod(
+                    dim=1)
+        else:
+            raise NotImplementedError(
+                "Currently, only greedy sampling is supported by "
+                "rejection sampler.")
         # Identify valid positions (non-padding).
         valid_mask = output_token_ids != INVALID_TOKEN_ID
         # Generate mask with bonus token.
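
# Not part of the diff: a toy illustration of the cumprod trick used in
# forward_native above, with made-up token IDs. The mask stays 1 until the
# first position where the greedy target token disagrees with the draft, so
# its row sum is the number of accepted draft tokens per request.
import torch

INVALID_TOKEN_ID = -1

target = torch.tensor([[3, 9, 4, 7],                 # last column = bonus token
                       [5, 1, 2, INVALID_TOKEN_ID]])
draft = torch.tensor([[3, 9, 8],                      # mismatch at position 2
                      [5, 4, INVALID_TOKEN_ID]])      # mismatch at position 1

accept_mask = (target[:, :-1] == draft).cumprod(dim=1)
print(accept_mask)             # -> [[1, 1, 0], [1, 0, 0]]
print(accept_mask.sum(dim=1))  # -> [2, 1] accepted draft tokens per request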