@@ -76,18 +76,21 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
       seed.
     - Keep max_tokens and max_model_len bounded for speed and memory use.
     """
-    random.seed(12345)
+    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
+    random.seed(seed)
 
     # Allow overrides from environment (useful for CI tuning)
     # "facebook/opt-125m" is too small, doesn't reliably test determinism
     model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
     num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
-    batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64"))
-    assert batch_size >= 2, "Batch size should be >= 2 to mix needle."
+    max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
+    min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
+    max_random_prompt = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
+    assert max_batch_size >= 2, "Batch size should be >= 2 to mix needle."
 
     # Keep GPU memory usage low to avoid startup allocation failures.
-    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3"))
-    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096"))
+    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4"))
+    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))
     swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))
 
     # Sampling parameters: longer outputs with a more random-sounding
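The new VLLM_MIN_PROMPT/VLLM_MAX_PROMPT knobs are forwarded to _random_prompt, a helper defined elsewhere in this test file and not shown in the diff. A minimal sketch of what such a helper could look like, assuming it just builds a filler prompt whose word count is drawn uniformly from [min_len, max_len] (the word pool and wording are illustrative, not the file's actual implementation):

import random  # already imported by the test module

def _random_prompt(min_len: int = 1024, max_len: int = 2048) -> str:
    # Hypothetical sketch: pick a random length, then join filler words.
    num_words = random.randint(min_len, max_len)
    pool = ["alpha", "bravo", "charlie", "delta", "echo", "foxtrot"]
    body = " ".join(random.choice(pool) for _ in range(num_words))
    return "Please continue the following text: " + body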
@@ -111,7 +114,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         # Engine with bs=1 behavior
         llm_bs1 = LLM_with_max_seqs(
             model=model,
-            max_num_seqs=1,
+            max_num_seqs=max_batch_size,
             gpu_memory_utilization=gpu_mem_util,
             max_model_len=max_model_len,
             swap_space=swap_space_gb,
@@ -126,7 +129,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         # Engine with larger batch limit (e.g., 64)
         llm_bsN = LLM_with_max_seqs(
             model=model,
-            max_num_seqs=batch_size,
+            max_num_seqs=max_batch_size,
             gpu_memory_utilization=gpu_mem_util,
             max_model_len=max_model_len,
             swap_space=swap_space_gb,
@@ -135,15 +138,16 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         mismatches = 0
 
         for trial in range(num_trials):
-            # Create a batch of size `batch_size` and insert the needle at
+            # Create a batch of size `max_batch_size` and insert the needle at
             # a random index
             prompts: list[str] = []
+            batch_size = random.randint(max_batch_size // 2, max_batch_size)
             needle_pos = random.randint(0, batch_size - 1)
             for i in range(batch_size):
                 if i == needle_pos:
                     prompts.append(needle_prompt)
                 else:
-                    prompts.append(_random_prompt())
+                    prompts.append(_random_prompt(min_random_prompt, max_random_prompt))
 
             # Generate with the larger-batch engine
             outputs = llm_bsN.generate(prompts, sampling)
@@ -154,19 +158,20 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
             text = needle_output.outputs[0].text
 
             if text != baseline_text:
+                print(f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
                 mismatches += 1
 
         passes = num_trials - mismatches
         # Dump how many passed vs failed
         print(
             f"[determinism] total={num_trials}, passed={passes}, "
-            f"failed={mismatches}, batch_size={batch_size}"
+            f"failed={mismatches}, max_batch_size={max_batch_size}"
         )
 
         if mismatches > 0:
             pytest.fail(
                 f"Nondeterministic outputs detected: {mismatches} failed out "
-                f"of {num_trials} trials (batch_size={batch_size})."
+                f"of {num_trials} trials (max_batch_size={max_batch_size})."
             )
 
     finally:
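For reference, one hedged way to drive this needle test locally through the environment knobs it reads; the values below are illustrative examples, not recommendations from this PR, and invoking pytest programmatically is just one option:

# Illustrative only: set the env vars the test reads, then run it by name.
import os
import pytest

os.environ["VLLM_TEST_MODEL"] = "Qwen/Qwen3-1.7B"
os.environ["VLLM_TEST_SEED"] = "12345"
os.environ["VLLM_NEEDLE_TRIALS"] = "5"
os.environ["VLLM_NEEDLE_BATCH_SIZE"] = "128"
os.environ["VLLM_MIN_PROMPT"] = "1024"
os.environ["VLLM_MAX_PROMPT"] = "2048"
os.environ["VLLM_GPU_MEMORY_UTILIZATION"] = "0.4"
os.environ["VLLM_MAX_MODEL_LEN"] = "5120"

pytest.main([
    "-s", "-k",
    "test_v1_generation_is_deterministic_across_batch_sizes_with_needle",
])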
@@ -199,25 +204,28 @@ def _extract_step_logprobs(request_output):
     not torch.cuda.is_available(),
     reason="Requires CUDA to match production inference path.",
 )
-def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
-    # model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m")
+@pytest.mark.parametrize("backend", ["FLEX_ATTENTION", "FLASHINFER"])
+def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
+    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
+    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+
+    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
+    random.seed(seed)
     model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
     # Force float32 to avoid precision-induced differences.
     llm = LLM(
         model=model_name,
         tensor_parallel_size=tp_size,
-        enforce_eager=True,  # helps reduce nondeterminism from some backends
+        enforce_eager=True,
+        enable_prefix_caching=False,
     )
 
-    prompts = [
-        "The capital of France is",
-        "The capital of Germany is",
-    ]
+    prompts = [_random_prompt(10, 1024) for i in range(100)]
 
     sp = SamplingParams(
-        temperature=0.0,
+        temperature=0.6,
         top_p=1.0,
         max_tokens=8,
         # Seed shouldn't matter at temperature=0, but keeping it stable anyway.
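The hunk header above refers to _extract_step_logprobs, which is defined earlier in the file but not shown in this diff. A minimal sketch of what such a helper might do, assuming logprobs are requested via SamplingParams so that CompletionOutput.logprobs is populated (an assumption, not the file's actual implementation):

# Hypothetical sketch (not the helper from this file): return one tensor per
# generation step, or None when logprobs were not requested so the caller
# can skip. Assumes CompletionOutput.logprobs is a list[dict[token_id, Logprob]].
import torch

def _extract_step_logprobs(request_output):
    completion = request_output.outputs[0]
    if completion.logprobs is None:
        return None
    return [
        # Logprob of the token actually sampled at each step.
        torch.tensor(step[token_id].logprob, dtype=torch.float32)
        for step, token_id in zip(completion.logprobs, completion.token_ids)
    ]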
@@ -238,29 +246,29 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
             )
         bs1_logprobs_per_prompt.append(step_logprobs)
 
-    # BS=2: run prompts in a batch and collect logprobs per step for each
+    # BS=N: run prompts in a batch and collect logprobs per step for each
     # prompt.
     outs_batched = llm.generate(prompts, sp, use_tqdm=False)
     assert len(outs_batched) == len(prompts)
-    bs2_logprobs_per_prompt = []
+    bsN_logprobs_per_prompt = []
     for o in outs_batched:
         step_logprobs = _extract_step_logprobs(o)
         if step_logprobs is None:
             pytest.skip(
                 "Logits are not available on RequestOutput; "
                 "enable logprobs return to run this test."
             )
-        bs2_logprobs_per_prompt.append(step_logprobs)
+        bsN_logprobs_per_prompt.append(step_logprobs)
 
-    # Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs.
-    for i, (logprobs_bs1, logprobs_bs2) in enumerate(
-        zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)
+    # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
+    for i, (logprobs_bs1, logprobs_bsN) in enumerate(
+        zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt)
     ):
-        assert len(logprobs_bs1) == len(logprobs_bs2), (
+        assert len(logprobs_bs1) == len(logprobs_bsN), (
             f"Different number of generation steps for prompt index {i}: "
-            f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)"
+            f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)"
         )
-        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)):
+        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
             assert a.shape == b.shape, (
                 f"Logits shape mismatch at prompt {i}, step {t}: {a.shape} vs {b.shape}"
             )
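The hunk ends right after the shape assertion; the equality check itself falls outside the diff context. Presumably it looks something like the following sketch, assuming torch.equal is used for the exact comparison:

# Sketch of the per-step bitwise check that presumably follows the shape
# assertion (not shown in the hunk). torch.equal requires identical shape
# and exact element equality, which is what bitwise invariance demands.
assert torch.equal(a, b), (
    f"Bitwise logprob mismatch at prompt {i}, step {t} (BS=1 vs BS=N)."
)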
@@ -297,6 +305,7 @@ def LLM_with_max_seqs(
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
         trust_remote_code=os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1",
         enable_prefix_caching=False,
+        enforce_eager=True,
         # Enable for MOE models
         # enable_expert_parallel=True,
     )