1+ #
2+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3+ # This file is a part of the vllm-ascend project.
4+ # Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_eagle_correctness.py
5+ # Copyright 2023 The vLLM team.
6+ #
7+ # Licensed under the Apache License, Version 2.0 (the "License");
8+ # you may not use this file except in compliance with the License.
9+ # You may obtain a copy of the License at
10+ #
11+ # http://www.apache.org/licenses/LICENSE-2.0
12+ #
13+ # Unless required by applicable law or agreed to in writing, software
14+ # distributed under the License is distributed on an "AS IS" BASIS,
15+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+ # See the License for the specific language governing permissions and
17+ # limitations under the License.
18+ #
19+ """This docstring details important information on the testing methodology.
20+
21+ Most of the tests rely on "greedy equality", where we expect the output of
22+ speculative decoding on a sequence to exactly match the output of normal non-
23+ speculative decoding.
24+
25+ Since speculative decoding with rejection sampling guarantees that the output
26+ distribution matches the target model's output distribution (up to hardware
27+ numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
28+ equality.
29+
30+ However, we still need to verify below scenario could be passed:
31+ * Batch size 1 greedy equality
32+ * Batch size >1 greedy equality
33+ * Test greedy equality under preemption
34+ * Test greedy equality under various number of speculative tokens.
35+
36+ With those tests, we can say at least, EAGLE would not break the
37+ correctness for the target model outputs.
38+ """
39+
40+ import pytest
41+
42+ from tests .long_term .spec_decode_v0 .e2e .conftest import \
43+ run_equality_correctness_test
44+
# main model: the target model whose outputs spec decode must reproduce
MAIN_MODEL = "JackFram/llama-68m"

# speculative model: EAGLE draft head trained for the target model above
SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random"

# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 4

# precision: used as the "dtype" for both baseline and spec-decode runs
PRECISION = "float32"
57+
58+
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
        # Print spec metrics.
        "disable_log_stats": False,
        # Precision
        "dtype": PRECISION,
        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
}])
@pytest.mark.parametrize("output_len", [128])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                      per_test_common_llm_kwargs,
                                      baseline_llm_kwargs, test_llm_kwargs,
                                      batch_size: int, output_len: int,
                                      seed: int):
    """Verify greedy equality of EAGLE spec decode vs. the plain target
    model, for batch sizes 1 and >1."""
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)
99+
100+
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
        # Print spec metrics.
        "disable_log_stats": False,
        # Precision
        "dtype": PRECISION,
        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        # Cover both the logprobs-enabled and logprobs-disabled paths.
        {
            "speculative_config": {
                "model": SPEC_MODEL,
                "num_speculative_tokens": MAX_SPEC_TOKENS,
                "disable_logprobs": False,
            },
        },
        {
            "speculative_config": {
                "model": SPEC_MODEL,
                "num_speculative_tokens": MAX_SPEC_TOKENS,
                "disable_logprobs": True,
            },
        },
    ])
@pytest.mark.parametrize("output_len", [128])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                   per_test_common_llm_kwargs,
                                   baseline_llm_kwargs, test_llm_kwargs,
                                   batch_size: int, output_len: int, seed: int,
                                   logprobs: int):
    """Verify greedy equality including token and prompt logprobs, with
    logprobs both enabled and disabled in the speculative config."""
    spec_config = test_llm_kwargs["speculative_config"]
    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
        logprobs=logprobs,
        prompt_logprobs=logprobs,
        disable_logprobs=spec_config["disable_logprobs"])
156+
157+
@pytest.mark.skipif(True, reason="Open it when graph mode ready.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Record cuda graphs (no eager fallback) to exercise the graph path.
        "enforce_eager": False,
        # Print spec metrics.
        "disable_log_stats": False,
        # Precision
        "dtype": PRECISION,
        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
}])
@pytest.mark.parametrize("output_len", [128])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_cuda_graph(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
        output_len: int, seed: int):
    """Verify greedy equality with cuda graph enabled and different
    batch sizes."""
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)
198+
199+
@pytest.mark.skipif(True, reason="Open it when preempt ready.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Keep the KV cache artificially small so sequences get preempted:
        # 2 blocks for the small prompt, 256//8 blocks for the generation.
        "block_size": 8,
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
        # Precision
        "dtype": PRECISION,
        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
}])
# Use small output len for fast test.
@pytest.mark.parametrize("output_len", [128])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_with_preemption(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
        output_len: int, seed: int):
    """Verify greedy equality, even when some sequences are preempted mid-
    generation.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)
247+
248+
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
        # Precision
        "dtype": PRECISION,
        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        # Sweep every supported draft length 1..MAX_SPEC_TOKENS.
        {
            "speculative_config": {
                "model": SPEC_MODEL,
                "num_speculative_tokens": num_tokens,
            },
        } for num_tokens in range(1, 1 + MAX_SPEC_TOKENS)
    ])
@pytest.mark.parametrize("batch_size", [2])
# Use smaller output len for fast test.
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_different_k(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int):
    """Verify that eagle speculative decoding produces exact equality
    to without spec decode with different values of num_speculative_tokens.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)
294+
295+
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
        # Precision
        "dtype": PRECISION,
        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        # Turn speculation off once the running batch reaches this size;
        # batch_size=5 below exercises the disabled path, batch_size=1 the
        # enabled one.
        "disable_by_batch_size": 4,
    },
}])
@pytest.mark.parametrize("batch_size", [1, 5])
# Use smaller output len for fast test.
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
                             per_test_common_llm_kwargs, baseline_llm_kwargs,
                             test_llm_kwargs, batch_size: int,
                             output_len: int, seed: int):
    """Verify that eagle speculative decoding produces exact equality
    to without spec decode when speculation is disabled for large
    batch sizes.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)
337+
338+
if __name__ == "__main__":
    # Allow running this module directly (``python <file>``); ``pytest`` is
    # already imported at the top of the file, so no local re-import is
    # needed here.
    pytest.main([__file__])