[V1] Prefix caching for multimodal language models #11187

Merged (10 commits) on Dec 18, 2024

Changes from 9 commits
88 changes: 86 additions & 2 deletions tests/v1/core/test_prefix_caching.py
@@ -2,16 +2,23 @@
import pytest

from vllm.inputs import token_inputs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens


def make_request(request_id, prompt_token_ids):
def make_request(request_id,
prompt_token_ids,
mm_positions=None,
mm_hashes=None):
return Request(
request_id=request_id,
inputs=token_inputs(prompt_token_ids=prompt_token_ids),
inputs=token_inputs(prompt_token_ids=prompt_token_ids,
multi_modal_placeholders={"image": mm_positions}
if mm_positions else None,
multi_modal_hashes=mm_hashes),
sampling_params=SamplingParams(max_tokens=17),
eos_token_id=100,
arrival_time=0,
@@ -38,6 +45,7 @@ def test_prefill():
all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
computed_blocks = manager.get_computed_blocks(req0)
assert len(req0.kv_block_hashes) == 3
assert not computed_blocks
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
@@ -61,6 +69,7 @@ def test_prefill():
unique_token_ids = [3] * 5
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks = manager.get_computed_blocks(req1)
assert len(req1.kv_block_hashes) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
@@ -90,6 +99,7 @@ def test_prefill():
unique_token_ids = [3] * 6
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_block = manager.get_computed_blocks(req2)
assert len(req2.kv_block_hashes) == 3
assert [b.block_id for b in computed_block] == [0, 1, 2]
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
@@ -416,3 +426,77 @@ def test_cache_blocks():
)
assert len(manager.cached_block_hash_to_block) == 3
assert blocks[0].block_hash is not None


def test_mm_prefix_caching():
"""
Test that multi-modal prefix caching is handled correctly.
"""
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)

# Common prompt tokens (T is text tokens and P is image placeholder tokens)
# [T,...,T, P0,...,P0], [P0,...,P0,T,...,T,P1,...,P1], [P1,...,P1]
common_token_ids = list(range(10)) + [-1] * 6
common_token_ids += [-1] * 4 + list(range(10, 20)) + [-1] * 2
common_token_ids += [-1] * 16

common_mm_positions = [
PlaceholderRange(offset=11, length=10),
PlaceholderRange(offset=30, length=18),
]
common_mm_hashes = ["aaa", "bbb"]

# A unique image plus some text tokens.
unique_token_ids = [-1] * 7 + [100] * 4
all_token_ids = common_token_ids + unique_token_ids
mm_positions = common_mm_positions + [
PlaceholderRange(offset=48, length=7)
]
mm_hashes = common_mm_hashes + ["ccc"]
req0 = make_request("0",
all_token_ids,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks = manager.get_computed_blocks(req0)

# Completed blocks should have hashes with extra keys.
assert not computed_blocks
assert len(req0.kv_block_hashes) == 3
assert req0.kv_block_hashes[0].extra_keys == (("aaa", 0), )
assert req0.kv_block_hashes[1].extra_keys == (("aaa", 5), ("bbb", 0))
assert req0.kv_block_hashes[2].extra_keys == (("bbb", 2), )

blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
req0.num_computed_tokens = 59

# Append slots without allocating a new block.
for _ in range(5):
req0.append_output_token_ids(8)
new_blocks = manager.append_slots(req0, 5)
assert new_blocks is not None and len(new_blocks) == 0

# The just completed block should have hashes with extra keys.
assert len(req0.kv_block_hashes) == 4
assert req0.kv_block_hashes[3].extra_keys == (("ccc", 0), )

# Cache hit.
unique_token_ids = [-1] * 7 + [200] * 5
all_token_ids = common_token_ids + unique_token_ids
mm_positions = common_mm_positions + [
PlaceholderRange(offset=48, length=7)
]
mm_hashes = common_mm_hashes + ["ccc"]
req1 = make_request("1",
all_token_ids,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks = manager.get_computed_blocks(req1)
assert len(computed_blocks) == 3
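
For reference, the extra keys asserted in this test follow directly from how each image placeholder range overlaps a 16-token block: each key pairs the image hash with the offset into that image at which the block starts covering it. The standalone sketch below reproduces the asserted values under that assumption; it is not the PR's implementation (the real logic is generate_block_hash_extra_keys in vllm/v1/core/kv_cache_utils.py), and it redefines a local PlaceholderRange so it can run on its own.

# Standalone sketch (not the PR's implementation): derive
# (mm_hash, start_offset_within_image) extra keys for each full block,
# assuming a key is emitted for every image placeholder range that
# overlaps the block. Mirrors the assertions in test_mm_prefix_caching.
from typing import List, NamedTuple, Optional, Tuple


class PlaceholderRange(NamedTuple):  # local stand-in for the vLLM class
    offset: int  # index of the first placeholder token of the image
    length: int  # number of placeholder tokens


def block_extra_keys(
    block_idx: int,
    block_size: int,
    mm_positions: List[PlaceholderRange],
    mm_hashes: List[str],
) -> Optional[Tuple[Tuple[str, int], ...]]:
    block_start = block_idx * block_size
    block_end = block_start + block_size
    keys = []
    for pos, mm_hash in zip(mm_positions, mm_hashes):
        mm_start, mm_end = pos.offset, pos.offset + pos.length
        if mm_start < block_end and block_start < mm_end:
            # Offset into the image at which this block starts covering it.
            keys.append((mm_hash, max(0, block_start - mm_start)))
    return tuple(keys) if keys else None


mm_positions = [PlaceholderRange(11, 10), PlaceholderRange(30, 18)]
mm_hashes = ["aaa", "bbb"]
for i in range(3):
    print(i, block_extra_keys(i, 16, mm_positions, mm_hashes))
# 0 (('aaa', 0),)
# 1 (('aaa', 5), ('bbb', 0))
# 2 (('bbb', 2),)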
15 changes: 0 additions & 15 deletions tests/v1/engine/test_engine_args.py
@@ -31,14 +31,6 @@ def test_prefix_caching_from_cli():
assert engine_args.enable_prefix_caching


def test_defaults():
engine_args = EngineArgs(model="facebook/opt-125m")

# Assert V1 defaults
assert (engine_args.enable_prefix_caching
), "V1 turns on prefix caching by default"


def test_defaults_with_usage_context():
engine_args = EngineArgs(model="facebook/opt-125m")
vllm_config: VllmConfig = engine_args.create_engine_config(
@@ -52,10 +44,3 @@ def test_defaults_with_usage_context():
UsageContext.OPENAI_API_SERVER)
assert vllm_config.scheduler_config.max_num_seqs == 1024
assert vllm_config.scheduler_config.max_num_batched_tokens == 2048


def test_prefix_cache_disabled_with_multimodel():
engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf")

vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
assert not vllm_config.cache_config.enable_prefix_caching
14 changes: 6 additions & 8 deletions vllm/engine/arg_utils.py
@@ -207,6 +207,7 @@ def __post_init__(self):
# by user.
if self.enable_prefix_caching is None:
self.enable_prefix_caching = bool(envs.VLLM_USE_V1)

# Override max_num_seqs if it's not set by user.
if self.max_num_seqs is None:
self.max_num_seqs = 256 if not envs.VLLM_USE_V1 else 1024
@@ -1027,11 +1028,11 @@ def create_engine_config(self,
device_config = DeviceConfig(device=self.device)
model_config = self.create_model_config()

if model_config.is_multimodal_model:
if self.enable_prefix_caching:
logger.warning(
"--enable-prefix-caching is currently not "
"supported for multimodal models and has been disabled.")
if (model_config.is_multimodal_model and not envs.VLLM_USE_V1
and self.enable_prefix_caching):
logger.warning("--enable-prefix-caching is currently not "
"supported for multimodal models in v0 and "
"has been disabled.")
self.enable_prefix_caching = False

cache_config = CacheConfig(
@@ -1266,9 +1267,6 @@ def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
Override the EngineConfig's configs based on the usage context for V1.
"""
assert envs.VLLM_USE_V1, "V1 is not enabled"
if engine_config.model_config.is_multimodal_model:
# TODO (ywang96): Enable APC by default when VLM supports it.
assert not engine_config.cache_config.enable_prefix_caching


@dataclass
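
The net effect of this arg_utils.py change is that the multimodal prefix-caching guard now only applies to V0. A quick sketch, adapted from the test removed above (which asserted the old behavior), of checking the new V1 default; the UsageContext import path and the VLLM_USE_V1 toggle are assumptions here, and in practice the variable must be set before vLLM is imported:

# Sketch: with V1 enabled, prefix caching stays on by default for a
# multimodal model; under V0 the warning path above would disable it.
import os

os.environ["VLLM_USE_V1"] = "1"  # assumed: set before importing vllm

from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext  # assumed import path

engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf")
vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
assert vllm_config.cache_config.enable_prefix_caching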
20 changes: 20 additions & 0 deletions vllm/inputs/data.py
@@ -162,6 +162,11 @@ class TokenInputs(TypedDict):
Placeholder ranges for the multi-modal data.
"""

multi_modal_hashes: NotRequired[List[str]]
"""
The hashes of the multi-modal data.
"""

mm_processor_kwargs: NotRequired[Dict[str, Any]]
"""
Optional multi-modal processor kwargs to be forwarded to the
@@ -177,6 +182,7 @@ def token_inputs(
prompt: Optional[str] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
multi_modal_inputs: Optional["MultiModalKwargs"] = None,
multi_modal_hashes: Optional[List[str]] = None,
multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> TokenInputs:
@@ -191,6 +197,8 @@
inputs["multi_modal_data"] = multi_modal_data
if multi_modal_inputs is not None:
inputs["multi_modal_inputs"] = multi_modal_inputs
if multi_modal_hashes is not None:
inputs["multi_modal_hashes"] = multi_modal_hashes
if multi_modal_placeholders is not None:
inputs["multi_modal_placeholders"] = multi_modal_placeholders
if mm_processor_kwargs is not None:
@@ -295,6 +303,18 @@ def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:

assert_never(inputs)

@cached_property
def multi_modal_hashes(self) -> List[str]:
inputs = self.inputs

if inputs["type"] == "token":
return inputs.get("multi_modal_hashes", [])

if inputs["type"] == "multimodal":
return inputs.get("mm_hashes", [])

assert_never(inputs)

@cached_property
def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
inputs = self.inputs
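
For context, the new multi_modal_hashes field flows in through token_inputs exactly the way the updated make_request helper in the test above constructs it. A minimal sketch (the placeholder values are illustrative only):

# Sketch: token inputs carrying multimodal placeholder ranges and their
# content hashes, mirroring make_request() in tests/v1/core/test_prefix_caching.py.
from vllm.inputs import token_inputs
from vllm.multimodal.inputs import PlaceholderRange

inputs = token_inputs(
    prompt_token_ids=list(range(10)) + [-1] * 10,  # -1 marks image placeholders
    multi_modal_placeholders={"image": [PlaceholderRange(offset=10, length=10)]},
    multi_modal_hashes=["aaa"],  # one content hash per image
)
assert inputs["multi_modal_hashes"] == ["aaa"]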
3 changes: 3 additions & 0 deletions vllm/multimodal/inputs.py
@@ -215,6 +215,9 @@ class MultiModalInputsV2(TypedDict):
mm_kwargs: MultiModalKwargs
"""Keyword arguments to be directly passed to the model after batching."""

mm_hashes: NotRequired[List[str]]
"""The hashes of the multi-modal data."""

mm_placeholders: MultiModalPlaceholderDict
"""
For each modality, information about the placeholder tokens in
74 changes: 50 additions & 24 deletions vllm/v1/core/kv_cache_manager.py
@@ -4,7 +4,9 @@
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock, hash_block_tokens,
KVCacheBlock,
generate_block_hash_extra_keys,
hash_block_tokens,
hash_request_tokens)
from vllm.v1.request import Request

@@ -83,10 +85,12 @@ def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:

computed_blocks = []

# TODO(rickyx): potentially we could cache this so we don't have to
# recompute it every time.
block_hashes = hash_request_tokens(self.block_size,
request.all_token_ids)
# The block hashes for the request may already be computed
# if the request was preempted and resumed.
if not request.kv_block_hashes:
request.set_kv_block_hashes(
hash_request_tokens(self.block_size, request))
block_hashes = request.kv_block_hashes

for block_hash in block_hashes:
# block_hashes is a chain of block hashes. If a block hash is not
@@ -242,14 +246,16 @@ def allocate_slots(
num_computed_tokens = len(computed_blocks) * self.block_size
num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size

self._cache_full_blocks(
request=request,
blk_start_idx=len(computed_blocks),
# The new full blocks are the full blocks that are not computed.
full_blocks=self.req_to_blocks[request.request_id]
[len(computed_blocks):num_full_blocks],
prev_block=computed_blocks[-1] if computed_blocks else None,
)
new_full_blocks = self.req_to_blocks[
request.request_id][len(computed_blocks):num_full_blocks]
if new_full_blocks:
self._cache_full_blocks(
request=request,
blk_start_idx=len(computed_blocks),
# The new full blocks are the full blocks that are not computed.
full_blocks=new_full_blocks,
prev_block=computed_blocks[-1] if computed_blocks else None,
)

return new_blocks

@@ -376,6 +382,8 @@ def _cache_full_blocks(
full_blocks: The list of blocks to update hash metadata.
prev_block: The previous block in the chain.
"""
num_cached_block_hashes = len(request.kv_block_hashes)

# Update the new blocks with the block hashes through the chain.
prev_block_hash_value = None
if prev_block is not None:
@@ -387,17 +395,35 @@
for i, blk in enumerate(full_blocks):
blk_idx = blk_start_idx + i

block_tokens = request.all_token_ids[blk_idx *
self.block_size:(blk_idx +
1) *
self.block_size]
assert len(block_tokens) == self.block_size, (
f"Expected {self.block_size} tokens, got {len(block_tokens)} "
f"at {blk_idx}th block for request "
f"{request.request_id}({request})")

# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value, block_tokens)
if blk_idx < num_cached_block_hashes:
# The block hash may already be computed in
# "get_computed_blocks" if the tokens are not generated by
# this request (either the prompt tokens or the previously
# generated tokens with preemption). In this case we simply
# reuse the block hash.
block_hash = request.kv_block_hashes[blk_idx]
else:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
start_token_idx = blk_idx * self.block_size
end_token_idx = (blk_idx + 1) * self.block_size
block_tokens = request.all_token_ids[
start_token_idx:end_token_idx]
assert len(block_tokens) == self.block_size, (
f"Expected {self.block_size} tokens, got "
f"{len(block_tokens)} at {blk_idx}th block for request "
f"{request.request_id}({request})")

# Generate extra keys for multi-modal inputs. Note that since
# we reach this branch only when the block is completed with
# generated tokens, we only need to consider the last mm input.
extra_keys, _ = generate_block_hash_extra_keys(
request, start_token_idx, end_token_idx, -1)

# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value,
block_tokens, extra_keys)
request.append_kv_block_hashes(block_hash)

# Update and add the full block to the cache.
blk.block_hash = block_hash
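
Conceptually, each full block's hash now folds in three ingredients: the previous block's hash, the block's token ids, and the multimodal extra keys. The toy chain below only illustrates that idea; the authoritative implementation is hash_block_tokens in vllm/v1/core/kv_cache_utils.py, which returns a BlockHashType rather than a plain int.

# Toy sketch of chained block hashing with multimodal extra keys
# (illustration only; not the real hash_block_tokens()).
from typing import Optional, Sequence, Tuple


def toy_hash_block_tokens(
    prev_block_hash: Optional[int],
    block_tokens: Sequence[int],
    extra_keys: Optional[Tuple] = None,
) -> int:
    # Folding in the previous hash means two blocks with identical tokens
    # only match when their entire prefixes (tokens and mm content) match.
    return hash((prev_block_hash, tuple(block_tokens), extra_keys))


h0 = toy_hash_block_tokens(None, range(16), (("aaa", 0),))
h1 = toy_hash_block_tokens(h0, range(16, 32), (("aaa", 5), ("bbb", 0)))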