Add Automatic Prefix Caching #2762

Merged Mar 2, 2024 (88 commits)

Changes from 62 commits

Commits
d1a91aa
init
SageMoore Feb 2, 2024
f27bfc8
Merge branch 'upstream-main' into prefix-caching
SageMoore Feb 5, 2024
ec21130
Move evictor and eviction policy to a separate class
ElizaWszola Feb 5, 2024
73ab52c
format, replace match with if-else
ElizaWszola Feb 5, 2024
76b5290
shore up some of the eviction logic
SageMoore Feb 5, 2024
fb9132b
autoformat
SageMoore Feb 5, 2024
c84bbda
Test block hashing
ElizaWszola Feb 6, 2024
be146c0
Format
ElizaWszola Feb 6, 2024
063d2fb
added block allocator tests
SageMoore Feb 6, 2024
15099d2
added timestamps to the PhysicalTokenBlock and updated the eviction l…
SageMoore Feb 7, 2024
9411e06
Delete the free hash table from the evictor class
SageMoore Feb 7, 2024
359b829
Remove the evictor class in favor of eviction free functions
SageMoore Feb 7, 2024
c9b0be6
debugging in progress
ElizaWszola Feb 8, 2024
cc80d1b
Merge branch 'prefix-caching' of https://github.com/neuralmagic/neura…
ElizaWszola Feb 8, 2024
6218d1a
partial block support
ElizaWszola Feb 8, 2024
b35819d
Move PhysicalTokenBlock.last_accessed updates to the block_manager/sc…
SageMoore Feb 8, 2024
38c1fc6
Remove overly aggressive assert
SageMoore Feb 8, 2024
b3e73f5
minor refactoring
SageMoore Feb 8, 2024
48624d9
Add prefix len to eviction strategy
ElizaWszola Feb 9, 2024
9780ccd
Merge branch 'prefix-caching' of https://github.com/neuralmagic/neura…
ElizaWszola Feb 9, 2024
bb471f2
fixed a few bugs in the partial block management code
SageMoore Feb 9, 2024
5d5db12
auto format
SageMoore Feb 9, 2024
ffbddd9
fix fork/cow mechanisms so that they work with partial blocks
SageMoore Feb 9, 2024
1f7fe42
replace the partial block table with a simpler promotion mechanism
SageMoore Feb 12, 2024
7ab75d7
clean up the BlockSpaceManager a bit
SageMoore Feb 12, 2024
ca3e288
fix minor typos
SageMoore Feb 12, 2024
ecf389d
minor name change
SageMoore Feb 12, 2024
427566a
update assert
SageMoore Feb 12, 2024
a3431bb
fix swap_in and swap_out
ElizaWszola Feb 12, 2024
dedc9c0
remove dead code in BlockSpaceManager
SageMoore Feb 12, 2024
86299a4
refactor swap_in/swap_out in BlockSpaceManager
SageMoore Feb 12, 2024
614a197
Update the partial block promotion logic to account for the full vers…
SageMoore Feb 12, 2024
0f85474
remove min from sequence hash
ElizaWszola Feb 13, 2024
9672b20
Remove prefix.py
SageMoore Feb 13, 2024
6044c2b
misc formatting
SageMoore Feb 13, 2024
9f7ae9f
bring back free table
ElizaWszola Feb 13, 2024
6655130
format
ElizaWszola Feb 13, 2024
1d6f0a0
update get_num_free_blocks to account for blocks in free table
SageMoore Feb 13, 2024
0ca5c43
add some more asserts to BlockAllocator
SageMoore Feb 13, 2024
7d6444d
contains_block() now looks at both table and free_table + a couple as…
ElizaWszola Feb 14, 2024
4775423
updated semantics of prefix length in block
SageMoore Feb 14, 2024
5cfee5f
bring back prefix block tables
ElizaWszola Feb 15, 2024
c64dd0d
Merge branch 'prefix-caching' of https://github.com/neuralmagic/neura…
ElizaWszola Feb 15, 2024
4925071
Nits (style)
ElizaWszola Feb 15, 2024
4fba5f9
delete comment
ElizaWszola Feb 15, 2024
46c62e4
Added computed_block_nums
SageMoore Feb 15, 2024
38b34d8
pythonize get_all_computed_block_ids
ElizaWszola Feb 15, 2024
ba97f80
account for prefix_len=0 in _prepare_prompt
SageMoore Feb 15, 2024
fe37722
attempt to fix build
SageMoore Feb 15, 2024
28f4ad2
attempt to fix build
SageMoore Feb 15, 2024
bff30a7
cap computed blocks to prefix length
SageMoore Feb 15, 2024
e829c34
misc fixes
SageMoore Feb 15, 2024
7f78ad4
typo
SageMoore Feb 15, 2024
18da5e6
account for none
SageMoore Feb 15, 2024
49357be
block manager refactoring
SageMoore Feb 15, 2024
ea4ec9d
clamp prefix length down to a multiple of block size
SageMoore Feb 15, 2024
f5fa2de
minor prefix length fix
SageMoore Feb 15, 2024
704aa47
replace 16 with block size
SageMoore Feb 16, 2024
8771b3f
First round of feedback changes
ElizaWszola Feb 21, 2024
b2c5992
Merge branch 'upstream-main' into prefix-caching
ElizaWszola Feb 21, 2024
2dba195
added a flag to disable automatic prefix caching
SageMoore Feb 21, 2024
ba01fa8
Update vllm/engine/llm_engine.py
mgoin Feb 21, 2024
16f9e80
Merge branch 'upstream-main' into prefix-caching
ElizaWszola Feb 23, 2024
2914b5a
remove explicit prefix pos
ElizaWszola Feb 23, 2024
bd235fd
remove assert for sliding window, check what will happen
ElizaWszola Feb 23, 2024
ba382d9
Try the other way around
ElizaWszola Feb 23, 2024
660007f
Delete redundant prefix caching test
ElizaWszola Feb 23, 2024
f74f67d
Don't add last block to
ElizaWszola Feb 23, 2024
093cb1c
Format
ElizaWszola Feb 23, 2024
d459d15
refactored the eviction logic into a separate class
SageMoore Feb 23, 2024
fea6789
minor fixes
SageMoore Feb 23, 2024
052c294
format evictor file
SageMoore Feb 23, 2024
e26cd8e
added documentation to the evictor class
SageMoore Feb 23, 2024
2335360
delete newline
SageMoore Feb 23, 2024
d66154c
format
SageMoore Feb 23, 2024
6a38439
Fix timestamp in eviction policy
ElizaWszola Feb 28, 2024
47e94ba
Merge branch 'upstream-main' into prefix-caching
ElizaWszola Feb 28, 2024
a449eb6
addressing review comments
SageMoore Feb 29, 2024
30708b8
minor evictor fix
SageMoore Feb 29, 2024
4e99660
format
SageMoore Feb 29, 2024
5b4413b
More protection against sliding window
ElizaWszola Feb 29, 2024
63a1985
Merge branch 'prefix-caching' of https://github.com/neuralmagic/neura…
ElizaWszola Feb 29, 2024
7d17304
Change automatic prefix caching arg to enable in arg utils
ElizaWszola Feb 29, 2024
6358bf0
fix minor BlockAllocator update_hash bug
SageMoore Feb 29, 2024
b9fbb66
fix test_prefix_caching test
SageMoore Feb 29, 2024
4ce8ceb
fix minor perf regression
SageMoore Mar 1, 2024
11126ab
Only mark last prefix block as computed, assume no computed blocks wi…
ElizaWszola Mar 1, 2024
e252bb6
Merge branch 'main' into prefix-caching
SageMoore Mar 1, 2024
4 changes: 4 additions & 0 deletions docs/source/models/engine_args.rst
@@ -81,6 +81,10 @@ Below, you can find an explanation of every engine argument for vLLM:

Token block size for contiguous chunks of tokens.

.. option:: --disable-prefix-caching

Disables automatic prefix caching.

.. option:: --seed <seed>

Random seed for operations.
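For context, the new option can also be toggled programmatically. A minimal sketch, assuming the `LLM` constructor forwards a `disable_prefix_caching` keyword through `EngineArgs` to `CacheConfig` (the forwarding path is not shown in this excerpt):

```python
from vllm import LLM

# Hypothetical usage of the flag documented above; assumes LLM/EngineArgs
# expose `disable_prefix_caching` and pass it through to CacheConfig.
llm = LLM(model="facebook/opt-125m", disable_prefix_caching=True)
```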
73 changes: 72 additions & 1 deletion tests/prefix_caching/test_prefix_caching.py
@@ -5,6 +5,8 @@
import pytest

from vllm import LLM, SamplingParams
from vllm.core.block_manager import BlockAllocator
from vllm.utils import Device

prefix = (
"You are an expert school principal, skilled in effectively managing "
@@ -18,6 +20,14 @@
"the following paragraph: ")


def allocate_all_blocks(block_allocator, num_blocks):
blocks = []
for i in range(num_blocks):
# use i as the block_hash
blocks.append(block_allocator.allocate(i, 0))
return blocks


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("max_tokens", [16])
def test_prefix_caching(
@@ -38,4 +48,65 @@ def test_prefix_caching(
outputs_without_prefix, outputs_with_prefix):
assert (output_without_prefix.outputs[0].token_ids ==
output_with_prefix.outputs[0].token_ids)
assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 1


@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [16])
def test_block_allocator(
block_size: int,
num_blocks: int,
):
block_hash = 1
block_allocator = BlockAllocator(Device.CPU, block_size, num_blocks)

# Allocate two PhysicalTokenBlocks with the same hash and check that they are the same PhysicalTokenBlock
first_block = block_allocator.allocate(block_hash, 0)
second_block = block_allocator.allocate(block_hash, 0)
assert (first_block == second_block)
assert (second_block.ref_count == 2)

# Free the first_block and confirm that the ref_count is correctly decremented on the second block
block_allocator.free(first_block)
assert (second_block.ref_count == 1)

# Free the second block
block_allocator.free(second_block)

# Reallocate the first block and confirm that, even after the block had its ref_count go to 0, we still get the same block back
first_block = block_allocator.allocate(block_hash, 0)
assert (first_block == second_block)
assert (first_block.block_hash == block_hash)


@pytest.mark.parametrize("num_blocks", [16])
def test_eviction(num_blocks: int, ):
block_size = 16
block_allocator = BlockAllocator(Device.CPU, block_size, num_blocks)
blocks = []

for i in range(num_blocks):
# use i as the block_hash
blocks.append(block_allocator.allocate(i, 0))

# Free all blocks
for block in blocks:
block_allocator.free(block)

# Allocate a new block and confirm that it's the first block freed, i.e. the least recently used block
new_block_hash = block_size
new_block = block_allocator.allocate(new_block_hash, 0)
assert (new_block == blocks[0])
assert (new_block.block_hash == new_block_hash)

# Reallocate the second block in `blocks` to remove it from the free list
realloc_block_hash = 1
realloc_block = block_allocator.allocate(realloc_block_hash, 0)
assert (realloc_block == blocks[realloc_block_hash])
assert (realloc_block.block_hash == realloc_block_hash)

# Allocate a new block and confirm that it's not the realloc_block, since the realloc_block shouldn't be in the free list
new_block_hash = block_size + 1
new_block = block_allocator.allocate(new_block_hash, 0)
assert (realloc_block != new_block)
assert (new_block.block_hash == new_block_hash)
assert (new_block.block_number == 2)
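Taken together, these tests pin down the allocator contract: allocating an existing hash returns the cached block with an incremented ref_count, freeing a block keeps it reachable by hash, and a full allocator evicts the least recently used freed block. A toy sketch of that contract (not the PR's implementation — in particular, the real allocator reuses the evicted block's physical slot rather than constructing a new object):

```python
from dataclasses import dataclass, field
from time import monotonic
from typing import Dict


@dataclass
class ToyBlock:
    block_hash: int
    ref_count: int = 0
    last_accessed: float = field(default_factory=monotonic)


class ToyAllocator:
    """Toy content-addressed allocator mirroring the tested contract."""

    def __init__(self, num_blocks: int) -> None:
        self.num_blocks = num_blocks
        self.in_use: Dict[int, ToyBlock] = {}      # hash -> block with ref_count > 0
        self.free_table: Dict[int, ToyBlock] = {}  # hash -> freed but still cached

    def allocate(self, block_hash: int) -> ToyBlock:
        if block_hash in self.in_use:          # cache hit on a live block: share it
            block = self.in_use[block_hash]
        elif block_hash in self.free_table:    # cache hit on a freed block: revive it
            block = self.free_table.pop(block_hash)
            self.in_use[block_hash] = block
        else:
            if len(self.in_use) + len(self.free_table) >= self.num_blocks:
                # Evict the least recently used freed block to make room.
                # (A full allocator with no freed blocks is out of memory.)
                lru = min(self.free_table,
                          key=lambda h: self.free_table[h].last_accessed)
                del self.free_table[lru]
            block = ToyBlock(block_hash)
            self.in_use[block_hash] = block
        block.ref_count += 1
        block.last_accessed = monotonic()
        return block

    def free(self, block: ToyBlock) -> None:
        block.ref_count -= 1
        if block.ref_count == 0:
            self.free_table[block.block_hash] = self.in_use.pop(block.block_hash)
```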
76 changes: 76 additions & 0 deletions tests/test_cache_block_hashing.py
@@ -0,0 +1,76 @@
"""Test hashing of cache blocks.

Run `pytest tests/test_cache_block_hashing.py`.
"""
import pytest

from vllm.transformers_utils.tokenizer import TokenizerGroup
from vllm.sequence import Sequence

# Make two prefixes with different first blocks.
prefix_start = [("You are an expert"), ("You are a")]
prefix_common = (
" school principal, skilled in effectively managing "
"faculty and staff. Draft 10-15 questions for a potential first grade "
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
"community, joyful discovery, and life-long learning. The candidate is "
"coming in for a first-round panel interview for a 8th grade Math "
"teaching role. They have 5 years of previous teaching experience "
"as an assistant teacher at a co-ed, public school with experience "
"in middle school math teaching. Based on this, fulfill "
"the following: ")
prefixes = [start + prefix_common for start in prefix_start]

# Sample prompts.
sample_prompts = [
"Hello, my name is", "The president of the United States is",
"The capital of France is", "The future of AI is"
]


# Helper function.
def flatten_2d(li):
return [lss for ls in li for lss in ls]


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("max_num_seqs", [256])
def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int):

tokenizer = TokenizerGroup(
tokenizer_id="facebook/opt-125m",
enable_lora=False,
max_num_seqs=max_num_seqs,
max_input_length=None,
)

hashes = []

for prefix in prefixes:
hashes.append([])
prompts = [prefix + prompt for prompt in sample_prompts]
seq_id = 0
for prompt in prompts:
hashes[-1].append([])
prompt_token_ids = tokenizer.encode(prompt)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)

num_blocks = len(prompt_token_ids) // block_size
for idx in range(num_blocks):
hashes[-1][-1].append(seq.hash(idx))

seq_id += 1

# Check that hashes made with two prefixes with different first blocks are
# different everywhere.
for hash0, hash1 in zip(flatten_2d(hashes[0]), flatten_2d(hashes[1])):
assert (hash0 != hash1)

# Check that hashes of different prompts made with the same prefix are the
# same for all blocks except the final one, which includes prompt tokens.
for hash_pref in hashes:
same_hashes = [tuple(h[:-1]) for h in hash_pref]
different_hashes = [h[-1] for h in hash_pref]
assert (len(set(same_hashes)) == 1)
assert (len(set(different_hashes)) == len(different_hashes))
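The property these assertions rely on is that a block's hash covers every token from the start of the sequence through the end of that block, so shared prefixes yield identical hashes while the final, prompt-bearing block differs. A sketch of such a prefix-cumulative hash (illustrative only; `Sequence.hash` may differ in detail):

```python
from typing import List


def block_hash(token_ids: List[int], block_size: int, idx: int) -> int:
    """Hash of logical block `idx`, covering tokens 0 .. (idx + 1) * block_size.

    Hashing the whole prefix (not just the block's own tokens) makes the
    hash position-dependent, so identical text at different offsets does
    not collide. Sketch only; vLLM's Sequence.hash may differ in detail.
    """
    prefix_tokens = tuple(token_ids[:(idx + 1) * block_size])
    return hash(prefix_tokens)


# Two prompts sharing a 16-token prefix agree on block 0's hash only:
a = list(range(32))
b = list(range(16)) + list(range(100, 116))
assert block_hash(a, 16, 0) == block_hash(b, 16, 0)
assert block_hash(a, 16, 1) != block_hash(b, 16, 1)
```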
14 changes: 13 additions & 1 deletion vllm/block.py
@@ -2,6 +2,7 @@
from typing import List

from vllm.utils import Device
from time import monotonic

_BLANK_TOKEN_ID = -1

@@ -55,17 +56,28 @@ def __init__(
device: Device,
block_number: int,
block_size: int,
block_hash: int,
prefix_len: int,
) -> None:
self.device = device
self.block_number = block_number
self.block_size = block_size
self.block_hash = block_hash
self.prefix_len = prefix_len

self.ref_count = 0
self.last_accessed = monotonic()

self.computed = False

# TODO: update this
def __repr__(self) -> str:
return (f'PhysicalTokenBlock(device={self.device}, '
f'block_number={self.block_number}, '
f'ref_count={self.ref_count})')
f'prefix_len={self.prefix_len}, '
f'ref_count={self.ref_count}, '
f'last_accessed={self.last_accessed}, '
f'computed={self.computed})')


# Mapping: logical block number -> physical block.
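The fields added here are what the eviction policy keys on: `block_hash` makes blocks content-addressable, `last_accessed` orders evictable blocks by recency, and `computed` marks blocks whose KV values can be skipped on a cache hit. A minimal sketch of LRU victim selection over freed blocks, assuming the evictor only ever sees blocks with `ref_count == 0`:

```python
from typing import Dict, Optional

from vllm.block import PhysicalTokenBlock


def select_lru_victim(
        free_table: Dict[int, PhysicalTokenBlock]) -> Optional[int]:
    # Sketch only: return the hash of the least recently used freed block,
    # or None if nothing is evictable. The PR's Evictor class may break
    # ties differently (a later commit adds prefix_len to the strategy).
    if not free_table:
        return None
    return min(free_table, key=lambda h: free_table[h].last_accessed)
```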
2 changes: 2 additions & 0 deletions vllm/config.py
@@ -295,12 +295,14 @@ def __init__(
swap_space: int,
cache_dtype: str,
sliding_window: Optional[int] = None,
disable_prefix_caching: bool = False,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GB
self.cache_dtype = cache_dtype
self.sliding_window = sliding_window
self.disable_prefix_caching = disable_prefix_caching
self._verify_args()
self._verify_cache_dtype()

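Finally, the new `disable_prefix_caching` argument defaults to False, so prefix caching is on unless explicitly disabled. A hypothetical direct construction with illustrative values (in practice `CacheConfig` is built from `EngineArgs`):

```python
from vllm.config import CacheConfig

# Illustrative values only; the signature follows the diff above.
cache_config = CacheConfig(
    block_size=16,
    gpu_memory_utilization=0.9,
    swap_space=4,                  # GiB; converted to bytes internally
    cache_dtype="auto",
    disable_prefix_caching=True,   # the argument added in this diff
)
```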