11from  __future__ import  annotations 
22
3+ import  os 
4+ 
35import  pytest 
46from  vllm  import  SamplingParams 
57from  vllm .config  import  CompilationConfig , CUDAGraphMode 
68
79from  tests .e2e .conftest  import  VllmRunner 
810
11+ os .environ ["VLLM_WORKER_MULTIPROC_METHOD" ] =  "spawn" 
12+ 
913
1014@pytest .fixture  
1115def  sampling_config ():
@@ -17,12 +21,11 @@ def model_name():
1721    return  "wemaster/deepseek_mtp_main_random_bf16" 
1822
1923
20- def  mtp_correctness (
21-     sampling_config : SamplingParams ,
22-     model_name : str ,
23-     num_speculative_tokens : int ,
24-     graph_mode : CUDAGraphMode  =  CUDAGraphMode .PIECEWISE ,
25- ):
24+ def  mtp_correctness (sampling_config : SamplingParams ,
25+                     model_name : str ,
26+                     num_speculative_tokens : int ,
27+                     graph_mode : CUDAGraphMode  =  CUDAGraphMode .PIECEWISE ,
28+                     disable_padded_drafter_batch = True ):
2629    example_prompts  =  [
2730        "Hello, my name is" ,
2831        "The president of the United States is" ,
@@ -54,6 +57,7 @@ def mtp_correctness(
5457            speculative_config = {
5558                "method" : "deepseek_mtp" ,
5659                "num_speculative_tokens" : num_speculative_tokens ,
60+                 "disable_padded_drafter_batch" : disable_padded_drafter_batch ,
5761            },
5862            enforce_eager = False ,
5963            max_model_len = 2000 ,
@@ -108,3 +112,23 @@ def test_mtp2_correctness_full_graph(
108112    model_name : str ,
109113):
110114    mtp_correctness (sampling_config , model_name , 2 , CUDAGraphMode .FULL )
115+ 
116+ 
117+ def  test_mtp1_correctness_piecewise_graph_with_pad (
118+     sampling_config : SamplingParams ,
119+     model_name : str ,
120+ ):
121+     mtp_correctness (sampling_config ,
122+                     model_name ,
123+                     1 ,
124+                     disable_padded_drafter_batch = False )
125+ 
126+ 
127+ def  test_mtp2_correctness_piecewise_graph_with_pad (
128+     sampling_config : SamplingParams ,
129+     model_name : str ,
130+ ):
131+     mtp_correctness (sampling_config ,
132+                     model_name ,
133+                     2 ,
134+                     disable_padded_drafter_batch = False )
0 commit comments