|
8 | 8 | """ |
9 | 9 |
|
10 | 10 | from abc import ABC, abstractmethod |
11 | | -from typing import TYPE_CHECKING, List, Optional |
| 11 | +from typing import TYPE_CHECKING, List, Optional, Tuple, Union |
12 | 12 |
|
13 | 13 | import torch |
14 | 14 |
|
| 15 | +from vllm.sequence import IntermediateTensors |
| 16 | + |
15 | 17 | if TYPE_CHECKING: |
16 | 18 | from vllm.config import KVTransferConfig |
17 | 19 | from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata |
@@ -119,18 +121,74 @@ def close(self) -> None: |
119 | 121 | raise NotImplementedError |
120 | 122 |
|
121 | 123 | @abstractmethod |
122 | | - def build_partial_prefill_input( |
| 124 | + def send_kv_caches_and_hidden_states( |
123 | 125 | self, |
| 126 | + model_executable: torch.nn.Module, |
124 | 127 | model_input: "ModelInputForGPUWithSamplingMetadata", |
125 | | - input_tokens_list: List[torch.Tensor], |
126 | | - num_computed_tokens_list: List[int], |
127 | | - start_pos_list: List[int], |
128 | | - slot_mapping_flat: torch.Tensor, |
129 | | - device: torch.device, |
130 | | - ) -> "ModelInputForGPUWithSamplingMetadata": |
131 | | - """Rebuild the model input based on how many KV caches are received |
| 128 | + kv_caches: List[torch.Tensor], |
| 129 | + hidden_or_intermediate_states: Union[torch.Tensor, |
| 130 | + IntermediateTensors], |
| 131 | + ) -> None: |
| 132 | + """ |
| 133 | + Send KV caches and hidden states to the connector. |
| 134 | +
|
| 135 | + This method processes the input tokens, KV caches, and |
| 136 | + hidden/intermediate states for a given model and sends the data to the |
| 137 | + decode instance. |
| 138 | +
|
| 139 | + Args: |
| 140 | + model_executable (torch.nn.Module): The model executable containing |
| 141 | + start and end layer information. |
| 142 | + model_input (ModelInputForGPUWithSamplingMetadata): The input |
| 143 | + metadata from vLLM. |
| 144 | + kv_caches (List[torch.Tensor]): List of KV caches (keys and values) |
| 145 | + for each layer. |
| 146 | + hidden_or_intermediate_states (Union[torch.Tensor, |
| 147 | + IntermediateTensors]): |
| 148 | + The hidden or intermediate states associated with the tokens. |
| 149 | +
|
| 150 | + Returns: |
| 151 | + None |
132 | 152 |
|
133 | | - Raises: |
134 | | - NotImplementedError: This method must be implemented in subclasses. |
135 | 153 | """ |
| 154 | + |
| 155 | + raise NotImplementedError |
| 156 | + |
| 157 | + @abstractmethod |
| 158 | + def recv_kv_caches_and_hidden_states( |
| 159 | + self, model_executable: torch.nn.Module, |
| 160 | + model_input: "ModelInputForGPUWithSamplingMetadata", |
| 161 | + kv_caches: List[torch.Tensor] |
| 162 | + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, |
| 163 | + "ModelInputForGPUWithSamplingMetadata"]: |
| 164 | + """ |
| 165 | + Receive KV caches and hidden states from the connector. |
| 166 | +
|
| 167 | + This method attempts to retrieve KV caches and hidden states for input |
| 168 | + tokens. If all required KV caches and hidden states are received, it |
| 169 | + will bypass model execution, else it will fall back to normal vLLM model |
| 170 | + forwarding. |
| 171 | +
|
| 172 | + Args: |
| 173 | + model_executable (torch.nn.Module): |
| 174 | + The model executable from the vLLM model runner. |
| 175 | + model_input (ModelInputForGPUWithSamplingMetadata): |
| 176 | + The model input from the vLLM model runner. |
| 177 | + kv_caches (List[torch.Tensor]): |
| 178 | + List of KV caches for each layer. |
| 179 | +
|
| 180 | + Returns: |
| 181 | + - hidden_or_intermediate_states (torch.Tensor or |
| 182 | + IntermediateTensors): |
| 183 | + Concatenated hidden states if all required data is retrieved, |
| 184 | + otherwise `None`. |
| 185 | + - bypass_model_exec (bool): |
| 186 | + Indicates whether the model execution can be skipped (True) or |
| 187 | + needs to be redone (False). |
| 188 | + - model_input (ModelInputForGPUWithSamplingMetadata): |
| 189 | + Optionally adjusted input metadata for re-execution when |
| 190 | + `bypass_model_exec=False`. |
| 191 | +
|
| 192 | + """ |
| 193 | + |
136 | 194 | raise NotImplementedError |
0 commit comments