diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py
index 359b5b263f68..d64142e77f37 100644
--- a/vllm/core/block/cpu_gpu_block_allocator.py
+++ b/vllm/core/block/cpu_gpu_block_allocator.py
@@ -341,8 +341,10 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float:
         assert device in self._allocators
         return self._allocators[device].get_prefix_cache_hit_rate()
 
-    def reset_prefix_cache(self) -> bool:
-        """Reset prefix cache for all devices."""
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        """Reset prefix cache for specified or all devices."""
+        if device:
+            return self._allocators[device].reset_prefix_cache()
         success = True
         for allocator in self._allocators.values():
             success = success and allocator.reset_prefix_cache()
diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py
index 0b0197deb8d4..301656996435 100644
--- a/vllm/core/block/interfaces.py
+++ b/vllm/core/block/interfaces.py
@@ -305,7 +305,7 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float:
         pass
 
     @abstractmethod
-    def reset_prefix_cache(self) -> bool:
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
         """Reset prefix cache."""
         pass
 
diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py
index c5b3b04f37ca..c6bf6d163132 100644
--- a/vllm/core/block_manager.py
+++ b/vllm/core/block_manager.py
@@ -456,8 +456,8 @@ def get_num_free_cpu_blocks(self) -> int:
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return self.block_allocator.get_prefix_cache_hit_rate(device)
 
-    def reset_prefix_cache(self) -> bool:
-        return self.block_allocator.reset_prefix_cache()
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        return self.block_allocator.reset_prefix_cache(device)
 
     def _can_swap(self,
                   seq_group: SequenceGroup,
diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py
index b48ba87e95a0..4c1182debcec 100644
--- a/vllm/core/interfaces.py
+++ b/vllm/core/interfaces.py
@@ -2,7 +2,7 @@
 
 import enum
 from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Optional
 from typing import Sequence as GenericSequence
 from typing import Tuple
 
@@ -125,8 +125,8 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float:
         pass
 
     @abstractmethod
-    def reset_prefix_cache(self) -> bool:
-        """Reset prefix cache for all devices."""
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        """Reset prefix cache for specified or all devices."""
         pass
 
     @abstractmethod
diff --git a/vllm/core/placeholder_block_space_manager.py b/vllm/core/placeholder_block_space_manager.py
index 70c22afa8e15..0f5d8ca6dc7e 100644
--- a/vllm/core/placeholder_block_space_manager.py
+++ b/vllm/core/placeholder_block_space_manager.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 from vllm.core.interfaces import AllocStatus, BlockSpaceManager
 from vllm.sequence import Sequence, SequenceGroup
@@ -92,7 +92,7 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup,
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return -1
 
-    def reset_prefix_cache(self) -> bool:
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
         return True
 
     def get_num_cached_tokens(self, seq: Sequence) -> int:
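For context, the allocator-level hunk above carries the core behavior change: given a `device`, the reset is dispatched to that device's allocator alone; without one, every allocator is reset as before. A minimal, self-contained sketch of this dispatch pattern (toy classes standing in for the real vLLM allocators):

from enum import Enum
from typing import Dict, Optional


class Device(Enum):
    GPU = 1
    CPU = 2


class ToyAllocator:
    """Stand-in for one per-device prefix-caching block allocator."""

    def reset_prefix_cache(self) -> bool:
        # The real allocator can report failure, e.g. when cached
        # blocks are still referenced.
        return True


class ToyCpuGpuAllocator:
    """Stand-in for CpuGpuBlockAllocator's device dispatch."""

    def __init__(self) -> None:
        self._allocators: Dict[Device, ToyAllocator] = {
            Device.GPU: ToyAllocator(),
            Device.CPU: ToyAllocator(),
        }

    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        if device:
            # Targeted reset: only the requested device's cache.
            return self._allocators[device].reset_prefix_cache()
        # Fallback keeps the old semantics: reset everything and report
        # success only if every allocator succeeded.
        success = True
        for allocator in self._allocators.values():
            success = success and allocator.reset_prefix_cache()
        return success


assert ToyCpuGpuAllocator().reset_prefix_cache(Device.GPU)

Note that `if device:` works here because plain `Enum` members are always truthy; `if device is not None:` would state the intent more explicitly.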
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index e93143c83d9f..cf85a2135c81 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -634,8 +634,8 @@ def has_unfinished_seqs(self) -> bool:
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return self.block_manager.get_prefix_cache_hit_rate(device)
 
-    def reset_prefix_cache(self) -> bool:
-        return self.block_manager.reset_prefix_cache()
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        return self.block_manager.reset_prefix_cache(device)
 
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 63787590bf47..c6fafbeea9c1 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -35,7 +35,7 @@
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import deprecate_kwargs, weak_bind
+from vllm.utils import Device, deprecate_kwargs, weak_bind
 
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -1216,8 +1216,9 @@ async def start_profile(self) -> None:
     async def stop_profile(self) -> None:
         self.engine.stop_profile()
 
-    async def reset_prefix_cache(self) -> None:
-        self.engine.reset_prefix_cache()
+    async def reset_prefix_cache(self,
+                                 device: Optional[Device] = None) -> None:
+        self.engine.reset_prefix_cache(device)
 
     async def sleep(self, level: int = 1) -> None:
         self.engine.sleep(level)
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index ca50f08a3804..51a82c415a8a 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -955,12 +955,12 @@ def has_unfinished_requests_for_virtual_engine(
         """
         return self.scheduler[virtual_engine].has_unfinished_seqs()
 
-    def reset_prefix_cache(self) -> bool:
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
         """Reset prefix cache for all devices."""
 
         success = True
         for scheduler in self.scheduler:
-            success = success and scheduler.reset_prefix_cache()
+            success = success and scheduler.reset_prefix_cache(device)
         return success
 
     @staticmethod
diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py
index 144dd822a177..fdad53580ee7 100644
--- a/vllm/engine/multiprocessing/__init__.py
+++ b/vllm/engine/multiprocessing/__init__.py
@@ -13,7 +13,7 @@
 from vllm.outputs import RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
-from vllm.utils import deprecate_kwargs
+from vllm.utils import Device, deprecate_kwargs
 
 VLLM_RPC_SUCCESS_STR = "SUCCESS"
 
@@ -123,8 +123,9 @@ class RPCUProfileRequest(Enum):
     STOP_PROFILE = 2
 
 
-class RPCResetPrefixCacheRequest(Enum):
-    RESET_PREFIX_CACHE = 1
+@dataclass
+class RPCResetPrefixCacheRequest:
+    device: Device
 
 
 class RPCSleepRequest(Enum):
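The RPC message change above follows from the new parameter: an argument-less command fits an `Enum` sentinel, but once it carries a payload it becomes a `@dataclass`. One wrinkle worth noting: the client may pass `device=None`, while the dataclass annotates the field as `Device`; dataclasses do not enforce annotations at runtime, so this works, but `Optional[Device]` would describe the wire format more precisely. A toy sketch of both message styles and the `isinstance`-based dispatch a receiver can use (illustrative names, not the actual vLLM engine loop):

from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union


class Device(Enum):
    GPU = 1
    CPU = 2


class ProfileCommand(Enum):
    # Parameterless commands fit naturally as Enum sentinels.
    START = 1
    STOP = 2


@dataclass
class ResetPrefixCacheCommand:
    # The payload an Enum sentinel cannot carry.
    device: Optional[Device]


def dispatch(msg: Union[ProfileCommand, ResetPrefixCacheCommand]) -> str:
    # isinstance() dispatch handles both message styles uniformly.
    if isinstance(msg, ResetPrefixCacheCommand):
        return f"reset prefix cache, device={msg.device}"
    return f"profile: {msg.name}"


print(dispatch(ResetPrefixCacheCommand(Device.GPU)))  # device=Device.GPU
print(dispatch(ProfileCommand.STOP))                  # profile: STOP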
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index e2ae9486e435..db91c5d3564a 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -47,7 +47,7 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
-from vllm.utils import deprecate_kwargs
+from vllm.utils import Device, deprecate_kwargs
 
 logger = init_logger(__name__)
 
@@ -684,11 +684,12 @@ async def stop_profile(self) -> None:
         await self._send_one_way_rpc_request(
             request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
 
-    async def reset_prefix_cache(self) -> None:
+    async def reset_prefix_cache(self,
+                                 device: Optional[Device] = None) -> None:
         """Reset the prefix cache"""
 
         await self._send_one_way_rpc_request(
-            request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE,
+            request=RPCResetPrefixCacheRequest(device),
             socket=self.input_socket)
 
     async def sleep(self, level: int = 1) -> None:
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index f314075b166e..be9f3af0b545 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -18,7 +18,7 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import collect_from_async_generator, random_uuid
+from vllm.utils import Device, collect_from_async_generator, random_uuid
 
 logger = init_logger(__name__)
 
@@ -274,7 +274,8 @@ async def stop_profile(self) -> None:
         ...
 
     @abstractmethod
-    async def reset_prefix_cache(self) -> None:
+    async def reset_prefix_cache(self,
+                                 device: Optional[Device] = None) -> None:
         """Reset the prefix cache"""
         ...
 
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index a0e2fa2918bd..84b0093ce4ca 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -42,7 +42,8 @@
     get_cached_tokenizer)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
+from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
+                        is_list_of)
 
 logger = init_logger(__name__)
 
@@ -1187,8 +1188,8 @@ def start_profile(self) -> None:
     def stop_profile(self) -> None:
         self.llm_engine.stop_profile()
 
-    def reset_prefix_cache(self) -> bool:
-        return self.llm_engine.reset_prefix_cache()
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        return self.llm_engine.reset_prefix_cache(device)
 
     def sleep(self, level: int = 1):
         """
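At the user-facing end, `LLM.reset_prefix_cache` now simply forwards the optional device to the engine. A short usage sketch (the model name is illustrative; prefix caching must be enabled for the reset to be meaningful):

from vllm import LLM
from vllm.utils import Device

llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)

# Existing call sites are unaffected: no argument resets all devices.
llm.reset_prefix_cache()

# New: reset only the GPU-side prefix cache.
llm.reset_prefix_cache(device=Device.GPU)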
""" - logger.info("Resetting prefix cache...") - await engine_client(raw_request).reset_prefix_cache() + device = None + device_str = raw_request.query_params.get("device") + if device_str is not None: + device = Device[device_str.upper()] + logger.info("Resetting prefix cache with specific %s...", str(device)) + await engine_client(raw_request).reset_prefix_cache(device) return Response(status_code=200) @router.post("/sleep") diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index d4ac9c066d50..171c1c7da28e 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -24,7 +24,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext -from vllm.utils import cdiv, kill_process_tree +from vllm.utils import Device, cdiv, kill_process_tree from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.parallel_sampling import ParentRequest @@ -398,7 +398,10 @@ async def start_profile(self) -> None: async def stop_profile(self) -> None: await self.engine_core.profile_async(False) - async def reset_prefix_cache(self) -> None: + async def reset_prefix_cache(self, + device: Optional[Device] = None) -> None: + if device == Device.CPU: + raise ValueError("Not supported on CPU.") await self.engine_core.reset_prefix_cache_async() async def sleep(self, level: int = 1) -> None: diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 63b0a8fca32b..14338e5cbe88 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -20,6 +20,7 @@ from vllm.transformers_utils.tokenizer_group import ( BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import UsageContext +from vllm.utils import Device from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.parallel_sampling import ParentRequest @@ -226,7 +227,7 @@ def start_profile(self): def stop_profile(self): self.engine_core.profile(False) - def reset_prefix_cache(self): + def reset_prefix_cache(self, device: Optional[Device] = None): self.engine_core.reset_prefix_cache() def sleep(self, level: int = 1):