[Core] Store only the keys for multi-modal data in P0 #22198
File 1: the vLLM V1 optimization guide (docs)

````diff
@@ -2,6 +2,9 @@
 This guide covers optimization strategies and performance tuning for vLLM V1.

+!!! tip
+    Running out of memory? Consult [this guide](./conserving_memory.md) on how to conserve memory.
+
 ## Preemption

 Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests.
````
````diff
@@ -126,62 +129,44 @@ Data parallelism replicates the entire model across multiple GPU sets and proces
 Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`.
 Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.

-## Reducing Memory Usage
+## Input Processing

-If you encounter out-of-memory issues, consider these strategies:
+### Parallel Processing

-### Context Length and Batch Size
-
-You can reduce memory usage by limiting the context length and batch size:
+You can run input processing in parallel via [API server scale-out](../serving/data_parallel_deployment.md#internal-load-balancing).
+This is useful when input processing (which is run inside the API server)
+becomes a bottleneck compared to model execution (which is run inside engine core)
+and you have excess CPU capacity.

-```python
-from vllm import LLM
-
-llm = LLM(
-    model="meta-llama/Llama-3.1-8B-Instruct",
-    max_model_len=2048,  # Limit context window
-    max_num_seqs=4  # Limit batch size
-)
+```console
+# Run 4 API processes and 1 engine core process
+vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4
+
+# Run 4 API processes and 2 engine core processes
+vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2
 ```

-### Adjust CUDA Graph Compilation
+!!! note
+    API server scale-out is only available for online inference.

-CUDA graph compilation in V1 uses more memory than in V0. You can reduce memory usage by adjusting the compilation level:
-
-```python
-from vllm import LLM
-from vllm.config import CompilationConfig, CompilationLevel
-
-llm = LLM(
-    model="meta-llama/Llama-3.1-8B-Instruct",
-    compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE,
-        cudagraph_capture_sizes=[1, 2, 4, 8]  # Capture fewer batch sizes
-    )
-)
-```
+!!! note
+    [Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled
+    because it requires a one-to-one correspondence between API and engine core processes.

-Or, if you are not concerned about latency or overall performance, disable CUDA graph compilation entirely with `enforce_eager=True`:
+## Multi-Modal Caching

-```python
-from vllm import LLM
-
-llm = LLM(
-    model="meta-llama/Llama-3.1-8B-Instruct",
-    enforce_eager=True  # Disable CUDA graph compilation
-)
-```
+### Processor Cache

-### Multimodal Models
+By default, the multi-modal processor cache is enabled to avoid repeatedly processing
+the same multi-modal inputs via Hugging Face `AutoProcessor`,
+which commonly occurs in multi-turn conversations.

-For multi-modal models, you can reduce memory usage by limiting the number of images/videos per request:
+You can adjust the size of the cache via the `VLLM_MM_INPUT_CACHE_GIB` environment variable
+(default 4 GiB per API process + 4 GiB per engine core process).
+
+If you do not benefit much from the cache, you can disable it completely via `disable_mm_preprocessor_cache`:

 ```python
 from vllm import LLM

-# Accept up to 2 images per prompt
-llm = LLM(
-    model="Qwen/Qwen2.5-VL-3B-Instruct",
-    limit_mm_per_prompt={"image": 2}
-)
+llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
+          disable_mm_preprocessor_cache=True)
 ```
````

Review comments on this file:

On "API server scale-out is only available for online inference.":

- Contributor: Hmmm. What's preventing offline inference from using this?
- Author (Member): That's just how API server scale-out is set up now. Perhaps @njhill can help answer this

On "[Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled":

- Contributor: Link the data_parallel_external_lb doc here as well?
- Author (Member): I think it's not needed because the link is already at the beginning of this section

On "(default 4 GiB per API process + 4 GiB per engine core process).":

- Contributor: Ditto
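As a side note on the new "Processor Cache" docs above, here is a minimal sketch of tuning the cache size for offline use. It is not part of this PR: it assumes `VLLM_MM_INPUT_CACHE_GIB` is read from the environment when vLLM starts, and the value `2` (GiB) is only an example.

```python
import os

# Assumption (not from the PR): set the variable before vLLM reads it,
# i.e. before importing vLLM. "2" GiB is an arbitrary example value.
os.environ["VLLM_MM_INPUT_CACHE_GIB"] = "2"

from vllm import LLM  # noqa: E402

llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct")
```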
File 2: new unit test for the multi-modal cache (@@ -0,0 +1,51 @@)

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
                                    MultiModalKwargsItem,
                                    MultiModalSharedField)


def _dummy_elem(modality: str, key: str, size: int):
    return MultiModalFieldElem(
        modality=modality,
        key=key,
        data=torch.empty((size, ), dtype=torch.int8),
        field=MultiModalSharedField(1),
    )


def _dummy_item(modality: str, size_by_key: dict[str, int]):
    return MultiModalKwargsItem.from_elems([
        _dummy_elem(modality, key, size) for key, size in size_by_key.items()
    ])


def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
    return MultiModalKwargs.from_items([
        _dummy_item(modality, size_by_key)
        for modality, size_by_key in size_by_key_modality.items()
    ])


# yapf: disable
@pytest.mark.parametrize(
    ("item", "expected_size"),
    [
        (_dummy_item("a", {"a1": 100}), 100),
        (_dummy_item("a", {"a1": 100, "a2": 110}), 210),
        (_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460),  # noqa: E501
    ],
)
# yapf: enable
def test_cache_item_size(item, expected_size):
    cache = MultiModalCache.get_lru_cache(2048, type(item))

    cache[""] = item
    assert cache.currsize == expected_size

    cache[""] = MultiModalCacheItemMetadata.wraps(item)
    assert cache.currsize == expected_size
```
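To relate the test to the PR title, here is a rough illustrative sketch (not code from this PR) of the idea that the API server (P0) can keep only the keys and lightweight metadata of multi-modal items while the engine core keeps the actual tensors, with both sides accounting for the same cache size. The modality `image`, the key `pixel_values`, and the cache key `mm-hash-0` are made-up names, and the second argument to `get_lru_cache` is treated as a typing hint, mirroring the test above.

```python
import torch

from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
                                    MultiModalSharedField)

# Build a dummy multi-modal item, mirroring the helpers in the test above.
# "image" and "pixel_values" are placeholder names for this sketch.
elem = MultiModalFieldElem(
    modality="image",
    key="pixel_values",
    data=torch.empty((1024, ), dtype=torch.int8),
    field=MultiModalSharedField(1),
)
item = MultiModalKwargsItem.from_elems([elem])

# Engine-core-side cache: stores the full item under its key.
engine_cache = MultiModalCache.get_lru_cache(2048, MultiModalKwargsItem)
engine_cache["mm-hash-0"] = item

# API-server-side (P0) cache: stores only lightweight metadata under the
# same key, so it never holds the tensors themselves.
api_cache = MultiModalCache.get_lru_cache(2048, MultiModalCacheItemMetadata)
api_cache["mm-hash-0"] = MultiModalCacheItemMetadata.wraps(item)

# Both caches report the same occupancy for the same key, so eviction
# decisions stay consistent across the two processes.
assert api_cache.currsize == engine_cache.currsize
```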