
Commit eeae693

Kfir Wolfson (kfirwolfson) authored and committed
Add cache hit threshold. Squash of 11 commits including:
Fix Gemini CR comments
Add unit tests
Move from SamplingParams to request
unit test remake
fix static code analysis rejects
Fix unit test
fix after local CR
fix pre-commit reject
add threshold to request logger and fix some calls to encode
fix ruff

Signed-off-by: Kfir Wolfson <kfirw@pliops.com>
1 parent f231e5b commit eeae693

File tree

19 files changed: +358, -18 lines


tests/v1/core/test_scheduler.py

Lines changed: 171 additions & 7 deletions
@@ -1143,6 +1143,176 @@ def test_kv_connector_handles_preemption():
     assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1


+def _iterate_until_done(scheduler: Scheduler):
+    while True:
+        scheduler_output = scheduler.schedule()
+        if len(scheduler.running) == 0:
+            break
+        model_runner_output = make_output(scheduler)
+        scheduler.update_from_output(scheduler_output, model_runner_output)
+
+
+@pytest.mark.parametrize(
+    "global_threshold,"
+    "request_num_tokens,"
+    "request_local_hit_blocks,"
+    "request_external_hit_blocks,"
+    "request_thresholds,"
+    "request_expected_scheduled",
+    [
+        (
+            0.0,
+            [57, 34, 28],
+            [1, 1, 0],
+            [0, 1, 0],
+            # expected hit ratio: [0.281, 0.941, 0.0]
+            # calculated as (local + external) * BLOCK_SIZE / tokens
+            [None, 0.4, 0.1],
+            [True, True, False],
+        ),
+        (
+            0.3,
+            [157, 134, 128, 20, 150],
+            [4, 1, 0, 0, 1],
+            [2, 4, 0, 1, 0],
+            # expected hit ratio: [0.611, 0.597, 0.0, 0.8, 0.106]
+            [0.8, 0.4, 0.1, None, None],
+            [False, True, False, True, False],
+        ),
+    ],
+)
+def test_cache_hit_threshold(
+    # we validate global_threshold is used when the request threshold is None
+    global_threshold: float,
+    # number of tokens in each request
+    request_num_tokens: list[int],
+    # number of blocks hit in local cache per request
+    request_local_hit_blocks: list[int],
+    # number of blocks hit in external cache per request
+    request_external_hit_blocks: list[int],
+    # optional cache_hit_threshold for each request
+    request_thresholds: list[Optional[float]],
+    # bool per request indicating if it is expected to be scheduled
+    request_expected_scheduled: list[bool],
+):
+    assert (
+        len(request_num_tokens)
+        == len(request_thresholds)
+        == len(request_local_hit_blocks)
+        == len(request_external_hit_blocks)
+        == len(request_expected_scheduled)
+    )
+
+    scheduler = create_scheduler(
+        enable_prefix_caching=True,
+        global_cache_hit_threshold=global_threshold,
+        use_kv_connector=True,
+    )
+
+    _insert_to_local_cache(request_local_hit_blocks, scheduler)
+    _mock_external_cache_hit(request_external_hit_blocks, scheduler)
+
+    requests, scheduler_output = _create_and_schedule_requests(
+        request_num_tokens, request_thresholds, scheduler
+    )
+
+    # assert all requests expected to be scheduled are indeed scheduled
+    assert [
+        r.request_id
+        for r, expected in zip(requests, request_expected_scheduled)
+        if expected
+    ] == [s.req_id for s in scheduler_output.scheduled_new_reqs]
+
+    # assert other requests are "finished" due to the cache threshold
+    requests_expected_not_scheduled = [
+        r for r, expected in zip(requests, request_expected_scheduled) if not expected
+    ]
+    assert all(
+        r.status == RequestStatus.FINISHED_CACHE_HIT_BELOW_THRESHOLD
+        for r in requests_expected_not_scheduled
+    )
+
+    _iterate_until_done(scheduler)
+    assert_scheduler_empty(scheduler)
+
+
+def _create_and_schedule_requests(
+    request_num_tokens: list[int],
+    request_thresholds: list[Optional[float]],
+    scheduler: Scheduler,
+):
+    num_requests = len(request_num_tokens)
+    requests = create_requests(
+        num_requests=num_requests,
+        num_tokens=request_num_tokens,
+        block_size=scheduler.cache_config.block_size,
+        cache_hit_thresholds=request_thresholds,
+    )
+
+    for request in requests:
+        scheduler.add_request(request)
+
+    scheduler_output = scheduler.schedule()
+    model_runner_output = make_output(scheduler)
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    return requests, scheduler_output
+
+
+def _mock_external_cache_hit(request_external_hit_blocks, scheduler: Scheduler):
+    BLOCK_SIZE = scheduler.cache_config.block_size
+    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
+    scheduler.connector.get_num_new_matched_tokens.side_effect = [
+        (i * BLOCK_SIZE, False) for i in request_external_hit_blocks
+    ]
+
+
+def _insert_to_local_cache(request_local_hit_blocks, scheduler: Scheduler):
+    """Schedule requests to fill in the local cache"""
+    BLOCK_SIZE = scheduler.cache_config.block_size
+    num_total_requests = len(request_local_hit_blocks)
+
+    requests_to_schedule = [
+        i for i, hit_blocks in enumerate(request_local_hit_blocks) if hit_blocks > 0
+    ]
+
+    num_requests_to_schedule = len(requests_to_schedule)
+    if num_requests_to_schedule == 0:
+        # nothing to do
+        return
+
+    # Mock no external cache hit for this cache-warmup phase
+    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
+    scheduler.connector.get_num_new_matched_tokens.return_value = (0, False)
+
+    # set threshold to 0.0 to ensure all are scheduled
+    zero_thresholds: list[Optional[float]] = [0.0] * num_total_requests
+
+    # Only requests with local hits should run and populate the cache.
+    # We create all requests to make sure the correct tokens are cached
+    # (since the tokens are generated according to request id).
+    requests = create_requests(
+        num_requests=num_total_requests,
+        num_tokens=[x * BLOCK_SIZE for x in request_local_hit_blocks],
+        block_size=BLOCK_SIZE,
+        cache_hit_thresholds=zero_thresholds,
+    )
+
+    # Only schedule the requests we want to run and populate the cache
+    for i in requests_to_schedule:
+        scheduler.add_request(requests[i])
+
+    scheduler_output = scheduler.schedule()
+
+    # verify all were indeed scheduled
+    assert len(scheduler_output.scheduled_new_reqs) == num_requests_to_schedule
+
+    # iterate until all scheduled requests are done
+    model_runner_output = make_output(scheduler)
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    _iterate_until_done(scheduler)
+    assert_scheduler_empty(scheduler)
+
+
 def make_output(scheduler: Scheduler):
     return ModelRunnerOutput(
         req_ids=[req.request_id for req in scheduler.running],
@@ -1215,13 +1385,7 @@ def test_memory_leak():
         model_runner_output = make_output(scheduler)
         scheduler.update_from_output(scheduler_output, model_runner_output)

-    # Iterate until done.
-    while True:
-        scheduler_output = scheduler.schedule()
-        if len(scheduler.running) == 0:
-            break
-        model_runner_output = make_output(scheduler)
-        scheduler.update_from_output(scheduler_output, model_runner_output)
+    _iterate_until_done(scheduler)

     # Confirm no memory leak.
     assert_scheduler_empty(scheduler)
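
The expected hit ratios in the parametrization follow the comment in the test: (local + external) * BLOCK_SIZE / tokens. Below is a stand-alone sanity check of that arithmetic, assuming the default test block size of 16 used by create_requests; the helper name is illustrative and not part of the diff.

BLOCK_SIZE = 16

def expected_hit_ratio(local_blocks: int, external_blocks: int, num_tokens: int) -> float:
    # (local + external) * BLOCK_SIZE / tokens, as in the test comment
    return (local_blocks + external_blocks) * BLOCK_SIZE / num_tokens

for tokens, local, external in zip([57, 34, 28], [1, 1, 0], [0, 1, 0]):
    print(round(expected_hit_ratio(local, external, tokens), 3))
# prints 0.281, 0.941, 0.0 -- the ratios quoted for the first parametrized case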

tests/v1/core/utils.py

Lines changed: 17 additions & 4 deletions
@@ -47,6 +47,7 @@ def create_scheduler(
     num_speculative_tokens: Optional[int] = None,
     skip_tokenizer_init: bool = False,
     async_scheduling: bool = False,
+    global_cache_hit_threshold: float = 0.0,
 ) -> Union[Scheduler, AsyncScheduler]:
     """Create scheduler under test.

@@ -71,6 +72,7 @@ def create_scheduler(
         disable_chunked_mm_input=disable_chunked_mm_input,
         enable_chunked_prefill=True,
         async_scheduling=async_scheduling,
+        global_cache_hit_threshold=global_cache_hit_threshold,
     )
     model_config = ModelConfig(
         model=model,
@@ -139,19 +141,21 @@ def create_scheduler(

 def create_requests(
     num_requests: int,
-    num_tokens: int = 10,
+    num_tokens: Union[int, list[int]] = 10,
     mm_positions: Optional[list[list[PlaceholderRange]]] = None,
     max_tokens: int = 16,
     stop_token_ids: Optional[list[int]] = None,
     prompt_logprobs: Optional[int] = None,
     same_prompt: bool = False,
     block_size: int = 16,
+    cache_hit_thresholds: Optional[list[Optional[float]]] = None,
 ) -> list[Request]:
     global _none_hash_initialized
     if not _none_hash_initialized:
         init_none_hash(sha256)
         _none_hash_initialized = True
-
+    if cache_hit_thresholds is not None:
+        assert len(cache_hit_thresholds) == num_requests
     block_hasher = get_request_block_hasher(block_size, sha256)
     sampling_params = SamplingParams(
         ignore_eos=False,
@@ -175,8 +179,16 @@ def create_requests(
                     modality="image",
                 )
                 mm_features.append(mm_feature)
-
-        prompt_token_ids = [0] * num_tokens if same_prompt else [i] * num_tokens
+        request_num_tokens: int = (
+            num_tokens[i] if isinstance(num_tokens, list) else num_tokens
+        )
+        prompt_token_ids = (
+            [0] * request_num_tokens if same_prompt else [i] * request_num_tokens
+        )
+        if cache_hit_thresholds is not None:
+            cache_hit_threshold = cache_hit_thresholds[i]
+        else:
+            cache_hit_threshold = None
         request = Request(
             request_id=f"{i}",
             prompt_token_ids=prompt_token_ids,
@@ -185,6 +197,7 @@ def create_requests(
             mm_features=mm_features if mm_features else None,
             eos_token_id=EOS_TOKEN_ID,
             block_hasher=block_hasher,
+            cache_hit_threshold=cache_hit_threshold,
         )
         requests.append(request)
     return requests
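
With the extended signature, num_tokens can be a per-request list and each request can carry its own cache_hit_threshold. A minimal illustrative call of the helper from tests/v1/core/utils.py; the argument values are made up for the example.

requests = create_requests(
    num_requests=3,
    num_tokens=[57, 34, 28],                # per-request prompt lengths
    block_size=16,
    cache_hit_thresholds=[None, 0.4, 0.1],  # None falls back to the global threshold
)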

vllm/config/scheduler.py

Lines changed: 17 additions & 0 deletions
@@ -151,6 +151,14 @@ class SchedulerConfig:
     structured outputs, speculative decoding, and pipeline parallelism.
     """

+    global_cache_hit_threshold: float = 0.0
+    """The cache hit ratio threshold applied to all requests, except for
+    requests that override it via the "cache_hit_threshold" field.
+    This feature enables a decode-first optimization in P/D disaggregation:
+    decode nodes can avoid remote prefill when the cache hit ratio is high.
+    If set to 0.0, the optimization is disabled.
+    """
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -316,4 +324,13 @@ def _verify_args(self) -> Self:
                 f"max_num_partial_prefills ({self.max_num_partial_prefills})."
             )

+        if (self.global_cache_hit_threshold < 0.0) or (
+            self.global_cache_hit_threshold > 1.0
+        ):
+            raise ValueError(
+                "global_cache_hit_threshold "
+                f"({self.global_cache_hit_threshold}) "
+                "must be between 0.0 and 1.0, inclusive."
+            )
+
         return self
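
A sketch of the admission rule these settings imply, as exercised by test_cache_hit_threshold above; this is an illustration of the semantics, not the scheduler's actual code path, and the function name is hypothetical.

from typing import Optional

def should_schedule(
    hit_tokens: int,
    num_prompt_tokens: int,
    global_threshold: float,
    request_threshold: Optional[float] = None,
) -> bool:
    # A per-request threshold overrides the global one; 0.0 lets everything through.
    threshold = request_threshold if request_threshold is not None else global_threshold
    return hit_tokens / num_prompt_tokens >= threshold

# e.g. 96 of 157 prompt tokens cached (ratio ~0.61) with a 0.8 threshold -> not scheduled
assert should_schedule(96, 157, global_threshold=0.3, request_threshold=0.8) is False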

vllm/config/vllm.py

Lines changed: 2 additions & 1 deletion
@@ -773,7 +773,8 @@ def __str__(self):
             f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
             f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, "  # noqa
             f"pooler_config={self.model_config.pooler_config!r}, "
-            f"compilation_config={self.compilation_config!r}"
+            f"compilation_config={self.compilation_config!r}, "
+            f"global_cache_hit_threshold={self.scheduler_config.global_cache_hit_threshold}"
         )

vllm/engine/arg_utils.py

Lines changed: 7 additions & 0 deletions
@@ -516,6 +516,8 @@ class EngineArgs:

     kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill

+    global_cache_hit_threshold: float = SchedulerConfig.global_cache_hit_threshold
+
     def __post_init__(self):
         # support `EngineArgs(compilation_config={...})`
         # without having to manually construct a
@@ -987,6 +989,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         scheduler_group.add_argument(
             "--scheduler-cls", **scheduler_kwargs["scheduler_cls"]
         )
+        scheduler_group.add_argument(
+            "--global-cache-hit-threshold",
+            **scheduler_kwargs["global_cache_hit_threshold"],
+        )
         scheduler_group.add_argument(
             "--disable-hybrid-kv-cache-manager",
             **scheduler_kwargs["disable_hybrid_kv_cache_manager"],
@@ -1485,6 +1491,7 @@ def create_engine_config(
             long_prefill_token_threshold=self.long_prefill_token_threshold,
             disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
             async_scheduling=self.async_scheduling,
+            global_cache_hit_threshold=self.global_cache_hit_threshold,
         )

         if not model_config.is_multimodal_model and self.default_mm_loras:
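
The new field plumbs through like any other scheduler argument. An illustrative programmatic equivalent of the --global-cache-hit-threshold flag added above; the model name is a placeholder.

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",       # placeholder model
    global_cache_hit_threshold=0.3,  # applies unless a request sets cache_hit_threshold
)
# CLI equivalent: vllm serve facebook/opt-125m --global-cache-hit-threshold 0.3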

vllm/engine/protocol.py

Lines changed: 2 additions & 0 deletions
@@ -57,6 +57,7 @@ def generate(
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
         data_parallel_rank: Optional[int] = None,
+        cache_hit_threshold: Optional[float] = None,
     ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request."""
         ...
@@ -245,6 +246,7 @@ def encode(
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
+        cache_hit_threshold: Optional[float] = None,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from a pooling model."""
         ...
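
An illustrative async call against the extended generate signature; the client object, prompt, and request id are placeholders, and an EngineClient implementation such as AsyncLLM is assumed.

from vllm import SamplingParams

async def run(engine_client):
    async for output in engine_client.generate(
        "Hello, my name is",
        SamplingParams(max_tokens=16),
        request_id="demo-0",
        cache_hit_threshold=0.5,  # new optional argument added by this change
    ):
        print(output.outputs[0].text)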

vllm/entrypoints/api_server.py

Lines changed: 4 additions & 1 deletion
@@ -58,11 +58,14 @@ async def generate(request: Request) -> Response:
 async def _generate(request_dict: dict, raw_request: Request) -> Response:
     prompt = request_dict.pop("prompt")
     stream = request_dict.pop("stream", False)
+    cache_hit_threshold = request_dict.pop("cache_hit_threshold", None)
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()

     assert engine is not None
-    results_generator = engine.generate(prompt, sampling_params, request_id)
+    results_generator = engine.generate(
+        prompt, sampling_params, request_id, cache_hit_threshold=cache_hit_threshold
+    )

     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
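
On the client side, the demo server now accepts an optional cache_hit_threshold field in the JSON body alongside the usual sampling parameters. A hypothetical request against a locally running api_server; host, port, and values are assumptions, not part of the diff.

import json
from urllib.request import Request, urlopen

payload = {
    "prompt": "San Francisco is a",
    "max_tokens": 32,
    "cache_hit_threshold": 0.5,  # popped by _generate before SamplingParams is built
}
req = Request(
    "http://localhost:8000/generate",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
print(urlopen(req).read().decode("utf-8"))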

vllm/entrypoints/logger.py

Lines changed: 4 additions & 1 deletion
@@ -26,6 +26,7 @@ def log_inputs(
         prompt_embeds: Optional[torch.Tensor],
         params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]],
         lora_request: Optional[LoRARequest],
+        cache_hit_threshold: Optional[float],
     ) -> None:
         max_log_len = self.max_log_len
         if max_log_len is not None:
@@ -39,13 +40,15 @@ def log_inputs(
             "Received request %s: prompt: %r, "
             "params: %s, prompt_token_ids: %s, "
             "prompt_embeds shape: %s, "
-            "lora_request: %s.",
+            "lora_request: %s, "
+            "cache_hit_threshold: %s.",
             request_id,
             prompt,
             params,
             prompt_token_ids,
             prompt_embeds.shape if prompt_embeds is not None else None,
             lora_request,
+            cache_hit_threshold,
         )

     def log_outputs(
