
Commit 9338a9b

complete e2e test

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
1 parent: f3dddfb

File tree: 1 file changed (+101, -63 lines)


tests/v1/e2e/test_async_scheduling.py

Lines changed: 101 additions & 63 deletions
@@ -39,7 +39,7 @@ def test_with_spec_decoding(self, monkeypatch: pytest.MonkeyPatch):
             monkeypatch,
             MTP_MODEL,
             [{}],
-            spec_config={"method": "mtp", "num_speculative_tokens": 1},
+            spec_configs=[{"method": "mtp", "num_speculative_tokens": 1}, None],
         )
 
     def test_without_spec_decoding(
@@ -62,7 +62,7 @@ def test_without_spec_decoding(
             ),
         ]
         self.preempt_and_async_scheduling_e2e(
-            monkeypatch, MODEL, sampling_param_tests, None
+            monkeypatch, MODEL, sampling_param_tests, [None]
         )
 
     @dynamo_config.patch(cache_size_limit=16)
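Both call sites now hand the helper a list of speculative configs instead of a single value, so each scheduler/executor combination can be driven both with and without spec decoding. A minimal standalone sketch (illustration only, not part of the commit) of what iterating such a list implies:

# Illustration: with an MTP config and None in the list, every combination is
# exercised twice, once with speculative decoding enabled and once without.
spec_configs = [{"method": "mtp", "num_speculative_tokens": 1}, None]
for spec_config in spec_configs:
    spec_decoding = spec_config is not None  # mirrors the flag derived in the test body
    print(f"spec_decoding={spec_decoding}, spec_config={spec_config}")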
@@ -71,83 +71,98 @@ def preempt_and_async_scheduling_e2e(
         monkeypatch: pytest.MonkeyPatch,
         model: str,
         sampling_param_tests: list[dict[str, Any]],
-        spec_config: dict | None,
+        spec_configs: list[dict | None],
     ):
         """Test consistency of combos of async scheduling, preemption,
         uni/multiproc executor with spec decoding."""
 
         with monkeypatch.context() as m:
             m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
             # m.setenv("VLLM_BATCH_INVARIANT", "1")
-            spec_decoding = False
-            if spec_config:
-                spec_decoding = True
             outputs: list[tuple[str, list]] = []
             for test_preemption in [False, True]:
                 for executor in ["mp", "uni"]:
                     for async_scheduling in [False, True]:
-                        cache_arg: dict[str, Any] = (
-                            dict(num_gpu_blocks_override=32)
-                            if test_preemption
-                            else dict(gpu_memory_utilization=0.7)
-                        )
-                        test_config = (
-                            f"executor={executor}, preemption={test_preemption}, "
-                            f"async_sched={async_scheduling}, "
-                            f"spec_decoding={spec_decoding}"
-                        )
-                        print("-" * 80)
-                        print(f"---- TESTING: {test_config}")
-                        print("-" * 80)
-                        with VllmRunner(
-                            model,
-                            max_model_len=512,
-                            enforce_eager=True,
-                            async_scheduling=async_scheduling,
-                            distributed_executor_backend=executor,
-                            dtype="float32",  # avoid precision errors
-                            speculative_config=spec_config,
-                            **cache_arg,
-                        ) as vllm_model:
-                            results = []
-                            for override_params in sampling_param_tests:
-                                print(f"----------- RUNNING PARAMS: {override_params}")
-                                results.append(
-                                    vllm_model.generate(
-                                        self.example_prompts,
-                                        sampling_params=SamplingParams(
-                                            **self.default_params, **override_params
-                                        ),
-                                        return_logprobs=True,
+                        for spec_config in spec_configs:
+                            spec_decoding = spec_config is not None
+                            cache_arg: dict[str, Any] = (
+                                dict(num_gpu_blocks_override=32)
+                                if test_preemption
+                                else dict(gpu_memory_utilization=0.7)
+                            )
+                            test_config = (
+                                f"executor={executor}, preemption={test_preemption}, "
+                                f"async_sched={async_scheduling}, "
+                                f"spec_decoding={spec_decoding}"
+                            )
+                            print("-" * 80)
+                            print(f"---- TESTING: {test_config}")
+                            print("-" * 80)
+                            with VllmRunner(
+                                model,
+                                max_model_len=512,
+                                enforce_eager=True,
+                                async_scheduling=async_scheduling,
+                                distributed_executor_backend=executor,
+                                dtype="float32",  # avoid precision errors
+                                speculative_config=spec_config,
+                                **cache_arg,
+                            ) as vllm_model:
+                                results = []
+                                acceptance_rates = []
+                                for override_params in sampling_param_tests:
+                                    print(
+                                        f"----------- RUNNING PARAMS: {override_params}"
                                     )
-                                )
-
-                            if not outputs and len(results) > 1:
-                                # First check that the different parameter configs
-                                # actually result in different output.
-                                for (
-                                    other_test_outs,
-                                    other_test_logprobs,
-                                ), params in zip(results[1:], sampling_param_tests[1:]):
-                                    with pytest.raises(AssertionError):
-                                        check_outputs_equal(
-                                            outputs_0_lst=results[0][0],
-                                            outputs_1_lst=other_test_outs,
-                                            name_0=f"baseline params={params}",
-                                            name_1=f"other params={params}",
-                                        )
-                                    assert _all_logprobs_match(
-                                        results[0][1], other_test_logprobs
+                                    results.append(
+                                        vllm_model.generate(
+                                            self.example_prompts,
+                                            sampling_params=SamplingParams(
+                                                **self.default_params,
+                                                **override_params,
+                                            ),
+                                            return_logprobs=True,
                                         )
+                                    )
+                                    acceptance_rates.append(
+                                        _calc_average_acceptance_rate(vllm_model)
+                                    )
 
-                            outputs.append((test_config, results))
-
-        baseline_config, baseline_tests = outputs[0]
+                                if not outputs and len(results) > 1:
+                                    # First check that the different parameter configs
+                                    # actually result in different output.
+                                    for (
+                                        other_test_outs,
+                                        other_test_logprobs,
+                                    ), params in zip(
+                                        results[1:], sampling_param_tests[1:]
+                                    ):
+                                        with pytest.raises(AssertionError):
+                                            check_outputs_equal(
+                                                outputs_0_lst=results[0][0],
+                                                outputs_1_lst=other_test_outs,
+                                                name_0=f"baseline params={params}",
+                                                name_1=f"other params={params}",
+                                            )
+                                        assert _all_logprobs_match(
+                                            results[0][1], other_test_logprobs
+                                        )
+
+                                outputs.append((test_config, results, acceptance_rates))
+
+        baseline_config, baseline_tests, base_acceptance_rates = outputs[0]
 
         failure = None
-        for test_config, test_outputs in outputs[1:]:
-            for (base_outs, base_logprobs), (test_outs, test_logprobs), params in zip(
-                baseline_tests, test_outputs, sampling_param_tests
+        for test_config, test_outputs, test_acceptance_rates in outputs[1:]:
+            for (base_outs, base_logprobs), base_acceptance_rate, (
+                test_outs,
+                test_logprobs,
+            ), test_acceptance_rate, params in zip(
+                baseline_tests,
+                base_acceptance_rates,
+                test_outputs,
+                test_acceptance_rates,
+                sampling_param_tests,
             ):
                 try:
                     check_outputs_equal(
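The hunk above nests a new spec_config loop inside the existing preemption/executor/async sweep and records an acceptance rate alongside each result. A standalone sketch of the resulting combinations (names mirror the diff, but this is an illustration, not the test itself):

from itertools import product

# Each combination yields one (test_config, results, acceptance_rates) entry in outputs.
spec_configs = [{"method": "mtp", "num_speculative_tokens": 1}, None]
for test_preemption, executor, async_scheduling, spec_config in product(
    [False, True], ["mp", "uni"], [False, True], spec_configs
):
    spec_decoding = spec_config is not None
    print(
        f"executor={executor}, preemption={test_preemption}, "
        f"async_sched={async_scheduling}, spec_decoding={spec_decoding}"
    )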
@@ -159,6 +174,13 @@ def preempt_and_async_scheduling_e2e(
 
                     assert _all_logprobs_match(base_logprobs, test_logprobs)
 
+                    # only check acceptance rate if spec decoding is used.
+                    if base_acceptance_rate > 0:
+                        assert (
+                            pytest.approx(test_acceptance_rate, rel=5e-2)
+                            == base_acceptance_rate
+                        )
+
                     print(f"PASSED: config=[{test_config}], params={params}")
                 except AssertionError as e:
                     print(f"FAILED: config=[{test_config}], params={params}")
@@ -188,3 +210,19 @@ def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
         and a.logprob == pytest.approx(b.logprob, rel=1e-3, abs=1e-6)
         for a, b in ((lps_a[x], lps_b[x]) for x in lps_a)
     )
+
+
+def _calc_average_acceptance_rate(vllm_model: VllmRunner) -> float:
+    metrics = vllm_model.llm.get_metrics()
+    num_draft = []
+    num_accept = []
+    for metric in metrics:
+        if metric.name == "vllm:spec_decode_num_draft_tokens":
+            num_draft.append(metric.value)
+        if metric.name == "vllm:spec_decode_num_accepted_tokens":
+            num_accept.append(metric.value)
+    acceptance_rates = []
+    for draft, accept in zip(num_draft, num_accept):
+        acceptance_rate = accept / draft if draft > 0 else 0
+        acceptance_rates.append(acceptance_rate)
+    return sum(acceptance_rates) / len(acceptance_rates) if acceptance_rates else 0
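The helper above pairs up draft/accepted token counters from the runner's metrics and averages the per-pair acceptance rates. A worked example with made-up counter values (the real numbers come from vllm:spec_decode_num_draft_tokens and vllm:spec_decode_num_accepted_tokens):

# Hypothetical counter samples, for illustration only.
num_draft = [100, 100]
num_accept = [80, 50]

rates = [accept / draft if draft > 0 else 0 for draft, accept in zip(num_draft, num_accept)]
average = sum(rates) / len(rates) if rates else 0
print(rates, average)  # [0.8, 0.5] 0.65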
