#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_eagle_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
19+ """This docstring details important information on the testing methodology.
20+
21+ Most of the tests rely on "greedy equality", where we expect the output of
22+ speculative decoding on a sequence to exactly match the output of normal non-
23+ speculative decoding.
24+
25+ Since speculative decoding with rejection sampling guarantees that the output
26+ distribution matches the target model's output distribution (up to hardware
27+ numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
28+ equality.
29+
However, we still need to verify that the following scenarios pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Greedy equality under preemption
* Greedy equality under various numbers of speculative tokens

With these tests, we can at least say that EAGLE does not break the
correctness of the target model outputs.
"""

import pytest

from tests.e2e.long_term.spec_decode_v0.e2e.conftest import \
    run_equality_correctness_test

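# A minimal sketch of the check that run_equality_correctness_test is
# expected to perform (illustrative only; the real implementation lives in
# the conftest imported above): under greedy sampling, speculative decoding
# must reproduce the baseline output token-for-token.
def _greedy_equality_sketch(baseline_outputs, spec_outputs):
    # Compare per-sequence outputs for exact equality.
    for baseline, spec in zip(baseline_outputs, spec_outputs):
        assert baseline == spec, f"{baseline!r} != {spec!r}"

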
# main model
MAIN_MODEL = "JackFram/llama-68m"

# speculative model
SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random"

# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 4
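# Hedged sketch: the value above could in principle be read from the
# speculator checkpoint instead of being hard-coded. This assumes the
# transformers AutoConfig API and a `num_heads` field in config.json, and
# is left commented out to avoid a download at import time:
#
#   from transformers import AutoConfig
#   MAX_SPEC_TOKENS = AutoConfig.from_pretrained(SPEC_MODEL).num_heads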

# precision
# TODO: Upstream vLLM uses float32 here, but some ops in vllm-ascend
# (e.g. RoPE) do not yet support float32. Once that is fixed, change
# this back to float32.
PRECISION = "float16"


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "model": SPEC_MODEL,
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                      per_test_common_llm_kwargs,
                                      baseline_llm_kwargs, test_llm_kwargs,
                                      batch_size: int, output_len: int,
                                      seed: int):
    """Verify greedy equality with and without speculative decoding
    across different batch sizes."""
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_logprobs": False,
    },
}, {
    "speculative_config": {
        "model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "disable_logprobs": True,
    },
}])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
                                   per_test_common_llm_kwargs,
                                   baseline_llm_kwargs, test_llm_kwargs,
                                   batch_size: int, output_len: int, seed: int,
                                   logprobs: int):
    """Verify greedy equality of logprobs, with logprob emission in the
    speculative config both enabled and disabled."""
    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
        logprobs=logprobs,
        prompt_logprobs=logprobs,
        disable_logprobs=test_llm_kwargs["speculative_config"]
        ["disable_logprobs"])


@pytest.mark.skipif(True, reason="Enable when graph mode is ready.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "enforce_eager": False,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "model": SPEC_MODEL,
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_cuda_graph(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
        output_len: int, seed: int):
    """Verify greedy equality with cuda graph enabled and different
    batch sizes."""
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.skipif(True, reason="Enable when preemption is ready.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # 2 for small prompt, 256//8 for generated.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
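# Worked arithmetic for the override above: 2 blocks cover the short prompt
# and 256 // 8 == 32 blocks cover generated tokens, i.e. 34 KV-cache blocks
# in total, with max_model_len = 34 * 8 = 272 tokens. The tight block budget
# is what forces sequences to be preempted mid-generation.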
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "model": SPEC_MODEL,
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        128,
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_with_preemption(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
        output_len: int, seed: int):
    """Verify greedy equality, even when some sequences are preempted
    mid-generation.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_config": {
                "model": SPEC_MODEL,
                "num_speculative_tokens": k,
            },
        }
        # Try a range of num. speculative tokens
        for k in range(1, 1 + MAX_SPEC_TOKENS)
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_eagle_different_k(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int):
    """Verify that EAGLE speculative decoding produces output exactly equal
    to decoding without speculation, for different values of
    num_speculative_tokens.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        # Disable speculation once the batch grows beyond this size.
        "disable_by_batch_size": 4,
    },
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
                             per_test_common_llm_kwargs, baseline_llm_kwargs,
                             test_llm_kwargs, batch_size: int, output_len: int,
                             seed: int):
    """Verify that EAGLE speculative decoding produces output exactly equal
    to decoding without speculation when speculation is disabled for large
    batch sizes.
    """
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)


if __name__ == "__main__":
    pytest.main([__file__])