

def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
-    # Lightweight random prompt generator to vary prompt lengths and content.
-    vocab = [
-        "alpha",
-        "bravo",
-        "charlie",
-        "delta",
-        "echo",
-        "foxtrot",
-        "golf",
-        "hotel",
-        "india",
-        "juliet",
-        "kilo",
-        "lima",
-        "mike",
-        "november",
-        "oscar",
-        "papa",
-        "quebec",
-        "romeo",
-        "sierra",
-        "tango",
-        "uniform",
-        "victor",
-        "whiskey",
-        "xray",
-        "yankee",
-        "zulu",
+    # Generate more realistic prompts that will actually produce varied tokens
+    # Use a mix of common English text patterns
+
+    prompt_templates = [
+        # Question-answer style
+        "Question: What is the capital of France?\nAnswer: The capital of France is",
+        "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
+        "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
+
+        # Story/narrative style
+        "Once upon a time in a distant galaxy, there lived",
+        "The old man walked slowly down the street, remembering",
+        "In the year 2157, humanity finally discovered",
+
+        # Technical/code style
+        "To implement a binary search tree in Python, first we need to",
+        "The algorithm works by iterating through the array and",
+        "Here's how to optimize database queries using indexing:",
+
+        # Factual/informative style
+        "The Renaissance was a period in European history that",
+        "Climate change is caused by several factors including",
+        "The human brain contains approximately 86 billion neurons which",
+
+        # Conversational style
+        "I've been thinking about getting a new laptop because",
+        "Yesterday I went to the store and bought",
+        "My favorite thing about summer is definitely",
    ]
-    n = random.randint(min_words, max_words)
-    words = random.choices(vocab, k=n)

-    # Add some noise and punctuation variability
-    if random.random() < 0.5:
-        words[0] = words[0].capitalize()
-    if random.random() < 0.2:
-        words.append("".join(random.choices(string.ascii_lowercase, k=5)))
-    punct = random.choice([".", "?", "!", "...", ""])
-    return " ".join(words) + punct
+    # Pick a random template
+    base_prompt = random.choice(prompt_templates)
+
+    # Add some padding to vary the length if needed
+    if min_words > 50:
+        # For longer prompts, repeat context
+        padding_text = " This is an interesting topic that deserves more explanation. " * (min_words // 50)
+        base_prompt = base_prompt + padding_text
+
+    return base_prompt


@pytest.mark.timeout(1000)
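As a quick illustration of the new `_random_prompt` above (a sketch that only restates what the hunk does; note that `max_words` is no longer used by the new implementation):

# With the default min_words=1024, the padding sentence is appended
# 1024 // 50 == 20 times after the chosen template.
p_long = _random_prompt()
assert "This is an interesting topic" in p_long

# With _random_prompt(10, 50) (the call used later in this diff), 10 > 50 is
# false, so a single template is returned unchanged, e.g.
#   "Once upon a time in a distant galaxy, there lived"
p_short = _random_prompt(10, 50)
assert "This is an interesting topic" not in p_short
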
@@ -91,7 +92,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
    # Keep GPU memory usage low to avoid startup allocation failures.
    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4"))
    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))
-    swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))

    # Sampling parameters: longer outputs with a more random-sounding
    # continuation, but still deterministic due to fixed seed.
@@ -117,7 +117,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
        max_num_seqs=max_batch_size,
        gpu_memory_utilization=gpu_mem_util,
        max_model_len=max_model_len,
-        swap_space=swap_space_gb,
    )

    # Baseline generation for the needle prompt alone.
@@ -132,7 +131,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
        max_num_seqs=max_batch_size,
        gpu_memory_utilization=gpu_mem_util,
        max_model_len=max_model_len,
-        swap_space=swap_space_gb,
    )

    mismatches = 0
@@ -195,16 +193,16 @@ def _extract_step_logprobs(request_output):
                ],
                dtype=torch.float32,
            )
-            return t
+            return t, inner.token_ids

-    return None
+    return None, None


@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="Requires CUDA to match production inference path.",
)
-@pytest.mark.parametrize("backend", ["FLEX_ATTENTION", "FLASHINFER"])
+@pytest.mark.parametrize("backend", ["FLEX_ATTENTION", "FLASH_ATTN", "FLASHINFER"])
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
    os.environ["VLLM_ATTENTION_BACKEND"] = backend
@@ -214,77 +212,226 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

-    # Force float32 to avoid precision-induced differences.
+    # For batch invariance, disable custom all-reduce to ensure deterministic
+    # all-reduce operations (custom all-reduce may not be deterministic)
+    from vllm.model_executor.layers.batch_invariant import (
+        vllm_kernel_override_batch_invariant,
+    )
+    disable_custom_ar = vllm_kernel_override_batch_invariant()
+
+    if disable_custom_ar:
+        print(f"\n{'=' * 80}")
+        print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})")
+        print(f"{'=' * 80}\n")
+
    llm = LLM(
        model=model_name,
        tensor_parallel_size=tp_size,
-        enforce_eager=True,
        enable_prefix_caching=False,
+        max_num_seqs=32,
+        max_model_len=8192,
+        dtype="bfloat16",  # not everything is supported
    )

-    prompts = [_random_prompt(10, 1024) for i in range(100)]
+    # Use more realistic prompts for better token generation
+    prompts = [_random_prompt(10, 50) for i in range(3)]

    sp = SamplingParams(
        temperature=0.6,
        top_p=1.0,
        max_tokens=8,
-        # Seed shouldn't matter at temperature=0, but keeping it stable anyway.
        seed=1234,
        logprobs=5,
    )

    # BS=1: run prompts individually and collect logprobs per step.
+    print("\n" + "=" * 80)
+    print("STARTING BS=1 RUNS (each prompt individually)")
+    print("=" * 80 + "\n")
+
    bs1_logprobs_per_prompt = []
-    for p in prompts:
+    bs1_tokens_per_prompt = []
+    for idx, p in enumerate(prompts):
+        print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...")
        outs = llm.generate([p], sp, use_tqdm=False)
        assert len(outs) == 1
-        step_logprobs = _extract_step_logprobs(outs[0])
+        step_logprobs, token_ids = _extract_step_logprobs(outs[0])
        if step_logprobs is None:
            pytest.skip(
                "Logits are not available on RequestOutput; "
                "enable logprobs return to run this test."
            )
        bs1_logprobs_per_prompt.append(step_logprobs)
+        bs1_tokens_per_prompt.append(token_ids)
+        print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")

    # BS=N: run prompts in a batch and collect logprobs per step for each
    # prompt.
+    print("\n" + "=" * 80)
+    print(f"STARTING BS={len(prompts)} RUN (all prompts batched)")
+    print("=" * 80 + "\n")
+
    outs_batched = llm.generate(prompts, sp, use_tqdm=False)
    assert len(outs_batched) == len(prompts)
    bsN_logprobs_per_prompt = []
-    for o in outs_batched:
-        step_logprobs = _extract_step_logprobs(o)
+    bsN_tokens_per_prompt = []
+
+    print(f"\n[BS={len(prompts)}] Processing batched outputs...")
+    for idx, o in enumerate(outs_batched):
+        print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {o.outputs[0].token_ids if o.outputs else 'N/A'}")
+        step_logprobs, token_ids = _extract_step_logprobs(o)
        if step_logprobs is None:
            pytest.skip(
                "Logits are not available on RequestOutput; "
                "enable logprobs return to run this test."
            )
        bsN_logprobs_per_prompt.append(step_logprobs)
+        bsN_tokens_per_prompt.append(token_ids)

    # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
-    for i, (logprobs_bs1, logprobs_bsN) in enumerate(
-        zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt)
+    failed_prompts = []
+    for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
+        zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt,
+            bs1_tokens_per_prompt, bsN_tokens_per_prompt)
    ):
-        assert len(logprobs_bs1) == len(logprobs_bsN), (
-            f"Different number of generation steps for prompt index {i}: "
-            f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)"
-        )
+        if len(logprobs_bs1) != len(logprobs_bsN):
+            failed_prompts.append({
+                "prompt_idx": i,
+                "step": "all",
+                "reason": f"Different number of steps: {len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)",
+                "prompt_preview": prompts[i][:100],
+                "bs1_tokens": tokens_bs1,
+                "bsN_tokens": tokens_bsN,
+            })
+            continue
+
+        # Check if tokens match first
+        if tokens_bs1 != tokens_bsN:
+            failed_prompts.append({
+                "prompt_idx": i,
+                "step": "sampling",
+                "reason": "Different tokens sampled",
+                "prompt_preview": prompts[i][:100],
+                "bs1_tokens": tokens_bs1,
+                "bsN_tokens": tokens_bsN,
+                "bs1_all_logprobs": [logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))],
+                "bsN_all_logprobs": [logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))],
+            })
+            continue
+
        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
-            assert a.shape == b.shape, (
-                f"Logits shape mismatch at prompt {i}, step {t}: {a.shape} vs {b.shape}"
-            )
-            # Bitwise exact equality.
-            assert torch.equal(a, b), (
-                f"Bitwise logprobs mismatch at prompt {i}, step {t} "
-                f"(dtype={a.dtype}, shape={a.shape})."
-            )
+            if a.shape != b.shape:
+                failed_prompts.append({
+                    "prompt_idx": i,
+                    "step": t,
+                    "reason": f"Shape mismatch: {a.shape} vs {b.shape}",
+                    "prompt_preview": prompts[i][:100],
+                    "bs1_tokens": tokens_bs1,
+                    "bsN_tokens": tokens_bsN,
+                })
+                break
+
+            if not torch.equal(a, b):
+                max_diff = torch.abs(a - b).max().item()
+                # Print which token failed
+                print(f"\n[DIVERGENCE] Prompt {i}, Token {t}: max_diff={max_diff:.6e}")
+                print(f"  Token IDs: bs1={tokens_bs1[t] if t < len(tokens_bs1) else 'N/A'}, bsN={tokens_bsN[t] if t < len(tokens_bsN) else 'N/A'}")
+                print(f"  BS=1 logprob: {a.tolist()}")
+                print(f"  BS=N logprob: {b.tolist()}")
+                failed_prompts.append({
+                    "prompt_idx": i,
+                    "step": t,
+                    "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
+                    "prompt_preview": prompts[i][:100],
+                    "bs1_tokens": tokens_bs1,
+                    "bsN_tokens": tokens_bsN,
+                    "bs1_all_logprobs": [logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))],
+                    "bsN_all_logprobs": [logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))],
+                })
+                break
+
+    # Print summary of all failures
+    if failed_prompts:
+        print(f"\n{'=' * 80}")
+        print(f"BATCH INVARIANCE FAILURES: {len(failed_prompts)}/{len(prompts)} prompts failed")
+        print(f"{'=' * 80}")
+        for fail in failed_prompts:
+            print(f"\nPrompt {fail['prompt_idx']} (step {fail['step']}):")
+            print(f"  Reason: {fail['reason']}")
+            print(f"  Preview: {fail['prompt_preview']}...")
+
+            # Always show the tokens
+            if "bs1_tokens" in fail:
+                print(f"  BS=1 tokens: {fail['bs1_tokens']}")
+            if "bsN_tokens" in fail:
+                print(f"  BS=N tokens: {fail['bsN_tokens']}")
+
+            if "bs1_all_logprobs" in fail:
+                print(f"  BS=1 logprobs for all {len(fail['bs1_all_logprobs'])} steps:")
+                for step_idx, logprobs in enumerate(fail['bs1_all_logprobs']):
+                    print(f"    Step {step_idx}: {logprobs}")
+                print(f"  BS=N logprobs for all {len(fail['bsN_all_logprobs'])} steps:")
+                for step_idx, logprobs in enumerate(fail['bsN_all_logprobs']):
+                    print(f"    Step {step_idx}: {logprobs}")
+        print(f"{'=' * 80}\n")
+
+        # Fail the test with summary
+        pytest.fail(
+            f"Batch invariance violated in {len(failed_prompts)}/{len(prompts)} prompts. "
+            f"See output above for details."
+        )
+
+
+def test_simple_generation():
+    """
+    Simple test that runs the model with a basic prompt and prints the output.
+    Useful for quick smoke testing and debugging.
+    """
+    model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+
+    llm = LLM(
+        model=model,
+        max_num_seqs=1,
+        tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
+        enforce_eager=True,
+        gpu_memory_utilization=0.9,
+        max_model_len=2048,
+        dtype="bfloat16",
+        enable_prefix_caching=False,
+    )
+
+    prompt = "the capital of france is"
+    sampling_params = SamplingParams(
+        temperature=0.0,
+        max_tokens=20,
+    )
+
+    print(f"\n{'=' * 80}")
+    print("Running simple generation test")
+    print(f"Prompt: '{prompt}'")
+    print(f"{'=' * 80}\n")
+
+    try:
+        outputs = llm.generate([prompt], sampling_params)
+
+        assert len(outputs) == 1
+        output_text = outputs[0].outputs[0].text
+
+        print(f"Output: '{output_text}'")
+        print(f"\n{'=' * 80}")
+        print(f"Full completion: '{prompt}{output_text}'")
+        print(f"{'=' * 80}\n")
+
+    finally:
+        with contextlib.suppress(Exception):
+            llm.shutdown()


def LLM_with_max_seqs(
    model: str,
    max_num_seqs: int,
    gpu_memory_utilization: float,
    max_model_len: int,
-    swap_space: int,
) -> LLM:
    """
    Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
@@ -293,17 +440,9 @@ def LLM_with_max_seqs(
    return LLM(
        model=model,
        max_num_seqs=max_num_seqs,
-        # Constrain GPU memory pool so test can run even on busy GPUs.
        gpu_memory_utilization=gpu_memory_utilization,
-        # Keep KV cache footprint small while allowing longer outputs.
        max_model_len=max_model_len,
-        # Allow some CPU offload if needed.
-        swap_space=swap_space,
-        # Keep things lean and CI-friendly.
-        dtype="float16",
-        # Single-GPU by default; override externally if desired.
+        dtype="bfloat16",
        tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
-        trust_remote_code=os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1",
-        enforce_eager=True,
        enable_prefix_caching=False,
    )
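
For readers who want to reproduce the core bitwise check outside pytest, here is a minimal standalone sketch using the same `LLM`/`SamplingParams` calls and settings as the test above. The `VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT` environment variable is an assumption about what `vllm_kernel_override_batch_invariant()` inspects; verify it against your vLLM version.

import os

# Assumption: this flag is what vllm_kernel_override_batch_invariant() reads.
os.environ.setdefault("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "1")

from vllm import LLM, SamplingParams

llm = LLM(
    model=os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B"),
    enable_prefix_caching=False,
    max_num_seqs=32,
    max_model_len=8192,
    dtype="bfloat16",
)
sp = SamplingParams(temperature=0.6, top_p=1.0, max_tokens=8, seed=1234, logprobs=5)

prompts = [
    "Once upon a time in a distant galaxy, there lived",
    "The Renaissance was a period in European history that",
    "To implement a binary search tree in Python, first we need to",
]

# BS=1: run each prompt on its own.
solo = [llm.generate([p], sp, use_tqdm=False)[0] for p in prompts]
# BS=N: run the same prompts as one batch.
batched = llm.generate(prompts, sp, use_tqdm=False)

for idx, (a, b) in enumerate(zip(solo, batched)):
    # Batch invariance: the sampled token ids should be identical in both runs.
    assert a.outputs[0].token_ids == b.outputs[0].token_ids, f"prompt {idx} diverged"
print("Batch-invariance smoke check passed.")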