diff --git a/tests/entrypoints/llm/test_mm_cache_stats.py b/tests/entrypoints/llm/test_mm_cache_stats.py new file mode 100644 index 000000000000..62bfefddbe1d --- /dev/null +++ b/tests/entrypoints/llm/test_mm_cache_stats.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm import LLM +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.v1.metrics.reader import Counter, Metric + +from ..openai.test_vision import TEST_IMAGE_ASSETS + + +def _make_messages(image_url: str) -> list[ChatCompletionMessageParam]: + return [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": image_url}, + }, + ], + } + ] + + +def _get_counter_value(metrics: list[Metric], name: str): + metric = next(m for m in metrics if m.name == name) + assert isinstance(metric, Counter) + return metric.value + + +def _get_mm_cache_stats(metrics: list[Metric]): + mm_cache_queries = _get_counter_value(metrics, "vllm:mm_cache_queries") + mm_cache_hits = _get_counter_value(metrics, "vllm:mm_cache_hits") + + return mm_cache_queries, mm_cache_hits + + +@pytest.mark.parametrize("image_urls", [TEST_IMAGE_ASSETS[:2]], indirect=True) +@pytest.mark.parametrize("mm_processor_cache_type", ["lru", "shm"]) +def test_mm_cache_stats( + num_gpus_available, + image_urls, + mm_processor_cache_type, +): + llm = LLM( + model="llava-hf/llava-1.5-7b-hf", + max_model_len=4096, + max_num_seqs=5, + enforce_eager=True, + mm_processor_cache_type=mm_processor_cache_type, + disable_log_stats=False, + limit_mm_per_prompt={"image": 2}, + ) + + llm.chat(_make_messages(image_urls[0])) + assert _get_mm_cache_stats(llm.get_metrics()) == (1, 0) + + llm.chat(_make_messages(image_urls[1])) + assert _get_mm_cache_stats(llm.get_metrics()) == (2, 0) + + llm.chat(_make_messages(image_urls[0])) + assert _get_mm_cache_stats(llm.get_metrics()) == (3, 1) + + # NOTE: This only resets hit rate stats in CachingMetrics + # The raw queries and hits counts remain unaffected + llm.reset_mm_cache() + + llm.chat(_make_messages(image_urls[0])) + assert _get_mm_cache_stats(llm.get_metrics()) == (4, 1) + + llm.chat(_make_messages(image_urls[1])) + assert _get_mm_cache_stats(llm.get_metrics()) == (5, 1) diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index 6b00dde494d1..dbcec9d31fc9 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -18,10 +18,18 @@ from ...utils import RemoteOpenAIServer -MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +MODELS = { + "text": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "multimodal": "HuggingFaceTB/SmolVLM-256M-Instruct", +} PREV_MINOR_VERSION = version._prev_minor_version() +@pytest.fixture(scope="module", params=list(MODELS.keys())) +def model_key(request): + yield request.param + + @pytest.fixture(scope="module") def default_server_args(): return [ @@ -45,11 +53,12 @@ def default_server_args(): f"--show-hidden-metrics-for-version={PREV_MINOR_VERSION}", ], ) -def server(default_server_args, request): +def server(model_key, default_server_args, request): if request.param: default_server_args.append(request.param) - with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + model_name = MODELS[model_key] + with RemoteOpenAIServer(model_name, default_server_args) as remote_server: yield remote_server @@ -60,64 +69,70 @@ async def client(server): _PROMPT = "Hello my name is Robert 
and I love magic" -tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) -_TOKENIZED_PROMPT = tokenizer(_PROMPT)["input_ids"] - -_NUM_REQUESTS = 10 -_NUM_PROMPT_TOKENS_PER_REQUEST = len(_TOKENIZED_PROMPT) -_NUM_GENERATION_TOKENS_PER_REQUEST = 10 - -# {metric_family: [(suffix, expected_value)]} -EXPECTED_VALUES = { - "vllm:time_to_first_token_seconds": [("_count", _NUM_REQUESTS)], - "vllm:time_per_output_token_seconds": [ - ("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1)) - ], - "vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_inference_time_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_prompt_tokens": [ - ("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST), - ("_count", _NUM_REQUESTS), - ], - "vllm:request_generation_tokens": [ - ("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), - ("_count", _NUM_REQUESTS), - ], - "vllm:request_params_n": [("_count", _NUM_REQUESTS)], - "vllm:request_params_max_tokens": [ - ("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), - ("_count", _NUM_REQUESTS), - ], - "vllm:iteration_tokens_total": [ - ( - "_sum", - _NUM_REQUESTS - * (_NUM_PROMPT_TOKENS_PER_REQUEST + _NUM_GENERATION_TOKENS_PER_REQUEST), - ), - ("_count", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), - ], - "vllm:prompt_tokens": [("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)], - "vllm:generation_tokens": [ - ("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST) - ], - "vllm:request_success": [("_total", _NUM_REQUESTS)], -} +_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + + +def _get_expected_values(num_requests: int, prompt_ids: list[int], max_tokens: int): + num_prompt_tokens = len(prompt_ids) + + # {metric_family: [(suffix, expected_value)]} + return { + "vllm:time_to_first_token_seconds": [("_count", num_requests)], + "vllm:time_per_output_token_seconds": [ + ("_count", num_requests * (max_tokens - 1)) + ], + "vllm:e2e_request_latency_seconds": [("_count", num_requests)], + "vllm:request_queue_time_seconds": [("_count", num_requests)], + "vllm:request_inference_time_seconds": [("_count", num_requests)], + "vllm:request_prefill_time_seconds": [("_count", num_requests)], + "vllm:request_decode_time_seconds": [("_count", num_requests)], + "vllm:request_prompt_tokens": [ + ("_sum", num_requests * num_prompt_tokens), + ("_count", num_requests), + ], + "vllm:request_generation_tokens": [ + ("_sum", num_requests * max_tokens), + ("_count", num_requests), + ], + "vllm:request_params_n": [("_count", num_requests)], + "vllm:request_params_max_tokens": [ + ("_sum", num_requests * max_tokens), + ("_count", num_requests), + ], + "vllm:iteration_tokens_total": [ + ( + "_sum", + num_requests * (num_prompt_tokens + max_tokens), + ), + ("_count", num_requests * max_tokens), + ], + "vllm:prompt_tokens": [("_total", num_requests * num_prompt_tokens)], + "vllm:generation_tokens": [("_total", num_requests * max_tokens)], + "vllm:request_success": [("_total", num_requests)], + } @pytest.mark.asyncio async def test_metrics_counts( server: RemoteOpenAIServer, client: openai.AsyncClient, + model_key: str, ): - for _ in range(_NUM_REQUESTS): + if model_key == "multimodal": + 
pytest.skip("Unnecessary test") + + model_name = MODELS[model_key] + tokenizer = AutoTokenizer.from_pretrained(model_name) + prompt_ids = tokenizer.encode(_PROMPT) + num_requests = 10 + max_tokens = 10 + + for _ in range(num_requests): # sending a request triggers the metrics to be logged. await client.completions.create( - model=MODEL_NAME, - prompt=_TOKENIZED_PROMPT, - max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST, + model=model_name, + prompt=prompt_ids, + max_tokens=max_tokens, ) response = requests.get(server.url_for("metrics")) @@ -125,8 +140,9 @@ async def test_metrics_counts( assert response.status_code == HTTPStatus.OK # Loop over all expected metric_families - for metric_family, suffix_values_list in EXPECTED_VALUES.items(): - if (metric_family not in EXPECTED_METRICS_V1) or ( + expected_values = _get_expected_values(num_requests, prompt_ids, max_tokens) + for metric_family, suffix_values_list in expected_values.items(): + if metric_family not in EXPECTED_METRICS_V1 or ( not server.show_hidden_metrics and metric_family in HIDDEN_DEPRECATED_METRICS ): @@ -217,6 +233,11 @@ async def test_metrics_counts( "vllm:request_decode_time_seconds_count", ] +EXPECTED_METRICS_MM = [ + "vllm:mm_cache_queries", + "vllm:mm_cache_hits", +] + HIDDEN_DEPRECATED_METRICS: list[str] = [ "vllm:gpu_cache_usage_perc", "vllm:gpu_prefix_cache_queries", @@ -231,19 +252,43 @@ async def test_metrics_counts( async def test_metrics_exist( server: RemoteOpenAIServer, client: openai.AsyncClient, + model_key: str, ): + model_name = MODELS[model_key] + # sending a request triggers the metrics to be logged. - await client.completions.create( - model=MODEL_NAME, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0, - ) + if model_key == "text": + await client.completions.create( + model=model_name, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0, + ) + else: + await client.chat.completions.create( + model=model_name, + messages=[ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": _IMAGE_URL}}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ], + max_tokens=5, + temperature=0.0, + ) response = requests.get(server.url_for("metrics")) assert response.status_code == HTTPStatus.OK - for metric in EXPECTED_METRICS_V1: + expected_metrics = EXPECTED_METRICS_V1 + if model_key == "multimodal": + # NOTE: Don't use in-place assignment + expected_metrics = expected_metrics + EXPECTED_METRICS_MM + + for metric in expected_metrics: if metric in HIDDEN_DEPRECATED_METRICS and not server.show_hidden_metrics: continue assert metric in response.text @@ -253,9 +298,14 @@ async def test_metrics_exist( async def test_abort_metrics_reset( server: RemoteOpenAIServer, client: openai.AsyncClient, + model_key: str, ): + model_name = MODELS[model_key] + tokenizer = AutoTokenizer.from_pretrained(model_name) + prompt_ids = tokenizer.encode(_PROMPT) + running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api( - server + server, ) # Expect no running requests or kvcache usage @@ -268,8 +318,8 @@ async def test_abort_metrics_reset( for _ in range(3): task = asyncio.create_task( client.completions.create( - model=MODEL_NAME, - prompt=_TOKENIZED_PROMPT, + model=model_name, + prompt=prompt_ids, max_tokens=100, # Long generation to give time to abort temperature=0.0, ) @@ -281,7 +331,7 @@ async def test_abort_metrics_reset( # Check that we have running requests running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api( - server + 
server, ) # Expect running requests and kvcache usage diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index a31817ec72b6..714a540e86b5 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -20,7 +20,6 @@ BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, - PrefixCachingMetrics, estimate_max_model_len, generate_block_hash_extra_keys, generate_scheduler_kv_cache_config, @@ -42,7 +41,7 @@ SlidingWindowSpec, UniformTypeKVCacheSpecs, ) -from vllm.v1.metrics.stats import PrefixCacheStats +from vllm.v1.metrics.stats import CachingMetrics, PrefixCacheStats from vllm.v1.request import Request pytestmark = pytest.mark.cpu_test @@ -536,7 +535,7 @@ def test_metrics(): """ Test the prefix caching metrics. """ - metrics = PrefixCachingMetrics(max_recent_requests=5) + metrics = CachingMetrics(max_recent_requests=5) assert metrics.hit_rate == 0.0 metrics.observe(_stats(1, 20, 9)) @@ -568,7 +567,7 @@ def test_metrics_empty_stats(): """ Test the prefix caching metrics with empty stats. """ - metrics = PrefixCachingMetrics(max_recent_requests=5) + metrics = CachingMetrics(max_recent_requests=5) metrics.observe(_stats(0, 0, 0)) metrics.observe(_stats(1, 20, 9)) metrics.observe(_stats(0, 0, 0)) diff --git a/tests/v1/distributed/test_async_llm_dp.py b/tests/v1/distributed/test_async_llm_dp.py index 75314dc37303..28bb91f34c39 100644 --- a/tests/v1/distributed/test_async_llm_dp.py +++ b/tests/v1/distributed/test_async_llm_dp.py @@ -17,7 +17,7 @@ from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.core_client import DPAsyncMPClient from vllm.v1.metrics.loggers import StatLoggerBase -from vllm.v1.metrics.stats import IterationStats, SchedulerStats +from vllm.v1.metrics.stats import IterationStats, MultiModalCacheStats, SchedulerStats DP_SIZE = int(os.getenv("DP_SIZE", 2)) @@ -93,6 +93,7 @@ def record( self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats], + mm_cache_stats: Optional[MultiModalCacheStats] = None, engine_idx: int = 0, ): if iteration_stats: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 9c8c1bd11ab4..8f47c20f27e0 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -354,6 +354,10 @@ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: else: self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer) + def reset_mm_cache(self) -> None: + self.processor.clear_mm_cache() + self.llm_engine.reset_mm_cache() + def get_default_sampling_params(self) -> SamplingParams: if self.default_sampling_params is None: self.default_sampling_params = self.model_config.get_diff_sampling_param() diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index df71fb790974..edb8ecc94382 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -274,6 +274,10 @@ def __init__( self.model_config = self.models.model_config self.max_model_len = self.model_config.max_model_len + async def reset_mm_cache(self) -> None: + self.processor.clear_mm_cache() + await self.engine_client.reset_mm_cache() + async def beam_search( self, prompt: PromptType, diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 3a7347b8e465..7bdef5cbe748 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -169,6 +169,10 @@ def list_loras(self) -> set[int]: assert s == sets[0], "All workers should have the same LORAs." 
return sets[0] + def reset_mm_cache(self) -> None: + """Reset the multi-modal cache in each worker.""" + self.collective_rpc("reset_mm_cache") + def start_profile(self) -> None: self.collective_rpc("start_profile") diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 8206f23d1878..612fd73c12b1 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -12,11 +12,8 @@ import vllm.envs as envs from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.cache import worker_receiver_cache_from_config from vllm.utils import get_distributed_init_method, get_ip, get_open_port, run_method from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType -from vllm.v1.executor.utils import get_and_update_mm_cache from vllm.v1.outputs import AsyncModelRunnerOutput from vllm.v1.worker.worker_base import WorkerWrapperBase @@ -30,16 +27,13 @@ def _init_executor(self) -> None: """Initialize the worker and load the model.""" self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0) distributed_init_method, rank, local_rank = self._distributed_args() - is_driver_worker = True kwargs = dict( vllm_config=self.vllm_config, local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, - is_driver_worker=is_driver_worker, - ) - self.mm_receiver_cache = worker_receiver_cache_from_config( - self.vllm_config, MULTIMODAL_REGISTRY, Lock() + is_driver_worker=True, + shared_worker_lock=Lock(), ) self.async_output_thread: Optional[ThreadPoolExecutor] = None @@ -74,8 +68,6 @@ def collective_rpc( ) -> list[Any]: if kwargs is None: kwargs = {} - if self.mm_receiver_cache is not None and method == "execute_model": - get_and_update_mm_cache(self.mm_receiver_cache, args) if not non_block: return [run_method(self.driver_worker, method, args, kwargs)] diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 19dd61b2e369..809f6c8d83f0 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -19,6 +19,7 @@ from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils.jsontree import json_iter_leaves +from vllm.v1.metrics.stats import MultiModalCacheStats from .data import ( DecoderOnlyInputs, @@ -56,6 +57,8 @@ def __init__( self.mm_registry = mm_registry self.mm_processor_cache = mm_processor_cache + self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None + def get_tokenizer(self) -> AnyTokenizer: if self.tokenizer is None: raise ValueError( @@ -664,14 +667,13 @@ def _process_decoder_only_prompt( return self._build_decoder_only_llm_inputs(prompt_comps) - def preprocess( + def _preprocess( self, prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, *, mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> ProcessorInputs: - """Preprocess the input prompt.""" if self.model_config.is_encoder_decoder: # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder. 
@@ -694,6 +696,40 @@ def preprocess( mm_uuids=mm_uuids, ) - def clear_cache(self) -> None: + def preprocess( + self, + prompt: PromptType, + tokenization_kwargs: Optional[dict[str, Any]] = None, + *, + mm_uuids: Optional[MultiModalUUIDDict] = None, + ) -> ProcessorInputs: + """Preprocess the input prompt.""" + res = self._preprocess( + prompt, + tokenization_kwargs, + mm_uuids=mm_uuids, + ) + + if self.mm_processor_cache and self.mm_cache_stats is not None: + delta = self.mm_processor_cache.make_stats(delta=True) + self.mm_cache_stats.requests += 1 + self.mm_cache_stats.queries += delta.total + self.mm_cache_stats.hits += delta.hits + + return res + + def stat_mm_cache(self) -> Optional[MultiModalCacheStats]: + mm_cache_stats = self.mm_cache_stats + if mm_cache_stats is None: + return None + + self.mm_cache_stats = MultiModalCacheStats() + + return mm_cache_stats + + def clear_mm_cache(self) -> None: if self.mm_processor_cache is not None: self.mm_processor_cache.clear_cache() + + if self.mm_cache_stats is not None: + self.mm_cache_stats.reset = True diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py index 7febc393157f..8b72bbe56eaf 100644 --- a/vllm/multimodal/cache.py +++ b/vllm/multimodal/cache.py @@ -18,7 +18,7 @@ from vllm.envs import VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME from vllm.logger import init_logger from vllm.utils import GiB_bytes, MiB_bytes -from vllm.utils.cache import LRUCache +from vllm.utils.cache import CacheInfo, LRUCache from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves from .inputs import ( @@ -302,6 +302,16 @@ def is_cached(self, mm_hashes: list[str]) -> list[bool]: """ return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes] + @abstractmethod + def make_stats(self, *, delta: bool = False) -> CacheInfo: + """ + Get (and reset) the multi-modal cache stats. + + Returns: + The current multi-modal caching stats. 
+ """ + raise NotImplementedError + class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache): """ @@ -347,6 +357,10 @@ def get_and_update_item( def clear_cache(self) -> None: self._cache.clear() + @override + def make_stats(self, *, delta: bool = False) -> CacheInfo: + return self._cache.stat(delta=delta) + class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache): """ @@ -397,6 +411,10 @@ def get_and_update_item( def clear_cache(self) -> None: self._cache.clear() + @override + def make_stats(self, *, delta: bool = False) -> CacheInfo: + return self._cache.stat(delta=delta) + class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache): """ @@ -430,6 +448,20 @@ def __init__(self, vllm_config: "VllmConfig") -> None: # cache (prompt_updates, modality) for P0 only self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], str]] = {} + self._hits = 0 + self._total = 0 + self._last_info = CacheInfo(hits=0, total=0) + + def _stat(self, *, delta: bool = False) -> CacheInfo: + info = CacheInfo(hits=self._hits, total=self._total) + + if delta: + info_delta = info - self._last_info + self._last_info = info + info = info_delta + + return info + @override def is_cached_item(self, mm_hash: str) -> bool: return self._shm_cache.is_cached(mm_hash) @@ -441,12 +473,17 @@ def get_and_update_item( mm_hash: str, ) -> MultiModalProcessorCacheOutItem: if self._shm_cache.is_cached(mm_hash): + self._hits += 1 + self._total += 1 + address, monotonic_id = self._shm_cache.get_cached(mm_hash) prompt_updates, modality = self._p0_cache[mm_hash] return self.address_as_item(address, monotonic_id, modality), prompt_updates assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + self._total += 1 + try: address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0]) # Try to remove dangling items if p0 cache is too large. @@ -469,6 +506,14 @@ def clear_cache(self) -> None: self._shm_cache.clear() self._p0_cache.clear() + self._hits = 0 + self._total = 0 + self._last_info = CacheInfo(hits=0, total=0) + + @override + def make_stats(self, *, delta: bool = False) -> CacheInfo: + return self._stat(delta=delta) + def remove_dangling_items(self) -> None: """Remove items that are no longer in the shared memory cache.""" cached_hashes = self._shm_cache.key_index.keys() diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index bb0b7e259b41..7a602b993685 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -4,7 +4,7 @@ import copy import os -from collections import defaultdict, deque +from collections import defaultdict from collections.abc import Iterable, Sequence from dataclasses import dataclass from typing import Any, Callable, NewType, Optional, Union @@ -23,7 +23,6 @@ SlidingWindowSpec, UniformTypeKVCacheSpecs, ) -from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request # BlockHash represents the hash of a single KV-cache block used for @@ -101,78 +100,6 @@ def init_none_hash(hash_fn: Callable[[Any], bytes]): NONE_HASH = BlockHash(hash_fn(hash_seed)) -class PrefixCachingMetrics: - """Metrics for prefix caching with a hit rate of the max recent N requests. - - Args: - max_recent_requests: The number of the max recent requests to aggregate. - Defaults to 1000. - """ - - def __init__(self, max_recent_requests: int = 1000): - self.max_recent_requests = max_recent_requests - # The current aggregated values. 
- self.aggregated_requests = 0 - self.aggregated_query_total = 0 - self.aggregated_query_hit = 0 - # A deque of (requests, queries, hits) for the most recent requests. - self.query_queue: deque[tuple[int, int, int]] = deque() - - def observe(self, stats: PrefixCacheStats): - """Observe the prefix caching for a set of requests. - - This function is called with information gathered when new requests - are being scheduled and are looking for computed blocks. - - When there are more than `max_recent_requests` requests, the oldest set - of requests are removed from the metrics. - - Args: - stats: The prefix cache stats. - """ - # reset_prefix_cache was invoked before the current update. - # Reset the metrics before aggregating the current stats. - if stats.reset: - self.reset() - - # DO NOT appending empty stats to avoid helpful info get kicked out - # due to sliding window. - if stats.requests == 0: - return - - # Update the metrics. - self.query_queue.append((stats.requests, stats.queries, stats.hits)) - self.aggregated_requests += stats.requests - self.aggregated_query_total += stats.queries - self.aggregated_query_hit += stats.hits - - # Remove the oldest stats until number of requests does not exceed - # the limit. - # NOTE: We preserve the latest added stats regardless. - while ( - len(self.query_queue) > 1 - and self.aggregated_requests > self.max_recent_requests - ): - old_requests, old_queries, old_hits = self.query_queue.popleft() - self.aggregated_requests -= old_requests - self.aggregated_query_total -= old_queries - self.aggregated_query_hit -= old_hits - - def reset(self): - """Reset the metrics.""" - self.aggregated_requests = 0 - self.aggregated_query_total = 0 - self.aggregated_query_hit = 0 - self.query_queue.clear() - - @property - def hit_rate(self) -> float: - """Calculate the hit rate for the past N requests.""" - if self.aggregated_query_total == 0: - return 0.0 - return self.aggregated_query_hit / self.aggregated_query_total - - @dataclass class KVCacheBlock: """KV-cache block metadata.""" diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 527528b806ea..112ec92b3af8 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -463,6 +463,7 @@ def _run_output_handler(self): output_processor = self.output_processor log_stats = self.log_stats logger_manager = self.logger_manager + processor = self.processor async def output_handler(): try: @@ -511,6 +512,7 @@ async def output_handler(): engine_idx=outputs.engine_index, scheduler_stats=outputs.scheduler_stats, iteration_stats=iteration_stats, + mm_cache_stats=processor.stat_mm_cache(), ) except Exception as e: logger.exception("AsyncLLM output_handler failed.") @@ -660,7 +662,7 @@ async def stop_profile(self) -> None: await asyncio.gather(*coros) async def reset_mm_cache(self) -> None: - self.processor.clear_cache() + self.processor.clear_mm_cache() await self.engine_core.reset_mm_cache_async() async def reset_prefix_cache(self, device: Optional[Device] = None) -> None: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 93f7fd5725bd..8e9ca7e0b8cb 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -319,7 +319,7 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: ) engine_core_outputs = self.scheduler.update_from_output( scheduler_output, model_output - ) # type: ignore + ) return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0) @@ -400,16 +400,19 @@ def profile(self, is_start: bool = True): def reset_mm_cache(self): # 
NOTE: Since this is mainly for debugging, we don't attempt to - # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror) + # re-sync the internal caches (P0 sender, P1 receiver) if self.scheduler.has_unfinished_requests(): logger.warning( "Resetting the multi-modal cache when requests are " "in progress may lead to desynced internal caches." ) + # The cache either exists in EngineCore or WorkerWrapperBase if self.mm_receiver_cache is not None: self.mm_receiver_cache.clear_cache() + self.model_executor.reset_mm_cache() + def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 2f8bd6c76ff9..b2261855d125 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -306,9 +306,11 @@ def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]: # 4) Record stats if self.logger_manager is not None: assert outputs.scheduler_stats is not None + self.logger_manager.record( scheduler_stats=outputs.scheduler_stats, iteration_stats=iteration_stats, + mm_cache_stats=self.processor.stat_mm_cache(), ) self.do_log_stats_with_interval() @@ -321,7 +323,7 @@ def stop_profile(self): self.engine_core.profile(False) def reset_mm_cache(self): - self.processor.clear_cache() + self.processor.clear_mm_cache() self.engine_core.reset_mm_cache() def reset_prefix_cache(self, device: Optional[Device] = None): diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 941b580e1f83..d106783d6dc1 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -21,6 +21,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.engine import EngineCoreRequest +from vllm.v1.metrics.stats import MultiModalCacheStats from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar from vllm.v1.structured_output.backend_lm_format_enforcer import ( validate_structured_output_request_lm_format_enforcer, @@ -573,5 +574,8 @@ def _validate_model_input( # check that chunked prefill does not truncate them # max_batch_len = self.scheduler_config.max_num_batched_tokens - def clear_cache(self) -> None: - self.input_preprocessor.clear_cache() + def stat_mm_cache(self) -> Optional[MultiModalCacheStats]: + return self.input_preprocessor.stat_mm_cache() + + def clear_mm_cache(self) -> None: + self.input_preprocessor.clear_mm_cache() diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 062b6042693b..d92c8f38571e 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -33,8 +33,6 @@ get_tp_group, ) from vllm.logger import init_logger -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.cache import worker_receiver_cache_from_config from vllm.utils import ( _maybe_force_spawn, decorate_logs, @@ -46,7 +44,6 @@ ) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.executor.abstract import Executor, FailureCallback -from vllm.v1.executor.utils import get_and_update_mm_cache from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput from vllm.v1.worker.worker_base import WorkerWrapperBase @@ -422,6 +419,7 @@ def __init__( "rank": rank, "distributed_init_method": distributed_init_method, "is_driver_worker": is_driver_worker, + "shared_worker_lock": shared_worker_lock, } wrapper.init_worker(all_kwargs) self.worker = wrapper @@ -445,11 +443,6 @@ def 
__init__( ) self.async_output_copy_thread.start() - # Initialize multimodal receiver cache if needed - self.mm_receiver_cache = worker_receiver_cache_from_config( - vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock - ) - # Initialize device self.worker.init_device() @@ -692,12 +685,7 @@ def worker_busy_loop(self, cancel: Optional[threading.Event] = None): func = getattr(self.worker, method) elif isinstance(method, bytes): func = partial(cloudpickle.loads(method), self.worker) - # retrieve from shm cache if available - if ( - self.mm_receiver_cache is not None - and func.__name__ == "execute_model" - ): - get_and_update_mm_cache(self.mm_receiver_cache, args) + output = func(*args, **kwargs) except Exception as e: # Notes have been introduced in python 3.11 diff --git a/vllm/v1/executor/utils.py b/vllm/v1/executor/utils.py deleted file mode 100644 index 884068a43882..000000000000 --- a/vllm/v1/executor/utils.py +++ /dev/null @@ -1,24 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.multimodal.cache import ShmObjectStoreReceiverCache -from vllm.v1.core.sched.output import SchedulerOutput - - -def get_and_update_mm_cache( - receiver_cache: ShmObjectStoreReceiverCache, - args: tuple[SchedulerOutput], -) -> None: - """ - For each MultiModalKwargsItem in SchedulerOutput, fetch from shared memory - cache as needed. - - Args: - receiver_cache: The receiver cache to update. - args: According to the collective_rpc call of execute_model method in - executor, args is a tuple of only one SchedulerOutput element. - """ - scheduler_output = args[0] - for request_data in scheduler_output.scheduled_new_reqs: - request_data.mm_features = receiver_cache.get_and_update_features( - request_data.mm_features - ) diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 541af7af1725..3db9a428e93a 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -11,10 +11,14 @@ from vllm.config import SupportsMetricsInfo, VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging from vllm.logger import init_logger -from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.engine import FinishReason from vllm.v1.metrics.prometheus import unregister_vllm_metrics -from vllm.v1.metrics.stats import IterationStats, SchedulerStats +from vllm.v1.metrics.stats import ( + CachingMetrics, + IterationStats, + MultiModalCacheStats, + SchedulerStats, +) from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm logger = init_logger(__name__) @@ -38,6 +42,7 @@ def record( self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats], + mm_cache_stats: Optional[MultiModalCacheStats] = None, engine_idx: int = 0, ): ... @@ -53,10 +58,15 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.engine_index = engine_index self.vllm_config = vllm_config self._reset(time.monotonic()) + self.last_scheduler_stats = SchedulerStats() - # Prefix cache metrics. This cannot be reset. + self.last_mm_cache_stats: Optional[MultiModalCacheStats] = None + + # Caching metrics. This cannot be reset. # TODO: Make the interval configurable. 
- self.prefix_caching_metrics = PrefixCachingMetrics() + self.prefix_caching_metrics = CachingMetrics() + self.mm_caching_metrics = CachingMetrics() + self.spec_decoding_logging = SpecDecodingLogging() kv_tranfer_config = self.vllm_config.kv_transfer_config self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config) @@ -86,6 +96,7 @@ def record( self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats], + mm_cache_stats: Optional[MultiModalCacheStats] = None, engine_idx: int = 0, ): """Log Stats to standard output.""" @@ -101,6 +112,11 @@ def record( self.kv_connector_logging.observe(kv_connector_stats) self.last_scheduler_stats = scheduler_stats + if mm_cache_stats: + self.mm_caching_metrics.observe(mm_cache_stats) + + self.last_mm_cache_stats = mm_cache_stats + def log(self): now = time.monotonic() prompt_throughput = self._get_throughput(self.num_prompt_tokens, now) @@ -125,21 +141,32 @@ def log(self): self.last_prompt_throughput = prompt_throughput # Format and print output. - log_fn( - "Engine %03d: " - "Avg prompt throughput: %.1f tokens/s, " - "Avg generation throughput: %.1f tokens/s, " - "Running: %d reqs, Waiting: %d reqs, " - "GPU KV cache usage: %.1f%%, " + log_parts = [ + "Avg prompt throughput: %.1f tokens/s", + "Avg generation throughput: %.1f tokens/s", + "Running: %d reqs", + "Waiting: %d reqs", + "GPU KV cache usage: %.1f%%", "Prefix cache hit rate: %.1f%%", - self.engine_index, + ] + log_args = [ prompt_throughput, generation_throughput, scheduler_stats.num_running_reqs, scheduler_stats.num_waiting_reqs, scheduler_stats.kv_cache_usage * 100, self.prefix_caching_metrics.hit_rate * 100, + ] + if self.last_mm_cache_stats: + log_parts.append("MM cache hit rate: %.1f%%") + log_args.append(self.mm_caching_metrics.hit_rate * 100) + + log_fn( + "Engine %03d: " + ", ".join(log_parts), + self.engine_index, + *log_args, ) + self.spec_decoding_logging.log(log_fn=log_fn) self.kv_connector_logging.log(log_fn=log_fn) @@ -288,6 +315,32 @@ def __init__( counter_prefix_cache_hits, engine_indexes, model_name ) + # + # Multi-modal cache + # + + counter_mm_cache_queries = self._counter_cls( + name="vllm:mm_cache_queries", + documentation=( + "Multi-modal cache queries, in terms of number of queried items." + ), + labelnames=labelnames, + ) + self.counter_mm_cache_queries = make_per_engine( + counter_mm_cache_queries, engine_indexes, model_name + ) + + counter_mm_cache_hits = self._counter_cls( + name="vllm:mm_cache_hits", + documentation=( + "Multi-modal cache hits, in terms of number of cached items." 
+ ), + labelnames=labelnames, + ) + self.counter_mm_cache_hits = make_per_engine( + counter_mm_cache_hits, engine_indexes, model_name + ) + # # Counters # @@ -657,6 +710,7 @@ def record( self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats], + mm_cache_stats: Optional[MultiModalCacheStats] = None, engine_idx: int = 0, ): """Log to prometheus.""" @@ -694,6 +748,10 @@ def record( scheduler_stats.spec_decoding_stats, engine_idx ) + if mm_cache_stats is not None: + self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries) + self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits) + if iteration_stats is None: return @@ -871,6 +929,7 @@ def record( self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats], + mm_cache_stats: Optional[MultiModalCacheStats] = None, engine_idx: Optional[int] = None, ): if engine_idx is None: @@ -878,9 +937,19 @@ def record( per_engine_loggers = self.per_engine_logger_dict[engine_idx] for logger in per_engine_loggers: - logger.record(scheduler_stats, iteration_stats, engine_idx) + logger.record( + scheduler_stats, + iteration_stats, + mm_cache_stats=mm_cache_stats, + engine_idx=engine_idx, + ) - self.prometheus_logger.record(scheduler_stats, iteration_stats, engine_idx) + self.prometheus_logger.record( + scheduler_stats, + iteration_stats, + mm_cache_stats=mm_cache_stats, + engine_idx=engine_idx, + ) def log(self): for per_engine_loggers in self.per_engine_logger_dict.values(): diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 5564718d5165..f0922288db32 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time +from collections import deque from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Optional @@ -13,24 +14,122 @@ @dataclass -class PrefixCacheStats: - """Stores prefix cache hit statistics.""" +class BaseCacheStats: + """Stores cache hit statistics.""" - # Whether reset_prefix_cache was invoked. reset: bool = False - # The number of new requests in this update. + """Whether the cache was reset.""" + requests: int = 0 - # The number of queries in these requests. Note that "queries" here - # means the number of tokens that were queried from the cache. + """The number of requests in this update.""" + queries: int = 0 - # The number of hits in these requests. + """The number of queries in these requests.""" + hits: int = 0 - # The number of previously preempted requests in this update. + """The number of hits in these requests.""" + + +class CachingMetrics: + """Metrics for caching with a hit rate of the most recent N requests. + Args: + interval: The number of the most recent requests to aggregate. + Defaults to 1000. + """ + + def __init__(self, max_recent_requests: int = 1000) -> None: + super().__init__() + + self.max_recent_requests = max_recent_requests + # The current aggregated values. + self.aggregated_requests = 0 + self.aggregated_query_total = 0 + self.aggregated_query_hit = 0 + + # A deque of (requests, queries, hits) for the most recent requests. + self.query_queue = deque[tuple[int, int, int]]() + + def observe(self, stats: BaseCacheStats): + """Observe the prefix caching for a set of requests. + + This function is called with information gathered when new requests + are being scheduled and are looking for computed blocks. 
+ + When there are more than `max_recent_requests` requests, the oldest set + of requests are removed from the metrics. + + Args: + stats: The prefix cache stats. + """ + # reset_prefix_cache was invoked before the current update. + # Reset the metrics before aggregating the current stats. + if stats.reset: + self.reset() + + # DO NOT appending empty stats to avoid helpful info get kicked out + # due to sliding window. + if stats.requests == 0: + return + + # Update the metrics. + self.query_queue.append((stats.requests, stats.queries, stats.hits)) + self.aggregated_requests += stats.requests + self.aggregated_query_total += stats.queries + self.aggregated_query_hit += stats.hits + + # Remove the oldest stats until number of requests does not exceed + # the limit. + # NOTE: We preserve the latest added stats regardless. + while ( + len(self.query_queue) > 1 + and self.aggregated_requests > self.max_recent_requests + ): + old_requests, old_queries, old_hits = self.query_queue.popleft() + self.aggregated_requests -= old_requests + self.aggregated_query_total -= old_queries + self.aggregated_query_hit -= old_hits + + def reset(self): + """Reset the metrics.""" + self.aggregated_requests = 0 + self.aggregated_query_total = 0 + self.aggregated_query_hit = 0 + self.query_queue.clear() + + @property + def hit_rate(self) -> float: + """Calculate the hit rate for the past N requests.""" + if self.aggregated_query_total == 0: + return 0.0 + return self.aggregated_query_hit / self.aggregated_query_total + + +@dataclass +class PrefixCacheStats(BaseCacheStats): + """ + Stores prefix cache hit statistics. + - `reset`: Whether `reset_prefix_cache` was invoked. + - `queries`: Refers to the number of tokens that were queried. + """ + preempted_requests: int = 0 - # The `queries` number for preempted requests. + """The number of previously preempted requests in this update.""" + preempted_queries: int = 0 - # The `hits` number for preempted requests. + """The `queries` number for preempted requests.""" + preempted_hits: int = 0 + """The `hits` number for preempted requests.""" + + +@dataclass +class MultiModalCacheStats(BaseCacheStats): + """ + Stores multi-modal cache hit statistics. + - `reset`: Whether `reset_mm_cache` was invoked. + - `queries`: Refers to the number of multi-modal data items + that were queried. + """ @dataclass diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8ebfb1f2b857..0ea1d9077f5d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -504,6 +504,10 @@ def __init__( pin_memory=self.pin_memory, ) + def reset_mm_cache(self) -> None: + if self.mm_budget: + self.mm_budget.reset_cache() + def _get_positions(self, num_tokens: Any): if isinstance(num_tokens, int): if self.uses_mrope: diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 70f7c1d45b5f..4f4da73fba6e 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -442,6 +442,9 @@ def compile_or_warm_up_model(self) -> None: # the model initialization and profiling. 
set_random_seed(self.model_config.seed) + def reset_mm_cache(self) -> None: + self.model_runner.reset_mm_cache() + def get_model(self) -> nn.Module: return self.model_runner.get_model() diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index f8c1ec850b1b..f9e1fcedc890 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -371,6 +371,10 @@ def __init__( else: self.sample_from_logits_func = self.sample_from_logits + def reset_mm_cache(self) -> None: + if self.mm_budget: + self.mm_budget.reset_cache() + def _update_num_xla_graphs(self, case_str): check_comp = self.check_recompilation and not self.enforce_eager if not check_comp: diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 861d7ae737ee..b64cec318f6c 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -293,6 +293,9 @@ def compile_or_warm_up_model(self) -> None: # the model initialization and profiling. set_random_seed(self.model_config.seed) + def reset_mm_cache(self) -> None: + self.model_runner.reset_mm_cache() + def get_model(self) -> nn.Module: return self.model_runner.get_model() diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index c3d16827f10e..6657a2a8db82 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -126,6 +126,10 @@ def get_max_items( return max_items_per_prompt, max_items_per_batch + def reset_cache(self) -> None: + if self.cache is not None: + self.cache.clear_cache() + @dataclass class AttentionGroup: diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index dc9bb3910fbc..8ee3b240904c 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -4,7 +4,7 @@ from __future__ import annotations import os -from typing import Any, Callable, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union import torch import torch.nn as nn @@ -12,7 +12,8 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import worker_receiver_cache_from_config from vllm.utils import ( enable_trace_function_call_for_thread, resolve_obj_by_qualname, @@ -21,7 +22,10 @@ warn_for_unimplemented_methods, ) from vllm.v1.kv_cache_interface import KVCacheSpec -from vllm.v1.outputs import SamplerOutput + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.outputs import ModelRunnerOutput logger = init_logger(__name__) @@ -103,6 +107,11 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: """Initialize the KV cache with the given size in blocks.""" raise NotImplementedError + def reset_mm_cache(self) -> None: + reset_fn = getattr(self.model_runner, "reset_mm_cache", None) + if callable(reset_fn): + reset_fn() + def get_model(self) -> nn.Module: raise NotImplementedError @@ -114,9 +123,7 @@ def load_model(self) -> None: """Load model onto target device.""" raise NotImplementedError - def execute_model( - self, execute_model_req: ExecuteModelRequest | None = None - ) -> list[SamplerOutput] | None: + def execute_model(self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput: raise NotImplementedError def start_worker_execution_loop(self) -> None: @@ -125,11 +132,7 @@ def start_worker_execution_loop(self) -> None: You can stop the loop by executing a driver 
worker with an empty output. See `stop_remote_worker_execution_loop` for more details. """ - with self.current_platform.inference_mode(): - while True: - output = self.execute_model(execute_model_req=None) - if output is None: - return None + raise NotImplementedError("Dead V0 code") def determine_num_available_blocks(self) -> tuple[int, int]: """Determine the number of available blocks for the GPU KV cache and @@ -289,6 +292,28 @@ def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None: worker_class, extended_calls, ) + + shared_worker_lock = kwargs.pop("shared_worker_lock", None) + if shared_worker_lock is None: + msg = ( + "Missing `shared_worker_lock` argument from executor. " + "This argument is needed for mm_processor_cache_type='shm'." + ) + + mm_config = self.vllm_config.model_config.multimodal_config + if mm_config and mm_config.mm_processor_cache_type == "shm": + raise ValueError(msg) + else: + logger.warning_once(msg) + + self.mm_receiver_cache = None + else: + self.mm_receiver_cache = worker_receiver_cache_from_config( + self.vllm_config, + MULTIMODAL_REGISTRY, + shared_worker_lock, + ) + with set_current_vllm_config(self.vllm_config): # To make vLLM config available during worker initialization self.worker = worker_class(**kwargs) @@ -323,5 +348,34 @@ def execute_method(self, method: Union[str, bytes], *args, **kwargs): logger.exception(msg) raise e - def __getattr__(self, attr): + def __getattr__(self, attr: str): return getattr(self.worker, attr) + + def _apply_mm_cache(self, scheduler_output: SchedulerOutput) -> None: + mm_cache = self.mm_receiver_cache + if mm_cache is None: + return + + for req_data in scheduler_output.scheduled_new_reqs: + req_data.mm_features = mm_cache.get_and_update_features( + req_data.mm_features + ) + + def execute_model( + self, + scheduler_output: SchedulerOutput, + *args, + **kwargs, + ) -> ModelRunnerOutput: + self._apply_mm_cache(scheduler_output) + + assert self.worker is not None + return self.worker.execute_model(scheduler_output, *args, **kwargs) + + def reset_mm_cache(self) -> None: + mm_receiver_cache = self.mm_receiver_cache + if mm_receiver_cache is not None: + mm_receiver_cache.clear_cache() + + assert self.worker is not None + self.worker.reset_mm_cache()
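
A minimal usage sketch, mirroring the new tests/entrypoints/llm/test_mm_cache_stats.py above, of how the multi-modal processor cache counters introduced in this patch can be read from the offline LLM entrypoint. The model name and engine arguments are taken from that test; the image URL is a placeholder.

# Minimal sketch based on tests/entrypoints/llm/test_mm_cache_stats.py.
# The model, image URL, and engine arguments are illustrative only.
from vllm import LLM
from vllm.v1.metrics.reader import Counter

llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",
    max_model_len=4096,
    enforce_eager=True,
    mm_processor_cache_type="shm",  # or "lru"
    disable_log_stats=False,  # stats must be enabled for the counters to update
    limit_mm_per_prompt={"image": 1},
)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
            {"type": "text", "text": "What's in this image?"},
        ],
    }
]

# The first request queries the processor cache and misses; repeating the
# same image should register a hit.
llm.chat(messages)
llm.chat(messages)

for metric in llm.get_metrics():
    if metric.name in ("vllm:mm_cache_queries", "vllm:mm_cache_hits"):
        assert isinstance(metric, Counter)
        print(metric.name, metric.value)

# llm.reset_mm_cache() clears the cache and the hit-rate window used for
# logging, but the raw query/hit counters above keep accumulating.

When serving through the OpenAI-compatible server, the same values are exported on the /metrics endpoint as vllm:mm_cache_queries and vllm:mm_cache_hits (see the test_metrics.py changes above), and custom StatLoggerBase implementations should accept the new optional mm_cache_stats argument in record(), as shown in the test_async_llm_dp.py change.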