@@ -39,7 +39,7 @@ def test_with_spec_decoding(self, monkeypatch: pytest.MonkeyPatch):
             monkeypatch,
             MTP_MODEL,
             [{}],
-            spec_config={"method": "mtp", "num_speculative_tokens": 1},
+            spec_configs=[{"method": "mtp", "num_speculative_tokens": 1}, None],
         )
 
     def test_without_spec_decoding(
@@ -62,7 +62,7 @@ def test_without_spec_decoding(
             ),
         ]
         self.preempt_and_async_scheduling_e2e(
-            monkeypatch, MODEL, sampling_param_tests, None
+            monkeypatch, MODEL, sampling_param_tests, [None]
         )
 
     @dynamo_config.patch(cache_size_limit=16)
@@ -71,83 +71,98 @@ def preempt_and_async_scheduling_e2e(
         monkeypatch: pytest.MonkeyPatch,
         model: str,
         sampling_param_tests: list[dict[str, Any]],
-        spec_config: dict | None,
+        spec_configs: list[dict | None],
     ):
         """Test consistency of combos of async scheduling, preemption,
         uni/multiproc executor with spec decoding."""
 
         with monkeypatch.context() as m:
             m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
             # m.setenv("VLLM_BATCH_INVARIANT", "1")
-            spec_decoding = False
-            if spec_config:
-                spec_decoding = True
             outputs: list[tuple[str, list]] = []
             for test_preemption in [False, True]:
                 for executor in ["mp", "uni"]:
                     for async_scheduling in [False, True]:
-                        cache_arg: dict[str, Any] = (
-                            dict(num_gpu_blocks_override=32)
-                            if test_preemption
-                            else dict(gpu_memory_utilization=0.7)
-                        )
-                        test_config = (
-                            f"executor={executor}, preemption={test_preemption}, "
-                            f"async_sched={async_scheduling}, "
-                            f"spec_decoding={spec_decoding}"
-                        )
-                        print("-" * 80)
-                        print(f"---- TESTING: {test_config}")
-                        print("-" * 80)
-                        with VllmRunner(
-                            model,
-                            max_model_len=512,
-                            enforce_eager=True,
-                            async_scheduling=async_scheduling,
-                            distributed_executor_backend=executor,
-                            dtype="float32",  # avoid precision errors
-                            speculative_config=spec_config,
-                            **cache_arg,
-                        ) as vllm_model:
-                            results = []
-                            for override_params in sampling_param_tests:
-                                print(f"----------- RUNNING PARAMS: {override_params}")
-                                results.append(
-                                    vllm_model.generate(
-                                        self.example_prompts,
-                                        sampling_params=SamplingParams(
-                                            **self.default_params, **override_params
-                                        ),
-                                        return_logprobs=True,
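+                        # Run the combination once per spec-decoding config
+                        # (None disables speculative decoding).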
+                        for spec_config in spec_configs:
+                            spec_decoding = spec_config is not None
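+                            # A KV cache of only 32 blocks should force the scheduler
+                            # to preempt; otherwise just cap GPU memory usage.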
+                            cache_arg: dict[str, Any] = (
+                                dict(num_gpu_blocks_override=32)
+                                if test_preemption
+                                else dict(gpu_memory_utilization=0.7)
+                            )
+                            test_config = (
+                                f"executor={executor}, preemption={test_preemption}, "
+                                f"async_sched={async_scheduling}, "
+                                f"spec_decoding={spec_decoding}"
+                            )
+                            print("-" * 80)
+                            print(f"---- TESTING: {test_config}")
+                            print("-" * 80)
+                            with VllmRunner(
+                                model,
+                                max_model_len=512,
+                                enforce_eager=True,
+                                async_scheduling=async_scheduling,
+                                distributed_executor_backend=executor,
+                                dtype="float32",  # avoid precision errors
+                                speculative_config=spec_config,
+                                **cache_arg,
+                            ) as vllm_model:
+                                results = []
+                                acceptance_rates = []
+                                for override_params in sampling_param_tests:
+                                    print(
+                                        f"----------- RUNNING PARAMS: {override_params}"
                                     )
-                                )
-
-                            if not outputs and len(results) > 1:
-                                # First check that the different parameter configs
-                                # actually result in different output.
-                                for (
-                                    other_test_outs,
-                                    other_test_logprobs,
-                                ), params in zip(results[1:], sampling_param_tests[1:]):
-                                    with pytest.raises(AssertionError):
-                                        check_outputs_equal(
-                                            outputs_0_lst=results[0][0],
-                                            outputs_1_lst=other_test_outs,
-                                            name_0=f"baseline params={params}",
-                                            name_1=f"other params={params}",
-                                        )
-                                    assert _all_logprobs_match(
-                                        results[0][1], other_test_logprobs
+                                    results.append(
+                                        vllm_model.generate(
+                                            self.example_prompts,
+                                            sampling_params=SamplingParams(
+                                                **self.default_params,
+                                                **override_params,
+                                            ),
+                                            return_logprobs=True,
                                         )
+                                    )
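+                                    # Record the acceptance rate observed for this param
+                                    # config so it can be compared across configs later.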
+                                    acceptance_rates.append(
+                                        _calc_average_acceptance_rate(vllm_model)
+                                    )
 
-                        outputs.append((test_config, results))
-
-            baseline_config, baseline_tests = outputs[0]
+                                if not outputs and len(results) > 1:
+                                    # First check that the different parameter configs
+                                    # actually result in different output.
+                                    for (
+                                        other_test_outs,
+                                        other_test_logprobs,
+                                    ), params in zip(
+                                        results[1:], sampling_param_tests[1:]
+                                    ):
+                                        with pytest.raises(AssertionError):
+                                            check_outputs_equal(
+                                                outputs_0_lst=results[0][0],
+                                                outputs_1_lst=other_test_outs,
+                                                name_0=f"baseline params={params}",
+                                                name_1=f"other params={params}",
+                                            )
+                                        assert _all_logprobs_match(
+                                            results[0][1], other_test_logprobs
+                                        )
+
+                            outputs.append((test_config, results, acceptance_rates))
+
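+            # Every other config is compared against the first one run (the baseline).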
+            baseline_config, baseline_tests, base_acceptance_rates = outputs[0]
 
             failure = None
-            for test_config, test_outputs in outputs[1:]:
-                for (base_outs, base_logprobs), (test_outs, test_logprobs), params in zip(
-                    baseline_tests, test_outputs, sampling_param_tests
+            for test_config, test_outputs, test_acceptance_rates in outputs[1:]:
+                for (base_outs, base_logprobs), base_acceptance_rate, (
+                    test_outs,
+                    test_logprobs,
+                ), test_acceptance_rate, params in zip(
+                    baseline_tests,
+                    base_acceptance_rates,
+                    test_outputs,
+                    test_acceptance_rates,
+                    sampling_param_tests,
                 ):
                     try:
                         check_outputs_equal(
@@ -159,6 +174,13 @@ def preempt_and_async_scheduling_e2e(
 
                         assert _all_logprobs_match(base_logprobs, test_logprobs)
 
+                        # only check acceptance rate if spec decoding is used.
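+                        # The rate should match the baseline within a 5% relative tolerance.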
+                        if base_acceptance_rate > 0:
+                            assert (
+                                pytest.approx(test_acceptance_rate, rel=5e-2)
+                                == base_acceptance_rate
+                            )
+
                         print(f"PASSED: config=[{test_config}], params={params}")
                     except AssertionError as e:
                         print(f"FAILED: config=[{test_config}], params={params}")
@@ -188,3 +210,19 @@ def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
         and a.logprob == pytest.approx(b.logprob, rel=1e-3, abs=1e-6)
         for a, b in ((lps_a[x], lps_b[x]) for x in lps_a)
     )
+
+
+def _calc_average_acceptance_rate(vllm_model: VllmRunner) -> float:
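+    """Average spec-decode acceptance rate taken from the engine's metrics.
+
+    Pairs up the vllm:spec_decode_num_draft_tokens and
+    vllm:spec_decode_num_accepted_tokens counter samples and returns the
+    mean accepted/draft ratio (0 if no draft-token metrics were recorded).
+    """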
+    metrics = vllm_model.llm.get_metrics()
+    num_draft = []
+    num_accept = []
+    for metric in metrics:
+        if metric.name == "vllm:spec_decode_num_draft_tokens":
+            num_draft.append(metric.value)
+        if metric.name == "vllm:spec_decode_num_accepted_tokens":
+            num_accept.append(metric.value)
+    acceptance_rates = []
+    for draft, accept in zip(num_draft, num_accept):
+        acceptance_rate = accept / draft if draft > 0 else 0
+        acceptance_rates.append(acceptance_rate)
+    return sum(acceptance_rates) / len(acceptance_rates) if acceptance_rates else 0