
Commit 8bddd03

[Model] Integrate PARD into vLLM pre-commit
Signed-off-by: root <anzihao_hh@126.com>
Signed-off-by: <anzihao_hh@126.com>
1 parent b1c6d0b commit 8bddd03

3 files changed: +77 additions, −52 deletions

vllm/spec_decode/batch_expansion.py
Lines changed: 5 additions & 2 deletions

@@ -73,14 +73,17 @@ def score_proposals(
             if VLLM_INVALID_TOKEN_ID not in proposals
         ]

-        (spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens) = self._expand_batch(
+        (spec_indices, non_spec_indices, target_seq_group_metadata_list,
+         num_scoring_tokens) = self._expand_batch(
             seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
             proposal_token_ids_list=proposal_token_ids_list_without_skips,
             proposal_lens_list=proposal_lens_list,
         )

         if keep_index is not None:
-            target_seq_group_metadata_list = [target_seq_group_metadata_list[i] for i in keep_index]
+            target_seq_group_metadata_list = [
+                target_seq_group_metadata_list[i] for i in keep_index
+            ]
         target_sampler_output = self._scorer_worker.execute_model(
             execute_model_req=execute_model_req.clone(
                 seq_group_metadata_list=target_seq_group_metadata_list))
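For orientation, here is a toy sketch (not vLLM code) of what the `keep_index` argument does in `score_proposals`: after the batch has been expanded to a uniform proposal length, only the entries listed in `keep_index` are forwarded to the scorer worker. The list contents and indices below are made-up placeholders.

```python
# Hypothetical stand-ins for expanded sequence-group metadata entries.
expanded = [
    "req0_pos0", "req0_pos1", "req0_pos2",
    "req1_pos0", "req1_pos1", "req1_pos2",
]
# Suppose req0 only has real tokens at its first two positions.
keep_index = [0, 1, 3, 4, 5]

# Mirrors the list comprehension reformatted in the diff above.
target_seq_group_metadata_list = [expanded[i] for i in keep_index]
print(target_seq_group_metadata_list)
# ['req0_pos0', 'req0_pos1', 'req1_pos0', 'req1_pos1', 'req1_pos2']
```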

vllm/spec_decode/multi_step_worker.py
Lines changed: 58 additions & 36 deletions

@@ -13,13 +13,11 @@
 from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
                            SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import SpeculativeProposals
-from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer

 if current_platform.is_cuda_alike():
     from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

-from vllm.spec_decode.interfaces import (SpeculativeProposals,
-                                         SpeculativeProposer)
+from vllm.spec_decode.interfaces import SpeculativeProposer
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker_base import DelegateWorkerBase
@@ -89,7 +87,8 @@ def sampler_output(
         model_outputs: List[SamplerOutput] = []
         if current_platform.is_cuda_alike() and isinstance(
                 self.model_runner, TP1DraftModelRunner
-        ) and self.model_runner.supports_gpu_multi_step(expanded_request) and not self.pard:
+        ) and self.model_runner.supports_gpu_multi_step(
+                expanded_request) and not self.pard:
             # Here we run the draft_model_runner with multi-step prepare
             # on the GPU directly
             expanded_request.num_steps = sample_len
@@ -107,7 +106,8 @@ def sampler_output(
             self.worker.model_runner.return_hidden_states = True

         if hasattr(self, "pard") and self.pard is True:
-            filtered_model_outputs = self.pard_infer(expanded_request, sample_len)
+            filtered_model_outputs = self.pard_infer(
+                expanded_request, sample_len)
             return filtered_model_outputs, True

         for _ in range(sample_len):
@@ -124,7 +124,6 @@ def sampler_output(
                     indices_of_seq_with_bonus_tokens)
             model_outputs.append(model_output)

-
         # move indices to device to avoid stream sync
         indices_of_seq_with_bonus_tokens = torch.tensor(
             indices_of_seq_with_bonus_tokens, device=self.device)
@@ -133,7 +132,7 @@
         return filtered_model_outputs, True

     def pard_infer(self, expanded_request: ExecuteModelRequest,
-                    sample_len: int) -> List[SamplerOutput]:
+                   sample_len: int) -> List[SamplerOutput]:
         # prepare recompute kv token
         # update seq_group_metadata_list
         mask_token_id = self.pard_token
@@ -147,69 +146,92 @@ def pard_infer(self, expanded_request: ExecuteModelRequest,
         for name, tmp_request in request_by_id.items():
             seq_num_base = len(tmp_request)
             group_key = list(tmp_request[-1].seq_data.keys())[0]
-            output_token_ids = tmp_request[-1].seq_data[group_key].output_token_ids
-            rm_num = min(sample_len - 1 + seq_num_base - 1, len(output_token_ids) - 1)
-            rm_token_ids = list(output_token_ids[len(output_token_ids) - rm_num:])
+            output_token_ids = tmp_request[-1].seq_data[
+                group_key].output_token_ids
+            rm_num = min(sample_len - 1 + seq_num_base - 1,
+                         len(output_token_ids) - 1)
+            rm_token_ids = list(output_token_ids[len(output_token_ids) -
+                                                 rm_num:])
             all_rm_token_ids.append(rm_token_ids)
             tmp_new_requests = tmp_request[-1]
-            tmp_new_requests.seq_data[group_key].output_token_ids = output_token_ids[:len(output_token_ids) - rm_num]
+            tmp_new_requests.seq_data[
+                group_key].output_token_ids = output_token_ids[:len(
+                    output_token_ids) - rm_num]
             tmp_new_requests.seq_data[group_key]._num_computed_tokens -= rm_num
             new_request_list.append(tmp_new_requests)
         expanded_request.seq_group_metadata_list = new_request_list
         max_rm_num = max([len(i) for i in all_rm_token_ids])
-        min_rm_num = min([len(i) for i in all_rm_token_ids])

         # get proposal
         proposal = SpeculativeProposals(
-            proposal_token_ids = torch.tensor([
-                rm_token_ids + [mask_token_id for i in range(sample_len - 1 + max_rm_num - len(rm_token_ids))]
-                for rm_token_ids in all_rm_token_ids], device=self.device),
-            proposal_probs=torch.tensor([sample_len - 1 + max_rm_num] * len(new_request_list), device=self.device), #fake
-            proposal_lens=torch.tensor([sample_len - 1 + max_rm_num] * len(new_request_list), device=self.device)
-        )
+            proposal_token_ids=torch.tensor([
+                rm_token_ids + [
+                    mask_token_id for i in range(sample_len - 1 + max_rm_num -
                                                 len(rm_token_ids))
+                ] for rm_token_ids in all_rm_token_ids
+            ],
+                                            device=self.device),
+            proposal_probs=torch.tensor([sample_len - 1 + max_rm_num] *
+                                        len(new_request_list),
+                                        device=self.device),  #fake
+            proposal_lens=torch.tensor([sample_len - 1 + max_rm_num] *
+                                       len(new_request_list),
+                                       device=self.device))

         # pard forward
         keep_index = []
         rm_token_num = []
         rm_token_num_sum = []
         for i, rm_token_ids in enumerate(all_rm_token_ids):
-            keep_index.extend([i * (sample_len + max_rm_num) + j for j in range(sample_len + len(rm_token_ids))])
+            keep_index.extend([
+                i * (sample_len + max_rm_num) + j
+                for j in range(sample_len + len(rm_token_ids))
+            ])
             rm_token_num.append(len(rm_token_ids))
             rm_token_num_sum.append(sum(rm_token_num))

-        pard_draft_out = self.pard_scorer.score_proposals(expanded_request, proposal, return_output=True, keep_index=keep_index)
+        pard_draft_out = self.pard_scorer.score_proposals(
+            expanded_request,
+            proposal,
+            return_output=True,
+            keep_index=keep_index)

         # align probs shape of target and draft model
         target_dim = self.pard_scorer._vocab_size
         if pard_draft_out.sampled_token_probs.shape[1] > target_dim:
-            pard_draft_out.sampled_token_probs = pard_draft_out.sampled_token_probs[:, :target_dim]
+            tmp_draft_probs = pard_draft_out.sampled_token_probs[:, :
+                                                                 target_dim]
+            pard_draft_out.sampled_token_probs = tmp_draft_probs
         elif pard_draft_out.sampled_token_probs.shape[1] < target_dim:
             pard_draft_out.sampled_token_probs = torch.nn.functional.pad(
-                pard_draft_out.sampled_token_probs, (0, target_dim - pard_draft_out.sampled_token_probs.shape[1]), value=0)
+                pard_draft_out.sampled_token_probs,
+                (0, target_dim - pard_draft_out.sampled_token_probs.shape[1]),
+                value=0)

         # get output
-        output_indices = torch.tensor([[i + tmp_rm + j * sample_len for j, tmp_rm in enumerate(rm_token_num_sum)]
-                                       for i in range(sample_len)], device=self.device)
+        output_indices = torch.tensor([[
+            i + tmp_rm + j * sample_len
+            for j, tmp_rm in enumerate(rm_token_num_sum)
+        ] for i in range(sample_len)],
+                                      device=self.device)
         filtered_model_outputs = [
             SamplerOutput(
                 outputs=[
                     pard_draft_out.outputs[i] for i in output_indices_to_retain
                 ] if len(pard_draft_out.outputs) > 0 else [],
                 sampled_token_probs=(
-                    pard_draft_out.sampled_token_probs[output_indices_to_retain]
-                    if pard_draft_out.sampled_token_probs is not None
-                    else None),
-                logprobs=(
-                    pard_draft_out.logprobs[output_indices_to_retain]
-                    if pard_draft_out.logprobs is not None else None),
-                sampled_token_ids=(pard_draft_out.
-                                   sampled_token_ids[output_indices_to_retain]
-                                   if pard_draft_out.sampled_token_ids
-                                   is not None else None))
+                    pard_draft_out.
+                    sampled_token_probs[output_indices_to_retain] if
+                    pard_draft_out.sampled_token_probs is not None else None),
+                logprobs=(pard_draft_out.logprobs[output_indices_to_retain]
+                          if pard_draft_out.logprobs is not None else None),
+                sampled_token_ids=(
+                    pard_draft_out.sampled_token_ids[output_indices_to_retain]
+                    if pard_draft_out.sampled_token_ids is not None else None))
             for output_indices_to_retain in output_indices
-            ]
+        ]
         return filtered_model_outputs
-
+
     @staticmethod
     def _maybe_update_previous_hidden_states(
         model_output: SamplerOutput,
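To make the index arithmetic in `pard_infer` easier to follow, here is a small self-contained sketch of the same formulas evaluated on made-up numbers. The values of `sample_len`, the rolled-back token ids, and the mask token id are assumptions for illustration only; this is not vLLM code.

```python
# Illustration only: the PARD padding / index formulas from pard_infer above.
sample_len = 4            # draft tokens requested per round (assumed)
mask_token_id = 128002    # placeholder PARD mask/pad token id (assumed)
all_rm_token_ids = [[11, 12], [21]]   # tokens rolled back per request (assumed)

max_rm_num = max(len(ids) for ids in all_rm_token_ids)

# Each request's proposal is its rolled-back tokens padded with the mask token
# to the common length sample_len - 1 + max_rm_num.
proposal_token_ids = [
    ids + [mask_token_id] * (sample_len - 1 + max_rm_num - len(ids))
    for ids in all_rm_token_ids
]

# keep_index marks the real positions inside the expanded batch, whose blocks
# are sample_len + max_rm_num wide per request.
keep_index, rm_token_num_sum, running = [], [], 0
for i, ids in enumerate(all_rm_token_ids):
    keep_index.extend(i * (sample_len + max_rm_num) + j
                      for j in range(sample_len + len(ids)))
    running += len(ids)
    rm_token_num_sum.append(running)

# output_indices[i][j] is the flattened position (after keep_index filtering)
# of request j's i-th fresh draft token, skipping its recomputed rm tokens.
output_indices = [[
    i + tmp_rm + j * sample_len for j, tmp_rm in enumerate(rm_token_num_sum)
] for i in range(sample_len)]

print(proposal_token_ids)  # both rows padded to length sample_len - 1 + max_rm_num
print(keep_index)          # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
print(output_indices)      # [[2, 7], [3, 8], [4, 9], [5, 10]]
```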

vllm/spec_decode/spec_decode_worker.py
Lines changed: 14 additions & 14 deletions

@@ -101,10 +101,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
         scorer_worker=target_worker,
         draft_worker_kwargs=draft_worker_kwargs,
         disable_mqa_scorer=speculative_config.disable_mqa_scorer,
-        disable_by_batch_size=speculative_config.
-        disable_by_batch_size,
-        draft_token_acceptance_method=speculative_config.
-        acceptance_method,
+        disable_by_batch_size=speculative_config.disable_by_batch_size,
+        draft_token_acceptance_method=speculative_config.acceptance_method,
         typical_acceptance_sampler_posterior_threshold=speculative_config.
         posterior_threshold,
         typical_acceptance_sampler_posterior_alpha=speculative_config.
@@ -201,7 +199,6 @@ def create_worker(
         if draft_model_config.hf_config.model_type == "eagle":
             enable_lm_head_weight_load = True

-
         proposer_worker = MultiStepWorker(**draft_worker_kwargs)

         if draft_model_config.hf_config.model_type == "deepseek_mtp":
@@ -210,10 +207,13 @@ def create_worker(

             proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                 proposer_worker, draft_tp, target_tp)
-        pard = draft_model_config.hf_config.__dict__.get('spd_type', None) == 'pard'
+        pard = draft_model_config.hf_config.__dict__.get('spd_type',
+                                                         None) == 'pard'
         proposer_worker.pard = pard
         if pard:
-            proposer_worker.pard_token = draft_model_config.hf_config.__dict__['pard_token']
+            pard_token = draft_model_config.hf_config.__dict__[
+                'pard_token']
+            proposer_worker.pard_token = pard_token

         logger.info("Configuring SpecDecodeWorker with proposer=%s",
                     type(proposer_worker))
@@ -350,7 +350,6 @@ def __init__(
         self._disable_log_stats = disable_log_stats
         self._num_spec_prefill_steps = num_spec_prefill_steps

-
     def init_device(self) -> None:
         """Initialize both scorer and proposer models.
         """
@@ -396,11 +395,11 @@ def init_device(self) -> None:
             device=self.device,
             vocab_size=self._vocab_size)

-
         if self.proposer_worker.pard:
-            self.proposer_worker.pard_scorer = scorer_cls(scorer_worker=self.proposer_worker,
-                                            device=self.device,
-                                            vocab_size=self._vocab_size)
+            self.proposer_worker.pard_scorer = scorer_cls(
+                scorer_worker=self.proposer_worker,
+                device=self.device,
+                vocab_size=self._vocab_size)

         self._configure_model_sampler_for_spec_decode()
@@ -796,7 +795,7 @@ def _run_speculative_decoding_step(
         # Pass last hidden states from target model to proposer
         execute_model_req.previous_hidden_states = self.previous_hidden_states
         self.previous_hidden_states = None
-
+
         with Timer() as proposal_timer:
             # Generate proposals using draft worker.
             proposals = self.proposer_worker.get_spec_proposals(
@@ -1275,7 +1274,8 @@ def _vocab_size(self) -> int:
             for worker in [self.proposer_worker, self.scorer_worker]
         ]
         if not self.proposer_worker.pard:
-            assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
+            assert all(vocab_sizes[0] == vocab_size
+                       for vocab_size in vocab_sizes)
         return vocab_sizes[0]

     @property
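Because `_vocab_size` no longer asserts that draft and target vocab sizes match when PARD is enabled, the draft probabilities have to be aligned to the target vocabulary elsewhere; the `multi_step_worker.py` hunk above does this by truncating or zero-padding the probability tensor. Below is a minimal stand-alone sketch of that alignment; the tensor shapes and the helper name `align_vocab` are assumptions for illustration, not vLLM API.

```python
import torch
import torch.nn.functional as F


def align_vocab(draft_probs: torch.Tensor, target_dim: int) -> torch.Tensor:
    """Resize the last (vocab) dim of draft probs to the target vocab size."""
    draft_dim = draft_probs.shape[1]
    if draft_dim > target_dim:
        return draft_probs[:, :target_dim]        # drop surplus vocab entries
    if draft_dim < target_dim:
        # zero-pad the missing tail of the vocabulary
        return F.pad(draft_probs, (0, target_dim - draft_dim), value=0)
    return draft_probs


probs = torch.softmax(torch.randn(3, 32000), dim=-1)   # toy draft distribution
print(align_vocab(probs, 32768).shape)  # torch.Size([3, 32768])
print(align_vocab(probs, 30000).shape)  # torch.Size([3, 30000])
```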
