#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

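# This module overrides vLLM's MultiStepWorker.sampler_output (see the
# assignment at the bottom of the file): when the draft model runs through
# vllm-ascend's TP1DraftModelRunner and multi-step execution is supported,
# the whole proposal loop is dispatched in a single call; otherwise each
# draft step is prepared on the CPU and executed one at a time.
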
from typing import List, Set, Tuple

import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.multi_step_worker import MultiStepWorker

from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner


def sampler_output(
    self,
    execute_model_req: ExecuteModelRequest,
    sample_len: int,
    seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
    """Run the model forward pass sample_len times. Returns the list of
    sampler outputs, one per model forward pass, along with an indicator of
    whether the torch tensors in the sampler output need to be transposed in
    the later sampler_output_to_torch logic.

    For the multi-step worker, this indicator shall be True.
    """
    self._raise_if_unsupported(execute_model_req)
    # Expand the batch for sequences with a bonus token.
    # Perform a forward pass on the expanded batch and filter the
    # response to retain only the original sequences' responses.
    expanded_request, indices_of_seq_with_bonus_tokens =\
        self._expand_execute_model_request(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    # Run model sample_len times.
    model_outputs: List[SamplerOutput] = []

    if isinstance(self.model_runner, TP1DraftModelRunner) and \
        self.model_runner.supports_gpu_multi_step(expanded_request):
        # Here we run the draft_model_runner with multi-step prepare
        # on the GPU directly
        expanded_request.num_steps = sample_len
        self.model_runner.set_indices_of_seq_with_bonus_tokens(
            indices_of_seq_with_bonus_tokens)
        model_outputs = self.execute_model(execute_model_req=expanded_request)
    else:
        # Here we run multi-step directly, with every step prepared
        # on the CPU.
        # TODO: Remove this branch once DraftModelRunner supports TP>1
        # and other restrictions that are part of DraftModelRunner's
        # supports_gpu_multi_step(..)
        for _ in range(sample_len):
            model_output: List[SamplerOutput] = self.worker.execute_model(
                execute_model_req=expanded_request)
            assert (len(model_output) == 1
                    ), "composing multistep workers not supported"
            model_output = model_output[0]

            self._append_new_tokens(model_output,
                                    expanded_request.seq_group_metadata_list,
                                    indices_of_seq_with_bonus_tokens)
            model_outputs.append(model_output)

    # move indices to device to avoid stream sync
    indices_of_seq_with_bonus_tokens = torch.tensor(
        indices_of_seq_with_bonus_tokens, device=self.device)
    filtered_model_outputs = self._filter_model_output(
        model_outputs, indices_of_seq_with_bonus_tokens)
    return filtered_model_outputs, True


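# Monkey-patch the upstream MultiStepWorker so speculative draft sampling goes
# through the implementation above. Wrapping it with torch.inference_mode()
# keeps the draft forward passes free of autograd tracking, equivalent to
# decorating the function with @torch.inference_mode().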
MultiStepWorker.sampler_output = torch.inference_mode()(sampler_output)