Skip to content

Commit f94cb15

Browse files
committed
[KVConnector][Feature] Support KV connector cache reset via /reset_prefix_cache
Signed-off-by: tovam <tovam@pliops.com>
1 parent ded8ada commit f94cb15

File tree

11 files changed

+76
-25
lines changed

11 files changed

+76
-25
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,3 +491,12 @@ def build_prom_metrics(
491491
expose connector transfer stats via Prometheus.
492492
"""
493493
return None
494+
495+
def reset_cache(self) -> bool:
496+
"""
497+
Reset the connector's internal cache.
498+
499+
Returns:
500+
bool: True if the cache was successfully reset, False otherwise.
501+
"""
502+
return False

vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,3 +438,6 @@ def build_prom_metrics(
438438
per_engine_labelvalues,
439439
prom_metrics,
440440
)
441+
442+
def reset_cache(self) -> bool:
443+
return any(connector.reset_cache() for connector in self._connectors)

vllm/engine/protocol.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,10 @@ async def reset_mm_cache(self) -> None:
124124
...
125125

126126
@abstractmethod
127-
async def reset_prefix_cache(self, device: Device | None = None) -> None:
128-
"""Reset the prefix cache"""
127+
async def reset_prefix_cache(
128+
self, device: Device | None = None, reset_connector: bool = False
129+
) -> None:
130+
"""Reset the prefix cache and optionally any configured connector cache"""
129131
...
130132

131133
@abstractmethod

vllm/entrypoints/llm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1491,8 +1491,12 @@ def start_profile(self) -> None:
14911491
def stop_profile(self) -> None:
14921492
self.llm_engine.stop_profile()
14931493

1494-
def reset_prefix_cache(self, device: Device | None = None) -> None:
1495-
self.llm_engine.reset_prefix_cache(device)
1494+
def reset_prefix_cache(
1495+
self,
1496+
device: Device | None = None,
1497+
reset_connector: bool = False,
1498+
) -> None:
1499+
self.llm_engine.reset_prefix_cache(device, reset_connector)
14961500

14971501
def sleep(self, level: int = 1):
14981502
"""

vllm/entrypoints/openai/api_server.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -991,15 +991,33 @@ async def show_server_info(
991991
@router.post("/reset_prefix_cache")
992992
async def reset_prefix_cache(raw_request: Request):
993993
"""
994-
Reset the prefix cache. Note that we currently do not check if the
995-
prefix cache is successfully reset in the API server.
994+
Reset the local prefix cache.
995+
996+
Optionally, if the query parameter `reset_external=true`
997+
also resets the external (connector-managed) prefix cache.
998+
999+
Note that we currently do not check if the prefix cache
1000+
is successfully reset in the API server.
1001+
1002+
Example:
1003+
POST /reset_prefix_cache?device=gpu&reset_external=true
9961004
"""
9971005
device = None
9981006
device_str = raw_request.query_params.get("device")
9991007
if device_str is not None:
10001008
device = Device[device_str.upper()]
1001-
logger.info("Resetting prefix cache with specific %s...", str(device))
1002-
await engine_client(raw_request).reset_prefix_cache(device)
1009+
1010+
reset_connector = (
1011+
raw_request.query_params.get("reset_external", "false").lower() == "true"
1012+
)
1013+
1014+
logger.info(
1015+
"Resetting prefix cache (device=%s, reset_external_cache=%s)",
1016+
str(device),
1017+
reset_connector,
1018+
)
1019+
1020+
await engine_client(raw_request).reset_prefix_cache(device, reset_connector)
10031021
return Response(status_code=200)
10041022

10051023
@router.post("/reset_mm_cache")

vllm/v1/core/sched/interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def has_requests(self) -> bool:
128128
return self.has_unfinished_requests() or self.has_finished_requests()
129129

130130
@abstractmethod
131-
def reset_prefix_cache(self) -> bool:
131+
def reset_prefix_cache(self, reset_connector: bool = False) -> bool:
132132
"""Reset the prefix cache for KV cache.
133133
134134
This is particularly required when the model weights are live-updated.

vllm/v1/core/sched/scheduler.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,8 +1240,17 @@ def get_num_unfinished_requests(self) -> int:
12401240
def has_finished_requests(self) -> bool:
12411241
return len(self.finished_req_ids) > 0
12421242

1243-
def reset_prefix_cache(self) -> bool:
1244-
return self.kv_cache_manager.reset_prefix_cache()
1243+
def reset_prefix_cache(self, reset_connector: bool = False) -> bool:
1244+
reset_success = self.kv_cache_manager.reset_prefix_cache()
1245+
if reset_connector:
1246+
reset_success = reset_success and self.reset_connector_cache()
1247+
return reset_success
1248+
1249+
def reset_connector_cache(self) -> bool:
1250+
if self.connector is None:
1251+
logger.warning("reset_connector called but no KV connector configured.")
1252+
return False
1253+
return self.connector.reset_cache()
12451254

12461255
def make_stats(
12471256
self,

vllm/v1/engine/async_llm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -680,10 +680,14 @@ async def reset_mm_cache(self) -> None:
680680
self.processor.clear_mm_cache()
681681
await self.engine_core.reset_mm_cache_async()
682682

683-
async def reset_prefix_cache(self, device: Device | None = None) -> None:
683+
async def reset_prefix_cache(
684+
self,
685+
device: Device | None = None,
686+
reset_connector: bool = False,
687+
) -> None:
684688
if device == Device.CPU:
685689
raise ValueError("Not supported on CPU.")
686-
await self.engine_core.reset_prefix_cache_async()
690+
await self.engine_core.reset_prefix_cache_async(reset_connector)
687691

688692
async def sleep(self, level: int = 1) -> None:
689693
await self.reset_prefix_cache()

vllm/v1/engine/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,8 +411,8 @@ def reset_mm_cache(self):
411411

412412
self.model_executor.reset_mm_cache()
413413

414-
def reset_prefix_cache(self):
415-
self.scheduler.reset_prefix_cache()
414+
def reset_prefix_cache(self, reset_connector: bool = False):
415+
self.scheduler.reset_prefix_cache(reset_connector)
416416

417417
def sleep(self, level: int = 1):
418418
self.model_executor.sleep(level)

vllm/v1/engine/core_client.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def profile(self, is_start: bool = True) -> None:
138138
def reset_mm_cache(self) -> None:
139139
raise NotImplementedError
140140

141-
def reset_prefix_cache(self) -> None:
141+
def reset_prefix_cache(self, reset_connector: bool = False) -> None:
142142
raise NotImplementedError
143143

144144
def sleep(self, level: int = 1) -> None:
@@ -208,7 +208,7 @@ async def profile_async(self, is_start: bool = True) -> None:
208208
async def reset_mm_cache_async(self) -> None:
209209
raise NotImplementedError
210210

211-
async def reset_prefix_cache_async(self) -> None:
211+
async def reset_prefix_cache_async(self, reset_connector: bool = False) -> None:
212212
raise NotImplementedError
213213

214214
async def sleep_async(self, level: int = 1) -> None:
@@ -287,8 +287,8 @@ def profile(self, is_start: bool = True) -> None:
287287
def reset_mm_cache(self) -> None:
288288
self.engine_core.reset_mm_cache()
289289

290-
def reset_prefix_cache(self) -> None:
291-
self.engine_core.reset_prefix_cache()
290+
def reset_prefix_cache(self, reset_connector: bool = False) -> None:
291+
self.engine_core.reset_prefix_cache(reset_connector)
292292

293293
def sleep(self, level: int = 1) -> None:
294294
self.engine_core.sleep(level)
@@ -750,8 +750,8 @@ def profile(self, is_start: bool = True) -> None:
750750
def reset_mm_cache(self) -> None:
751751
self.call_utility("reset_mm_cache")
752752

753-
def reset_prefix_cache(self) -> None:
754-
self.call_utility("reset_prefix_cache")
753+
def reset_prefix_cache(self, reset_connector: bool = False) -> None:
754+
self.call_utility("reset_prefix_cache", reset_connector)
755755

756756
def add_lora(self, lora_request: LoRARequest) -> bool:
757757
return self.call_utility("add_lora", lora_request)
@@ -954,8 +954,8 @@ async def profile_async(self, is_start: bool = True) -> None:
954954
async def reset_mm_cache_async(self) -> None:
955955
await self.call_utility_async("reset_mm_cache")
956956

957-
async def reset_prefix_cache_async(self) -> None:
958-
await self.call_utility_async("reset_prefix_cache")
957+
async def reset_prefix_cache_async(self, reset_connector: bool = False) -> None:
958+
await self.call_utility_async("reset_prefix_cache", reset_connector)
959959

960960
async def sleep_async(self, level: int = 1) -> None:
961961
await self.call_utility_async("sleep", level)

0 commit comments

Comments
 (0)