|
93 | 93 | ]) |
94 | 94 | @pytest.mark.parametrize("batch_size", [1, 32]) |
95 | 95 | @pytest.mark.parametrize("seed", [1]) |
96 | | -def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, |
97 | | - per_test_common_llm_kwargs, |
98 | | - baseline_llm_kwargs, test_llm_kwargs, |
99 | | - batch_size: int, output_len: int, |
100 | | - seed: int): |
| 96 | +def test_mtp_e2e_greedy_correctness( |
| 97 | + enable_modelscope_env, |
| 98 | + vllm_runner, |
| 99 | + common_llm_kwargs, |
| 100 | + per_test_common_llm_kwargs, |
| 101 | + baseline_llm_kwargs, |
| 102 | + test_llm_kwargs, |
| 103 | + batch_size: int, |
| 104 | + output_len: int, |
| 105 | + seed: int, |
| 106 | +): |
101 | 107 |
|
102 | 108 | run_equality_correctness_test(vllm_runner, common_llm_kwargs, |
103 | 109 | per_test_common_llm_kwargs, |
@@ -138,12 +144,17 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, |
138 | 144 | ]) |
139 | 145 | @pytest.mark.parametrize("batch_size", [1, 32]) |
140 | 146 | @pytest.mark.parametrize("seed", [1]) |
141 | | -def test_mtp_e2e_quant_greedy_correctness(vllm_runner, common_llm_kwargs, |
142 | | - per_test_common_llm_kwargs, |
143 | | - baseline_llm_kwargs, test_llm_kwargs, |
144 | | - batch_size: int, output_len: int, |
145 | | - seed: int): |
146 | | - |
| 147 | +def test_mtp_e2e_quant_greedy_correctness( |
| 148 | + enable_modelscope_env, |
| 149 | + vllm_runner, |
| 150 | + common_llm_kwargs, |
| 151 | + per_test_common_llm_kwargs, |
| 152 | + baseline_llm_kwargs, |
| 153 | + test_llm_kwargs, |
| 154 | + batch_size: int, |
| 155 | + output_len: int, |
| 156 | + seed: int, |
| 157 | +): |
147 | 158 | run_equality_correctness_test(vllm_runner, common_llm_kwargs, |
148 | 159 | per_test_common_llm_kwargs, |
149 | 160 | baseline_llm_kwargs, test_llm_kwargs, |
@@ -192,12 +203,18 @@ def test_mtp_e2e_quant_greedy_correctness(vllm_runner, common_llm_kwargs, |
192 | 203 | @pytest.mark.parametrize("batch_size", [8]) |
193 | 204 | @pytest.mark.parametrize("seed", [1]) |
194 | 205 | @pytest.mark.parametrize("logprobs", [1, 6]) |
195 | | -def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, |
196 | | - per_test_common_llm_kwargs, |
197 | | - baseline_llm_kwargs, test_llm_kwargs, |
198 | | - batch_size: int, output_len: int, seed: int, |
199 | | - logprobs: int): |
200 | | - |
| 206 | +def test_mtp_e2e_greedy_logprobs( |
| 207 | + enable_modelscope_env, |
| 208 | + vllm_runner, |
| 209 | + common_llm_kwargs, |
| 210 | + per_test_common_llm_kwargs, |
| 211 | + baseline_llm_kwargs, |
| 212 | + test_llm_kwargs, |
| 213 | + batch_size: int, |
| 214 | + output_len: int, |
| 215 | + seed: int, |
| 216 | + logprobs: int, |
| 217 | +): |
201 | 218 | run_equality_correctness_test( |
202 | 219 | vllm_runner, |
203 | 220 | common_llm_kwargs, |
@@ -246,9 +263,16 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, |
246 | 263 | @pytest.mark.parametrize("batch_size", [1, 32]) |
247 | 264 | @pytest.mark.parametrize("seed", [1]) |
248 | 265 | def test_mtp_e2e_greedy_correctness_torchair_graph( |
249 | | - vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, |
250 | | - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, |
251 | | - seed: int): |
| 266 | + enable_modelscope_env, |
| 267 | + vllm_runner, |
| 268 | + common_llm_kwargs, |
| 269 | + per_test_common_llm_kwargs, |
| 270 | + baseline_llm_kwargs, |
| 271 | + test_llm_kwargs, |
| 272 | + batch_size: int, |
| 273 | + output_len: int, |
| 274 | + seed: int, |
| 275 | +): |
252 | 276 | """Verify greedy equality with torchair graph enabled and different |
253 | 277 | batch sizes using bfloat16 weights.""" |
254 | 278 | run_equality_correctness_test(vllm_runner, common_llm_kwargs, |
@@ -290,9 +314,16 @@ def test_mtp_e2e_greedy_correctness_torchair_graph( |
290 | 314 | @pytest.mark.parametrize("batch_size", [1, 32]) |
291 | 315 | @pytest.mark.parametrize("seed", [1]) |
292 | 316 | def test_mtp_e2e_quant_greedy_correctness_torchair_graph( |
293 | | - vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, |
294 | | - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, |
295 | | - seed: int): |
| 317 | + enable_modelscope_env, |
| 318 | + vllm_runner, |
| 319 | + common_llm_kwargs, |
| 320 | + per_test_common_llm_kwargs, |
| 321 | + baseline_llm_kwargs, |
| 322 | + test_llm_kwargs, |
| 323 | + batch_size: int, |
| 324 | + output_len: int, |
| 325 | + seed: int, |
| 326 | +): |
296 | 327 | """Verify greedy equality with torchair graph enabled and different |
297 | 328 | batch sizes using quant weights.""" |
298 | 329 | run_equality_correctness_test(vllm_runner, common_llm_kwargs, |
@@ -341,9 +372,16 @@ def test_mtp_e2e_quant_greedy_correctness_torchair_graph( |
341 | 372 | @pytest.mark.parametrize("batch_size", [4]) |
342 | 373 | @pytest.mark.parametrize("seed", [1]) |
343 | 374 | def test_mtp_e2e_greedy_correctness_with_preemption( |
344 | | - vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, |
345 | | - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, |
346 | | - seed: int): |
| 375 | + enable_modelscope_env, |
| 376 | + vllm_runner, |
| 377 | + common_llm_kwargs, |
| 378 | + per_test_common_llm_kwargs, |
| 379 | + baseline_llm_kwargs, |
| 380 | + test_llm_kwargs, |
| 381 | + batch_size: int, |
| 382 | + output_len: int, |
| 383 | + seed: int, |
| 384 | +): |
347 | 385 | """Verify greedy equality, even when some sequences are preempted mid- |
348 | 386 | generation. |
349 | 387 | """ |
@@ -391,10 +429,17 @@ def test_mtp_e2e_greedy_correctness_with_preemption( |
391 | 429 | 32, |
392 | 430 | ]) |
393 | 431 | @pytest.mark.parametrize("seed", [1]) |
394 | | -def test_mtp_different_k(vllm_runner, common_llm_kwargs, |
395 | | - per_test_common_llm_kwargs, baseline_llm_kwargs, |
396 | | - test_llm_kwargs, batch_size: int, output_len: int, |
397 | | - seed: int): |
| 432 | +def test_mtp_different_k( |
| 433 | + enable_modelscope_env, |
| 434 | + vllm_runner, |
| 435 | + common_llm_kwargs, |
| 436 | + per_test_common_llm_kwargs, |
| 437 | + baseline_llm_kwargs, |
| 438 | + test_llm_kwargs, |
| 439 | + batch_size: int, |
| 440 | + output_len: int, |
| 441 | + seed: int, |
| 442 | +): |
398 | 443 | """Verify that mtp speculative decoding produces exact equality |
399 | 444 | to without spec decode with different values of num_speculative_tokens. |
400 | 445 | """ |
@@ -437,10 +482,17 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs, |
437 | 482 | 32, |
438 | 483 | ]) |
439 | 484 | @pytest.mark.parametrize("seed", [1]) |
440 | | -def test_mtp_disable_queue(vllm_runner, common_llm_kwargs, |
441 | | - per_test_common_llm_kwargs, baseline_llm_kwargs, |
442 | | - test_llm_kwargs, batch_size: int, output_len: int, |
443 | | - seed: int): |
| 485 | +def test_mtp_disable_queue( |
| 486 | + enable_modelscope_env, |
| 487 | + vllm_runner, |
| 488 | + common_llm_kwargs, |
| 489 | + per_test_common_llm_kwargs, |
| 490 | + baseline_llm_kwargs, |
| 491 | + test_llm_kwargs, |
| 492 | + batch_size: int, |
| 493 | + output_len: int, |
| 494 | + seed: int, |
| 495 | +): |
444 | 496 | """Verify that mtp speculative decoding produces exact equality |
445 | 497 | to without spec decode when speculation is disabled for large |
446 | 498 | batch sizes. |
|
0 commit comments