1 change: 1 addition & 0 deletions tests/test_logger.py
@@ -444,6 +444,7 @@ def test_request_logger_log_outputs_integration():
prompt_embeds=None,
params=None,
lora_request=None,
cache_hit_threshold=None,
)

request_logger.log_outputs(
178 changes: 171 additions & 7 deletions tests/v1/core/test_scheduler.py
@@ -1287,6 +1287,176 @@ def test_kv_connector_handles_preemption():
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1


def _iterate_until_done(scheduler: Scheduler):
while True:
scheduler_output = scheduler.schedule()
if len(scheduler.running) == 0:
break
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)


@pytest.mark.parametrize(
"global_threshold,"
"request_num_tokens,"
"request_local_hit_blocks,"
"request_external_hit_blocks,"
"request_thresholds,"
"request_expected_scehduled",
[
(
0.0,
[57, 34, 28],
[1, 1, 0],
[0, 1, 0],
# expected hit ratio: [0.281, 0.941, 0.0]
# calculated as (local + external) * BLOCK_SIZE / tokens
[None, 0.4, 0.1],
[True, True, False],
),
(
0.3,
[157, 134, 128, 20, 150],
[4, 1, 0, 0, 1],
[2, 4, 0, 1, 0],
# expected hit ratio: [0.611, 0.597, 0.0, 0.8, 0.106]
[0.8, 0.4, 0.1, None, None],
[False, True, False, True, False],
),
],
)
def test_cache_hit_threshold(
# we validate that global_threshold is used when a request's threshold is None
global_threshold: float,
# number of tokens in each request
request_num_tokens: list[int],
# number of blocks hit in local cache per request
request_local_hit_blocks: list[int],
# number of blocks hit in external cache per request
request_external_hit_blocks: list[int],
# optional cache_hit_threshold for each request
request_thresholds: list[float | None],
# bool per request indicating if it is expected to be scheduled
request_expected_scheduled: list[bool],
):
assert (
len(request_num_tokens)
== len(request_thresholds)
== len(request_local_hit_blocks)
== len(request_external_hit_blocks)
== len(request_expected_scheduled)
)

scheduler = create_scheduler(
enable_prefix_caching=True,
global_cache_hit_threshold=global_threshold,
use_kv_connector=True,
)

_insert_to_local_cache(request_local_hit_blocks, scheduler)
_mock_external_cache_hit(request_external_hit_blocks, scheduler)

requests, scheduler_output = _create_and_schedule_requests(
request_num_tokens, request_thresholds, scheduler
)

# assert all requests expected to be scheduled are indeed scheduled
assert [
r.request_id
for r, expected in zip(requests, request_expected_scheduled)
if expected
] == [s.req_id for s in scheduler_output.scheduled_new_reqs]

# assert other requests are "finished" due to cache threshold
requests_expected_not_scheduled = [
r for r, expected in zip(requests, request_expected_scheduled) if not expected
]
assert all(
r.status == RequestStatus.FINISHED_CACHE_HIT_BELOW_THRESHOLD
for r in requests_expected_not_scheduled
)

_iterate_until_done(scheduler)
assert_scheduler_empty(scheduler)

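As a quick standalone check (not part of the change itself), the expected hit ratios quoted in the parametrize comments above can be reproduced from the stated formula, assuming the default test block size of 16:

# Sanity check of the expected hit ratios quoted above:
# ratio = (local_hit_blocks + external_hit_blocks) * BLOCK_SIZE / num_tokens,
# with BLOCK_SIZE = 16 (the default block_size in these test utilities).
BLOCK_SIZE = 16

def hit_ratio(num_tokens: int, local_blocks: int, external_blocks: int) -> float:
    return (local_blocks + external_blocks) * BLOCK_SIZE / num_tokens

# First case: 16/57 ~ 0.281, 32/34 ~ 0.941, 0/28 = 0.0. Against thresholds
# [None -> global 0.0, 0.4, 0.1], only the last request falls short,
# matching request_expected_scheduled = [True, True, False].
print([round(hit_ratio(t, l, e), 3)
       for t, l, e in zip([57, 34, 28], [1, 1, 0], [0, 1, 0])])

# Second case: 96/157 ~ 0.611, 80/134 ~ 0.597, 0.0, 16/20 = 0.8, 16/150 ~ 0.107.
# Against thresholds [0.8, 0.4, 0.1, None -> global 0.3, None -> global 0.3],
# only the second and fourth requests clear their thresholds,
# matching request_expected_scheduled = [False, True, False, True, False].
print([round(hit_ratio(t, l, e), 3)
       for t, l, e in zip([157, 134, 128, 20, 150], [4, 1, 0, 0, 1], [2, 4, 0, 1, 0])])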

def _create_and_schedule_requests(
request_num_tokens: list[int],
request_thresholds: list[float | None],
scheduler: Scheduler,
):
num_requests = len(request_num_tokens)
requests = create_requests(
num_requests=num_requests,
num_tokens=request_num_tokens,
block_size=scheduler.cache_config.block_size,
cache_hit_thresholds=request_thresholds,
)

for request in requests:
scheduler.add_request(request)

scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
return requests, scheduler_output


def _mock_external_cache_hit(request_external_hit_blocks, scheduler: Scheduler):
BLOCK_SIZE = scheduler.cache_config.block_size
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.side_effect = [
(i * BLOCK_SIZE, False) for i in request_external_hit_blocks
]


def _insert_to_local_cache(request_local_hit_blocks, scheduler: Scheduler):
"""Schedule requests to fill in the local cache"""
BLOCK_SIZE = scheduler.cache_config.block_size
num_total_requests = len(request_local_hit_blocks)

requests_to_schedule = [
i for i, hit_blocks in enumerate(request_local_hit_blocks) if hit_blocks > 0
]

num_requests_to_schedule = len(requests_to_schedule)
if num_requests_to_schedule == 0:
# nothing to do
return

# Mock no external Cache Hit for this cache-warmup phase
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = (0, False)

# set threshold to 0.0 to ensure all are scheduled
zero_thresholds: list[float | None] = [0.0] * num_total_requests

# Only requests with local hits should run and populate the cache
# We create all requests to make sure the correct tokens are cached
# (since the tokens are generated according to request id)
requests = create_requests(
num_requests=num_total_requests,
num_tokens=[x * BLOCK_SIZE for x in request_local_hit_blocks],
block_size=BLOCK_SIZE,
cache_hit_thresholds=zero_thresholds,
)

# Only schedule the requests we want to run so that they populate the cache
for i in requests_to_schedule:
scheduler.add_request(requests[i])

scheduler_output = scheduler.schedule()

# verify all were indeed scheduled
assert len(scheduler_output.scheduled_new_reqs) == num_requests_to_schedule

# iterate until all scheduled requests are done
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
_iterate_until_done(scheduler)
assert_scheduler_empty(scheduler)


def make_output(scheduler: Scheduler):
return ModelRunnerOutput(
req_ids=[req.request_id for req in scheduler.running],
@@ -1359,13 +1529,7 @@ def test_memory_leak():
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)

# Iterate until done.
while True:
scheduler_output = scheduler.schedule()
if len(scheduler.running) == 0:
break
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
_iterate_until_done(scheduler)

# Confirm no memory leak.
assert_scheduler_empty(scheduler)
21 changes: 17 additions & 4 deletions tests/v1/core/utils.py
@@ -47,6 +47,7 @@ def create_scheduler(
skip_tokenizer_init: bool = False,
async_scheduling: bool = False,
disable_hybrid_kv_cache_manager: bool = False,
global_cache_hit_threshold: float = 0.0,
) -> Scheduler | AsyncScheduler:
"""Create scheduler under test.

@@ -72,6 +73,7 @@ def create_scheduler(
enable_chunked_prefill=True,
async_scheduling=async_scheduling,
disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager,
global_cache_hit_threshold=global_cache_hit_threshold,
)
model_config = ModelConfig(
model=model,
@@ -141,19 +143,21 @@ def create_scheduler(

def create_requests(
num_requests: int,
num_tokens: int = 10,
num_tokens: int | list[int] = 10,
mm_positions: list[list[PlaceholderRange]] | None = None,
max_tokens: int = 16,
stop_token_ids: list[int] | None = None,
prompt_logprobs: int | None = None,
same_prompt: bool = False,
block_size: int = 16,
cache_hit_thresholds: list[float | None] | None = None,
) -> list[Request]:
global _none_hash_initialized
if not _none_hash_initialized:
init_none_hash(sha256)
_none_hash_initialized = True

if cache_hit_thresholds is not None:
assert len(cache_hit_thresholds) == num_requests
block_hasher = get_request_block_hasher(block_size, sha256)
sampling_params = SamplingParams(
ignore_eos=False,
@@ -177,8 +181,16 @@ def create_requests(
modality="image",
)
mm_features.append(mm_feature)

prompt_token_ids = [0] * num_tokens if same_prompt else [i] * num_tokens
request_num_tokens: int = (
num_tokens[i] if isinstance(num_tokens, list) else num_tokens
)
prompt_token_ids = (
[0] * request_num_tokens if same_prompt else [i] * request_num_tokens
)
if cache_hit_thresholds is not None:
cache_hit_threshold = cache_hit_thresholds[i]
else:
cache_hit_threshold = None
request = Request(
request_id=f"{i}",
prompt_token_ids=prompt_token_ids,
@@ -187,6 +199,7 @@ def create_requests(
mm_features=mm_features if mm_features else None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=block_hasher,
cache_hit_threshold=cache_hit_threshold,
)
requests.append(request)
return requests
17 changes: 17 additions & 0 deletions vllm/config/scheduler.py
@@ -138,6 +138,14 @@ class SchedulerConfig:
structured outputs, speculative decoding, and pipeline parallelism.
"""

global_cache_hit_threshold: float = 0.0
"""The threshold for cache hit ratio to handle all requests,
except for requests which override it using the "cache_hit_threshold" field.
This feature enables Decode-first optimization in P/D disaggregation:
Decode nodes can avoide remote Prefill in case of high cache hit ratio.
If set to 0.0, the optimization is disabled.
"""

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@@ -295,4 +303,13 @@ def _verify_args(self) -> Self:
f"{self.max_num_partial_prefills=}."
)

if (self.global_cache_hit_threshold < 0.0) or (
self.global_cache_hit_threshold > 1.0
):
raise ValueError(
"global_cache_hit_threshold "
f"({self.global_cache_hit_threshold}) "
"must be between 0.0 and 1.0, inclusive."
)

return self
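For orientation, a minimal sketch of how such a threshold can gate a request, assuming the hit ratio is computed as matched prefix tokens (local plus external) over prompt tokens and that a per-request cache_hit_threshold, when present, takes precedence over the global value. The helper name and exact comparison are illustrative, not the scheduler's actual code path:

def passes_cache_hit_threshold(
    num_matched_tokens: int,  # prefix-cache hits (local + external), in tokens
    num_prompt_tokens: int,
    request_threshold: float | None,  # per-request cache_hit_threshold, if any
    global_threshold: float,  # SchedulerConfig.global_cache_hit_threshold
) -> bool:
    """Illustrative check: should this request be served on this node?"""
    threshold = (
        request_threshold if request_threshold is not None else global_threshold
    )
    if threshold == 0.0:
        # A threshold of 0.0 disables the check.
        return True
    hit_ratio = num_matched_tokens / max(num_prompt_tokens, 1)
    return hit_ratio >= threshold

In the scheduler tests above, a request that fails this kind of check is not scheduled and is finished with RequestStatus.FINISHED_CACHE_HIT_BELOW_THRESHOLD.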
3 changes: 2 additions & 1 deletion vllm/config/vllm.py
@@ -964,7 +964,8 @@ def __str__(self):
f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa
f"pooler_config={self.model_config.pooler_config!r}, "
f"compilation_config={self.compilation_config!r}"
f"compilation_config={self.compilation_config!r}",
f"global_cache_hit_threshold={self.scheduler_config.global_cache_hit_threshold}",
)

@model_validator(mode="after")
6 changes: 6 additions & 0 deletions vllm/engine/arg_utils.py
@@ -563,6 +563,7 @@ class EngineArgs:
kv_offloading_backend: KVOffloadingBackend | None = (
CacheConfig.kv_offloading_backend
)
global_cache_hit_threshold: float = SchedulerConfig.global_cache_hit_threshold

def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
@@ -1054,6 +1055,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
scheduler_group.add_argument(
"--scheduler-cls", **scheduler_kwargs["scheduler_cls"]
)
scheduler_group.add_argument(
"--global-cache-hit-threshold",
**scheduler_kwargs["global_cache_hit_threshold"],
)
scheduler_group.add_argument(
"--disable-hybrid-kv-cache-manager",
**scheduler_kwargs["disable_hybrid_kv_cache_manager"],
@@ -1601,6 +1606,7 @@ def create_engine_config(
long_prefill_token_threshold=self.long_prefill_token_threshold,
disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
async_scheduling=self.async_scheduling,
global_cache_hit_threshold=self.global_cache_hit_threshold,
)

if not model_config.is_multimodal_model and self.default_mm_loras:
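As a usage sketch, the new option can also be set programmatically through EngineArgs, mirroring the --global-cache-hit-threshold CLI flag added above; the model name below is a placeholder:

from vllm.engine.arg_utils import EngineArgs

# Programmatic equivalent of "--global-cache-hit-threshold 0.3".
# With this setting, requests whose prefix-cache hit ratio is below 30%
# (and that do not override the value per request) are finished early
# instead of being scheduled.
engine_args = EngineArgs(
    model="facebook/opt-125m",  # placeholder model
    global_cache_hit_threshold=0.3,
)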
2 changes: 2 additions & 0 deletions vllm/engine/protocol.py
@@ -64,6 +64,7 @@ def generate(
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
data_parallel_rank: int | None = None,
cache_hit_threshold: float | None = None,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request."""
...
@@ -79,6 +80,7 @@ def encode(
priority: int = 0,
truncate_prompt_tokens: int | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
cache_hit_threshold: float | None = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model."""
...
5 changes: 4 additions & 1 deletion vllm/entrypoints/api_server.py
@@ -60,11 +60,14 @@ async def generate(request: Request) -> Response:
async def _generate(request_dict: dict, raw_request: Request) -> Response:
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)
cache_hit_threshold = request_dict.pop("cache_hit_threshold", None)
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()

assert engine is not None
results_generator = engine.generate(prompt, sampling_params, request_id)
results_generator = engine.generate(
prompt, sampling_params, request_id, cache_hit_threshold=cache_hit_threshold
)

# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
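For illustration, a client of this demo server can pass the per-request threshold alongside the usual sampling fields; a sketch assuming a local deployment on the default port with the /generate route:

import requests

# "cache_hit_threshold" is popped from the JSON body by _generate() above and
# forwarded to engine.generate(); the remaining keys become SamplingParams.
# Host, port, and route are assumptions about a local demo deployment.
response = requests.post(
    "http://localhost:8000/generate",
    json={
        "prompt": "Hello, my name is",
        "max_tokens": 16,
        "cache_hit_threshold": 0.5,  # serve only if >= 50% of the prompt is cached
    },
)
print(response.json())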
7 changes: 5 additions & 2 deletions vllm/entrypoints/logger.py
@@ -25,6 +25,7 @@ def log_inputs(
prompt_embeds: torch.Tensor | None,
params: SamplingParams | PoolingParams | BeamSearchParams | None,
lora_request: LoRARequest | None,
cache_hit_threshold: float | None = None,
) -> None:
max_log_len = self.max_log_len
if max_log_len is not None:
@@ -37,18 +38,20 @@ def log_inputs(
logger.debug(
"Request %s details: prompt: %r, "
"prompt_token_ids: %s, "
"prompt_embeds shape: %s.",
"prompt_embeds shape: %s, ",
request_id,
prompt,
prompt_token_ids,
prompt_embeds.shape if prompt_embeds is not None else None,
)

logger.info(
"Received request %s: params: %s, lora_request: %s.",
"Received request %s: params: %s, lora_request: %s ",
"cache_hit_threshold: %s.",
request_id,
params,
lora_request,
cache_hit_threshold,
)

def log_outputs(