33 | 33 | from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput |
34 | 34 | from vllm.v1.sample.metadata import SamplingMetadata |
35 | 35 | from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID |
| 36 | +from vllm.v1.spec_decode.ngram_proposer import NgramProposer |
36 | 37 | from vllm.v1.utils import bind_kv_cache |
37 | 38 | from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch |
38 | 39 | from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin |
@@ -117,6 +118,15 @@ def __init__( |
117 | 118 | # req_id -> (input_id -> encoder_output) |
118 | 119 | self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} |
119 | 120 |
| 121 | + # Set up speculative decoding. |
| 122 | + self.use_spec_decode = False |
| 123 | + if self.speculative_config: |
| 124 | + # TODO: find a better way to check if we are using ngram. |
| 125 | + assert self.speculative_config.ngram_prompt_lookup_min, \ |
| 126 | + "Currently, only ngram spec decode is supported in V1." |
| 127 | + self.drafter = NgramProposer() |
| 128 | + self.use_spec_decode = True |
| 129 | + |
120 | 130 | # Request states. |
121 | 131 | self.requests: Dict[str, CachedRequestState] = {} |
122 | 132 | # Persistent batch. |
@@ -367,6 +377,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: |
367 | 377 | self.input_batch.token_ids_cpu[ |
368 | 378 | req_index, |
369 | 379 | start_token_index:end_token_index] = req_data.new_token_ids |
| 380 | + self.input_batch.num_tokens_no_spec[req_index] = end_token_index |
370 | 381 | # Add spec_token_ids to token_ids_cpu. |
371 | 382 | spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( |
372 | 383 | req_id, []) |
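As context for the bookkeeping added above: `num_tokens_no_spec` records where the verified (non-speculative) tokens of a request end, so scheduled draft tokens can be appended after that boundary without advancing it. Below is a toy illustration of that invariant in plain NumPy; it is a hedged sketch and not vLLM's actual `InputBatch`.

```python
import numpy as np

# One request's row in a toy persistent batch (capacity 16 tokens).
token_ids_cpu = np.zeros((1, 16), dtype=np.int64)
num_tokens_no_spec = np.zeros(1, dtype=np.int64)

# Step N: five verified tokens are written and the boundary advances.
token_ids_cpu[0, 0:5] = [11, 12, 13, 14, 15]
num_tokens_no_spec[0] = 5

# The scheduler's draft tokens for the next step are appended after the
# boundary, but the boundary itself stays at 5, so the verified prefix is
# still intact even if every draft token is later rejected.
token_ids_cpu[0, 5:7] = [99, 98]
print(token_ids_cpu[0, :num_tokens_no_spec[0]])  # -> [11 12 13 14 15]
```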
@@ -1009,15 +1020,51 @@ def execute_model( |
1009 | 1020 | for seq in sampled_token_ids[valid_mask].split(gen_lens) |
1010 | 1021 | ] |
1011 | 1022 |
| 1023 | + if not self.use_spec_decode: |
| 1024 | + spec_token_ids = None |
| 1025 | + else: |
| 1026 | + spec_token_ids = self.generate_draft_token_ids( |
| 1027 | + valid_sampled_token_ids) |
| 1028 | + |
1012 | 1029 | model_runner_output = ModelRunnerOutput( |
1013 | 1030 | req_ids=req_ids, |
1014 | 1031 | req_id_to_index=self.input_batch.req_id_to_index, |
1015 | 1032 | sampled_token_ids=valid_sampled_token_ids, |
| 1033 | + spec_token_ids=spec_token_ids, |
1016 | 1034 | logprobs=logprobs_lists, |
1017 | 1035 | prompt_logprobs_dict=prompt_logprobs_dict, |
1018 | 1036 | ) |
1019 | 1037 | return model_runner_output |
1020 | 1038 |
| 1039 | + def generate_draft_token_ids( |
| 1040 | + self, |
| 1041 | + sampled_token_ids: List[List[int]], |
| 1042 | + ) -> List[List[int]]: |
| 1043 | + # TODO(woosuk): Optimize. |
| 1044 | + num_reqs = len(sampled_token_ids) |
| 1045 | + draft_token_ids: List[List[int]] = [] |
| 1046 | + for i in range(num_reqs): |
| 1047 | + if len(sampled_token_ids[i]) == 0: |
| 1048 | + # Skip speculative decoding. |
| 1049 | + draft_token_ids.append([]) |
| 1050 | + continue |
| 1051 | + |
| 1052 | + # Add sampled_token_ids to token_ids_cpu. |
| 1053 | + start_idx = self.input_batch.num_tokens_no_spec[i] |
| 1054 | + end_idx = start_idx + len(sampled_token_ids[i]) |
| 1055 | + self.input_batch.token_ids_cpu[ |
| 1056 | + i, start_idx:end_idx] = sampled_token_ids[i] |
| 1057 | + drafter_output = self.drafter.propose( |
| 1058 | + self.input_batch.token_ids_cpu[i, :end_idx], |
| 1059 | + self.speculative_config.ngram_prompt_lookup_min, |
| 1060 | + self.speculative_config.num_speculative_tokens, |
| 1061 | + ) |
| 1062 | + if drafter_output is None or len(drafter_output) == 0: |
| 1063 | + draft_token_ids.append([]) |
| 1064 | + else: |
| 1065 | + draft_token_ids.append(drafter_output.tolist()) |
| 1066 | + return draft_token_ids |
| 1067 | + |
1021 | 1068 | def load_model(self) -> None: |
1022 | 1069 | logger.info("Starting to load model %s...", self.model_config.model) |
1023 | 1070 | with DeviceMemoryProfiler() as m: # noqa: SIM117 |
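For readers unfamiliar with n-gram (prompt-lookup) drafting, the sketch below illustrates the idea behind the `drafter.propose(token_ids, ngram_prompt_lookup_min, num_speculative_tokens)` call in `generate_draft_token_ids`: match the current suffix of the token history against earlier context and propose the tokens that followed a previous occurrence. This is a minimal, self-contained sketch, not vLLM's `NgramProposer`; the real implementation (including how it handles min/max n-gram sizes and the search direction) may differ, and the function and parameter names below are illustrative.

```python
from typing import Optional

import numpy as np


def propose_ngram_draft(
    token_ids: np.ndarray,   # all tokens of the request so far (1-D array)
    min_ngram: int,          # suffix length to match, cf. ngram_prompt_lookup_min
    num_draft_tokens: int,   # cf. num_speculative_tokens
) -> Optional[np.ndarray]:
    """Return up to `num_draft_tokens` draft tokens, or None if no match."""
    n = len(token_ids)
    if n <= min_ngram:
        return None
    suffix = token_ids[n - min_ngram:]
    # Search backwards for the most recent earlier occurrence of the suffix.
    for start in range(n - min_ngram - 1, -1, -1):
        if np.array_equal(token_ids[start:start + min_ngram], suffix):
            draft_start = start + min_ngram
            return token_ids[draft_start:draft_start + num_draft_tokens]
    return None


# Example: the suffix (7, 8) previously appeared at positions 1-2 (0-indexed),
# so the tokens that followed it (9, 3, 7) become the draft.
history = np.array([1, 7, 8, 9, 3, 7, 8], dtype=np.int64)
print(propose_ngram_draft(history, min_ngram=2, num_draft_tokens=3))  # -> [9 3 7]
```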