@@ -305,6 +305,150 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
305305 batch_size , output_len , seed )
306306
307307
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Run eagerly (skip cuda graph recording) to keep the test fast.
        "enforce_eager": True,
        # Keep stats logging on so speculative metrics are printed.
        "disable_log_stats": False,
        # Precision
        "dtype": "float16",
        # Target (main) model
        "model_name": "meta-llama/Llama-2-7b-chat-hf",
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": "yuhuili/EAGLE-llama2-chat-7B",
    "num_speculative_tokens": MAX_SPEC_TOKENS,
}])
# Short generation length keeps the test fast.
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("seed", [1])
def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                             per_test_common_llm_kwargs,
                                             baseline_llm_kwargs,
                                             test_llm_kwargs, batch_size: int,
                                             output_len: int, seed: int):
    """Check that greedy (temperature=0) generation with the EAGLE draft
    model matches the non-speculative baseline for Llama-2-7b-chat."""
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed,
                                  temperature=0.0)
354+
355+
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Run eagerly (skip cuda graph recording) to keep the test fast.
        "enforce_eager": True,
        # Keep stats logging on so speculative metrics are printed.
        "disable_log_stats": False,
        # Precision
        "dtype": "float16",
        # Target (main) model
        "model_name": "meta-llama/Meta-Llama-3-8B-Instruct",
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
    "num_speculative_tokens": MAX_SPEC_TOKENS,
}])
# Short generation length keeps the test fast.
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("seed", [1])
def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                             per_test_common_llm_kwargs,
                                             baseline_llm_kwargs,
                                             test_llm_kwargs, batch_size: int,
                                             output_len: int, seed: int):
    """Check that greedy (temperature=0) generation with the EAGLE draft
    model matches the non-speculative baseline for Meta-Llama-3-8B-Instruct."""
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed,
                                  temperature=0.0)
402+
403+
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Run eagerly (skip cuda graph recording) to keep the test fast.
        "enforce_eager": True,
        # Keep stats logging on so speculative metrics are printed.
        "disable_log_stats": False,
        # Precision
        "dtype": "float16",
        # Target (main) model
        "model_name": "Qwen/Qwen2-7B-Instruct",
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
    "num_speculative_tokens": MAX_SPEC_TOKENS,
}])
# Short generation length keeps the test fast.
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("seed", [1])
def test_qwen2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                            per_test_common_llm_kwargs,
                                            baseline_llm_kwargs,
                                            test_llm_kwargs, batch_size: int,
                                            output_len: int, seed: int):
    """Check that greedy (temperature=0) generation with the EAGLE draft
    model matches the non-speculative baseline for Qwen2-7B-Instruct."""
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed,
                                  temperature=0.0)
450+
451+
# Allow running this test module directly; pytest is the normal entry point.
if __name__ == "__main__":
    import pytest
    pytest.main([__file__])
0 commit comments