
Commit 26dd972

[FEAT] Support reset prefix cache by specified device (#15003)
1 parent 61c7a1b commit 26dd972

15 files changed (+49 / -34 lines)
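
A minimal usage sketch of the device-scoped reset introduced by this commit. Only the Device enum and the reset_prefix_cache signature come from the diff below; the engine object is assumed to be an already-constructed LLMEngine, and the helper name is hypothetical.

    from typing import Optional

    from vllm.utils import Device


    def clear_prefix_cache(engine, device: Optional[Device] = None) -> bool:
        # Hypothetical wrapper around the changed API: pass Device.GPU or
        # Device.CPU to reset a single allocator, or leave device as None
        # to keep the old behavior of resetting every device.
        return engine.reset_prefix_cache(device)


    # e.g. clear_prefix_cache(engine, Device.GPU)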

vllm/core/block/cpu_gpu_block_allocator.py

Lines changed: 4 additions & 2 deletions
@@ -341,8 +341,10 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float:
         assert device in self._allocators
         return self._allocators[device].get_prefix_cache_hit_rate()
 
-    def reset_prefix_cache(self) -> bool:
-        """Reset prefix cache for all devices."""
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        """Reset prefix cache for specified or all devices."""
+        if device:
+            return self._allocators[device].reset_prefix_cache()
         success = True
         for allocator in self._allocators.values():
             success = success and allocator.reset_prefix_cache()

vllm/core/block/interfaces.py

Lines changed: 1 addition & 1 deletion
@@ -305,7 +305,7 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float:
         pass
 
     @abstractmethod
-    def reset_prefix_cache(self) -> bool:
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
         """Reset prefix cache."""
         pass

vllm/core/block_manager.py

Lines changed: 2 additions & 2 deletions
@@ -456,8 +456,8 @@ def get_num_free_cpu_blocks(self) -> int:
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return self.block_allocator.get_prefix_cache_hit_rate(device)
 
-    def reset_prefix_cache(self) -> bool:
-        return self.block_allocator.reset_prefix_cache()
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        return self.block_allocator.reset_prefix_cache(device)
 
     def _can_swap(self,
                   seq_group: SequenceGroup,

vllm/core/interfaces.py

Lines changed: 3 additions & 3 deletions
@@ -2,7 +2,7 @@
 
 import enum
 from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Optional
 from typing import Sequence as GenericSequence
 from typing import Tuple
 
@@ -125,8 +125,8 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float:
         pass
 
     @abstractmethod
-    def reset_prefix_cache(self) -> bool:
-        """Reset prefix cache for all devices."""
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        """Reset prefix cache for specified or all devices."""
         pass
 
     @abstractmethod

vllm/core/placeholder_block_space_manager.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 from vllm.core.interfaces import AllocStatus, BlockSpaceManager
 from vllm.sequence import Sequence, SequenceGroup
@@ -92,7 +92,7 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup,
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return -1
 
-    def reset_prefix_cache(self) -> bool:
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
         return True
 
     def get_num_cached_tokens(self, seq: Sequence) -> int:

vllm/core/scheduler.py

Lines changed: 2 additions & 2 deletions
@@ -634,8 +634,8 @@ def has_unfinished_seqs(self) -> bool:
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return self.block_manager.get_prefix_cache_hit_rate(device)
 
-    def reset_prefix_cache(self) -> bool:
-        return self.block_manager.reset_prefix_cache()
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        return self.block_manager.reset_prefix_cache(device)
 
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)

vllm/engine/async_llm_engine.py

Lines changed: 4 additions & 3 deletions
@@ -35,7 +35,7 @@
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import deprecate_kwargs, weak_bind
+from vllm.utils import Device, deprecate_kwargs, weak_bind
 
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -1216,8 +1216,9 @@ async def start_profile(self) -> None:
     async def stop_profile(self) -> None:
         self.engine.stop_profile()
 
-    async def reset_prefix_cache(self) -> None:
-        self.engine.reset_prefix_cache()
+    async def reset_prefix_cache(self,
+                                 device: Optional[Device] = None) -> None:
+        self.engine.reset_prefix_cache(device)
 
     async def sleep(self, level: int = 1) -> None:
         self.engine.sleep(level)

vllm/engine/llm_engine.py

Lines changed: 2 additions & 2 deletions
@@ -955,12 +955,12 @@ def has_unfinished_requests_for_virtual_engine(
         """
         return self.scheduler[virtual_engine].has_unfinished_seqs()
 
-    def reset_prefix_cache(self) -> bool:
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
         """Reset prefix cache for all devices."""
 
         success = True
         for scheduler in self.scheduler:
-            success = success and scheduler.reset_prefix_cache()
+            success = success and scheduler.reset_prefix_cache(device)
         return success
 
     @staticmethod

vllm/engine/multiprocessing/__init__.py

Lines changed: 4 additions & 3 deletions
@@ -13,7 +13,7 @@
 from vllm.outputs import RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
-from vllm.utils import deprecate_kwargs
+from vllm.utils import Device, deprecate_kwargs
 
 VLLM_RPC_SUCCESS_STR = "SUCCESS"
 
@@ -123,8 +123,9 @@ class RPCUProfileRequest(Enum):
     STOP_PROFILE = 2
 
 
-class RPCResetPrefixCacheRequest(Enum):
-    RESET_PREFIX_CACHE = 1
+@dataclass
+class RPCResetPrefixCacheRequest:
+    device: Device
 
 
 class RPCSleepRequest(Enum):
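
To carry the target device through the RPC path, the request type above changes from a one-member Enum to a dataclass. A short sketch of constructing the new message, mirroring what client.py (next file) now sends; note that the client may pass None through the Device-annotated field to request a reset on all devices.

    from dataclasses import dataclass

    from vllm.utils import Device


    @dataclass
    class RPCResetPrefixCacheRequest:
        device: Device


    # The device now travels inside the message instead of a fixed enum value.
    request = RPCResetPrefixCacheRequest(Device.GPU)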

vllm/engine/multiprocessing/client.py

Lines changed: 4 additions & 3 deletions
@@ -47,7 +47,7 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
-from vllm.utils import deprecate_kwargs
+from vllm.utils import Device, deprecate_kwargs
 
 logger = init_logger(__name__)
 
@@ -684,11 +684,12 @@ async def stop_profile(self) -> None:
         await self._send_one_way_rpc_request(
             request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
 
-    async def reset_prefix_cache(self) -> None:
+    async def reset_prefix_cache(self,
+                                 device: Optional[Device] = None) -> None:
         """Reset the prefix cache"""
 
         await self._send_one_way_rpc_request(
-            request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE,
+            request=RPCResetPrefixCacheRequest(device),
             socket=self.input_socket)
 
     async def sleep(self, level: int = 1) -> None:
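
For completeness, a hedged sketch of driving the new async client signature; client stands in for an already-connected MQLLMEngineClient and is not constructed here.

    from vllm.utils import Device


    async def reset_gpu_prefix_cache(client) -> None:
        # Scope the reset to the GPU allocator; calling it with no argument
        # still resets every device, as before this change.
        await client.reset_prefix_cache(Device.GPU)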
