 from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
                            SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import SpeculativeProposals
-from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
 
 if current_platform.is_cuda_alike():
     from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
 
-from vllm.spec_decode.interfaces import (SpeculativeProposals,
-                                         SpeculativeProposer)
+from vllm.spec_decode.interfaces import SpeculativeProposer
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker_base import DelegateWorkerBase
@@ -89,7 +87,8 @@ def sampler_output(
         model_outputs: List[SamplerOutput] = []
         if current_platform.is_cuda_alike() and isinstance(
                 self.model_runner, TP1DraftModelRunner
-        ) and self.model_runner.supports_gpu_multi_step(expanded_request) and not self.pard:
+        ) and self.model_runner.supports_gpu_multi_step(
+                expanded_request) and not self.pard:
             # Here we run the draft_model_runner with multi-step prepare
             # on the GPU directly
             expanded_request.num_steps = sample_len
@@ -107,7 +106,8 @@ def sampler_output(
                 self.worker.model_runner.return_hidden_states = True
 
             if hasattr(self, "pard") and self.pard is True:
-                filtered_model_outputs = self.pard_infer(expanded_request, sample_len)
+                filtered_model_outputs = self.pard_infer(
+                    expanded_request, sample_len)
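+                # pard_infer drafts all sample_len tokens in one batched
+                # forward pass, so the step-by-step loop below is skipped.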
                 return filtered_model_outputs, True
 
             for _ in range(sample_len):
@@ -124,7 +124,6 @@ def sampler_output(
                     indices_of_seq_with_bonus_tokens)
                 model_outputs.append(model_output)
 
-
         # move indices to device to avoid stream sync
         indices_of_seq_with_bonus_tokens = torch.tensor(
             indices_of_seq_with_bonus_tokens, device=self.device)
@@ -133,7 +132,7 @@ def sampler_output(
         return filtered_model_outputs, True
 
     def pard_infer(self, expanded_request: ExecuteModelRequest,
-            sample_len: int) -> List[SamplerOutput]:
+                   sample_len: int) -> List[SamplerOutput]:
         # prepare recompute kv token
         # update seq_group_metadata_list
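+        # PARD rolls back each request's most recent output tokens so their
+        # KV-cache entries are recomputed, then re-proposes them together with
+        # mask_token_id placeholders in a single draft forward pass.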
         mask_token_id = self.pard_token
@@ -147,69 +146,92 @@ def pard_infer(self, expanded_request: ExecuteModelRequest,
         for name, tmp_request in request_by_id.items():
             seq_num_base = len(tmp_request)
             group_key = list(tmp_request[-1].seq_data.keys())[0]
-            output_token_ids = tmp_request[-1].seq_data[group_key].output_token_ids
-            rm_num = min(sample_len - 1 + seq_num_base - 1, len(output_token_ids) - 1)
-            rm_token_ids = list(output_token_ids[len(output_token_ids) - rm_num:])
+            output_token_ids = tmp_request[-1].seq_data[
+                group_key].output_token_ids
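+            # Roll back up to (sample_len - 1 + seq_num_base - 1) trailing
+            # output tokens, always keeping at least one generated token.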
+            rm_num = min(sample_len - 1 + seq_num_base - 1,
+                         len(output_token_ids) - 1)
+            rm_token_ids = list(output_token_ids[len(output_token_ids) -
+                                                 rm_num:])
             all_rm_token_ids.append(rm_token_ids)
             tmp_new_requests = tmp_request[-1]
-            tmp_new_requests.seq_data[group_key].output_token_ids = output_token_ids[:len(output_token_ids) - rm_num]
+            tmp_new_requests.seq_data[
+                group_key].output_token_ids = output_token_ids[:len(
+                    output_token_ids) - rm_num]
             tmp_new_requests.seq_data[group_key]._num_computed_tokens -= rm_num
             new_request_list.append(tmp_new_requests)
         expanded_request.seq_group_metadata_list = new_request_list
         max_rm_num = max([len(i) for i in all_rm_token_ids])
-        min_rm_num = min([len(i) for i in all_rm_token_ids])
 
         # get proposal
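+        # Each proposal row holds that request's rolled-back tokens followed
+        # by mask_token_id padding up to the shared length
+        # sample_len - 1 + max_rm_num; proposal_probs is a placeholder (#fake).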
         proposal = SpeculativeProposals(
-            proposal_token_ids=torch.tensor([
-                rm_token_ids + [mask_token_id for i in range(sample_len - 1 + max_rm_num - len(rm_token_ids))]
-                for rm_token_ids in all_rm_token_ids], device=self.device),
-            proposal_probs=torch.tensor([sample_len - 1 + max_rm_num] * len(new_request_list), device=self.device), #fake
-            proposal_lens=torch.tensor([sample_len - 1 + max_rm_num] * len(new_request_list), device=self.device)
-        )
+            proposal_token_ids=torch.tensor([
+                rm_token_ids + [
+                    mask_token_id for i in range(sample_len - 1 + max_rm_num -
+                                                 len(rm_token_ids))
+                ] for rm_token_ids in all_rm_token_ids
+            ],
+                                            device=self.device),
+            proposal_probs=torch.tensor([sample_len - 1 + max_rm_num] *
+                                        len(new_request_list),
+                                        device=self.device),  #fake
+            proposal_lens=torch.tensor([sample_len - 1 + max_rm_num] *
+                                       len(new_request_list),
+                                       device=self.device))
 
         # pard forward
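+        # keep_index retains, per sequence, only the first
+        # sample_len + len(rm_token_ids) scorer rows; rows produced for the
+        # mask padding beyond this sequence's own rollback are dropped.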
         keep_index = []
         rm_token_num = []
         rm_token_num_sum = []
         for i, rm_token_ids in enumerate(all_rm_token_ids):
-            keep_index.extend([i * (sample_len + max_rm_num) + j for j in range(sample_len + len(rm_token_ids))])
+            keep_index.extend([
+                i * (sample_len + max_rm_num) + j
+                for j in range(sample_len + len(rm_token_ids))
+            ])
             rm_token_num.append(len(rm_token_ids))
             rm_token_num_sum.append(sum(rm_token_num))
 
-        pard_draft_out = self.pard_scorer.score_proposals(expanded_request, proposal, return_output=True, keep_index=keep_index)
+        pard_draft_out = self.pard_scorer.score_proposals(
+            expanded_request,
+            proposal,
+            return_output=True,
+            keep_index=keep_index)
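+        # return_output=True / keep_index are assumed to be PARD-specific
+        # scorer arguments: the raw draft sampler output is returned,
+        # restricted to the keep_index rows.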
 
         # align probs shape of target and draft model
         target_dim = self.pard_scorer._vocab_size
         if pard_draft_out.sampled_token_probs.shape[1] > target_dim:
-            pard_draft_out.sampled_token_probs = pard_draft_out.sampled_token_probs[:, :target_dim]
+            tmp_draft_probs = pard_draft_out.sampled_token_probs[:, :
+                                                                 target_dim]
+            pard_draft_out.sampled_token_probs = tmp_draft_probs
         elif pard_draft_out.sampled_token_probs.shape[1] < target_dim:
             pard_draft_out.sampled_token_probs = torch.nn.functional.pad(
-                pard_draft_out.sampled_token_probs, (0, target_dim - pard_draft_out.sampled_token_probs.shape[1]), value=0)
+                pard_draft_out.sampled_token_probs,
+                (0, target_dim - pard_draft_out.sampled_token_probs.shape[1]),
+                value=0)
 
         # get output
-        output_indices = torch.tensor([[i + tmp_rm + j * sample_len for j, tmp_rm in enumerate(rm_token_num_sum)]
-                                       for i in range(sample_len)], device=self.device)
+        output_indices = torch.tensor([[
+            i + tmp_rm + j * sample_len
+            for j, tmp_rm in enumerate(rm_token_num_sum)
+        ] for i in range(sample_len)],
+                                      device=self.device)
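+        # Worked example of the index bookkeeping (illustrative values only):
+        #   sample_len=3, all_rm_token_ids=[[7, 8], [9]] -> max_rm_num=2,
+        #   rm_token_num_sum=[2, 3], keep_index=[0, 1, ..., 8] and
+        #   output_indices=[[2, 6], [3, 7], [4, 8]]; draft step i of sequence
+        #   j sits at flat row i + rm_token_num_sum[j] + j * sample_len.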
         filtered_model_outputs = [
             SamplerOutput(
                 outputs=[
                     pard_draft_out.outputs[i] for i in output_indices_to_retain
                 ] if len(pard_draft_out.outputs) > 0 else [],
                 sampled_token_probs=(
-                    pard_draft_out.sampled_token_probs[output_indices_to_retain]
-                    if pard_draft_out.sampled_token_probs is not None
-                    else None),
-                logprobs=(
-                    pard_draft_out.logprobs[output_indices_to_retain]
-                    if pard_draft_out.logprobs is not None else None),
-                sampled_token_ids=(pard_draft_out.
-                                   sampled_token_ids[output_indices_to_retain]
-                                   if pard_draft_out.sampled_token_ids
-                                   is not None else None))
+                    pard_draft_out.
+                    sampled_token_probs[output_indices_to_retain] if
+                    pard_draft_out.sampled_token_probs is not None else None),
+                logprobs=(pard_draft_out.logprobs[output_indices_to_retain]
+                          if pard_draft_out.logprobs is not None else None),
+                sampled_token_ids=(
+                    pard_draft_out.sampled_token_ids[output_indices_to_retain]
+                    if pard_draft_out.sampled_token_ids is not None else None))
             for output_indices_to_retain in output_indices
-            ]
+        ]
         return filtered_model_outputs
-
+
     @staticmethod
     def _maybe_update_previous_hidden_states(
             model_output: SamplerOutput,