diff --git a/tests/test_logger.py b/tests/test_logger.py
index 01672358902f..f5f99f3b9392 100644
--- a/tests/test_logger.py
+++ b/tests/test_logger.py
@@ -444,6 +444,7 @@ def test_request_logger_log_outputs_integration():
         prompt_embeds=None,
         params=None,
         lora_request=None,
+        cache_hit_threshold=None,
     )

     request_logger.log_outputs(
diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 92e3831b9c7a..4e988ad1a79d 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -1287,6 +1287,176 @@ def test_kv_connector_handles_preemption():
     assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1


+def _iterate_until_done(scheduler: Scheduler):
+    while True:
+        scheduler_output = scheduler.schedule()
+        if len(scheduler.running) == 0:
+            break
+        model_runner_output = make_output(scheduler)
+        scheduler.update_from_output(scheduler_output, model_runner_output)
+
+
+@pytest.mark.parametrize(
+    "global_threshold,"
+    "request_num_tokens,"
+    "request_local_hit_blocks,"
+    "request_external_hit_blocks,"
+    "request_thresholds,"
+    "request_expected_scheduled",
+    [
+        (
+            0.0,
+            [57, 34, 28],
+            [1, 1, 0],
+            [0, 1, 0],
+            # expected hit ratio: [0.281, 0.941, 0.0]
+            # calculated as (local + external) * BLOCK_SIZE / tokens
+            [None, 0.4, 0.1],
+            [True, True, False],
+        ),
+        (
+            0.3,
+            [157, 134, 128, 20, 150],
+            [4, 1, 0, 0, 1],
+            [2, 4, 0, 1, 0],
+            # expected hit ratio: [0.611, 0.597, 0.0, 0.8, 0.106]
+            [0.8, 0.4, 0.1, None, None],
+            [False, True, False, True, False],
+        ),
+    ],
+)
+def test_cache_hit_threshold(
+    # we validate that the global threshold is used when a request's threshold is None
+    global_threshold: float,
+    # number of tokens in each request
+    request_num_tokens: list[int],
+    # number of blocks hit in local cache per request
+    request_local_hit_blocks: list[int],
+    # number of blocks hit in external cache per request
+    request_external_hit_blocks: list[int],
+    # optional cache_hit_threshold for each request
+    request_thresholds: list[float | None],
+    # bool per request indicating if it is expected to be scheduled
+    request_expected_scheduled: list[bool],
+):
+    assert (
+        len(request_num_tokens)
+        == len(request_thresholds)
+        == len(request_local_hit_blocks)
+        == len(request_external_hit_blocks)
+        == len(request_expected_scheduled)
+    )
+
+    scheduler = create_scheduler(
+        enable_prefix_caching=True,
+        global_cache_hit_threshold=global_threshold,
+        use_kv_connector=True,
+    )
+
+    _insert_to_local_cache(request_local_hit_blocks, scheduler)
+    _mock_external_cache_hit(request_external_hit_blocks, scheduler)
+
+    requests, scheduler_output = _create_and_schedule_requests(
+        request_num_tokens, request_thresholds, scheduler
+    )
+
+    # assert all requests expected to be scheduled are indeed scheduled
+    assert [
+        r.request_id
+        for r, expected in zip(requests, request_expected_scheduled)
+        if expected
+    ] == [s.req_id for s in scheduler_output.scheduled_new_reqs]
+
+    # assert other requests are "finished" due to cache threshold
+    requests_expected_not_scheduled = [
+        r for r, expected in zip(requests, request_expected_scheduled) if not expected
+    ]
+    assert all(
+        r.status == RequestStatus.FINISHED_CACHE_HIT_BELOW_THRESHOLD
+        for r in requests_expected_not_scheduled
+    )
+
+    _iterate_until_done(scheduler)
+    assert_scheduler_empty(scheduler)
+
+
+def _create_and_schedule_requests(
+    request_num_tokens: list[int],
+    request_thresholds: list[float | None],
+    scheduler: Scheduler,
+):
+    num_requests = len(request_num_tokens)
+    requests = create_requests(
+        num_requests=num_requests,
+        num_tokens=request_num_tokens,
+        block_size=scheduler.cache_config.block_size,
+        cache_hit_thresholds=request_thresholds,
+    )
+
+    for request in requests:
+        scheduler.add_request(request)
+
+    scheduler_output = scheduler.schedule()
+    model_runner_output = make_output(scheduler)
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    return requests, scheduler_output
+
+
+def _mock_external_cache_hit(request_external_hit_blocks, scheduler: Scheduler):
+    BLOCK_SIZE = scheduler.cache_config.block_size
+    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
+    scheduler.connector.get_num_new_matched_tokens.side_effect = [
+        (i * BLOCK_SIZE, False) for i in request_external_hit_blocks
+    ]
+
+
+def _insert_to_local_cache(request_local_hit_blocks, scheduler: Scheduler):
+    """Schedule requests to fill in the local cache"""
+    BLOCK_SIZE = scheduler.cache_config.block_size
+    num_total_requests = len(request_local_hit_blocks)
+
+    requests_to_schedule = [
+        i for i, hit_blocks in enumerate(request_local_hit_blocks) if hit_blocks > 0
+    ]
+
+    num_requests_to_schedule = len(requests_to_schedule)
+    if num_requests_to_schedule == 0:
+        # nothing to do
+        return
+
+    # Mock no external cache hit for this cache-warmup phase
+    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
+    scheduler.connector.get_num_new_matched_tokens.return_value = (0, False)
+
+    # Set threshold to 0.0 to ensure all requests are scheduled
+    zero_thresholds: list[float | None] = [0.0] * num_total_requests
+
+    # Only requests with local hits should run and populate the cache.
+    # We create all requests to make sure the correct tokens are cached
+    # (since the tokens are generated according to request id).
+    requests = create_requests(
+        num_requests=num_total_requests,
+        num_tokens=[x * BLOCK_SIZE for x in request_local_hit_blocks],
+        block_size=BLOCK_SIZE,
+        cache_hit_thresholds=zero_thresholds,
+    )
+
+    # Only schedule the requests we want to run and populate the cache
+    for i in requests_to_schedule:
+        scheduler.add_request(requests[i])
+
+    scheduler_output = scheduler.schedule()
+
+    # verify all were indeed scheduled
+    assert len(scheduler_output.scheduled_new_reqs) == num_requests_to_schedule
+
+    # iterate until all scheduled requests are done
+    model_runner_output = make_output(scheduler)
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    _iterate_until_done(scheduler)
+    assert_scheduler_empty(scheduler)
+
+
 def make_output(scheduler: Scheduler):
     return ModelRunnerOutput(
         req_ids=[req.request_id for req in scheduler.running],
@@ -1359,13 +1529,7 @@ def test_memory_leak():
         model_runner_output = make_output(scheduler)
         scheduler.update_from_output(scheduler_output, model_runner_output)

-    # Iterate until done.
-    while True:
-        scheduler_output = scheduler.schedule()
-        if len(scheduler.running) == 0:
-            break
-        model_runner_output = make_output(scheduler)
-        scheduler.update_from_output(scheduler_output, model_runner_output)
+    _iterate_until_done(scheduler)

     # Confirm no memory leak.
     assert_scheduler_empty(scheduler)
diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py
index 3f5e1b9eeaf7..43f2ac6dd19c 100644
--- a/tests/v1/core/utils.py
+++ b/tests/v1/core/utils.py
@@ -47,6 +47,7 @@ def create_scheduler(
     skip_tokenizer_init: bool = False,
     async_scheduling: bool = False,
     disable_hybrid_kv_cache_manager: bool = False,
+    global_cache_hit_threshold: float = 0.0,
 ) -> Scheduler | AsyncScheduler:
     """Create scheduler under test.
@@ -72,6 +73,7 @@ def create_scheduler(
         enable_chunked_prefill=True,
         async_scheduling=async_scheduling,
         disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager,
+        global_cache_hit_threshold=global_cache_hit_threshold,
     )
     model_config = ModelConfig(
         model=model,
@@ -141,19 +143,21 @@ def create_scheduler(

 def create_requests(
     num_requests: int,
-    num_tokens: int = 10,
+    num_tokens: int | list[int] = 10,
     mm_positions: list[list[PlaceholderRange]] | None = None,
     max_tokens: int = 16,
     stop_token_ids: list[int] | None = None,
     prompt_logprobs: int | None = None,
     same_prompt: bool = False,
     block_size: int = 16,
+    cache_hit_thresholds: list[float | None] | None = None,
 ) -> list[Request]:
     global _none_hash_initialized
     if not _none_hash_initialized:
         init_none_hash(sha256)
         _none_hash_initialized = True
-
+    if cache_hit_thresholds is not None:
+        assert len(cache_hit_thresholds) == num_requests
     block_hasher = get_request_block_hasher(block_size, sha256)
     sampling_params = SamplingParams(
         ignore_eos=False,
@@ -177,8 +181,16 @@ def create_requests(
                 modality="image",
             )
             mm_features.append(mm_feature)
-
-        prompt_token_ids = [0] * num_tokens if same_prompt else [i] * num_tokens
+        request_num_tokens: int = (
+            num_tokens[i] if isinstance(num_tokens, list) else num_tokens
+        )
+        prompt_token_ids = (
+            [0] * request_num_tokens if same_prompt else [i] * request_num_tokens
+        )
+        if cache_hit_thresholds is not None:
+            cache_hit_threshold = cache_hit_thresholds[i]
+        else:
+            cache_hit_threshold = None
         request = Request(
             request_id=f"{i}",
             prompt_token_ids=prompt_token_ids,
@@ -187,6 +199,7 @@ def create_requests(
             mm_features=mm_features if mm_features else None,
             eos_token_id=EOS_TOKEN_ID,
             block_hasher=block_hasher,
+            cache_hit_threshold=cache_hit_threshold,
         )
         requests.append(request)
     return requests
diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py
index b837b830e774..1c615168b26c 100644
--- a/vllm/config/scheduler.py
+++ b/vllm/config/scheduler.py
@@ -138,6 +138,14 @@ class SchedulerConfig:
     structured outputs, speculative decoding, and pipeline parallelism.
     """

+    global_cache_hit_threshold: float = 0.0
+    """Minimum KV-cache hit ratio required to handle a request. Applies to all
+    requests except those that override it via their "cache_hit_threshold" field.
+    This enables a decode-first optimization in P/D disaggregation:
+    decode nodes can avoid remote prefill when the cache hit ratio is high.
+    If set to 0.0, the optimization is disabled.
+    """
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -295,4 +303,13 @@ def _verify_args(self) -> Self:
                 f"{self.max_num_partial_prefills=}."
             )

+        if (self.global_cache_hit_threshold < 0.0) or (
+            self.global_cache_hit_threshold > 1.0
+        ):
+            raise ValueError(
+                "global_cache_hit_threshold "
+                f"({self.global_cache_hit_threshold}) "
+                "must be between 0.0 and 1.0, inclusive."
+            )
+
         return self
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index ee91cb0ef5c3..b42837cd1d5d 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -964,7 +964,8 @@ def __str__(self):
             f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
             f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, "  # noqa
             f"pooler_config={self.model_config.pooler_config!r}, "
-            f"compilation_config={self.compilation_config!r}"
+            f"compilation_config={self.compilation_config!r}, "
+            f"global_cache_hit_threshold={self.scheduler_config.global_cache_hit_threshold}"
         )

     @model_validator(mode="after")
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 14fd4e70ad6c..f568e4ff6665 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -563,6 +563,7 @@ class EngineArgs:
     kv_offloading_backend: KVOffloadingBackend | None = (
         CacheConfig.kv_offloading_backend
     )
+    global_cache_hit_threshold: float = SchedulerConfig.global_cache_hit_threshold

     def __post_init__(self):
         # support `EngineArgs(compilation_config={...})`
@@ -1054,6 +1055,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         scheduler_group.add_argument(
             "--scheduler-cls", **scheduler_kwargs["scheduler_cls"]
         )
+        scheduler_group.add_argument(
+            "--global-cache-hit-threshold",
+            **scheduler_kwargs["global_cache_hit_threshold"],
+        )
         scheduler_group.add_argument(
             "--disable-hybrid-kv-cache-manager",
             **scheduler_kwargs["disable_hybrid_kv_cache_manager"],
@@ -1601,6 +1606,7 @@ def create_engine_config(
             long_prefill_token_threshold=self.long_prefill_token_threshold,
             disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
             async_scheduling=self.async_scheduling,
+            global_cache_hit_threshold=self.global_cache_hit_threshold,
         )

         if not model_config.is_multimodal_model and self.default_mm_loras:
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index 24fcd9fe1cab..6b1acce4f930 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -64,6 +64,7 @@ def generate(
         trace_headers: Mapping[str, str] | None = None,
         priority: int = 0,
         data_parallel_rank: int | None = None,
+        cache_hit_threshold: float | None = None,
     ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request."""
         ...
@@ -79,6 +80,7 @@ def encode(
         priority: int = 0,
         truncate_prompt_tokens: int | None = None,
         tokenization_kwargs: dict[str, Any] | None = None,
+        cache_hit_threshold: float | None = None,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from a pooling model."""
         ...
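For context, a minimal sketch (not part of the diff) of how the new setting is intended to behave: a per-request `cache_hit_threshold` overrides `--global-cache-hit-threshold`, and a request is admitted only when the fraction of already-computed prompt tokens (local plus external KV-cache hits) meets the effective threshold. The helper name `admits` is illustrative only; the real check is added to `Scheduler.schedule()` later in this diff.

```python
# Illustrative sketch of the admission rule, assuming the semantics wired up
# by this diff; `admits` is a hypothetical helper, not vLLM API.
def admits(
    num_computed_tokens: int,
    num_prompt_tokens: int,
    request_threshold: float | None,
    global_threshold: float,
) -> bool:
    # A per-request cache_hit_threshold overrides the global setting.
    threshold = (
        request_threshold if request_threshold is not None else global_threshold
    )
    hit_ratio = (
        num_computed_tokens / num_prompt_tokens if num_prompt_tokens > 0 else 0.0
    )
    return hit_ratio >= threshold


# 32 of 128 prompt tokens cached -> hit ratio 0.25: rejected at a 0.3 global
# threshold, admitted when the request lowers its own threshold to 0.2,
# and always admitted at the default 0.0 (feature disabled).
assert not admits(32, 128, None, 0.3)
assert admits(32, 128, 0.2, 0.3)
assert admits(32, 128, None, 0.0)
```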
diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index 154cdeb42a3e..ab0aa55d9497 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -60,11 +60,14 @@ async def generate(request: Request) -> Response:
 async def _generate(request_dict: dict, raw_request: Request) -> Response:
     prompt = request_dict.pop("prompt")
     stream = request_dict.pop("stream", False)
+    cache_hit_threshold = request_dict.pop("cache_hit_threshold", None)
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()

     assert engine is not None
-    results_generator = engine.generate(prompt, sampling_params, request_id)
+    results_generator = engine.generate(
+        prompt, sampling_params, request_id, cache_hit_threshold=cache_hit_threshold
+    )

     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py
index 678a7b3a60b5..e0accd847498 100644
--- a/vllm/entrypoints/logger.py
+++ b/vllm/entrypoints/logger.py
@@ -25,6 +25,7 @@ def log_inputs(
         prompt_embeds: torch.Tensor | None,
         params: SamplingParams | PoolingParams | BeamSearchParams | None,
         lora_request: LoRARequest | None,
+        cache_hit_threshold: float | None = None,
     ) -> None:
         max_log_len = self.max_log_len
         if max_log_len is not None:
@@ -45,10 +46,12 @@ def log_inputs(
         )

         logger.info(
-            "Received request %s: params: %s, lora_request: %s.",
+            "Received request %s: params: %s, lora_request: %s, "
+            "cache_hit_threshold: %s.",
             request_id,
             params,
             lora_request,
+            cache_hit_threshold,
         )

     def log_outputs(
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index d0061f9d5b40..da994d6a265d 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -530,6 +530,17 @@ def function_call_parsing(cls, data):
         return data


+def _validate_cache_hit_threshold(cls, data):
+    cache_hit_threshold = data.get("cache_hit_threshold")
+    if cache_hit_threshold is not None and (
+        cache_hit_threshold < 0.0 or cache_hit_threshold > 1.0
+    ):
+        raise ValueError(
+            "Parameter `cache_hit_threshold` must be between 0.0 and 1.0 if provided."
+        )
+    return data
+
+
 class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
@@ -777,6 +788,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
         description="KVTransfer parameters used for disaggregated serving.",
     )

+    cache_hit_threshold: float | None = Field(
+        default=None,
+        description="Minimum required KV-cache hit ratio to process the request.",
+    )
+
     vllm_xargs: dict[str, str | int | float] | None = Field(
         default=None,
         description=(
@@ -892,6 +908,7 @@ def to_sampling_params(
         if self.kv_transfer_params:
             # Pass in kv_transfer_params via extra_args
             extra_args["kv_transfer_params"] = self.kv_transfer_params
+
         return SamplingParams.from_optional(
             n=self.n,
             best_of=self.best_of,
@@ -1080,6 +1097,10 @@ def check_generation_prompt(cls, data):
         )
         return data

+    _validate_cache_hit_threshold = model_validator(mode="before")(
+        _validate_cache_hit_threshold
+    )
+
     @model_validator(mode="before")
     @classmethod
     def check_cache_salt_support(cls, data):
@@ -1274,6 +1295,11 @@ class CompletionRequest(OpenAIBaseModel):
         description="KVTransfer parameters used for disaggregated serving.",
     )

+    cache_hit_threshold: float | None = Field(
+        default=None,
+        description="Minimum required KV-cache hit ratio to process the request.",
+    )
+
     vllm_xargs: dict[str, str | int | float] | None = Field(
         default=None,
         description=(
@@ -1386,6 +1412,7 @@ def to_sampling_params(
         if self.kv_transfer_params:
             # Pass in kv_transfer_params via extra_args
             extra_args["kv_transfer_params"] = self.kv_transfer_params
+
         return SamplingParams.from_optional(
             n=self.n,
             best_of=self.best_of,
@@ -1499,6 +1526,10 @@ def check_cache_salt_support(cls, data):
         )
         return data

+    _validate_cache_hit_threshold = model_validator(mode="before")(
+        _validate_cache_hit_threshold
+    )
+

 class EmbeddingCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 25979d5502b0..ee3e2dbcccfa 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -295,11 +295,14 @@ async def create_chat_completion(
                     self.default_sampling_params,
                 )

+                cache_hit_threshold = request.cache_hit_threshold
+
                 self._log_inputs(
                     request_id,
                     request_prompts[i],
                     params=sampling_params,
                     lora_request=lora_request,
+                    cache_hit_threshold=cache_hit_threshold,
                 )

                 trace_headers = (
@@ -323,6 +326,7 @@ async def create_chat_completion(
                     lora_request=lora_request,
                     trace_headers=trace_headers,
                     priority=request.priority,
+                    cache_hit_threshold=request.cache_hit_threshold,
                 )

                 generator = self.engine_client.generate(
@@ -335,6 +339,7 @@ async def create_chat_completion(
                     prompt_text=prompt_text,
                     tokenization_kwargs=tokenization_kwargs,
                     data_parallel_rank=data_parallel_rank,
+                    cache_hit_threshold=request.cache_hit_threshold,
                 )

                 generators.append(generator)
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 14dbdd4cb4c7..47802e231edc 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -183,12 +183,14 @@ async def create_completion(
                 )

                 request_id_item = f"{request_id}-{i}"
+                cache_hit_threshold = request.cache_hit_threshold

                 self._log_inputs(
                     request_id_item,
                     engine_prompt,
                     params=sampling_params,
                     lora_request=lora_request,
+                    cache_hit_threshold=cache_hit_threshold,
                 )

                 trace_headers = (
@@ -216,6 +218,7 @@ async def create_completion(
                     lora_request=lora_request,
                     trace_headers=trace_headers,
                     priority=request.priority,
+                    cache_hit_threshold=cache_hit_threshold,
                 )

                 generator = self.engine_client.generate(
@@ -228,6 +231,7 @@ async def create_completion(
                     prompt_text=prompt_text,
                     tokenization_kwargs=tokenization_kwargs,
                     data_parallel_rank=data_parallel_rank,
+                    cache_hit_threshold=request.cache_hit_threshold,
                 )

                 generators.append(generator)
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index 51f6106acec3..1b6ff54b9170 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -228,12 +228,14 @@ async def _process_chunked_request(
                 prompt=chunk_text, prompt_token_ids=chunk_tokens
             )

+            cache_hit_threshold = getattr(ctx.request, "cache_hit_threshold", None)
             # Log the chunk
             self._log_inputs(
                 chunk_request_id,
                 chunk_request_prompt,
                 params=pooling_params,
                 lora_request=ctx.lora_request,
+                cache_hit_threshold=cache_hit_threshold,
             )

             # Create generator for this chunk and wrap it to return indices
@@ -244,6 +246,7 @@ async def _process_chunked_request(
                 lora_request=ctx.lora_request,
                 trace_headers=trace_headers,
                 priority=getattr(ctx.request, "priority", 0),
+                cache_hit_threshold=cache_hit_threshold,
             )

             generators.append(original_generator)
@@ -343,12 +346,14 @@ async def _create_single_prompt_generator(
     ) -> AsyncGenerator[RequestOutput | PoolingRequestOutput, None]:
         """Create a generator for a single prompt using standard processing."""
         request_id_item = f"{ctx.request_id}-{prompt_index}"
+        cache_hit_threshold = getattr(ctx.request, "cache_hit_threshold", None)

         self._log_inputs(
             request_id_item,
             engine_prompt,
             params=pooling_params,
             lora_request=ctx.lora_request,
+            cache_hit_threshold=cache_hit_threshold,
         )

         # Return the original generator without wrapping
@@ -359,6 +364,7 @@ async def _create_single_prompt_generator(
             lora_request=ctx.lora_request,
             trace_headers=trace_headers,
             priority=getattr(ctx.request, "priority", 0),
+            cache_hit_threshold=cache_hit_threshold,
         )

     @override
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 46e79edbde61..5227563b9aee 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -637,12 +637,13 @@ async def _prepare_generators(
         for i, engine_prompt in enumerate(ctx.engine_prompts):
             request_id_item = f"{ctx.request_id}-{i}"
-
+            cache_hit_threshold = getattr(ctx.request, "cache_hit_threshold", None)
             self._log_inputs(
                 request_id_item,
                 engine_prompt,
                 params=pooling_params,
                 lora_request=ctx.lora_request,
+                cache_hit_threshold=cache_hit_threshold,
             )

             generator = self.engine_client.encode(
@@ -652,6 +653,7 @@ async def _prepare_generators(
                 lora_request=ctx.lora_request,
                 trace_headers=trace_headers,
                 priority=getattr(ctx.request, "priority", 0),
+                cache_hit_threshold=cache_hit_threshold,
             )

             generators.append(generator)
@@ -1150,6 +1152,7 @@ async def _process_inputs(
         lora_request: LoRARequest | None,
         trace_headers: Mapping[str, str] | None,
         priority: int,
+        cache_hit_threshold: float | None = None,
     ) -> tuple[EngineCoreRequest, dict[str, Any]]:
         """Use the Processor to process inputs for AsyncLLM."""
         tokenization_kwargs: dict[str, Any] = {}
@@ -1165,6 +1168,7 @@ async def _process_inputs(
             tokenization_kwargs=tokenization_kwargs,
             trace_headers=trace_headers,
             priority=priority,
+            cache_hit_threshold=cache_hit_threshold,
         )
         return engine_request, tokenization_kwargs

@@ -1182,11 +1186,13 @@ async def _generate_with_builtin_tools(
         prompt_text, _, _ = self._get_prompt_components(request_prompt)
         orig_priority = priority
         while True:
+            cache_hit_threshold = kwargs.get("cache_hit_threshold")
             self._log_inputs(
                 request_id,
                 request_prompt,
                 params=sampling_params,
                 lora_request=lora_request,
+                cache_hit_threshold=cache_hit_threshold,
             )
             trace_headers = kwargs.get("trace_headers")
             engine_request, tokenization_kwargs = await self._process_inputs(
@@ -1196,6 +1202,7 @@ async def _generate_with_builtin_tools(
                 lora_request=lora_request,
                 trace_headers=trace_headers,
                 priority=priority,
+                cache_hit_threshold=cache_hit_threshold,
             )

             generator = self.engine_client.generate(
@@ -1250,6 +1257,7 @@ def _log_inputs(
         inputs: RequestPrompt | PromptType,
         params: SamplingParams | PoolingParams | BeamSearchParams | None,
         lora_request: LoRARequest | None,
+        cache_hit_threshold: float | None = None,
     ) -> None:
         if self.request_logger is None:
             return
@@ -1263,6 +1271,7 @@ def _log_inputs(
             prompt_embeds,
             params=params,
             lora_request=lora_request,
+            cache_hit_threshold=cache_hit_threshold,
         )

     async def _get_trace_headers(
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index f51744eb2640..fd43bdbc94c9 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -110,6 +110,11 @@ def __init__(
             self.parallel_config.data_parallel_rank,
         )

+        # List to collect requests that are below the cache hit ratio
+        # threshold. These requests will be finished and the list cleared
+        # in update_from_output().
+        self.cache_hit_below_threshold_request_ids: list[str] = []
+
         num_gpu_blocks = self.cache_config.num_gpu_blocks
         assert num_gpu_blocks is not None and num_gpu_blocks > 0

@@ -432,6 +437,40 @@ def schedule(self) -> SchedulerOutput:
                     num_computed_tokens = (
                         num_new_local_computed_tokens + num_external_computed_tokens
                     )
+                    # Cache hit threshold in request overrides global setting
+                    scheduler_config = self.vllm_config.scheduler_config
+                    cache_hit_threshold = (
+                        request.cache_hit_threshold
+                        if request.cache_hit_threshold is not None
+                        else scheduler_config.global_cache_hit_threshold
+                    )
+
+                    # Check if cache hit is above threshold
+                    cache_hit_percent = (
+                        num_computed_tokens / request.num_prompt_tokens
+                        if request.num_prompt_tokens > 0
+                        else 0.0
+                    )
+                    if cache_hit_percent < cache_hit_threshold:
+                        threshold_source = (
+                            "request"
+                            if request.cache_hit_threshold is not None
+                            else "global"
+                        )
+                        logger.debug(
+                            "Request %s rejected: cache hit rate %.2f"
+                            " < threshold %.2f (%s)",
+                            request.request_id,
+                            cache_hit_percent,
+                            cache_hit_threshold,
+                            threshold_source,
+                        )
+                        self.waiting.pop_request()
+                        self.cache_hit_below_threshold_request_ids.append(
+                            request.request_id
+                        )
+                        continue
+
                 # KVTransfer: WAITING reqs have num_computed_tokens > 0
                 # after async KV recvs are completed.
                 else:
@@ -1059,6 +1098,26 @@ def update_from_output(
             batch = KVEventBatch(ts=time.time(), events=events)
             self.kv_event_publisher.publish(batch)

+        # Handle requests that were rejected due to low cache hit rate.
+        if self.cache_hit_below_threshold_request_ids:
+            for req_id in self.cache_hit_below_threshold_request_ids:
+                req = self.requests.get(req_id)
+                if req is None:
+                    # The request is already finished, e.g. aborted.
+                    continue
+                # Add EngineCoreOutput for this Request.
+                req.status = RequestStatus.FINISHED_CACHE_HIT_BELOW_THRESHOLD
+                outputs[req.client_index].append(
+                    EngineCoreOutput(
+                        request_id=req_id,
+                        new_token_ids=[],
+                        finish_reason=req.get_finished_reason(),
+                    )
+                )
+                self._free_request(req)
+            # Clear the list after finishing all such requests.
+            self.cache_hit_below_threshold_request_ids.clear()
+
         # Create EngineCoreOutputs for all clients that have requests with
         # outputs in this step.
         engine_core_outputs = {
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index e2c1ed7b561c..3480d6582a86 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -18,7 +18,7 @@

 # These are possible values of RequestOutput.finish_reason,
 # so form part of the external API.
-FINISH_REASON_STRINGS = ("stop", "length", "abort")
+FINISH_REASON_STRINGS = ("stop", "length", "abort", "cache_threshold")


 class FinishReason(enum.IntEnum):
@@ -30,12 +30,14 @@ class FinishReason(enum.IntEnum):
     stop - a stop string was emitted
     length - max_tokens was consumed, or max_model_len was reached
     abort - aborted for another reason
+    cache_threshold - rejected because the cache hit ratio was below the threshold
     """

     STOP = 0
     LENGTH = 1
     ABORT = 2
+    CACHE_THRESHOLD = 3

     def __str__(self):
         return FINISH_REASON_STRINGS[self.value]
@@ -58,6 +60,7 @@ class EngineCoreRequest(
     cache_salt: str | None
     data_parallel_rank: int | None
     prompt_embeds: torch.Tensor | None = None
+    cache_hit_threshold: float | None = None

     # Index of the client, used to ensure outputs are sent back to the same
     # client for this request when scaling out the front-end.
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index f0d5b77e8e18..240eabef09ea 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -268,6 +268,7 @@ async def add_request(
         priority: int = 0,
         data_parallel_rank: int | None = None,
         prompt_text: str | None = None,
+        cache_hit_threshold: float | None = None,
     ) -> RequestOutputCollector:
         """Add new request to the AsyncLLM."""

@@ -298,6 +299,7 @@ async def add_request(
             trace_headers,
             priority,
             data_parallel_rank,
+            cache_hit_threshold=cache_hit_threshold,
         )
         if isinstance(prompt, str):
             prompt_text = prompt
@@ -359,6 +361,7 @@ async def generate(
         trace_headers: Mapping[str, str] | None = None,
         priority: int = 0,
         data_parallel_rank: int | None = None,
+        cache_hit_threshold: float | None = None,
     ) -> AsyncGenerator[RequestOutput, None]:
         """
         Main function called by the API server to kick off a request
@@ -411,6 +414,7 @@ async def generate(
             priority=priority,
             data_parallel_rank=data_parallel_rank,
             prompt_text=prompt_text,
+            cache_hit_threshold=cache_hit_threshold,
         )

         # The output_handler task pushes items into the queue.
@@ -546,6 +550,7 @@ async def encode(
         priority: int = 0,
         truncate_prompt_tokens: int | None = None,
         tokenization_kwargs: dict[str, Any] | None = None,
+        cache_hit_threshold: float | None = None,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """
         Main function called by the API server to kick off a request
@@ -583,6 +588,7 @@ async def encode(
             tokenization_kwargs=tokenization_kwargs,
             trace_headers=trace_headers,
             priority=priority,
+            cache_hit_threshold=cache_hit_threshold,
         )

         # The output_handler task pushes items into the queue.
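For reference, a hedged usage sketch (not part of the diff) of the per-request field added to the OpenAI-compatible protocol above, sent through the `openai` client's `extra_body`; the server URL and model name are placeholders, and the server is assumed to have been started with `--global-cache-hit-threshold 0.3`.

```python
# Illustrative client-side usage of the new per-request cache_hit_threshold.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
    # Require at least 80% of the prompt to be served from KV cache;
    # this overrides the server-wide --global-cache-hit-threshold for this request.
    extra_body={"cache_hit_threshold": 0.8},
)
print(completion.choices[0].message.content)
```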
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index f44b6b2070d9..2f01a9eaee36 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -225,6 +225,7 @@ def add_request(
         trace_headers: Mapping[str, str] | None = None,
         priority: int = 0,
         prompt_text: str | None = None,
+        cache_hit_threshold: float | None = None,
     ) -> None:
         # Validate the request_id type.
         if not isinstance(request_id, str):
@@ -248,6 +249,7 @@ def add_request(
             tokenization_kwargs,
             trace_headers,
             priority,
+            cache_hit_threshold=cache_hit_threshold,
         )
         if isinstance(prompt, str):
             prompt_text = prompt
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index c49fd1bde8b9..e20aa1a705df 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -349,6 +349,7 @@ def process_inputs(
         trace_headers: Mapping[str, str] | None = None,
         priority: int = 0,
         data_parallel_rank: int | None = None,
+        cache_hit_threshold: float | None = None,
     ) -> EngineCoreRequest:
         self._validate_lora(lora_request)
         self._validate_params(params)
@@ -477,6 +478,7 @@ def process_inputs(
             priority=priority,
             data_parallel_rank=data_parallel_rank,
             trace_headers=trace_headers,
+            cache_hit_threshold=cache_hit_threshold,
         )

     def _validate_model_inputs(
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 864b0eb7fa41..a317556013a1 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -43,6 +43,7 @@ def __init__(
         cache_salt: str | None = None,
         priority: int = 0,
         trace_headers: Mapping[str, str] | None = None,
+        cache_hit_threshold: float | None = None,
         block_hasher: Callable[["Request"], list["BlockHash"]] | None = None,
     ) -> None:
         self.request_id = request_id
@@ -65,6 +66,8 @@ def __init__(
         # P/D: Connector-specific KV transfer parameters.
         self.kv_transfer_params: dict[str, Any] | None = None

+        self.cache_hit_threshold: float | None = cache_hit_threshold
+
         if pooling_params is not None:
             # Pooling models.
             self.max_tokens = 1
@@ -148,6 +151,7 @@ def from_engine_core_request(
             priority=request.priority,
             trace_headers=request.trace_headers,
             block_hasher=block_hasher,
+            cache_hit_threshold=request.cache_hit_threshold,
         )

     def append_output_token_ids(
@@ -223,6 +227,7 @@ class RequestStatus(enum.IntEnum):
     FINISHED_LENGTH_CAPPED = enum.auto()
     FINISHED_ABORTED = enum.auto()
     FINISHED_IGNORED = enum.auto()
+    FINISHED_CACHE_HIT_BELOW_THRESHOLD = enum.auto()

     def __str__(self):
         return self.name
@@ -245,4 +250,5 @@ def get_finished_reason(status: "RequestStatus") -> FinishReason | None:
         RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
         RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
         RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
+        RequestStatus.FINISHED_CACHE_HIT_BELOW_THRESHOLD: FinishReason.CACHE_THRESHOLD,
     }
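Finally, a speculative sketch (not part of the diff) of how a decode-side caller might react to the new `cache_threshold` finish reason exposed above: `prefill_generate` is a hypothetical fallback path, and the output handling assumes the usual `RequestOutput` shape where `finish_reason` is reported as a string.

```python
# Illustrative decode-first flow, assuming `engine` is an AsyncLLM instance and
# `prefill_generate` is a hypothetical coroutine that routes to a prefill node.
from vllm import SamplingParams


async def generate_decode_first(engine, prompt, request_id, prefill_generate):
    final_output = None
    async for output in engine.generate(
        prompt,
        SamplingParams(max_tokens=64),
        request_id,
        # Only serve locally if at least 80% of the prompt is already cached.
        cache_hit_threshold=0.8,
    ):
        final_output = output

    if final_output is not None and any(
        o.finish_reason == "cache_threshold" for o in final_output.outputs
    ):
        # Cache hit ratio was below the threshold: fall back to remote prefill.
        return await prefill_generate(prompt)
    return final_output
```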