tests/e2e/singlecard/core/test_ascend_scheduler.py (76 changes: 50 additions, 26 deletions)
@@ -98,11 +98,7 @@ def create_scheduler(
     )
     kv_cache_config = KVCacheConfig(
         num_blocks=num_blocks,  # A large number of blocks to hold all requests
-        **({
-            "tensors": {}
-        } if vllm_version_is("0.9.0") else {
-            "kv_cache_tensors": []
-        }),
+        kv_cache_tensors=[],
         kv_cache_groups=[
             KVCacheGroupSpec(['layer'],
                              FullAttentionSpec(block_size, 1, 1, torch.float32,
@@ -145,8 +141,8 @@ def create_requests(num_requests: int,
             multi_modal_hashes=None,
             eos_token_id=EOS_TOKEN_ID,
             **({
-                "arrival_time": 0.0
-            } if vllm_version_is("0.9.0") else {}),
+                "pooling_params": None
+            } if not vllm_version_is("0.9.1") else {}),
         )
         requests.append(request)
     return requests
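Note: this hunk and the ones below all rely on the same conditional-kwargs idiom to keep the tests working on both vLLM 0.9.1 and newer releases. A minimal, self-contained sketch of the pattern follows; the local `vllm_version_is` and the `installed` string are stand-ins for the real helper in vllm_ascend.utils, not its implementation.

```python
# Sketch of the version-gated kwargs idiom used in create_requests above.
# Stand-in for vllm_ascend.utils.vllm_version_is (assumed to be an exact
# comparison against the installed vLLM version string).
def vllm_version_is(version: str, installed: str = "0.9.2") -> bool:
    return installed == version


def request_kwargs() -> dict:
    kwargs = {"eos_token_id": 50256}
    # Mirror the guard in the diff: only pass pooling_params when the
    # installed vLLM is something other than 0.9.1.
    kwargs.update({"pooling_params": None} if not vllm_version_is("0.9.1") else {})
    return kwargs


print(request_kwargs())  # {'eos_token_id': 50256, 'pooling_params': None}
```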
@@ -262,7 +258,9 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-    )
+        **({
+            "pooler_output": []
+        } if not vllm_version_is("0.9.1") else {}))
     scheduler.update_from_output(output, model_runner_output)

     # Schedule the next step. All three requests are running.
@@ -286,7 +284,10 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-    )
+        **({
+            "pooler_output": []
+        } if not vllm_version_is("0.9.1") else {}))
+
     scheduler.update_from_output(output1, model_runner_output)
     output2 = scheduler.schedule()
     assert len(scheduler.running) == 3
@@ -337,7 +338,10 @@ def test_stop_via_update_from_output():
                             11]],  # First request hits EOS, second continues
         spec_token_ids=None,
         logprobs=None,
-        prompt_logprobs_dict={})
+        prompt_logprobs_dict={},
+        **({
+            "pooler_output": []
+        } if not vllm_version_is("0.9.1") else {}))

     scheduler.update_from_output(scheduler_output, model_output)

@@ -385,7 +389,10 @@ def test_stop_via_update_from_output():
                             [13, 14]],  # First request hits stop token
         spec_token_ids=None,
         logprobs=None,
-        prompt_logprobs_dict={})
+        prompt_logprobs_dict={},
+        **({
+            "pooler_output": []
+        } if not vllm_version_is("0.9.1") else {}))

     scheduler.update_from_output(scheduler_output, model_output)

@@ -432,7 +439,10 @@ def test_stop_via_update_from_output():
                             [13]],  # First request exceeds max_tokens
         spec_token_ids=None,
         logprobs=None,
-        prompt_logprobs_dict={})
+        prompt_logprobs_dict={},
+        **({
+            "pooler_output": []
+        } if not vllm_version_is("0.9.1") else {}))

     scheduler.update_from_output(scheduler_output, model_output)

@@ -474,7 +484,10 @@ def test_stop_via_update_from_output():
         sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
         spec_token_ids=None,
         logprobs=None,
-        prompt_logprobs_dict={})
+        prompt_logprobs_dict={},
+        **({
+            "pooler_output": []
+        } if not vllm_version_is("0.9.1") else {}))

     scheduler.update_from_output(scheduler_output, model_output)

@@ -524,7 +537,10 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-    )
+        **({
+            "pooler_output": []
+        } if not vllm_version_is("0.9.1") else {}))
+
     scheduler.update_from_output(scheduler_output0, model_runner_output)

     # Schedule the next step.
@@ -541,7 +557,10 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-    )
+        **({
+            "pooler_output": []
+        } if not vllm_version_is("0.9.1") else {}))
+
     scheduler.update_from_output(scheduler_output1, model_runner_output)


@@ -565,8 +584,6 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
     1. Speculated tokens get scheduled correctly
     2. Spec decoding stats properly count number of draft and accepted tokens
     """
-    if vllm_version_is("0.9.0"):
-        return
     num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
     scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
     requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
@@ -593,7 +610,10 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
         spec_token_ids=spec_tokens,
         logprobs=None,
         prompt_logprobs_dict={},
-    )
+        **({
+            "pooler_output": []
+        } if not vllm_version_is("0.9.1") else {}))
+
     engine_core_outputs = scheduler.update_from_output(output,
                                                        model_runner_output)

@@ -632,7 +652,10 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-    )
+        **({
+            "pooler_output": []
+        } if not vllm_version_is("0.9.1") else {}))
+
     engine_core_outputs = scheduler.update_from_output(output,
                                                        model_runner_output)

@@ -727,7 +750,9 @@ def make_output(scheduler: AscendScheduler):
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-    )
+        **({
+            "pooler_output": []
+        } if not vllm_version_is("0.9.1") else {}))


 def assert_scheduler_empty(scheduler: AscendScheduler):
@@ -744,11 +769,10 @@ def assert_scheduler_empty(scheduler: AscendScheduler):
     assert len(scheduler.encoder_cache_manager.cached) == 0

     # KVCache Manager.
-    if not vllm_version_is("0.9.0"):
-        assert len(scheduler.kv_cache_manager.coordinator.
-                   single_type_managers[0].req_to_blocks) == 0
-        assert len(scheduler.kv_cache_manager.coordinator.
-                   single_type_managers[0].num_cached_block) == 0
+    assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
+               req_to_blocks) == 0
+    assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
+               num_cached_block) == 0
     assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
     num_free_blocks = (
         scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
@@ -789,4 +813,4 @@ def test_memory_leak():
     scheduler.update_from_output(scheduler_output, model_runner_output)

     # Confirm no memory leak.
-    assert_scheduler_empty(scheduler)
+    assert_scheduler_empty(scheduler)
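Note: every ModelRunnerOutput built in this file repeats the same `pooler_output` guard. If that repetition ever becomes a maintenance burden, one option is a tiny factory for the kwargs; the sketch below is only an illustration, with `version_is` standing in for vllm_ascend.utils.vllm_version_is so it runs without vLLM installed.

```python
from typing import Any, Callable


def output_kwargs(version_is: Callable[[str], bool], **fields: Any) -> dict:
    """Build ModelRunnerOutput kwargs with the same guard as the hunks above:
    pooler_output=[] is added only when the installed vLLM is not 0.9.1."""
    if not version_is("0.9.1"):
        fields.setdefault("pooler_output", [])
    return fields


# Stand-in version check: pretend the installed vLLM is newer than 0.9.1.
kwargs = output_kwargs(lambda version: False,
                       sampled_token_ids=[[50256, 10, 11]],
                       spec_token_ids=None,
                       logprobs=None,
                       prompt_logprobs_dict={})
assert "pooler_output" in kwargs
```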
tests/e2e/singlecard/test_scheduler.py (24 changes: 20 additions, 4 deletions)
@@ -31,6 +31,7 @@
 from vllm.v1.structured_output import StructuredOutputManager

 from vllm_ascend.core.scheduler import AscendScheduler
+from vllm_ascend.utils import vllm_version_is

 EOS_TOKEN_ID = 50256

@@ -130,6 +131,9 @@ def create_requests(num_requests: int,
             multi_modal_placeholders=mm_position,
             multi_modal_hashes=None,
             eos_token_id=EOS_TOKEN_ID,
+            **({
+                "pooling_params": None
+            } if not vllm_version_is("0.9.1") else {}),
         )
         requests.append(request)
     return requests
@@ -237,7 +241,10 @@ def test_stop_via_update_from_output():
                             11]],  # First request hits EOS, second continues
         spec_token_ids=None,
         logprobs=None,
-        prompt_logprobs_dict={})
+        prompt_logprobs_dict={},
+        **({
+            "pooler_output": []
+        } if not vllm_version_is("0.9.1") else {}))

     scheduler.update_from_output(scheduler_output, model_output)

@@ -283,7 +290,10 @@ def test_stop_via_update_from_output():
                             [13, 14]],  # First request hits stop token
         spec_token_ids=None,
         logprobs=None,
-        prompt_logprobs_dict={})
+        prompt_logprobs_dict={},
+        **({
+            "pooler_output": []
+        } if not vllm_version_is("0.9.1") else {}))

     scheduler.update_from_output(scheduler_output, model_output)

@@ -328,7 +338,10 @@ def test_stop_via_update_from_output():
                             [13]],  # First request exceeds max_tokens
         spec_token_ids=None,
         logprobs=None,
-        prompt_logprobs_dict={})
+        prompt_logprobs_dict={},
+        **({
+            "pooler_output": []
+        } if not vllm_version_is("0.9.1") else {}))

     scheduler.update_from_output(scheduler_output, model_output)

@@ -369,7 +382,10 @@ def test_stop_via_update_from_output():
         sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
         spec_token_ids=None,
         logprobs=None,
-        prompt_logprobs_dict={})
+        prompt_logprobs_dict={},
+        **({
+            "pooler_output": []
+        } if not vllm_version_is("0.9.1") else {}))

     scheduler.update_from_output(scheduler_output, model_output)
