@@ -3,6 +3,7 @@
 tensor parallelism.
 """
 
+import json
 from typing import Optional
 
 import pytest
@@ -28,14 +29,14 @@
 @pytest.mark.parametrize("test_llm_kwargs", [
     [
         "--speculative_config",
-        str({
+        json.dumps({
             "model": "JackFram/llama-68m",
             "num_speculative_tokens": 3,
         }),
     ],
     [
         "--speculative_config",
-        str({
+        json.dumps({
             "model": "ngram",
             "num_speculative_tokens": 5,
             "prompt_lookup_max": 3,
@@ -88,15 +89,15 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
     "model, test_llm_kwargs",
     [("JackFram/llama-68m", [
         "--speculative_config",
-        str({
+        json.dumps({
             "model": "JackFram/llama-68m",
             "num_speculative_tokens": 5,
             "draft_tensor_parallel_size": 1,
         }),
     ]),
     ("ibm-granite/granite-3b-code-instruct", [
         "--speculative_config",
-        str({
+        json.dumps({
             "model": "ibm-granite/granite-3b-code-instruct",
             "num_speculative_tokens": 5,
             "draft_tensor_parallel_size": 1,
@@ -147,20 +148,20 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
 @pytest.mark.parametrize("model, test_llm_kwargs",
                          [("JackFram/llama-68m", [
                              "--speculative_config",
-                             str({
+                             json.dumps({
                                  "model": "JackFram/llama-68m",
                                  "num_speculative_tokens": 3,
                              }),
                          ]),
                          ("JackFram/llama-68m", [
                              "--speculative_config",
-                             str({
+                             json.dumps({
                                  "model": "JackFram/llama-68m",
                                  "num_speculative_tokens": 3,
                                  "draft_tensor_parallel_size": 1,
                              }),
                          ])])
-@pytest.mark.parametrize("logprobs", [None, 2])
+@pytest.mark.parametrize("logprobs", [None])
 @pytest.mark.parametrize("batch_size", [2])
 @pytest.mark.parametrize("seed", [1])
 def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
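
Note that logprobs=2 is not simply dropped from the parametrization above: the hunk below moves that case into a dedicated test, test_spec_decode_chunked_prefill_tp2_with_logprobs, where logprob emission is declared up front via disable_logprobs in the speculative config instead of being patched into test_llm_kwargs at runtime.
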
@@ -171,9 +172,68 @@ def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
     """Verify spec decode works well with same and different TP size for
     the draft model with chunked prefill.
     """
-    if logprobs:
-        test_llm_kwargs.extend(
-            ["--disable_logprobs_during_spec_decoding", "False"])
+    run_equality_correctness_test_tp(model,
+                                     common_llm_kwargs,
+                                     per_test_common_llm_kwargs,
+                                     baseline_llm_kwargs,
+                                     test_llm_kwargs,
+                                     batch_size,
+                                     max_output_len=32,
+                                     seed=seed,
+                                     temperature=0.0,
+                                     logprobs=logprobs)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [[
+        # Skip cuda graph recording for fast test.
+        "--enforce-eager",
+        "--tensor_parallel_size",
+        "2",
+
+        # precision
+        "--dtype",
+        "bfloat16",
+    ]])
+@pytest.mark.parametrize(
+    "per_test_common_llm_kwargs",
+    [["--enable-chunked-prefill", "False"],
+     [
+         "--enable-chunked-prefill", "True", "--max-num-batched-tokens", "4",
+         "--max-num-seqs", "4"
+     ]])
+@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
+@pytest.mark.parametrize("model, test_llm_kwargs",
+                         [("JackFram/llama-68m", [
+                             "--speculative_config",
+                             json.dumps({
+                                 "model": "JackFram/llama-68m",
+                                 "num_speculative_tokens": 3,
+                                 "disable_logprobs": False,
+                             }),
+                         ]),
+                         ("JackFram/llama-68m", [
+                             "--speculative_config",
+                             json.dumps({
+                                 "model": "JackFram/llama-68m",
+                                 "num_speculative_tokens": 3,
+                                 "draft_tensor_parallel_size": 1,
+                                 "disable_logprobs": False,
+                             }),
+                         ])])
+@pytest.mark.parametrize("logprobs", [2])
+@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_chunked_prefill_tp2_with_logprobs(
+        model, common_llm_kwargs, per_test_common_llm_kwargs,
+        baseline_llm_kwargs, test_llm_kwargs, logprobs: Optional[int],
+        batch_size: int, seed: int):
234+ """Verify spec decode works well with same and different TP size for
235+ the draft model with chunked prefill.
236+ """
177237 run_equality_correctness_test_tp (model ,
178238 common_llm_kwargs ,
179239 per_test_common_llm_kwargs ,
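
The net effect of the last hunk is easiest to see side by side. Below is a sketch of the before/after wiring of logprobs into the test arguments, reconstructed from the removed and added lines above; how vLLM consumes disable_logprobs downstream is not shown in this diff.

    import json

    # Before: the test mutated its argument list at runtime, appending a
    # standalone CLI flag (the lines removed in the final hunk).
    old_style = [
        "--speculative_config",
        str({"model": "JackFram/llama-68m", "num_speculative_tokens": 3}),
        "--disable_logprobs_during_spec_decoding", "False",
    ]

    # After: the toggle lives inside the speculative config itself, so the
    # parametrization is declarative and the value is valid JSON end to end.
    new_style = [
        "--speculative_config",
        json.dumps({
            "model": "JackFram/llama-68m",
            "num_speculative_tokens": 3,
            "disable_logprobs": False,
        }),
    ]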