@@ -76,18 +76,21 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
       seed.
     - Keep max_tokens and max_model_len bounded for speed and memory use.
     """
-    random.seed(12345)
+    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
+    random.seed(seed)
 
     # Allow overrides from environment (useful for CI tuning)
     # "facebook/opt-125m" is too small; it doesn't reliably test determinism
     model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
     num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
-    batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64"))
-    assert batch_size >= 2, "Batch size should be >= 2 to mix needle."
+    max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
+    min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
+    max_random_prompt = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
+    assert max_batch_size >= 2, "Batch size should be >= 2 to mix needle."
 
     # Keep GPU memory usage low to avoid startup allocation failures.
-    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3"))
-    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096"))
+    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4"))
+    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))
     swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))
 
     # Sampling parameters: longer outputs with a more random-sounding
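
The new VLLM_MIN_PROMPT / VLLM_MAX_PROMPT knobs feed the `_random_prompt` helper called in the trial loop below. The helper itself sits outside this diff; a minimal sketch of the shape its call sites imply (the word list is an illustrative assumption, and the real helper may measure tokens rather than words):

```python
import random

_FILLER_WORDS = "time year people way day man thing woman life child".split()

def _random_prompt(min_len: int = 1024, max_len: int = 2048) -> str:
    # Draw the length from the module-level RNG so the test's
    # random.seed(...) call above keeps prompt contents reproducible.
    n = random.randint(min_len, max_len)
    return " ".join(random.choice(_FILLER_WORDS) for _ in range(n))
```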
@@ -111,7 +114,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         # Engine with bs=1 behavior
         llm_bs1 = LLM_with_max_seqs(
             model=model,
-            max_num_seqs=1,
+            max_num_seqs=128,
             gpu_memory_utilization=gpu_mem_util,
             max_model_len=max_model_len,
             swap_space=swap_space_gb,
@@ -126,7 +129,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         # Engine with larger batch limit (e.g., 64)
         llm_bsN = LLM_with_max_seqs(
             model=model,
-            max_num_seqs=batch_size,
+            max_num_seqs=128,
             gpu_memory_utilization=gpu_mem_util,
             max_model_len=max_model_len,
             swap_space=swap_space_gb,
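
Both engines go through the file's `LLM_with_max_seqs` helper, which the diff only calls. Assuming it is a thin wrapper over `vllm.LLM`, it would look roughly like the sketch below. Note that after this change both constructions pass `max_num_seqs=128`, so the two engines now differ only in how the test submits work to them, not in their batching limits:

```python
from vllm import LLM

def LLM_with_max_seqs(model, max_num_seqs, gpu_memory_utilization,
                      max_model_len, swap_space):
    # Presumed thin wrapper: pin max_num_seqs so the engine cannot batch
    # more sequences than the test intends to compare.
    return LLM(
        model=model,
        max_num_seqs=max_num_seqs,
        gpu_memory_utilization=gpu_memory_utilization,
        max_model_len=max_model_len,
        swap_space=swap_space,
    )
```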
@@ -135,15 +138,17 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         mismatches = 0
 
         for trial in range(num_trials):
-            # Create a batch of size `batch_size` and insert the needle at
+            # Create a batch of up to `max_batch_size` and insert the needle at
             # a random index
             prompts: list[str] = []
+            batch_size = random.randint(max_batch_size // 2, max_batch_size)
             needle_pos = random.randint(0, batch_size - 1)
             for i in range(batch_size):
                 if i == needle_pos:
                     prompts.append(needle_prompt)
                 else:
-                    prompts.append(_random_prompt())
+                    prompts.append(
+                        _random_prompt(min_random_prompt, max_random_prompt))
 
             # Generate with the larger-batch engine
             outputs = llm_bsN.generate(prompts, sampling)
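
The `baseline_text` compared against below is produced outside these hunks; presumably the needle prompt is run once, on its own, through the bs=1-style engine. A hypothetical reconstruction of that elided step:

```python
# Assumed shape of the elided baseline step: generate the needle prompt
# alone so batch composition cannot influence its output.
baseline_out = llm_bs1.generate([needle_prompt], sampling)
baseline_text = baseline_out[0].outputs[0].text

# Later, the needle's slot in the mixed batch is compared against it:
needle_output = outputs[needle_pos]
```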
@@ -154,17 +159,19 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
             text = needle_output.outputs[0].text
 
             if text != baseline_text:
+                print(
+                    f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
                 mismatches += 1
 
         passes = num_trials - mismatches
         # Dump how many passed vs failed
         print(f"[determinism] total={num_trials} passed={passes} "
-              f"failed={mismatches} batch_size={batch_size}")
+              f"failed={mismatches} max_batch_size={max_batch_size}")
 
         if mismatches > 0:
             pytest.fail(
                 f"Nondeterministic outputs detected: {mismatches} "
-                f"of {num_trials} trials (batch_size={batch_size})")
+                f"of {num_trials} trials (max_batch_size={max_batch_size})")
 
     finally:
         # Ensure engines are shut down to free GPU/VRAM across test sessions
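
Every knob above is a plain environment variable, so a quick local run can be driven entirely from Python. The test-file path below is illustrative, not taken from this diff:

```python
import os
import pytest

# Hypothetical local invocation; the env var names mirror those read above.
os.environ["VLLM_TEST_MODEL"] = "Qwen/Qwen3-1.7B"
os.environ["VLLM_NEEDLE_TRIALS"] = "3"
os.environ["VLLM_NEEDLE_BATCH_SIZE"] = "32"
os.environ["VLLM_MIN_PROMPT"] = "256"
os.environ["VLLM_MAX_PROMPT"] = "512"

# "-s" keeps the [determinism] summary print visible in the output.
pytest.main(["-s", "-k", "needle", "tests/v1/test_batch_invariance.py"])
```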
@@ -196,9 +203,14 @@ def _extract_step_logprobs(request_output):
     not torch.cuda.is_available(),
     reason="Requires CUDA to match production inference path.",
 )
-def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
+@pytest.mark.parametrize("backend", ["FLEX_ATTENTION", "FLASHINFER"])
+def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
 
-    #model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m")
+    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
+    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+
+    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
+    random.seed(seed)
     model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
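
The first two added lines establish a precedence rule: an externally set VLLM_ATTENTION_BACKEND overrides the parametrized value, and the choice is then written back so the engine sees the same backend the test believes it is exercising. Distilled into a helper (hypothetical name), with the caveat that pinning the env var makes both parametrized cases run the same backend:

```python
import os

def resolve_backend(parametrized: str) -> str:
    # Env var wins over the pytest parameter; write the result back so
    # vLLM, which reads VLLM_ATTENTION_BACKEND itself, stays consistent
    # with what the test reports.
    chosen = os.getenv("VLLM_ATTENTION_BACKEND", parametrized)
    os.environ["VLLM_ATTENTION_BACKEND"] = chosen
    return chosen
```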
@@ -212,10 +224,15 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
     prompts = [
         "The capital of France is",
         "The capital of Germany is",
+        _random_prompt(10, 1024),
+        _random_prompt(10, 1024),
+        _random_prompt(10, 1024),
+        _random_prompt(10, 1024),
+        _random_prompt(10, 1024),
     ]
 
     sp = SamplingParams(
-        temperature=0.0,
+        temperature=0.6,
         top_p=1.0,
         max_tokens=8,
         # Seed shouldn't matter at temperature=0, but keeping it stable anyway.
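
With temperature raised to 0.6, decoding is no longer greedy, so the unchanged context comment above ("Seed shouldn't matter at temperature=0") is now stale: the per-request seed is exactly what keeps repeated runs comparable token-for-token. A sketch of the full construction this hunk implies; the seed value and the logprobs count are assumptions, since both sit outside the hunk:

```python
from vllm import SamplingParams

# Assumed shape of the elided construction: at temperature > 0, a fixed
# per-request seed makes sampled outputs reproducible across runs.
sp = SamplingParams(
    temperature=0.6,
    top_p=1.0,
    max_tokens=8,
    seed=1234,    # illustrative; the real value is outside this hunk
    logprobs=5,   # needed so _extract_step_logprobs has data to read
)
```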
@@ -234,25 +251,25 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
                         "enable logprobs return to run this test.")
         bs1_logprobs_per_prompt.append(step_logprobs)
 
-    # BS=2: run prompts in a batch and collect logprobs per step for each
+    # BS=N: run prompts in a batch and collect logprobs per step for each
     # prompt.
     outs_batched = llm.generate(prompts, sp, use_tqdm=False)
     assert len(outs_batched) == len(prompts)
-    bs2_logprobs_per_prompt = []
+    bsN_logprobs_per_prompt = []
     for o in outs_batched:
         step_logprobs = _extract_step_logprobs(o)
         if step_logprobs is None:
             pytest.skip("Logits are not available on RequestOutput; "
                         "enable logprobs return to run this test.")
-        bs2_logprobs_per_prompt.append(step_logprobs)
+        bsN_logprobs_per_prompt.append(step_logprobs)
 
-    # Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs.
-    for i, (logprobs_bs1, logprobs_bs2) in enumerate(
-            zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)):
-        assert len(logprobs_bs1) == len(logprobs_bs2), (
+    # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
+    for i, (logprobs_bs1, logprobs_bsN) in enumerate(
+            zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt)):
+        assert len(logprobs_bs1) == len(logprobs_bsN), (
             f"Different number of generation steps for prompt index {i}: "
-            f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)")
-        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)):
+            f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)")
+        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
             assert a.shape == b.shape, (
                 f"Logits shape mismatch at prompt {i}, step {t}: "
                 f"{a.shape} vs {b.shape}")
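
The assertions above lean on `_extract_step_logprobs`, which this diff only touches in a hunk header. A sketch consistent with how it is used here, returning one float32 tensor per decode step and `None` when logprobs were not requested; the exact tensor layout is an assumption:

```python
import torch

def _extract_step_logprobs(request_output):
    # One tensor per generated token, or None if the engine was not asked
    # to return logprobs for this request.
    completion = request_output.outputs[0]
    if completion.logprobs is None:
        return None
    steps = []
    for step in completion.logprobs:
        # step maps token_id -> Logprob; sort by token id so the tensor
        # layout is stable across runs before bitwise comparison.
        items = sorted(step.items())
        steps.append(torch.tensor([lp.logprob for _, lp in items],
                                  dtype=torch.float32))
    return steps
```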