
Commit 529c425

fall back to full prefill when any of the KV cache receive fails.
1 parent 0e6c4aa commit 529c425
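In effect, the decode instance now treats any failed KV-cache receive as "recompute everything." A minimal sketch of that control flow against the connector interface changed below; this function and the `run_full_prefill` callable are illustrative stand-ins, not actual vLLM code:

```python
# Sketch only: illustrates the fallback this commit adds. Neither this
# function nor `run_full_prefill` exists in vLLM; the real logic lives in
# the model runner and the KV connector (see the diffs below).
from typing import Callable, List

import torch


def prefill_with_fallback(
    connector,                   # object implementing the connector ABC below
    model_executable: torch.nn.Module,
    model_input,                 # ModelInputForGPUWithSamplingMetadata
    kv_caches: List[torch.Tensor],
    run_full_prefill: Callable,  # stand-in for the normal forward pass
):
    # Try to pull the prefill instance's KV caches and hidden states.
    hidden_states, bypass_model_exec, model_input = \
        connector.recv_kv_caches_and_hidden_states(
            model_executable, model_input, kv_caches)

    if not bypass_model_exec:
        # At least one receive failed: redo the full prefill locally
        # instead of resuming from partially received KV caches.
        hidden_states = run_full_prefill(model_executable, model_input,
                                         kv_caches)
    return hidden_states
```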

File tree

4 files changed: +228 -288 lines changed


benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh

Lines changed: 5 additions & 6 deletions
```diff
@@ -17,9 +17,8 @@ set -ex
 
 kill_gpu_processes() {
   # kill all processes on GPU.
-  pkill -f pt_main_thread
-  pkill -f python3
-  pgrep pt_main_thread | xargs kill -9
+  pgrep pt_main_thread | xargs -r kill -9
+  pgrep python3 | xargs -r kill -9
   for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done
   sleep 1
 }
@@ -64,7 +63,7 @@ launch_disagg_prefill() {
   # disagg prefill
   CUDA_VISIBLE_DEVICES=0 python3 \
     -m vllm.entrypoints.openai.api_server \
-    --model meta-llama/Meta-Llama-3.1-8B-Instruct \
+    --model $model \
     --port 8100 \
     --max-model-len 10000 \
     --gpu-memory-utilization 0.6 \
@@ -75,7 +74,7 @@ launch_disagg_prefill() {
     --kv-buffer-size 5e9 &
   CUDA_VISIBLE_DEVICES=1 python3 \
     -m vllm.entrypoints.openai.api_server \
-    --model meta-llama/Meta-Llama-3.1-8B-Instruct \
+    --model $model \
     --port 8200 \
     --max-model-len 10000 \
     --gpu-memory-utilization 0.6 \
@@ -93,7 +92,7 @@ launch_disagg_prefill() {
 
 benchmark() {
   results_folder="./results"
-  model="meta-llama/Meta-Llama-3.1-70B-Instruct"
+  model="meta-llama/Meta-Llama-3.1-8B-Instruct"
   dataset_name="sonnet"
   dataset_path="../sonnet_4x.txt"
   num_prompts=100
```

vllm/distributed/kv_transfer/kv_connector/base.py

Lines changed: 69 additions & 11 deletions
```diff
@@ -8,10 +8,12 @@
 """
 
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
 
+from vllm.sequence import IntermediateTensors
+
 if TYPE_CHECKING:
     from vllm.config import KVTransferConfig
     from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
@@ -119,18 +121,74 @@ def close(self) -> None:
         raise NotImplementedError
 
     @abstractmethod
-    def build_partial_prefill_input(
+    def send_kv_caches_and_hidden_states(
         self,
+        model_executable: torch.nn.Module,
         model_input: "ModelInputForGPUWithSamplingMetadata",
-        input_tokens_list: List[torch.Tensor],
-        num_computed_tokens_list: List[int],
-        start_pos_list: List[int],
-        slot_mapping_flat: torch.Tensor,
-        device: torch.device,
-    ) -> "ModelInputForGPUWithSamplingMetadata":
-        """Rebuild the model input based on how many KV caches are received
+        kv_caches: List[torch.Tensor],
+        hidden_or_intermediate_states: Union[torch.Tensor,
+                                             IntermediateTensors],
+    ) -> None:
+        """
+        Send KV caches and hidden states to the connector.
+
+        This method processes the input tokens, KV caches, and
+        hidden/intermediate states for a given model and sends the data to the
+        decode instance.
+
+        Args:
+            model_executable (torch.nn.Module): The model executable containing
+                start and end layer information.
+            model_input (ModelInputForGPUWithSamplingMetadata): The input
+                metadata from vLLM.
+            kv_caches (List[torch.Tensor]): List of KV caches (keys and values)
+                for each layer.
+            hidden_or_intermediate_states (Union[torch.Tensor,
+            IntermediateTensors]):
+                The hidden or intermediate states associated with the tokens.
+
+        Returns:
+            None
 
-        Raises:
-            NotImplementedError: This method must be implemented in subclasses.
         """
+
+        raise NotImplementedError
+
+    @abstractmethod
+    def recv_kv_caches_and_hidden_states(
+        self, model_executable: torch.nn.Module,
+        model_input: "ModelInputForGPUWithSamplingMetadata",
+        kv_caches: List[torch.Tensor]
+    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+               "ModelInputForGPUWithSamplingMetadata"]:
+        """
+        Receive KV caches and hidden states from the connector.
+
+        This method attempts to retrieve KV caches and hidden states for input
+        tokens. If all required KV caches and hidden states are received, it
+        will bypass model execution; otherwise it will fall back to normal
+        vLLM model forwarding.
+
+        Args:
+            model_executable (torch.nn.Module):
+                The model executable from the vLLM model runner.
+            model_input (ModelInputForGPUWithSamplingMetadata):
+                The model input from the vLLM model runner.
+            kv_caches (List[torch.Tensor]):
+                List of KV caches for each layer.
+
+        Returns:
+            - hidden_or_intermediate_states (torch.Tensor or
+              IntermediateTensors):
+                Concatenated hidden states if all required data is retrieved,
+                otherwise `None`.
+            - bypass_model_exec (bool):
+                Indicates whether the model execution can be skipped (True) or
+                needs to be redone (False).
+            - model_input (ModelInputForGPUWithSamplingMetadata):
+                Optionally adjusted input metadata for re-execution when
+                `bypass_model_exec=False`.
+
+        """
+
         raise NotImplementedError
```
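For illustration, a hedged sketch of a connector written against this interface. The base-class name `KVConnectorBase` and import path are assumed from base.py; the subclass and its behavior are hypothetical, and any other abstract hooks on the base class (e.g., a constructor) would need stubs too:

```python
from typing import List

import torch

# Import path and class name assumed from base.py above.
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase


class AlwaysRecomputeConnector(KVConnectorBase):
    """Toy connector that never delivers caches, so the decode side always
    takes the full-prefill fallback path."""

    def send_kv_caches_and_hidden_states(
            self, model_executable: torch.nn.Module, model_input,
            kv_caches: List[torch.Tensor],
            hidden_or_intermediate_states) -> None:
        # A real connector would serialize each layer's keys/values for the
        # tokens in model_input and push them, along with the hidden states,
        # to the decode instance.
        pass

    def recv_kv_caches_and_hidden_states(
            self, model_executable: torch.nn.Module, model_input,
            kv_caches: List[torch.Tensor]):
        # Signal failure: no hidden states, execution cannot be bypassed,
        # and model_input is returned unchanged so the caller re-runs the
        # full prefill.
        return None, False, model_input

    def close(self) -> None:
        # Nothing to tear down in this toy connector.
        pass
```

Returning `(None, False, model_input)` is exactly the failure signal this commit handles: the caller sees `bypass_model_exec=False` and re-runs the full prefill with the unmodified `model_input`.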
