Commit 2665ff6

[Bugfix] Fix num_tokens_to_schedule & Add unit test

Signed-off-by: herotai214 <herotai214@gmail.com>

Parent: 5e79b0b

File tree: 3 files changed (+180, −6)

tests/v1/core/test_scheduler.py

Lines changed: 172 additions & 5 deletions
@@ -22,6 +22,7 @@
 )
 from vllm.sampling_params import SamplingParams, StructuredOutputsParams
 from vllm.utils.hashing import sha256
+from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
 from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
 from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
@@ -1518,6 +1519,7 @@ def create_requests_with_priority(
     starting_idx: int = 0,
     same_prompt: bool = False,
     block_size: int = 16,
+    req_ids: list[str] | None = None,
 ):
     """Create requests with specified priorities and arrival times."""
     assert len(priorities) == num_requests
@@ -1553,6 +1555,11 @@ def create_requests_with_priority(
     # Verify mm items with identical identifiers share the same mm_position.length
     seen_hashes: dict[str, int] = {}

+    if req_ids:
+        assert len(req_ids) == num_requests
+    else:
+        req_ids = [f"{i + starting_idx}" for i in range(num_requests)]
+
     for i in range(num_requests):
         mm_features = []

@@ -1589,7 +1596,7 @@ def create_requests_with_priority(
             else [i + starting_idx] * num_tokens
         )
         request = Request(
-            request_id=f"{i + starting_idx}",
+            request_id=req_ids[i],
             prompt_token_ids=prompt_token_ids,
             sampling_params=sampling_params,
             pooling_params=None,
@@ -2273,6 +2280,7 @@ def _validate_chunked_prefill_settings_for_encoder_decoder(

 def _assert_right_encoder_cache_allocated(
     scheduler: Scheduler,
+    hashes_to_check: list[str] | None = None,
     requests: list[Request] | None = None,
     expected_total_allocated: int | None = None,
 ):
@@ -2291,6 +2299,13 @@ def _assert_right_encoder_cache_allocated(
     # Verify each request with MM data is in cache
     cached_hashes = set(encoder_cache_manager.cached.keys())

+    if hashes_to_check:
+        missed_hashes = set(hashes_to_check) - cached_hashes
+        assert not missed_hashes, (
+            f"Missing hashes: {missed_hashes} "
+            f"Existing encoder cache: {encoder_cache_manager.cached}"
+        )
+
     for req in requests if requests is not None else []:
         if req.mm_features:
             mm_hashes = [f.identifier for f in req.mm_features]
@@ -2572,7 +2587,7 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
     # Encoder cache should contain all mm items from request
     _assert_right_encoder_cache_allocated(scheduler, requests=[request1])

-    # Should call update_state_after_alloc for external load
+    # Should have called update_state_after_alloc for external load
     scheduler.ec_connector.update_state_after_alloc.assert_called()
     scheduler.ec_connector.update_state_after_alloc.reset_mock()

@@ -2716,7 +2731,7 @@ def test_ec_connector_schedule_multiple_requests(cache_exist, use_kv_connector):

     ## Encoder-cache-specific checks:
     # mm_hashes of requests exist in cache after scheduling for all scenarios
-    _assert_right_encoder_cache_allocated(scheduler, requests)
+    _assert_right_encoder_cache_allocated(scheduler, requests=requests)

     # Should only call update_state_after_alloc when loaded externally
     if cache_exist == "connector_only":
@@ -2814,7 +2829,7 @@ def test_ec_connector_unable_to_allocate(use_kv_connector):
     assert len(scheduler.running) == 1
     assert len(scheduler.waiting) == 1

-    # Should call update_state_after_alloc for external load
+    # Should have called update_state_after_alloc for external load
     scheduler.ec_connector.update_state_after_alloc.assert_called_with(
         scheduler.running[0], 0
     )
@@ -3051,7 +3066,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(

     ## Encoder-cache-specific checks:
     # mm_hash of request_low exists in cache after scheduling for all scenarios
-    _assert_right_encoder_cache_allocated(scheduler, [request_low])
+    _assert_right_encoder_cache_allocated(scheduler, requests=[request_low])

     # Should only call update_state_after_alloc when loaded externally
     if cache_exist == "connector_only":
@@ -3080,6 +3095,158 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
     _assert_right_encoder_inputs(output, expected_total_reqs=0)


+@pytest.mark.parametrize("use_kv_connector", [False, True])
+def test_ec_connector_allocate_encoder_tokens_with_external_load(use_kv_connector):
+    """
+    Scenario:
+    - Encoder cache size: 32
+    - Request A: 1 feature (12 tokens) → NOT cached remotely.
+    - Request B: 3 features (3 x 10 tokens) → ALL cached remotely.
+
+    Steps:
+    1. Schedule Request A (computes its 12 encoder tokens locally).
+    2. Schedule Request B (loaded from remote cache, no local encoder run);
+       only its 1st and 2nd features fit in the encoder cache.
+    3. Free A's cache, then schedule B again (continuation); its 3rd feature
+       is now scheduled.
+    """
+    scheduler = create_scheduler(
+        model="llava-hf/llava-1.5-7b-hf",
+        max_num_batched_tokens=1024,
+        enable_prefix_caching=True,
+        use_kv_connector=use_kv_connector,
+        block_size=16,
+        num_blocks=11,  # Can hold 160 tokens (first block is null)
+        use_ec_connector=True,
+        ec_role="ec_consumer",
+        disable_hybrid_kv_cache_manager=use_kv_connector,
+    )
+
+    # Limit the encoder cache to 32 tokens
+    scheduler.encoder_cache_manager = EncoderCacheManager(cache_size=32)
+
+    # Create MM request1
+    NUM_TOKENS_1 = 50  # NOTE: includes mm tokens
+    NUM_ENCODER_TOKENS_1 = 12
+    mm_hashes_list_1 = [["hash1_1"]]
+    mm_positions_1 = [[PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_1)]]
+
+    request1 = create_requests(
+        num_requests=1,
+        num_tokens=NUM_TOKENS_1,
+        mm_hashes_list=mm_hashes_list_1,
+        mm_positions=mm_positions_1,
+        max_tokens=1,  # For simplicity
+        req_ids=["req1"],
+    )[0]
+
+    # Create MM request2 with 3 MM items
+    NUM_TOKENS_2 = 40
+    NUM_ENCODER_TOKENS_2 = 10
+    mm_hashes_list_2 = [["hash2_1", "hash2_2", "hash2_3"]]
+    mm_positions_2 = [
+        [
+            PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_2),
+            PlaceholderRange(offset=12, length=NUM_ENCODER_TOKENS_2),
+            PlaceholderRange(offset=24, length=NUM_ENCODER_TOKENS_2),
+        ]
+    ]
+
+    request2 = create_requests(
+        num_requests=1,
+        num_tokens=NUM_TOKENS_2,
+        mm_hashes_list=mm_hashes_list_2,
+        mm_positions=mm_positions_2,
+        max_tokens=10,
+        req_ids=["req2"],
+    )[0]
+
+    # Mock cache hit: MM of request1 NOT cached remotely, request2 cached remotely
+    scheduler.ec_connector.has_caches = Mock(
+        side_effect=lambda req: [True, True, True] if req == request2 else [False]
+    )
+    scheduler.ec_connector.update_state_after_alloc = Mock(
+        wraps=scheduler.ec_connector.update_state_after_alloc
+    )
+
+    scheduler.add_request(request1)
+    scheduler.add_request(request2)
+    output = scheduler.schedule()
+
+    # Since the encoder cache manager can only store 32 tokens,
+    # it should allocate mm items hash1_1, hash2_1 and hash2_2
+    scheduled_tokens = output.num_scheduled_tokens[request1.request_id]
+    assert scheduled_tokens == NUM_TOKENS_1
+    assert scheduler.get_num_unfinished_requests() == 2
+
+    # Encoder cache should contain request1's item and request2's first two
+    _assert_right_encoder_cache_allocated(
+        scheduler, hashes_to_check=["hash1_1", "hash2_1", "hash2_2"]
+    )
+
+    # request2's 2nd mm item is the last call of update_state_after_alloc
+    scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 1)
+    scheduler.ec_connector.update_state_after_alloc.reset_mock()
+
+    # ECConnector should carry metadata of hash2_1 and hash2_2 ONLY
+    _assert_right_ec_connector_metadata(
+        output, mm_features_list=[request2.mm_features[0], request2.mm_features[1]]
+    )
+
+    # Should schedule ONLY 1 encoder input
+    _assert_right_encoder_inputs(
+        output,
+        requests=[request1],
+        expected_encoder_inputs=[[0]],  # index 0 of request1's mm items
+        expected_total_reqs=1,
+    )
+
+    # Simulate one step of model execution
+    model_output = ModelRunnerOutput(
+        req_ids=[request1.request_id, request2.request_id],
+        req_id_to_index={request1.request_id: 0, request2.request_id: 1},
+        sampled_token_ids=[[100], [121]],
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+    scheduler.update_from_output(output, model_output)
+
+    # request1 is finished after outputting 1 token
+    scheduler.finish_requests(request1.request_id, RequestStatus.FINISHED_LENGTH_CAPPED)
+    assert scheduler.get_num_unfinished_requests() == 1
+
+    # Schedule again
+    # request1's encoder cache is now freed -> hash2_3 can be scheduled and allocated
+    output = scheduler.schedule()
+
+    # request2's remaining prompt tokens should now all be scheduled
+    # (step 1 stopped at hash2_3's offset, i.e. 24 of its 40 tokens)
+    scheduled_tokens = output.num_scheduled_tokens[request2.request_id]
+    assert scheduled_tokens == NUM_TOKENS_2 - 24
+
+    # Encoder cache should contain all mm items from request2
+    _assert_right_encoder_cache_allocated(scheduler, requests=[request2])
+
+    # request2's 3rd mm item is the ONLY call of update_state_after_alloc
+    scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 2)
+    scheduler.ec_connector.update_state_after_alloc.assert_called_once()
+    scheduler.ec_connector.update_state_after_alloc.reset_mock()
+
+    # ECConnector should carry metadata for hash2_3 ONLY
+    _assert_right_ec_connector_metadata(
+        output, mm_features_list=[request2.mm_features[2]]
+    )
+
+    # Should schedule no encoder input
+    _assert_right_encoder_inputs(
+        output,
+        expected_total_reqs=0,
+    )
+
 # ==============================================================================
 # EPD (Encoder-Prefill-Decode) Encoder-cache-specific tests end
 # ==============================================================================
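
Note: the arithmetic the new test exercises, as a standalone sketch (plain Python; the numbers come from the test above, not from the committed code):

CACHE_SIZE = 32  # encoder cache tokens, as set in the test

# Step 1: hash1_1 (12) + hash2_1 (10) + hash2_2 (10) fill the cache exactly,
# so hash2_3 (10 more tokens) cannot be allocated yet and request2 is only
# scheduled up to hash2_3's offset (24 of its 40 prompt tokens).
used = 12 + 10 + 10
assert used == CACHE_SIZE
assert used + 10 > CACHE_SIZE

# Steps 2-3: request1 finishes and its 12 tokens are freed, so hash2_3 now
# fits and request2's remaining 16 prompt tokens can be scheduled.
used -= 12
assert used + 10 <= CACHE_SIZE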

tests/v1/core/utils.py

Lines changed: 7 additions & 1 deletion
@@ -163,6 +163,7 @@ def create_requests(
     prompt_logprobs: int | None = None,
     same_prompt: bool = False,
     block_size: int = 16,
+    req_ids: list[str] | None = None,
 ) -> list[Request]:
     global _none_hash_initialized
     if not _none_hash_initialized:
@@ -191,6 +192,11 @@ def create_requests(
     # Verify mm items with identical identifiers share the same mm_position.length
     seen_hashes: dict[str, int] = {}

+    if req_ids:
+        assert len(req_ids) == num_requests
+    else:
+        req_ids = [f"{i}" for i in range(num_requests)]
+
     for i in range(num_requests):
         mm_features = []

@@ -223,7 +229,7 @@ def create_requests(

         prompt_token_ids = [0] * num_tokens if same_prompt else [i] * num_tokens
         request = Request(
-            request_id=f"{i}",
+            request_id=req_ids[i],
             prompt_token_ids=prompt_token_ids,
             sampling_params=sampling_params,
             pooling_params=None,
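
Note: the new req_ids parameter lets a test pin readable request ids instead of the default "0", "1", ... numbering. A minimal usage sketch (assuming the helper's other arguments keep their defaults):

from tests.v1.core.utils import create_requests

# Default ids are stringified indices
reqs = create_requests(num_requests=2, num_tokens=8)
assert [r.request_id for r in reqs] == ["0", "1"]

# Pinned ids, as the new test does with req_ids=["req1"] / ["req2"]
reqs = create_requests(num_requests=2, num_tokens=8, req_ids=["reqA", "reqB"])
assert [r.request_id for r in reqs] == ["reqA", "reqB"]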

vllm/v1/core/sched/scheduler.py

Lines changed: 1 addition & 0 deletions
@@ -898,6 +898,7 @@ def _try_schedule_encoder_inputs(
             if self.ec_connector is not None and remote_cache_has_item[i]:
                 mm_hashes_to_schedule.add(request.mm_features[i].identifier)
                 external_load_encoder_input.append(i)
+                num_tokens_to_schedule += num_encoder_tokens
                 continue

             num_tokens_to_schedule += num_encoder_tokens
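
Note on the fix: before this commit, encoder inputs whose embeddings are loaded through the EC connector took the continue branch without being counted against num_tokens_to_schedule, so a later encoder input could oversubscribe the encoder cache even though externally loaded items still occupy local cache space. A minimal sketch of the corrected accounting (illustrative names, not the actual vLLM code):

def plan_encoder_inputs(
    items: list[tuple[str, int, bool]],  # (mm_hash, num_encoder_tokens, cached_remotely)
    encoder_budget: int,
) -> list[str]:
    """Return the mm hashes that fit into the encoder cache this step."""
    num_tokens_to_schedule = 0
    scheduled: list[str] = []
    for mm_hash, num_encoder_tokens, cached_remotely in items:
        if num_tokens_to_schedule + num_encoder_tokens > encoder_budget:
            break  # no cache space left; stop at this encoder input
        scheduled.append(mm_hash)
        # The fix: remotely cached items (cached_remotely=True) are counted
        # too; previously they skipped this line, so later items could
        # oversubscribe the cache.
        num_tokens_to_schedule += num_encoder_tokens
    return scheduled

# With the test's 32-token budget, the first three items fill the cache
# exactly, so hash2_3 must wait until request1's 12 tokens are freed:
assert plan_encoder_inputs(
    [
        ("hash1_1", 12, False),
        ("hash2_1", 10, True),
        ("hash2_2", 10, True),
        ("hash2_3", 10, True),
    ],
    encoder_budget=32,
) == ["hash1_1", "hash2_1", "hash2_2"]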
