Skip to content

Commit

Permalink
[V1] Simplify prefix caching logic by removing `num_evictable_compute…
Browse files Browse the repository at this point in the history
…d_blocks` (vllm-project#11310)
  • Loading branch information
heheda12345 authored and Ubuntu committed Jan 19, 2025
1 parent a7d6962 commit ef5bb50
Showing 1 changed file with 2 additions and 11 deletions.
13 changes: 2 additions & 11 deletions vllm/v1/core/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,32 +201,23 @@ def allocate_slots(
f"num_tokens must be greater than 0, got {num_tokens}")

# Touch the computed blocks to make sure they won't be evicted.
num_evictable_computed_blocks = 0
if self.enable_caching:
self._touch(computed_blocks)

# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request.
num_evictable_computed_blocks = len(
[blk for blk in computed_blocks if blk.ref_cnt == 0])
else:
assert not computed_blocks, (
"Computed blocks should be empty when "
"prefix caching is disabled")

num_required_blocks = cdiv(num_tokens, self.block_size)
if (num_required_blocks > self.free_block_queue.num_free_blocks -
num_evictable_computed_blocks):
if (num_required_blocks > self.free_block_queue.num_free_blocks):
# Cannot allocate new blocks.
return None

# Determine the number of new blocks to allocate considering
# preallocated blocks.
num_new_blocks = min(
num_required_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks -
num_evictable_computed_blocks,
self.free_block_queue.num_free_blocks,
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
# [..., max_num_blocks_per_req].
Expand Down

0 comments on commit ef5bb50

Please sign in to comment.