Skip to content

Commit fbcfee9

Browse files
authored
Clean unused KVCache after usage (vllm-project#10)
* Add underlying functions * tests done
1 parent a8561b8 commit fbcfee9

File tree

5 files changed

+39
-8
lines changed

5 files changed

+39
-8
lines changed

tests/under_models/send_mock_request.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ async def step_async(self) -> List[RequestOutput]:
4242
blocks_to_swap_in={},
4343
blocks_to_swap_out={},
4444
blocks_to_copy={},
45+
finished_seqs=[],
4546
)
4647
print(output)
4748

@@ -68,6 +69,7 @@ async def step_async_multiple(self) -> List[RequestOutput]:
6869
blocks_to_swap_in={},
6970
blocks_to_swap_out={},
7071
blocks_to_copy={},
72+
finished_seqs=[],
7173
)
7274

7375
# TODO: change this to real one

vllm/core/scheduler.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def __init__(
3636
blocks_to_swap_out: Dict[int, int],
3737
blocks_to_copy: Dict[int, List[int]],
3838
ignored_seq_groups: List[SequenceGroup],
39+
finished_seqs: List[int],
3940
) -> None:
4041
self.scheduled_seq_groups = scheduled_seq_groups
4142
self.prompt_run = prompt_run
@@ -46,11 +47,13 @@ def __init__(
4647
# Swap in and swap out should never happen at the same time.
4748
assert not (blocks_to_swap_in and blocks_to_swap_out)
4849
self.ignored_seq_groups = ignored_seq_groups
50+
self.finished_seqs = finished_seqs
4951

5052
def is_empty(self) -> bool:
5153
# NOTE: We do not consider the ignored sequence groups.
5254
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
53-
and not self.blocks_to_swap_out and not self.blocks_to_copy)
55+
and not self.blocks_to_swap_out and not self.blocks_to_copy
56+
and not self.finished_seqs)
5457

5558

5659
class Scheduler:
@@ -417,6 +420,7 @@ def __init__(
417420
self.waiting: List[SequenceGroup] = []
418421
# Sequence groups in the RUNNING state.
419422
self.running: List[SequenceGroup] = []
423+
self.cleaned: List[int] = []
420424

421425
def add_seq_group(self, seq_group: SequenceGroup) -> None:
422426
# Add sequence groups to the waiting queue.
@@ -456,6 +460,8 @@ def _schedule(self) -> SchedulerOutputs:
456460

457461
ignored_seq_groups: List[SequenceGroup] = []
458462
scheduled: List[SequenceGroup] = []
463+
finished_seqs: List[int] = self.cleaned.copy()
464+
self.cleaned=[]
459465
# The total number of sequences on the fly, including the
460466
# requests in the generation phase.
461467
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
@@ -518,6 +524,7 @@ def _schedule(self) -> SchedulerOutputs:
518524
blocks_to_swap_out={},
519525
blocks_to_copy={},
520526
ignored_seq_groups=ignored_seq_groups,
527+
finished_seqs=finished_seqs,
521528
)
522529
return scheduler_outputs
523530

@@ -539,6 +546,7 @@ def _schedule(self) -> SchedulerOutputs:
539546
blocks_to_swap_out={},
540547
blocks_to_copy={},
541548
ignored_seq_groups=[],
549+
finished_seqs=finished_seqs,
542550
)
543551
return scheduler_outputs
544552

@@ -576,7 +584,8 @@ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
576584
self.block_manager.fork(parent_seq, child_seq)
577585

578586
def free_seq(self, seq: Sequence) -> None:
579-
self.block_manager.free(seq)
587+
#self.block_manager.free(seq)
588+
self.cleaned.append(seq.seq_id)
580589

581590
def free_finished_seq_groups(self) -> None:
582591
for seq_group in self.running:

vllm/engine/async_llm_engine.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,12 +192,15 @@ async def step_async(self) -> List[RequestOutput]:
192192
return ignored
193193

194194
# Execute the model.
195+
# Co(gc): Now that we do not have page table support, we need to pass the
196+
# list of sequences that have been finished so that we can clean the KVCache.
195197
output = await self._run_workers_async(
196198
"execute_model",
197199
seq_group_metadata_list=seq_group_metadata_list,
198200
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
199201
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
200202
blocks_to_copy=scheduler_outputs.blocks_to_copy,
203+
finished_seqs=scheduler_outputs.finished_seqs,
201204
)
202205
print("We finished model_execution")
203206
return self._process_model_outputs(output, scheduler_outputs) + ignored

vllm/engine/llm_engine.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -384,8 +384,7 @@ def _process_sequence_group_samples(
384384
# not be used in the future iterations.
385385
parent.status = SequenceStatus.FINISHED_ABORTED
386386
seq_group.remove(parent.seq_id)
387-
# TODO(gc): Should we do anything special in this case?
388-
# self.scheduler.free_seq(parent)
387+
self.scheduler.free_seq(parent)
389388
continue
390389
# Fork the parent sequence if there are multiple child samples.
391390
# The outputs diverges, we need to fork the requests
@@ -425,7 +424,7 @@ def _process_sequence_group_samples(
425424
# old sequences.
426425
for seq, parent in child_seqs:
427426
if seq is parent and seq.is_finished():
428-
#self.scheduler.free_seq(seq)
427+
self.scheduler.free_seq(seq)
429428
pass
430429
return
431430

@@ -523,8 +522,7 @@ def _process_sequence_group_samples(
523522
# manager. Keep them in the sequence group as candidate output.
524523
for seq, parent in selected_child_seqs:
525524
if seq is parent and seq.is_finished():
526-
#self.scheduler.free_seq(seq)
527-
pass
525+
self.scheduler.free_seq(seq)
528526

529527
# Remove the unselected parent sequences from the sequence group and
530528
# free their memory in block manager.
@@ -533,7 +531,7 @@ def _process_sequence_group_samples(
533531
# Remove the parent sequence if it is not selected for next
534532
# iteration
535533
seq_group.remove(seq.seq_id)
536-
#self.scheduler.free_seq(seq)
534+
self.scheduler.free_seq(seq)
537535

538536
def _process_model_outputs(
539537
self, output: SamplerOutput,

vllm/worker/worker.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,21 @@ def __init__(
5050

5151
self.kv_cache = dict()
5252

53+
def clean_finished_seqs(
54+
self,
55+
finished_seqs: List[int]
56+
):
57+
"""
58+
This function cleans the finished sequences and their KVCache in self.kv_cache
59+
"""
60+
for seq_id in finished_seqs:
61+
if seq_id not in self.kv_cache.keys():
62+
raise ValueError(
63+
f"Duplicate key {seq_id} received during clean worker's KVCache"
64+
)
65+
del self.kv_cache[seq_id]
66+
67+
5368
def init_model(self):
5469
# This env var set by Ray causes exceptions with graph building.
5570
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
@@ -293,6 +308,7 @@ def execute_model(
293308
blocks_to_swap_in: Dict[int, int],
294309
blocks_to_swap_out: Dict[int, int],
295310
blocks_to_copy: Dict[int, List[int]],
311+
finished_seqs: List[int],
296312
) -> SamplerOutput:
297313
# Issue cache operations.
298314
# issued_cache_op = False
@@ -310,6 +326,9 @@ def execute_model(
310326
# cache_events = self.cache_events
311327
# else:
312328
# cache_events = None
329+
if finished_seqs:
330+
self.clean_finished_seqs(finished_seqs)
331+
313332
cache_events = None
314333
# If there is no input, we don't need to execute the model.
315334
if not seq_group_metadata_list:

0 commit comments

Comments
 (0)